mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31678935aa | ||
|
|
823a0b3e36 | ||
|
1591cd3d64
|
|||
|
|
6714ceeb92 | ||
|
|
73fae04333 | ||
|
|
32ed36e102 | ||
|
|
48567310bc | ||
|
|
de51ed4675 | ||
|
|
794767edbb | ||
|
|
9c136f05bb | ||
|
3299a439fe
|
|||
|
|
d5b22a72fd | ||
|
|
c32f2e18be | ||
|
d971261f98
|
|||
|
|
74a54b7396 | ||
|
|
19805ab376 | ||
|
|
d4498e2063 | ||
| f59c1a17e2 | |||
|
|
8982ba18e3 | ||
|
|
71fe6f478f | ||
|
|
1cfbf14986 | ||
|
|
e3ff535b7e | ||
|
|
8825c772ce | ||
|
|
c8c263ca8f | ||
|
2020fa2f92
|
|||
|
|
1ea316bef4 | ||
|
|
ced1a655f2 | ||
|
|
290b2a06ec | ||
|
|
baa9711665 | ||
|
d526969d0e
|
|||
|
|
e24153053e | ||
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 | ||
|
e861a0a49a
|
|||
|
|
cb2cf572e0 | ||
|
494869a172
|
|||
|
|
e0bc93096d | ||
|
1ff94eb9d3
|
|||
|
|
97ab10edcd | ||
|
|
3ff7ff18bb | ||
|
0f50c8a0f0
|
|||
|
|
691fb78fda | ||
|
|
34ef4da317 | ||
|
|
8c287b3ce7 |
2
.github/workflows/build-release.yml
vendored
2
.github/workflows/build-release.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
run: uv python install 3.14
|
run: uv python install 3.14
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: uv build
|
run: uv build
|
||||||
|
|||||||
16
.github/workflows/ci.yml
vendored
16
.github/workflows/ci.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run Ruff linter
|
- name: Run Ruff linter
|
||||||
run: uv run ruff check .
|
run: uv run ruff check .
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
run: uv python install 3.13
|
run: uv python install 3.13
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run ty
|
- name: Run ty
|
||||||
run: uv run ty check
|
run: uv run ty check
|
||||||
@@ -83,18 +83,26 @@ jobs:
|
|||||||
run: uv python install ${{ matrix.python-version }}
|
run: uv python install ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --extra dev
|
run: uv sync --group dev
|
||||||
|
|
||||||
- name: Run tests with coverage
|
- name: Run tests with coverage
|
||||||
env:
|
env:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@localhost:5432/test_db
|
||||||
run: |
|
run: |
|
||||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
uv run pytest --cov --cov-report=xml --cov-report=term-missing --junitxml=junit.xml -o junit_family=legacy
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
if: matrix.python-version == '3.14'
|
if: matrix.python-version == '3.14'
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: coverage
|
||||||
files: ./coverage.xml
|
files: ./coverage.xml
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
- name: Upload test results to Codecov
|
||||||
|
if: matrix.python-version == '3.14'
|
||||||
|
uses: codecov/codecov-action@v5
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
report_type: test_results
|
||||||
|
|||||||
38
.github/workflows/docs.yml
vendored
Normal file
38
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: Documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/configure-pages@v5
|
||||||
|
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.13
|
||||||
|
|
||||||
|
- run: uv sync --group dev
|
||||||
|
|
||||||
|
- run: uv run zensical build --clean
|
||||||
|
|
||||||
|
- uses: actions/upload-pages-artifact@v4
|
||||||
|
with:
|
||||||
|
path: site
|
||||||
|
|
||||||
|
- uses: actions/deploy-pages@v4
|
||||||
|
id: deployment
|
||||||
36
README.md
36
README.md
@@ -1,6 +1,6 @@
|
|||||||
# FastAPI Toolsets
|
# FastAPI Toolsets
|
||||||
|
|
||||||
FastAPI Toolsets provides production-ready utilities for FastAPI applications built with async SQLAlchemy and PostgreSQL. It includes generic CRUD operations, a fixture system with dependency resolution, a Django-like CLI, standardized API responses, and structured exception handling with automatic OpenAPI documentation.
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
@@ -20,17 +20,43 @@ FastAPI Toolsets provides production-ready utilities for FastAPI applications bu
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, logging):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv add fastapi-toolsets
|
uv add fastapi-toolsets
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]" # CLI (typer)
|
||||||
|
uv add "fastapi-toolsets[metrics]" # Prometheus metrics (prometheus_client)
|
||||||
|
uv add "fastapi-toolsets[pytest]" # Pytest helpers (httpx, pytest-xdist)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **CRUD**: Generic async CRUD operations with `CrudFactory`
|
### Core
|
||||||
- **Fixtures**: Fixture system with dependency management, context support and pytest integration
|
|
||||||
- **CLI**: Django-like command-line interface for fixtures and custom commands
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in search with relationship traversal
|
||||||
- **Standardized API Responses**: Consistent response format across your API
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `PaginatedResponse`, and `PydanticBase`
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
67
docs/index.md
Normal file
67
docs/index.md
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# FastAPI Toolsets
|
||||||
|
|
||||||
|
A modular collection of production-ready utilities for FastAPI. Install only what you need — from async CRUD and database helpers to CLI tooling, Prometheus metrics, and pytest fixtures. Each module is independently installable via optional extras, keeping your dependency footprint minimal.
|
||||||
|
|
||||||
|
[](https://github.com/d3vyce/fastapi-toolsets/actions/workflows/ci.yml)
|
||||||
|
[](https://codecov.io/gh/d3vyce/fastapi-toolsets)
|
||||||
|
[](https://github.com/astral-sh/ty)
|
||||||
|
[](https://github.com/astral-sh/uv)
|
||||||
|
[](https://github.com/astral-sh/ruff)
|
||||||
|
[](https://www.python.org/downloads/)
|
||||||
|
[](https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Documentation**: [https://fastapi-toolsets.d3vyce.fr](https://fastapi-toolsets.d3vyce.fr)
|
||||||
|
|
||||||
|
**Source Code**: [https://github.com/d3vyce/fastapi-toolsets](https://github.com/d3vyce/fastapi-toolsets)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The base package includes the core modules (CRUD, database, schemas, exceptions, fixtures, dependencies, logging):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add fastapi-toolsets
|
||||||
|
```
|
||||||
|
|
||||||
|
Install only the extras you need:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[cli]" # CLI (typer)
|
||||||
|
uv add "fastapi-toolsets[metrics]" # Prometheus metrics (prometheus_client)
|
||||||
|
uv add "fastapi-toolsets[pytest]" # Pytest helpers (httpx, pytest-xdist)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install everything:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "fastapi-toolsets[all]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Core
|
||||||
|
|
||||||
|
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in search with relationship traversal
|
||||||
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
|
- **Standardized API Responses**: Consistent response format with `Response`, `PaginatedResponse`, and `PydanticBase`
|
||||||
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
- **CLI**: Django-like command-line interface with fixture management and custom commands support
|
||||||
|
- **Metrics**: Prometheus metrics endpoint with provider/collector registry
|
||||||
|
- **Pytest Helpers**: Async test client, database session management, `pytest-xdist` support, and table cleanup utilities
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License - see [LICENSE](LICENSE) for details.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please feel free to submit issues and pull requests.
|
||||||
93
docs/module/cli.md
Normal file
93
docs/module/cli.md
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# CLI
|
||||||
|
|
||||||
|
Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[cli]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configure the CLI in your `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli" # Custom Typer app
|
||||||
|
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
|
||||||
|
db_context = "myapp.db:db_context" # Async context manager for sessions
|
||||||
|
```
|
||||||
|
|
||||||
|
All fields are optional. Without configuration, the `manager` command still works but no command are available.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Manager commands
|
||||||
|
manager --help
|
||||||
|
|
||||||
|
Usage: manager [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
FastAPI utilities CLI.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --install-completion Install completion for the current shell. │
|
||||||
|
│ --show-completion Show completion for the current shell, to copy it │
|
||||||
|
│ or customize the installation. │
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ check-db │
|
||||||
|
│ fixtures Manage database fixtures. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
|
||||||
|
# Fixtures commands
|
||||||
|
manager fixtures --help
|
||||||
|
|
||||||
|
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
|
||||||
|
|
||||||
|
Manage database fixtures.
|
||||||
|
|
||||||
|
╭─ Options ────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ --help Show this message and exit. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ list List all registered fixtures. │
|
||||||
|
│ load Load fixtures into the database. │
|
||||||
|
╰──────────────────────────────────────────────────────────────────────────────────╯
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom CLI
|
||||||
|
|
||||||
|
You can extend the CLI by providing your own Typer app. The `manager` entry point will merge your app's commands with the built-in ones:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# myapp/cli.py
|
||||||
|
import typer
|
||||||
|
|
||||||
|
cli = typer.Typer()
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def hello():
|
||||||
|
print("Hello from my app!")
|
||||||
|
```
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[tool.fastapi-toolsets]
|
||||||
|
cli = "myapp.cli:cli"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/cli.md)
|
||||||
152
docs/module/crud.md
Normal file
152
docs/module/crud.md
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
# CRUD
|
||||||
|
|
||||||
|
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support. This module has features that are only compatible with Postgres.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), an abstract base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
||||||
|
|
||||||
|
## Creating a CRUD class
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(model=User)
|
||||||
|
```
|
||||||
|
|
||||||
|
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model.
|
||||||
|
|
||||||
|
## Basic operations
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create
|
||||||
|
user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
|
||||||
|
|
||||||
|
# Get one (raises NotFoundError if not found)
|
||||||
|
user = await UserCrud.get(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Get first or None
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.email == email])
|
||||||
|
|
||||||
|
# Get multiple
|
||||||
|
users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
|
||||||
|
|
||||||
|
# Update
|
||||||
|
user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
await UserCrud.delete(session=session, filters=[User.id == user_id])
|
||||||
|
|
||||||
|
# Count / exists
|
||||||
|
count = await UserCrud.count(session=session, filters=[User.is_active == True])
|
||||||
|
exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pagination
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
page: int = 1,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The [`paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) function will return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse).
|
||||||
|
|
||||||
|
## Search
|
||||||
|
|
||||||
|
Declare searchable fields on the CRUD class. Relationship traversal is supported via tuples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
searchable_fields=[
|
||||||
|
Post.title,
|
||||||
|
Post.content,
|
||||||
|
(Post.author, User.username), # search across relationship
|
||||||
|
],
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This allow to do a search with the [`paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) function:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"",
|
||||||
|
response_model=PaginatedResponse[User],
|
||||||
|
)
|
||||||
|
async def get_users(
|
||||||
|
session: SessionDep,
|
||||||
|
items_per_page: int = 50,
|
||||||
|
page: int = 1,
|
||||||
|
search: str | None = None,
|
||||||
|
):
|
||||||
|
return await crud.UserCrud.paginate(
|
||||||
|
session=session,
|
||||||
|
items_per_page=items_per_page,
|
||||||
|
page=page,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Many-to-many relationships
|
||||||
|
|
||||||
|
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
model=Post,
|
||||||
|
m2m_fields={"tag_ids": Post.tags},
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
|
||||||
|
```
|
||||||
|
|
||||||
|
## Upsert
|
||||||
|
|
||||||
|
Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await UserCrud.upsert(
|
||||||
|
session=session,
|
||||||
|
obj=UserCreateSchema(email="alice@example.com", username="alice"),
|
||||||
|
index_elements=[User.email],
|
||||||
|
set_={"username"},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## `as_response`
|
||||||
|
|
||||||
|
Pass `as_response=True` to any write operation to get a [`Response[ModelType]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) back directly for API usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@router.get(
|
||||||
|
"/{uuid}",
|
||||||
|
response_model=Response[User],
|
||||||
|
responses=generate_error_responses(NotFoundError),
|
||||||
|
)
|
||||||
|
async def get_user(session: SessionDep, uuid: UUID):
|
||||||
|
return await crud.UserCrud.get(
|
||||||
|
session=session,
|
||||||
|
filters=[User.id == uuid],
|
||||||
|
as_response=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/crud.md)
|
||||||
92
docs/module/db.md
Normal file
92
docs/module/db.md
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# DB
|
||||||
|
|
||||||
|
SQLAlchemy async session management with transactions, table locking, and row-change polling.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
|
||||||
|
|
||||||
|
## Session dependency
|
||||||
|
|
||||||
|
Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_dependency) to create a FastAPI dependency that yields a session and auto-commits on success:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
|
|
||||||
|
engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
|
||||||
|
session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
get_db = create_db_dependency(session_maker=session_maker)
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Session context manager
|
||||||
|
|
||||||
|
Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_context) for sessions outside request handlers (e.g. background tasks, CLI commands):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
|
db_context = create_db_context(session_maker=session_maker)
|
||||||
|
|
||||||
|
async def seed():
|
||||||
|
async with db_context() as session:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Nested transactions
|
||||||
|
|
||||||
|
[`get_transaction`](../reference/db.md#fastapi_toolsets.db.get_transaction) handles savepoints automatically, allowing safe nesting:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import get_transaction
|
||||||
|
|
||||||
|
async def create_user_with_role(session=session):
|
||||||
|
async with get_transaction(session=session):
|
||||||
|
...
|
||||||
|
async with get_transaction(session=session): # uses savepoint
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Table locking
|
||||||
|
|
||||||
|
[`lock_tables`](../reference/db.md#fastapi_toolsets.db.lock_tables) acquires PostgreSQL table-level locks before executing critical sections:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import lock_tables
|
||||||
|
|
||||||
|
async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
|
||||||
|
# No other transaction can modify User until this block exits
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Available lock modes are defined in [`LockMode`](../reference/db.md#fastapi_toolsets.db.LockMode): `ACCESS_SHARE`, `ROW_SHARE`, `ROW_EXCLUSIVE`, `SHARE_UPDATE_EXCLUSIVE`, `SHARE`, `SHARE_ROW_EXCLUSIVE`, `EXCLUSIVE`, `ACCESS_EXCLUSIVE`.
|
||||||
|
|
||||||
|
## Row-change polling
|
||||||
|
|
||||||
|
[`wait_for_row_change`](../reference/db.md#fastapi_toolsets.db.wait_for_row_change) polls a row until a specific column changes value, useful for waiting on async side effects:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait up to 30s for order.status to change
|
||||||
|
await wait_for_row_change(
|
||||||
|
session=session,
|
||||||
|
model=Order,
|
||||||
|
pk_value=order_id,
|
||||||
|
columns=[Order.status],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/db.md)
|
||||||
50
docs/module/dependencies.md
Normal file
50
docs/module/dependencies.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Dependencies
|
||||||
|
|
||||||
|
FastAPI dependency factories for automatic model resolution from path and body parameters.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `dependencies` module provides two factory functions that create FastAPI dependencies to fetch a model instance from the database automatically — either from a path parameter or from a request body field — and inject it directly into your route handler.
|
||||||
|
|
||||||
|
## `PathDependency`
|
||||||
|
|
||||||
|
[`PathDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.PathDependency) resolves a model from a URL path parameter and injects it into the route handler. Raises [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) automatically if the record does not exist.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency
|
||||||
|
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/users/{user_id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep):
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## `BodyDependency`
|
||||||
|
|
||||||
|
[`BodyDependency`](../reference/dependencies.md#fastapi_toolsets.dependencies.BodyDependency) resolves a model from a field in the request body. Useful when a body contains a foreign key and you want the full object injected:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency
|
||||||
|
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||||
|
user = User(username=body.username, role=role)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/dependencies.md)
|
||||||
85
docs/module/exceptions.md
Normal file
85
docs/module/exceptions.md
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Exceptions
|
||||||
|
|
||||||
|
Structured API exceptions with consistent error responses and automatic OpenAPI documentation.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Register the exception handlers on your FastAPI app at startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app=app)
|
||||||
|
```
|
||||||
|
|
||||||
|
This registers handlers for:
|
||||||
|
|
||||||
|
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
||||||
|
- `RequestValidationError` — Pydantic request validation (422)
|
||||||
|
- `ResponseValidationError` — Pydantic response validation (422)
|
||||||
|
- `Exception` — unhandled errors (500)
|
||||||
|
|
||||||
|
## Built-in exceptions
|
||||||
|
|
||||||
|
| Exception | Status | Default message |
|
||||||
|
|-----------|--------|-----------------|
|
||||||
|
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
||||||
|
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
||||||
|
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not found |
|
||||||
|
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||||
|
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields |
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(id: int, session: AsyncSession = Depends(get_db)):
|
||||||
|
user = await UserCrud.first(session=session, filters=[User.id == id])
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError
|
||||||
|
return user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Custom exceptions
|
||||||
|
|
||||||
|
Subclass [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) and define an `api_error` class variable:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import ApiException
|
||||||
|
from fastapi_toolsets.schemas import ApiError
|
||||||
|
|
||||||
|
class PaymentRequiredError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=402,
|
||||||
|
msg="Payment required",
|
||||||
|
desc="Your subscription has expired.",
|
||||||
|
err_code="PAYMENT_REQUIRED",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## OpenAPI response documentation
|
||||||
|
|
||||||
|
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, ForbiddenError
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/users/{id}",
|
||||||
|
responses=generate_error_responses(NotFoundError, ForbiddenError),
|
||||||
|
)
|
||||||
|
async def get_user(...): ...
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
The pydantic validation error is automatically added by FastAPI.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/exceptions.md)
|
||||||
123
docs/module/fixtures.md
Normal file
123
docs/module/fixtures.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# Fixtures
|
||||||
|
|
||||||
|
Dependency-aware database seeding with context-based loading strategies.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `fixtures` module lets you define named fixtures with dependencies between them, then load them into the database in the correct order. Fixtures can be scoped to contexts (e.g. base data, testing data) so that only the relevant ones are loaded for each environment.
|
||||||
|
|
||||||
|
## Defining fixtures
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
|
fixtures = FixtureRegistry()
|
||||||
|
|
||||||
|
@fixtures.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def test_users():
|
||||||
|
return [
|
||||||
|
User(id=1, username="alice", role_id=1),
|
||||||
|
User(id=2, username="bob", role_id=2),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Dependencies declared via `depends_on` are resolved topologically — `roles` will always be loaded before `test_users`.
|
||||||
|
|
||||||
|
## Loading fixtures
|
||||||
|
|
||||||
|
By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import load_fixtures
|
||||||
|
|
||||||
|
async with db_context() as session:
|
||||||
|
await load_fixtures(session=session, registry=fixtures)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contexts
|
||||||
|
|
||||||
|
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
||||||
|
|
||||||
|
| Context | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `Context.BASE` | Core data required in all environments |
|
||||||
|
| `Context.TESTING` | Data only loaded during tests |
|
||||||
|
| `Context.PRODUCTION` | Data only loaded in production |
|
||||||
|
|
||||||
|
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
||||||
|
|
||||||
|
## Load strategies
|
||||||
|
|
||||||
|
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
||||||
|
|
||||||
|
| Strategy | Description |
|
||||||
|
|----------|-------------|
|
||||||
|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
||||||
|
| `LoadStrategy.UPSERT` | Insert or update on conflict |
|
||||||
|
| `LoadStrategy.SKIP` | Skip rows that already exist |
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split fixtures definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.fixtures.dev import dev_fixtures
|
||||||
|
from myapp.fixtures.prod import prod_fixtures
|
||||||
|
|
||||||
|
fixtures = fixturesRegistry()
|
||||||
|
fixtures.include_registry(registry=dev_fixtures)
|
||||||
|
fixtures.include_registry(registry=prod_fixtures)
|
||||||
|
|
||||||
|
## Pytest integration
|
||||||
|
|
||||||
|
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# conftest.py
|
||||||
|
import pytest
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, register_fixtures
|
||||||
|
from app.fixtures import registry
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
register_fixtures(registry=registry, namespace=globals())
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# test_users.py
|
||||||
|
async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
||||||
|
|
||||||
|
## CLI integration
|
||||||
|
|
||||||
|
Fixtures can be triggered from the CLI. See the [CLI module](cli.md) for setup instructions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/fixtures.md)
|
||||||
39
docs/module/logger.md
Normal file
39
docs/module/logger.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Logger
|
||||||
|
|
||||||
|
Lightweight logging utilities with consistent formatting and uvicorn integration.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `logger` module provides two helpers: one to configure the root logger (and uvicorn loggers) at startup, and one to retrieve a named logger anywhere in your codebase.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Call [`configure_logging`](../reference/logger.md#fastapi_toolsets.logger.configure_logging) once at application startup:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
configure_logging(level="INFO")
|
||||||
|
```
|
||||||
|
|
||||||
|
This sets up a stdout handler with a consistent format and also configures uvicorn's access and error loggers so all log output shares the same style.
|
||||||
|
|
||||||
|
## Getting a logger
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__)
|
||||||
|
logger.info("User created")
|
||||||
|
```
|
||||||
|
|
||||||
|
When called without arguments, [`get_logger`](../reference/logger.md#fastapi_toolsets.logger.get_logger) auto-detects the caller's module name via frame inspection:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Equivalent to get_logger(name=__name__)
|
||||||
|
logger = get_logger()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/logger.md)
|
||||||
86
docs/module/metrics.md
Normal file
86
docs/module/metrics.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Metrics
|
||||||
|
|
||||||
|
Prometheus metrics integration with a decorator-based registry and multi-process support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[metrics]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `metrics` module provides a [`MetricsRegistry`](../reference/metrics.md#fastapi_toolsets.metrics.registry.MetricsRegistry) to declare Prometheus metrics with decorators, and an [`init_metrics`](../reference/metrics.md#fastapi_toolsets.metrics.handler.init_metrics) function to mount a `/metrics` endpoint on your FastAPI app.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
init_metrics(app=app, registry=metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
This mounts the `/metrics` endpoint that Prometheus can scrape.
|
||||||
|
|
||||||
|
## Declaring metrics
|
||||||
|
|
||||||
|
### Providers
|
||||||
|
|
||||||
|
Providers are called once at startup and register metrics that are updated externally (e.g. counters, histograms):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Histogram
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def request_duration():
|
||||||
|
return Histogram("request_duration_seconds", "Request duration")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Collectors
|
||||||
|
|
||||||
|
Collectors are called on every scrape. Use them for metrics that reflect current state (e.g. gauges):
|
||||||
|
|
||||||
|
```python
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def queue_depth():
|
||||||
|
gauge = Gauge("queue_depth", "Current queue depth")
|
||||||
|
gauge.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Merging registries
|
||||||
|
|
||||||
|
Split metrics definitions across modules and merge them:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from myapp.metrics.http import http_metrics
|
||||||
|
from myapp.metrics.db import db_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
metrics.include_registry(registry=http_metrics)
|
||||||
|
metrics.include_registry(registry=db_metrics)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-process mode
|
||||||
|
|
||||||
|
Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DIR` environment variable is set. No code changes are required.
|
||||||
|
|
||||||
|
!!! warning "Environment variable name"
|
||||||
|
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/metrics.md)
|
||||||
86
docs/module/pytest.md
Normal file
86
docs/module/pytest.md
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Pytest
|
||||||
|
|
||||||
|
Testing helpers for FastAPI applications with async client, database sessions, and parallel worker support.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
=== "uv"
|
||||||
|
``` bash
|
||||||
|
uv add "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "pip"
|
||||||
|
``` bash
|
||||||
|
pip install "fastapi-toolsets[pytest]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `pytest` module provides utilities for setting up async test clients, managing test database sessions, and supporting parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Creating an async client
|
||||||
|
|
||||||
|
Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_async_client) to get an `httpx.AsyncClient` configured for your FastAPI app:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def http_client(db_session):
|
||||||
|
async def _override_get_db():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app=app,
|
||||||
|
base_url="http://127.0.0.1/api/v1",
|
||||||
|
dependency_overrides={get_db: _override_get_db},
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
|
|
||||||
|
## Database sessions in tests
|
||||||
|
|
||||||
|
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session, create_worker_database
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(
|
||||||
|
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
|
||||||
|
) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
database_url=worker_db_url, base=Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
In this example, the database is reset between each test using the argument `cleanup=True`.
|
||||||
|
|
||||||
|
## Parallel testing with pytest-xdist
|
||||||
|
|
||||||
|
The examples above are already compatible with parallel test execution with `pytest-xdist`.
|
||||||
|
|
||||||
|
## Cleaning up tables
|
||||||
|
|
||||||
|
If you want to manually clean up a database you can use [`cleanup_tables`](../reference/pytest.md#fastapi_toolsets.pytest.utils.cleanup_tables), this will truncates all tables between tests for fast isolation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import cleanup_tables
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def clean(db_session):
|
||||||
|
yield
|
||||||
|
await cleanup_tables(session=db_session, base=Base)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/pytest.md)
|
||||||
49
docs/module/schemas.md
Normal file
49
docs/module/schemas.md
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Schemas
|
||||||
|
|
||||||
|
Standardized Pydantic response models for consistent API responses across your FastAPI application.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The `schemas` module provides generic response wrappers that enforce a uniform response structure. All models use `from_attributes=True` for ORM compatibility and `validate_assignment=True` for runtime type safety.
|
||||||
|
|
||||||
|
## Response models
|
||||||
|
|
||||||
|
### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
|
||||||
|
|
||||||
|
The most common wrapper for a single resource response.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
@router.get("/users/{id}")
|
||||||
|
async def get_user(user: User = UserDep) -> Response[UserSchema]:
|
||||||
|
return Response(data=user, message="User retrieved")
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
|
||||||
|
|
||||||
|
Wraps a list of items with pagination metadata.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import PaginatedResponse, Pagination
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
async def list_users() -> PaginatedResponse[UserSchema]:
|
||||||
|
return PaginatedResponse(
|
||||||
|
data=users,
|
||||||
|
pagination=Pagination(
|
||||||
|
total_count=100,
|
||||||
|
items_per_page=10,
|
||||||
|
page=1,
|
||||||
|
has_more=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
|
||||||
|
|
||||||
|
Returned automatically by the exceptions handler.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[:material-api: API Reference](../reference/schemas.md)
|
||||||
7
docs/overrides/main.html
Normal file
7
docs/overrides/main.html
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{% extends "base.html" %} {% block extrahead %}
|
||||||
|
<script
|
||||||
|
defer
|
||||||
|
src="https://analytics.d3vyce.fr/script.js"
|
||||||
|
data-website-id="338b8816-7b99-4c6a-82f3-15595be3fd47"
|
||||||
|
></script>
|
||||||
|
{{ super() }} {% endblock %}
|
||||||
27
docs/reference/cli.md
Normal file
27
docs/reference/cli.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# `cli`
|
||||||
|
|
||||||
|
Here's the reference for the CLI configuration helpers used to load settings from `pyproject.toml`.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.cli.config`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
import_from_string,
|
||||||
|
get_config_value,
|
||||||
|
get_fixtures_registry,
|
||||||
|
get_db_context,
|
||||||
|
get_custom_cli,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.import_from_string
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_config_value
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_fixtures_registry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.config.get_custom_cli
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.cli.utils.async_command
|
||||||
20
docs/reference/crud.md
Normal file
20
docs/reference/crud.md
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# `crud`
|
||||||
|
|
||||||
|
Here's the reference for the CRUD classes, factory, and search utilities.
|
||||||
|
|
||||||
|
You can import the main symbols from `fastapi_toolsets.crud`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud import CrudFactory, AsyncCrud
|
||||||
|
from fastapi_toolsets.crud.search import SearchConfig, get_searchable_fields, build_search_filters
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.AsyncCrud
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.factory.CrudFactory
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.SearchConfig
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.get_searchable_fields
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.crud.search.build_search_filters
|
||||||
28
docs/reference/db.md
Normal file
28
docs/reference/db.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `db`
|
||||||
|
|
||||||
|
Here's the reference for all database session utilities, transaction helpers, and locking functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.db`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import (
|
||||||
|
LockMode,
|
||||||
|
create_db_dependency,
|
||||||
|
create_db_context,
|
||||||
|
get_transaction,
|
||||||
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.LockMode
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_dependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.create_db_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.get_transaction
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.lock_tables
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.db.wait_for_row_change
|
||||||
13
docs/reference/dependencies.md
Normal file
13
docs/reference/dependencies.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `dependencies`
|
||||||
|
|
||||||
|
Here's the reference for the FastAPI dependency factory functions.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.dependencies`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.dependencies import PathDependency, BodyDependency
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.PathDependency
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.dependencies.BodyDependency
|
||||||
34
docs/reference/exceptions.md
Normal file
34
docs/reference/exceptions.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `exceptions`
|
||||||
|
|
||||||
|
Here's the reference for all exception classes and handler utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.exceptions`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.exceptions import (
|
||||||
|
ApiException,
|
||||||
|
UnauthorizedError,
|
||||||
|
ForbiddenError,
|
||||||
|
NotFoundError,
|
||||||
|
ConflictError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
|
generate_error_responses,
|
||||||
|
init_exceptions_handlers,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ApiException
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.UnauthorizedError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ForbiddenError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NotFoundError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.ConflictError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||||
31
docs/reference/fixtures.md
Normal file
31
docs/reference/fixtures.md
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# `fixtures`
|
||||||
|
|
||||||
|
Here's the reference for the fixture registry, enums, and loading utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.fixtures`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import (
|
||||||
|
Context,
|
||||||
|
LoadStrategy,
|
||||||
|
Fixture,
|
||||||
|
FixtureRegistry,
|
||||||
|
load_fixtures,
|
||||||
|
load_fixtures_by_context,
|
||||||
|
get_obj_by_attr,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.Context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.enum.LoadStrategy
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.Fixture
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.registry.FixtureRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.load_fixtures_by_context
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.fixtures.utils.get_obj_by_attr
|
||||||
13
docs/reference/logger.md
Normal file
13
docs/reference/logger.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# `logger`
|
||||||
|
|
||||||
|
Here's the reference for the logging utilities.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.logger`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging, get_logger
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.configure_logging
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.logger.get_logger
|
||||||
15
docs/reference/metrics.md
Normal file
15
docs/reference/metrics.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# `metrics`
|
||||||
|
|
||||||
|
Here's the reference for the Prometheus metrics registry and endpoint handler.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.metrics`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.Metric
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.registry.MetricsRegistry
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.metrics.handler.init_metrics
|
||||||
28
docs/reference/pytest.md
Normal file
28
docs/reference/pytest.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# `pytest`
|
||||||
|
|
||||||
|
Here's the reference for all testing utilities and pytest fixtures.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.pytest`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
register_fixtures,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
worker_database_url,
|
||||||
|
create_worker_database,
|
||||||
|
cleanup_tables,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.plugin.register_fixtures
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_async_client
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_db_session
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.worker_database_url
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.create_worker_database
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.pytest.utils.cleanup_tables
|
||||||
34
docs/reference/schemas.md
Normal file
34
docs/reference/schemas.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# `schemas` module
|
||||||
|
|
||||||
|
Here's the reference for all response models and types provided by the `schemas` module.
|
||||||
|
|
||||||
|
You can import them directly from `fastapi_toolsets.schemas`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.schemas import (
|
||||||
|
PydanticBase,
|
||||||
|
ResponseStatus,
|
||||||
|
ApiError,
|
||||||
|
BaseResponse,
|
||||||
|
Response,
|
||||||
|
ErrorResponse,
|
||||||
|
Pagination,
|
||||||
|
PaginatedResponse,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PydanticBase
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ResponseStatus
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ApiError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.BaseResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Response
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.ErrorResponse
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.Pagination
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.schemas.PaginatedResponse
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "0.4.1"
|
version = "1.0.0"
|
||||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -31,12 +31,10 @@ classifiers = [
|
|||||||
"Typing :: Typed",
|
"Typing :: Typed",
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.100.0",
|
|
||||||
"sqlalchemy[asyncio]>=2.0",
|
|
||||||
"asyncpg>=0.29.0",
|
"asyncpg>=0.29.0",
|
||||||
|
"fastapi>=0.100.0",
|
||||||
"pydantic>=2.0",
|
"pydantic>=2.0",
|
||||||
"typer>=0.9.0",
|
"sqlalchemy[asyncio]>=2.0",
|
||||||
"httpx>=0.25.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
@@ -46,23 +44,47 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
|
|||||||
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
test = [
|
cli = [
|
||||||
"pytest>=8.0.0",
|
"typer>=0.9.0",
|
||||||
"pytest-anyio>=0.0.0",
|
|
||||||
"coverage>=7.0.0",
|
|
||||||
"pytest-cov>=4.0.0",
|
|
||||||
]
|
]
|
||||||
dev = [
|
metrics = [
|
||||||
"fastapi-toolsets[test]",
|
"prometheus_client>=0.20.0",
|
||||||
"ruff>=0.1.0",
|
]
|
||||||
"ty>=0.0.1a0",
|
pytest = [
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
all = [
|
||||||
|
"fastapi-toolsets[cli,metrics,pytest]",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
fastapi-toolsets = "fastapi_toolsets.cli:app"
|
manager = "fastapi_toolsets.cli.app:cli"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
{include-group = "tests"},
|
||||||
|
{include-group = "docs"},
|
||||||
|
"fastapi-toolsets[all]",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"ty>=0.0.1a0",
|
||||||
|
]
|
||||||
|
tests = [
|
||||||
|
"coverage>=7.0.0",
|
||||||
|
"httpx>=0.25.0",
|
||||||
|
"pytest-anyio>=0.0.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"pytest-xdist>=3.0.0",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
]
|
||||||
|
docs = [
|
||||||
|
"mkdocstrings-python>=2.0.2",
|
||||||
|
"zensical>=0.0.23",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
requires = ["uv_build>=0.10,<0.11.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.4.1"
|
__version__ = "1.0.0"
|
||||||
|
|||||||
9
src/fastapi_toolsets/_imports.py
Normal file
9
src/fastapi_toolsets/_imports.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Optional dependency helpers."""
|
||||||
|
|
||||||
|
|
||||||
|
def require_extra(package: str, extra: str) -> None:
|
||||||
|
"""Raise *ImportError* with an actionable install instruction."""
|
||||||
|
raise ImportError(
|
||||||
|
f"'{package}' is required to use this module. "
|
||||||
|
f"Install it with: pip install fastapi-toolsets[{extra}]"
|
||||||
|
)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""CLI for FastAPI projects."""
|
"""CLI for FastAPI projects."""
|
||||||
|
|
||||||
from .app import app, register_command
|
from .utils import async_command
|
||||||
|
|
||||||
__all__ = ["app", "register_command"]
|
__all__ = ["async_command"]
|
||||||
|
|||||||
@@ -1,97 +1,37 @@
|
|||||||
"""Main CLI application."""
|
"""Main CLI application."""
|
||||||
|
|
||||||
import importlib.util
|
try:
|
||||||
import sys
|
import typer
|
||||||
from pathlib import Path
|
except ImportError:
|
||||||
from typing import Annotated
|
from .._imports import require_extra
|
||||||
|
|
||||||
import typer
|
require_extra(package="typer", extra="cli")
|
||||||
|
|
||||||
from .commands import fixtures
|
from ..logger import configure_logging
|
||||||
|
from .config import get_custom_cli
|
||||||
|
from .pyproject import load_pyproject
|
||||||
|
|
||||||
app = typer.Typer(
|
# Use custom CLI if configured, otherwise create default one
|
||||||
name="fastapi-utils",
|
_custom_cli = get_custom_cli()
|
||||||
help="CLI utilities for FastAPI projects.",
|
|
||||||
no_args_is_help=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register built-in commands
|
if _custom_cli is not None:
|
||||||
app.add_typer(fixtures.app, name="fixtures")
|
cli = _custom_cli
|
||||||
|
else:
|
||||||
|
cli = typer.Typer(
|
||||||
|
name="manager",
|
||||||
|
help="CLI utilities for FastAPI projects.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_config = load_pyproject()
|
||||||
|
if _config.get("fixtures") and _config.get("db_context"):
|
||||||
|
from .commands.fixtures import fixture_cli
|
||||||
|
|
||||||
|
cli.add_typer(fixture_cli, name="fixtures")
|
||||||
|
|
||||||
|
|
||||||
def register_command(command: typer.Typer, name: str) -> None:
|
@cli.callback()
|
||||||
"""Register a custom command group.
|
def main(ctx: typer.Context) -> None:
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Typer app for the command group
|
|
||||||
name: Name for the command group
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# In your project's cli.py:
|
|
||||||
import typer
|
|
||||||
from fastapi_toolsets.cli import app, register_command
|
|
||||||
|
|
||||||
my_commands = typer.Typer()
|
|
||||||
|
|
||||||
@my_commands.command()
|
|
||||||
def seed():
|
|
||||||
'''Seed the database.'''
|
|
||||||
...
|
|
||||||
|
|
||||||
register_command(my_commands, "db")
|
|
||||||
# Now available as: fastapi-utils db seed
|
|
||||||
"""
|
|
||||||
app.add_typer(command, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback()
|
|
||||||
def main(
|
|
||||||
ctx: typer.Context,
|
|
||||||
config: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option(
|
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
help="Path to project config file (Python module with fixtures registry).",
|
|
||||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
|
||||||
),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""FastAPI utilities CLI."""
|
"""FastAPI utilities CLI."""
|
||||||
|
configure_logging()
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
|
|
||||||
if config:
|
|
||||||
ctx.obj["config_path"] = config
|
|
||||||
# Load the config module
|
|
||||||
config_module = _load_module_from_path(config)
|
|
||||||
ctx.obj["config_module"] = config_module
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module_from_path(path: Path) -> object:
|
|
||||||
"""Load a Python module from a file path.
|
|
||||||
|
|
||||||
Handles both absolute and relative imports by adding the config's
|
|
||||||
parent directory to sys.path temporarily.
|
|
||||||
"""
|
|
||||||
path = path.resolve()
|
|
||||||
|
|
||||||
# Add the parent directory to sys.path to support relative imports
|
|
||||||
parent_dir = str(
|
|
||||||
path.parent.parent
|
|
||||||
) # Go up two levels (e.g., from app/cli_config.py to project root)
|
|
||||||
if parent_dir not in sys.path:
|
|
||||||
sys.path.insert(0, parent_dir)
|
|
||||||
|
|
||||||
# Also add immediate parent for direct module imports
|
|
||||||
immediate_parent = str(path.parent)
|
|
||||||
if immediate_parent not in sys.path:
|
|
||||||
sys.path.insert(0, immediate_parent)
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("config", path)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules["config"] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|||||||
@@ -1,138 +1,66 @@
|
|||||||
"""Fixture management commands."""
|
"""Fixture management commands."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
|
||||||
|
from ..config import get_db_context, get_fixtures_registry
|
||||||
|
from ..utils import async_command
|
||||||
|
|
||||||
app = typer.Typer(
|
fixture_cli = typer.Typer(
|
||||||
name="fixtures",
|
name="fixtures",
|
||||||
help="Manage database fixtures.",
|
help="Manage database fixtures.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
@fixture_cli.command("list")
|
||||||
"""Get fixture registry from context."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
|
||||||
)
|
|
||||||
|
|
||||||
registry = getattr(config, "fixtures", None)
|
|
||||||
if registry is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(registry, FixtureRegistry):
|
|
||||||
raise typer.BadParameter(
|
|
||||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db_context(ctx: typer.Context):
|
|
||||||
"""Get database context manager from config."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter("No config provided.")
|
|
||||||
|
|
||||||
get_db_context = getattr(config, "get_db_context", None)
|
|
||||||
if get_db_context is None:
|
|
||||||
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
|
|
||||||
|
|
||||||
return get_db_context
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("list")
|
|
||||||
def list_fixtures(
|
def list_fixtures(
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
context: Annotated[
|
context: Annotated[
|
||||||
str | None,
|
Context | None,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
"--context",
|
"--context",
|
||||||
"-c",
|
"-c",
|
||||||
help="Filter by context (base, production, development, testing).",
|
help="Filter by context.",
|
||||||
),
|
),
|
||||||
] = None,
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List all registered fixtures."""
|
"""List all registered fixtures."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
|
fixtures = registry.get_by_context(context.value) if context else registry.get_all()
|
||||||
if context:
|
|
||||||
fixtures = registry.get_by_context(context)
|
|
||||||
else:
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
if not fixtures:
|
if not fixtures:
|
||||||
typer.echo("No fixtures found.")
|
print("No fixtures found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
table = Table("Name", "Contexts", "Dependencies")
|
||||||
typer.echo("-" * 80)
|
|
||||||
|
|
||||||
for fixture in fixtures:
|
for fixture in fixtures:
|
||||||
contexts = ", ".join(fixture.contexts)
|
contexts = ", ".join(fixture.contexts)
|
||||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
table.add_row(fixture.name, contexts, deps)
|
||||||
|
|
||||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
console.print(table)
|
||||||
|
print(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||||
|
|
||||||
|
|
||||||
@app.command("graph")
|
@fixture_cli.command("load")
|
||||||
def show_graph(
|
@async_command
|
||||||
ctx: typer.Context,
|
async def load(
|
||||||
fixture_name: Annotated[
|
|
||||||
str | None,
|
|
||||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Show fixture dependency graph."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
if fixture_name:
|
|
||||||
try:
|
|
||||||
order = registry.resolve_dependencies(fixture_name)
|
|
||||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
|
||||||
for i, name in enumerate(order):
|
|
||||||
indent = " " * i
|
|
||||||
arrow = "└─> " if i > 0 else ""
|
|
||||||
typer.echo(f"{indent}{arrow}{name}")
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
else:
|
|
||||||
# Show full graph
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
typer.echo("\nFixture Dependency Graph:\n")
|
|
||||||
for fixture in fixtures:
|
|
||||||
deps = (
|
|
||||||
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
|
||||||
)
|
|
||||||
typer.echo(f" {fixture.name}{deps}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("load")
|
|
||||||
def load(
|
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
contexts: Annotated[
|
contexts: Annotated[
|
||||||
list[str] | None,
|
list[Context] | None,
|
||||||
typer.Argument(
|
typer.Argument(help="Contexts to load."),
|
||||||
help="Contexts to load (base, production, development, testing)."
|
|
||||||
),
|
|
||||||
] = None,
|
] = None,
|
||||||
strategy: Annotated[
|
strategy: Annotated[
|
||||||
str,
|
LoadStrategy,
|
||||||
typer.Option(
|
typer.Option("--strategy", "-s", help="Load strategy."),
|
||||||
"--strategy", "-s", help="Load strategy: merge, insert, skip_existing."
|
] = LoadStrategy.MERGE,
|
||||||
),
|
|
||||||
] = "merge",
|
|
||||||
dry_run: Annotated[
|
dry_run: Annotated[
|
||||||
bool,
|
bool,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
@@ -141,85 +69,32 @@ def load(
|
|||||||
] = False,
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load fixtures into the database."""
|
"""Load fixtures into the database."""
|
||||||
registry = _get_registry(ctx)
|
registry = get_fixtures_registry()
|
||||||
get_db_context = _get_db_context(ctx)
|
db_context = get_db_context()
|
||||||
|
|
||||||
# Parse contexts
|
context_list = [c.value for c in contexts] if contexts else [Context.BASE]
|
||||||
if contexts:
|
|
||||||
context_list = contexts
|
|
||||||
else:
|
|
||||||
context_list = [Context.BASE]
|
|
||||||
|
|
||||||
# Parse strategy
|
|
||||||
try:
|
|
||||||
load_strategy = LoadStrategy(strategy)
|
|
||||||
except ValueError:
|
|
||||||
typer.echo(
|
|
||||||
f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True
|
|
||||||
)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Resolve what will be loaded
|
|
||||||
ordered = registry.resolve_context_dependencies(*context_list)
|
ordered = registry.resolve_context_dependencies(*context_list)
|
||||||
|
|
||||||
if not ordered:
|
if not ordered:
|
||||||
typer.echo("No fixtures to load for the specified context(s).")
|
print("No fixtures to load for the specified context(s).")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
print(f"\nFixtures to load ({strategy.value} strategy):")
|
||||||
for name in ordered:
|
for name in ordered:
|
||||||
fixture = registry.get(name)
|
fixture = registry.get(name)
|
||||||
instances = list(fixture.func())
|
instances = list(fixture.func())
|
||||||
model_name = type(instances[0]).__name__ if instances else "?"
|
model_name = type(instances[0]).__name__ if instances else "?"
|
||||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
print(f" - {name}: {len(instances)} {model_name}(s)")
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
typer.echo("\n[Dry run - no changes made]")
|
print("\n[Dry run - no changes made]")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo("\nLoading...")
|
async with db_context() as session:
|
||||||
|
result = await load_fixtures_by_context(
|
||||||
async def do_load():
|
session, registry, *context_list, strategy=strategy
|
||||||
async with get_db_context() as session:
|
)
|
||||||
result = await load_fixtures_by_context(
|
|
||||||
session, registry, *context_list, strategy=load_strategy
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = asyncio.run(do_load())
|
|
||||||
|
|
||||||
total = sum(len(items) for items in result.values())
|
total = sum(len(items) for items in result.values())
|
||||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
print(f"\nLoaded {total} record(s) successfully.")
|
||||||
|
|
||||||
|
|
||||||
@app.command("show")
|
|
||||||
def show_fixture(
|
|
||||||
ctx: typer.Context,
|
|
||||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
|
||||||
) -> None:
|
|
||||||
"""Show details of a specific fixture."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
try:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
typer.echo(f"\nFixture: {fixture.name}")
|
|
||||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
|
||||||
typer.echo(
|
|
||||||
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show instances
|
|
||||||
instances = list(fixture.func())
|
|
||||||
if instances:
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
|
||||||
for instance in instances[:10]: # Limit to 10
|
|
||||||
typer.echo(f" - {instance!r}")
|
|
||||||
if len(instances) > 10:
|
|
||||||
typer.echo(f" ... and {len(instances) - 10} more")
|
|
||||||
else:
|
|
||||||
typer.echo("\nNo instances (empty fixture)")
|
|
||||||
|
|||||||
123
src/fastapi_toolsets/cli/config.py
Normal file
123
src/fastapi_toolsets/cli/config.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""CLI configuration and dynamic imports."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from .pyproject import find_pyproject, load_pyproject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_project_in_path():
|
||||||
|
"""Add project root to sys.path if not installed in editable mode."""
|
||||||
|
pyproject = find_pyproject()
|
||||||
|
if pyproject:
|
||||||
|
project_root = str(pyproject.parent)
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_string(import_path: str) -> Any:
|
||||||
|
"""Import an object from a dotted string path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
import_path: Import path in ``"module.submodule:attribute"`` format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The imported attribute
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If the import path is invalid or import fails
|
||||||
|
"""
|
||||||
|
if ":" not in import_path:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
|
||||||
|
)
|
||||||
|
|
||||||
|
module_path, attr_name = import_path.rsplit(":", 1)
|
||||||
|
|
||||||
|
_ensure_project_in_path()
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
|
||||||
|
|
||||||
|
if not hasattr(module, attr_name):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Module '{module_path}' has no attribute '{attr_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return getattr(module, attr_name)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_config_value(key: str, required: Literal[True]) -> Any: ... # pragma: no cover
|
||||||
|
@overload
|
||||||
|
def get_config_value(
|
||||||
|
key: str, required: bool = False
|
||||||
|
) -> Any | None: ... # pragma: no cover
|
||||||
|
def get_config_value(key: str, required: bool = False) -> Any | None:
|
||||||
|
"""Get a configuration value from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The configuration key in [tool.fastapi-toolsets].
|
||||||
|
required: If True, raises an error when the key is missing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configuration value, or None if not found and not required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
typer.BadParameter: If required=True and the key is missing.
|
||||||
|
"""
|
||||||
|
config = load_pyproject()
|
||||||
|
value = config.get(key)
|
||||||
|
|
||||||
|
if required and value is None:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"No '{key}' configured. "
|
||||||
|
f"Add '{key}' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def get_fixtures_registry() -> FixtureRegistry:
|
||||||
|
"""Import and return the fixtures registry from config."""
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
import_path = get_config_value("fixtures", required=True)
|
||||||
|
registry = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(registry, FixtureRegistry):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_context() -> Any:
|
||||||
|
"""Import and return the db_context function from config."""
|
||||||
|
import_path = get_config_value("db_context", required=True)
|
||||||
|
return import_from_string(import_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_cli() -> typer.Typer | None:
|
||||||
|
"""Import and return the custom CLI Typer instance from config."""
|
||||||
|
import_path = get_config_value("custom_cli")
|
||||||
|
if not import_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
custom = import_from_string(import_path)
|
||||||
|
|
||||||
|
if not isinstance(custom, typer.Typer):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'custom_cli' must be a Typer instance, got {type(custom).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom
|
||||||
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
49
src/fastapi_toolsets/cli/pyproject.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
TOOL_NAME = "fastapi-toolsets"
|
||||||
|
|
||||||
|
|
||||||
|
def find_pyproject(start_path: Path | None = None) -> Path | None:
|
||||||
|
"""Find pyproject.toml by walking up the directory tree.
|
||||||
|
|
||||||
|
Similar to how pytest, black, and ruff discover their config files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_path: Directory to start searching from. Defaults to cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to pyproject.toml, or None if not found.
|
||||||
|
"""
|
||||||
|
path = (start_path or Path.cwd()).resolve()
|
||||||
|
|
||||||
|
for directory in [path, *path.parents]:
|
||||||
|
pyproject = directory / "pyproject.toml"
|
||||||
|
if pyproject.is_file():
|
||||||
|
return pyproject
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_pyproject(path: Path | None = None) -> dict:
|
||||||
|
"""Load tool configuration from pyproject.toml.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Explicit path to pyproject.toml. If None, searches up from cwd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The [tool.fastapi-toolsets] section as a dict, or empty dict if not found.
|
||||||
|
"""
|
||||||
|
pyproject_path = path or find_pyproject()
|
||||||
|
|
||||||
|
if not pyproject_path:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(pyproject_path, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
return data.get("tool", {}).get(TOOL_NAME, {})
|
||||||
|
except (OSError, tomllib.TOMLDecodeError):
|
||||||
|
return {}
|
||||||
29
src/fastapi_toolsets/cli/utils.py
Normal file
29
src/fastapi_toolsets/cli/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""CLI utility functions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
|
||||||
|
"""Decorator to run an async function as a sync CLI command.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@fixture_cli.command("load")
|
||||||
|
@async_command
|
||||||
|
async def load(ctx: typer.Context) -> None:
|
||||||
|
async with get_db_context() as session:
|
||||||
|
await load_fixtures(session, registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
return asyncio.run(func(*args, **kwargs))
|
||||||
|
|
||||||
|
return wrapper
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
from ..exceptions import NoSearchableFieldsError
|
from ..exceptions import NoSearchableFieldsError
|
||||||
from .factory import CrudFactory
|
from .factory import CrudFactory, JoinType, M2MFieldType
|
||||||
from .search import (
|
from .search import (
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
SearchFieldType,
|
|
||||||
get_searchable_fields,
|
get_searchable_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CrudFactory",
|
"CrudFactory",
|
||||||
"get_searchable_fields",
|
"get_searchable_fields",
|
||||||
|
"JoinType",
|
||||||
|
"M2MFieldType",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
"SearchConfig",
|
"SearchConfig",
|
||||||
"SearchFieldType",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from __future__ import annotations
|
||||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import and_, func, select
|
from sqlalchemy import and_, func, select
|
||||||
@@ -9,14 +11,17 @@ from sqlalchemy import delete as sql_delete
|
|||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
from sqlalchemy.sql.roles import WhereHavingRole
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from ..exceptions import NotFoundError
|
from ..exceptions import NotFoundError
|
||||||
|
from ..schemas import PaginatedResponse, Pagination, Response
|
||||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||||
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
class AsyncCrud(Generic[ModelType]):
|
||||||
@@ -27,27 +32,148 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
model: ClassVar[type[DeclarativeBase]]
|
model: ClassVar[type[DeclarativeBase]]
|
||||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||||
|
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def create( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: BaseModel,
|
||||||
|
*,
|
||||||
|
as_response: Literal[True],
|
||||||
|
) -> Response[ModelType]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def create( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: BaseModel,
|
||||||
|
*,
|
||||||
|
as_response: Literal[False] = ...,
|
||||||
|
) -> ModelType: ...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _resolve_m2m(
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: BaseModel,
|
||||||
|
*,
|
||||||
|
only_set: bool = False,
|
||||||
|
) -> dict[str, list[Any]]:
|
||||||
|
"""Resolve M2M fields from a Pydantic schema into related model instances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB async session
|
||||||
|
obj: Pydantic model containing M2M ID fields
|
||||||
|
only_set: If True, only process fields explicitly set on the schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping relationship attr names to lists of related instances
|
||||||
|
"""
|
||||||
|
result: dict[str, list[Any]] = {}
|
||||||
|
if not cls.m2m_fields:
|
||||||
|
return result
|
||||||
|
|
||||||
|
for schema_field, rel in cls.m2m_fields.items():
|
||||||
|
rel_attr = rel.property.key
|
||||||
|
related_model = rel.property.mapper.class_
|
||||||
|
if only_set and schema_field not in obj.model_fields_set:
|
||||||
|
continue
|
||||||
|
ids = getattr(obj, schema_field, None)
|
||||||
|
if ids is not None:
|
||||||
|
related = (
|
||||||
|
(
|
||||||
|
await session.execute(
|
||||||
|
select(related_model).where(related_model.id.in_(ids))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if len(related) != len(ids):
|
||||||
|
found_ids = {r.id for r in related}
|
||||||
|
missing = set(ids) - found_ids
|
||||||
|
raise NotFoundError(
|
||||||
|
f"Related {related_model.__name__} not found for IDs: {missing}"
|
||||||
|
)
|
||||||
|
result[rel_attr] = list(related)
|
||||||
|
else:
|
||||||
|
result[rel_attr] = []
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _m2m_schema_fields(cls: type[Self]) -> set[str]:
|
||||||
|
"""Return the set of schema field names that are M2M fields."""
|
||||||
|
if not cls.m2m_fields:
|
||||||
|
return set()
|
||||||
|
return set(cls.m2m_fields.keys())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
obj: BaseModel,
|
obj: BaseModel,
|
||||||
) -> ModelType:
|
*,
|
||||||
|
as_response: bool = False,
|
||||||
|
) -> ModelType | Response[ModelType]:
|
||||||
"""Create a new record in the database.
|
"""Create a new record in the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
obj: Pydantic model with data to create
|
obj: Pydantic model with data to create
|
||||||
|
as_response: If True, wrap result in Response object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Created model instance
|
Created model instance or Response wrapping it
|
||||||
"""
|
"""
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
db_model = cls.model(**obj.model_dump())
|
m2m_exclude = cls._m2m_schema_fields()
|
||||||
|
data = (
|
||||||
|
obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump()
|
||||||
|
)
|
||||||
|
db_model = cls.model(**data)
|
||||||
|
|
||||||
|
if m2m_exclude:
|
||||||
|
m2m_resolved = await cls._resolve_m2m(session, obj)
|
||||||
|
for rel_attr, related_instances in m2m_resolved.items():
|
||||||
|
setattr(db_model, rel_attr, related_instances)
|
||||||
|
|
||||||
session.add(db_model)
|
session.add(db_model)
|
||||||
await session.refresh(db_model)
|
await session.refresh(db_model)
|
||||||
return cast(ModelType, db_model)
|
result = cast(ModelType, db_model)
|
||||||
|
if as_response:
|
||||||
|
return Response(data=result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def get( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
|
load_options: list[Any] | None = None,
|
||||||
|
as_response: Literal[True],
|
||||||
|
) -> Response[ModelType]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def get( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
|
with_for_update: bool = False,
|
||||||
|
load_options: list[Any] | None = None,
|
||||||
|
as_response: Literal[False] = ...,
|
||||||
|
) -> ModelType: ...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(
|
async def get(
|
||||||
@@ -55,25 +181,39 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any],
|
filters: list[Any],
|
||||||
*,
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
with_for_update: bool = False,
|
with_for_update: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
) -> ModelType:
|
as_response: bool = False,
|
||||||
|
) -> ModelType | Response[ModelType]:
|
||||||
"""Get exactly one record. Raises NotFoundError if not found.
|
"""Get exactly one record. Raises NotFoundError if not found.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
with_for_update: Lock the row for update
|
with_for_update: Lock the row for update
|
||||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||||
|
as_response: If True, wrap result in Response object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance
|
Model instance or Response wrapping it
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: If no record found
|
NotFoundError: If no record found
|
||||||
MultipleResultsFound: If more than one record found
|
MultipleResultsFound: If more than one record found
|
||||||
"""
|
"""
|
||||||
q = select(cls.model).where(and_(*filters))
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
q = q.options(*load_options)
|
q = q.options(*load_options)
|
||||||
if with_for_update:
|
if with_for_update:
|
||||||
@@ -82,7 +222,10 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
item = result.unique().scalar_one_or_none()
|
item = result.unique().scalar_one_or_none()
|
||||||
if not item:
|
if not item:
|
||||||
raise NotFoundError()
|
raise NotFoundError()
|
||||||
return cast(ModelType, item)
|
result = cast(ModelType, item)
|
||||||
|
if as_response:
|
||||||
|
return Response(data=result)
|
||||||
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def first(
|
async def first(
|
||||||
@@ -90,6 +233,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
*,
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
) -> ModelType | None:
|
) -> ModelType | None:
|
||||||
"""Get the first matching record, or None.
|
"""Get the first matching record, or None.
|
||||||
@@ -97,12 +242,21 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance or None
|
Model instance or None
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
@@ -116,6 +270,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: Any | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
@@ -126,6 +282,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
order_by: Column or list of columns to order by
|
order_by: Column or list of columns to order by
|
||||||
limit: Max number of rows to return
|
limit: Max number of rows to return
|
||||||
@@ -135,6 +293,13 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
List of model instances
|
List of model instances
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
@@ -148,6 +313,32 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def update( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: BaseModel,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
exclude_unset: bool = True,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
as_response: Literal[True],
|
||||||
|
) -> Response[ModelType]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def update( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: BaseModel,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
exclude_unset: bool = True,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
as_response: Literal[False] = ...,
|
||||||
|
) -> ModelType: ...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def update(
|
async def update(
|
||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
@@ -157,7 +348,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
*,
|
*,
|
||||||
exclude_unset: bool = True,
|
exclude_unset: bool = True,
|
||||||
exclude_none: bool = False,
|
exclude_none: bool = False,
|
||||||
) -> ModelType:
|
as_response: bool = False,
|
||||||
|
) -> ModelType | Response[ModelType]:
|
||||||
"""Update a record in the database.
|
"""Update a record in the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -166,21 +358,45 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
exclude_unset: Exclude fields not explicitly set in the schema
|
exclude_unset: Exclude fields not explicitly set in the schema
|
||||||
exclude_none: Exclude fields with None value
|
exclude_none: Exclude fields with None value
|
||||||
|
as_response: If True, wrap result in Response object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated model instance
|
Updated model instance or Response wrapping it
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: If no record found
|
NotFoundError: If no record found
|
||||||
"""
|
"""
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
db_model = await cls.get(session=session, filters=filters)
|
m2m_exclude = cls._m2m_schema_fields()
|
||||||
|
|
||||||
|
# Eagerly load M2M relationships that will be updated so that
|
||||||
|
# setattr does not trigger a lazy load (which fails in async).
|
||||||
|
m2m_load_options: list[Any] = []
|
||||||
|
if m2m_exclude and cls.m2m_fields:
|
||||||
|
for schema_field, rel in cls.m2m_fields.items():
|
||||||
|
if schema_field in obj.model_fields_set:
|
||||||
|
m2m_load_options.append(selectinload(rel))
|
||||||
|
|
||||||
|
db_model = await cls.get(
|
||||||
|
session=session,
|
||||||
|
filters=filters,
|
||||||
|
load_options=m2m_load_options or None,
|
||||||
|
)
|
||||||
values = obj.model_dump(
|
values = obj.model_dump(
|
||||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
exclude=m2m_exclude,
|
||||||
)
|
)
|
||||||
for key, value in values.items():
|
for key, value in values.items():
|
||||||
setattr(db_model, key, value)
|
setattr(db_model, key, value)
|
||||||
|
|
||||||
|
if m2m_exclude:
|
||||||
|
m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True)
|
||||||
|
for rel_attr, related_instances in m2m_resolved.items():
|
||||||
|
setattr(db_model, rel_attr, related_instances)
|
||||||
await session.refresh(db_model)
|
await session.refresh(db_model)
|
||||||
|
if as_response:
|
||||||
|
return Response(data=db_model)
|
||||||
return db_model
|
return db_model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -229,24 +445,49 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
)
|
)
|
||||||
return cast(ModelType | None, db_model)
|
return cast(ModelType | None, db_model)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def delete( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
as_response: Literal[True],
|
||||||
|
) -> Response[None]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
async def delete( # pragma: no cover
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
as_response: Literal[False] = ...,
|
||||||
|
) -> bool: ...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete(
|
async def delete(
|
||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any],
|
filters: list[Any],
|
||||||
) -> bool:
|
*,
|
||||||
|
as_response: bool = False,
|
||||||
|
) -> bool | Response[None]:
|
||||||
"""Delete records from the database.
|
"""Delete records from the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
as_response: If True, wrap result in Response object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if deletion was executed
|
True if deletion was executed, or Response wrapping it
|
||||||
"""
|
"""
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
q = sql_delete(cls.model).where(and_(*filters))
|
q = sql_delete(cls.model).where(and_(*filters))
|
||||||
await session.execute(q)
|
await session.execute(q)
|
||||||
|
if as_response:
|
||||||
|
return Response(data=None)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -254,17 +495,29 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Count records matching the filters.
|
"""Count records matching the filters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of matching records
|
Number of matching records
|
||||||
"""
|
"""
|
||||||
q = select(func.count()).select_from(cls.model)
|
q = select(func.count()).select_from(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
@@ -275,17 +528,30 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any],
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if a record exists.
|
"""Check if a record exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if at least one record matches
|
True if at least one record matches
|
||||||
"""
|
"""
|
||||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
q = q.where(and_(*filters)).exists().select()
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
return bool(result.scalar())
|
return bool(result.scalar())
|
||||||
|
|
||||||
@@ -295,18 +561,22 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: Any | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> PaginatedResponse[ModelType]:
|
||||||
"""Get paginated results with metadata.
|
"""Get paginated results with metadata.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
order_by: Column or list of columns to order by
|
order_by: Column or list of columns to order by
|
||||||
page: Page number (1-indexed)
|
page: Page number (1-indexed)
|
||||||
@@ -319,7 +589,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
"""
|
"""
|
||||||
filters = list(filters) if filters else []
|
filters = list(filters) if filters else []
|
||||||
offset = (page - 1) * items_per_page
|
offset = (page - 1) * items_per_page
|
||||||
joins: list[Any] = []
|
search_joins: list[Any] = []
|
||||||
|
|
||||||
# Build search filters
|
# Build search filters
|
||||||
if search:
|
if search:
|
||||||
@@ -330,11 +600,21 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
default_fields=cls.searchable_fields,
|
default_fields=cls.searchable_fields,
|
||||||
)
|
)
|
||||||
filters.extend(search_filters)
|
filters.extend(search_filters)
|
||||||
joins.extend(search_joins)
|
|
||||||
|
|
||||||
# Build query with joins
|
# Build query with joins
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
for join_rel in joins:
|
|
||||||
|
# Apply explicit joins
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search joins (always outer joins for search)
|
||||||
|
for join_rel in search_joins:
|
||||||
q = q.outerjoin(join_rel)
|
q = q.outerjoin(join_rel)
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
@@ -346,46 +626,63 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
q = q.offset(offset).limit(items_per_page)
|
q = q.offset(offset).limit(items_per_page)
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
items = result.unique().scalars().all()
|
items = cast(list[ModelType], result.unique().scalars().all())
|
||||||
|
|
||||||
# Count query (with same joins and filters)
|
# Count query (with same joins and filters)
|
||||||
pk_col = cls.model.__mapper__.primary_key[0]
|
pk_col = cls.model.__mapper__.primary_key[0]
|
||||||
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
||||||
count_q = count_q.select_from(cls.model)
|
count_q = count_q.select_from(cls.model)
|
||||||
for join_rel in joins:
|
|
||||||
|
# Apply explicit joins to count query
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
count_q = (
|
||||||
|
count_q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else count_q.join(model, condition)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search joins to count query
|
||||||
|
for join_rel in search_joins:
|
||||||
count_q = count_q.outerjoin(join_rel)
|
count_q = count_q.outerjoin(join_rel)
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
count_q = count_q.where(and_(*filters))
|
count_q = count_q.where(and_(*filters))
|
||||||
|
|
||||||
count_result = await session.execute(count_q)
|
count_result = await session.execute(count_q)
|
||||||
total_count = count_result.scalar_one()
|
total_count = count_result.scalar_one()
|
||||||
|
|
||||||
return {
|
return PaginatedResponse(
|
||||||
"data": items,
|
data=items,
|
||||||
"pagination": {
|
pagination=Pagination(
|
||||||
"total_count": total_count,
|
total_count=total_count,
|
||||||
"items_per_page": items_per_page,
|
items_per_page=items_per_page,
|
||||||
"page": page,
|
page=page,
|
||||||
"has_more": page * items_per_page < total_count,
|
has_more=page * items_per_page < total_count,
|
||||||
},
|
),
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
def CrudFactory(
|
def CrudFactory(
|
||||||
model: type[ModelType],
|
model: type[ModelType],
|
||||||
*,
|
*,
|
||||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
m2m_fields: M2MFieldType | None = None,
|
||||||
) -> type[AsyncCrud[ModelType]]:
|
) -> type[AsyncCrud[ModelType]]:
|
||||||
"""Create a CRUD class for a specific model.
|
"""Create a CRUD class for a specific model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: SQLAlchemy model class
|
model: SQLAlchemy model class
|
||||||
searchable_fields: Optional list of searchable fields
|
searchable_fields: Optional list of searchable fields
|
||||||
|
m2m_fields: Optional mapping for many-to-many relationships.
|
||||||
|
Maps schema field names (containing lists of IDs) to
|
||||||
|
SQLAlchemy relationship attributes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AsyncCrud subclass bound to the model
|
AsyncCrud subclass bound to the model
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
from myapp.models import User, Post
|
from myapp.models import User, Post
|
||||||
|
|
||||||
@@ -398,12 +695,37 @@ def CrudFactory(
|
|||||||
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# With many-to-many fields:
|
||||||
|
# Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
Post,
|
||||||
|
m2m_fields={"tag_ids": Post.tags},
|
||||||
|
)
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||||
|
|
||||||
|
# Create with M2M - tag_ids are automatically resolved
|
||||||
|
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
|
||||||
|
|
||||||
# With search
|
# With search
|
||||||
result = await UserCrud.paginate(session, search="john")
|
result = await UserCrud.paginate(session, search="john")
|
||||||
|
|
||||||
|
# With joins (inner join by default):
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
session,
|
||||||
|
joins=[(Post, Post.user_id == User.id)],
|
||||||
|
filters=[Post.published == True],
|
||||||
|
)
|
||||||
|
|
||||||
|
# With outer join:
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
session,
|
||||||
|
joins=[(Post, Post.user_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
cls = type(
|
cls = type(
|
||||||
f"Async{model.__name__}Crud",
|
f"Async{model.__name__}Crud",
|
||||||
@@ -411,6 +733,7 @@ def CrudFactory(
|
|||||||
{
|
{
|
||||||
"model": model,
|
"model": model,
|
||||||
"searchable_fields": searchable_fields,
|
"searchable_fields": searchable_fields,
|
||||||
|
"m2m_fields": m2m_fields,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return cast(type[AsyncCrud[ModelType]], cls)
|
return cast(type[AsyncCrud[ModelType]], cls)
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
"""Database utilities: sessions, transactions, and locks."""
|
"""Database utilities: sessions, transactions, and locks."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
@@ -14,6 +16,7 @@ __all__ = [
|
|||||||
"create_db_dependency",
|
"create_db_dependency",
|
||||||
"lock_tables",
|
"lock_tables",
|
||||||
"get_transaction",
|
"get_transaction",
|
||||||
|
"wait_for_row_change",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +35,7 @@ def create_db_dependency(
|
|||||||
An async generator function usable with FastAPI's Depends()
|
An async generator function usable with FastAPI's Depends()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_dependency
|
from fastapi_toolsets.db import create_db_dependency
|
||||||
@@ -43,6 +47,7 @@ def create_db_dependency(
|
|||||||
@app.get("/users")
|
@app.get("/users")
|
||||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
@@ -69,6 +74,7 @@ def create_db_context(
|
|||||||
An async context manager function
|
An async context manager function
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from fastapi_toolsets.db import create_db_context
|
from fastapi_toolsets.db import create_db_context
|
||||||
|
|
||||||
@@ -80,6 +86,7 @@ def create_db_context(
|
|||||||
async with get_db_context() as session:
|
async with get_db_context() as session:
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
get_db = create_db_dependency(session_maker)
|
get_db = create_db_dependency(session_maker)
|
||||||
return asynccontextmanager(get_db)
|
return asynccontextmanager(get_db)
|
||||||
@@ -101,9 +108,11 @@ async def get_transaction(
|
|||||||
The session within the transaction context
|
The session within the transaction context
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
session.add(model)
|
session.add(model)
|
||||||
# Auto-commits on exit, rolls back on exception
|
# Auto-commits on exit, rolls back on exception
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if session.in_transaction():
|
if session.in_transaction():
|
||||||
async with session.begin_nested():
|
async with session.begin_nested():
|
||||||
@@ -155,6 +164,7 @@ async def lock_tables(
|
|||||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.db import lock_tables, LockMode
|
from fastapi_toolsets.db import lock_tables, LockMode
|
||||||
|
|
||||||
async with lock_tables(session, [User, Account]):
|
async with lock_tables(session, [User, Account]):
|
||||||
@@ -166,6 +176,7 @@ async def lock_tables(
|
|||||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||||
# Exclusive lock - no other transactions can access
|
# Exclusive lock - no other transactions can access
|
||||||
await process_order(session, order_id)
|
await process_order(session, order_id)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
table_names = ",".join(table.__tablename__ for table in tables)
|
table_names = ",".join(table.__tablename__ for table in tables)
|
||||||
|
|
||||||
@@ -173,3 +184,85 @@ async def lock_tables(
|
|||||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_row_change(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[_M],
|
||||||
|
pk_value: Any,
|
||||||
|
*,
|
||||||
|
columns: list[str] | None = None,
|
||||||
|
interval: float = 0.5,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> _M:
|
||||||
|
"""Poll a database row until a change is detected.
|
||||||
|
|
||||||
|
Queries the row every ``interval`` seconds and returns the model instance
|
||||||
|
once a change is detected in any column (or only the specified ``columns``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: AsyncSession instance
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
pk_value: Primary key value of the row to watch
|
||||||
|
columns: Optional list of column names to watch. If None, all columns
|
||||||
|
are watched.
|
||||||
|
interval: Polling interval in seconds (default: 0.5)
|
||||||
|
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The refreshed model instance with updated values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LookupError: If the row does not exist or is deleted during polling
|
||||||
|
TimeoutError: If timeout expires before a change is detected
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.db import wait_for_row_change
|
||||||
|
|
||||||
|
# Wait for any column to change
|
||||||
|
updated = await wait_for_row_change(session, User, user_id)
|
||||||
|
|
||||||
|
# Watch specific columns with a timeout
|
||||||
|
updated = await wait_for_row_change(
|
||||||
|
session, User, user_id,
|
||||||
|
columns=["status", "email"],
|
||||||
|
interval=1.0,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||||
|
|
||||||
|
if columns is not None:
|
||||||
|
watch_cols = columns
|
||||||
|
else:
|
||||||
|
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||||
|
|
||||||
|
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
elapsed += interval
|
||||||
|
|
||||||
|
if timeout is not None and elapsed >= timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"No change detected on {model.__name__} "
|
||||||
|
f"with pk={pk_value!r} within {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.expunge(instance)
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||||
|
|
||||||
|
current = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
if current != initial:
|
||||||
|
return instance
|
||||||
|
|||||||
145
src/fastapi_toolsets/dependencies.py
Normal file
145
src/fastapi_toolsets/dependencies.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Dependency factories for FastAPI routes."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from .crud import CrudFactory
|
||||||
|
|
||||||
|
__all__ = ["BodyDependency", "PathDependency"]
|
||||||
|
|
||||||
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
|
|
||||||
|
|
||||||
|
def PathDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
param_name: str | None = None,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a path parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
param_name: Path parameter name (defaults to model_field, e.g., user_id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = PathDependency(User, User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
@router.get("/user/{id}")
|
||||||
|
async def get(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
name = (
|
||||||
|
param_name
|
||||||
|
if param_name is not None
|
||||||
|
else "{}_{}".format(model.__name__.lower(), field.key)
|
||||||
|
)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[name]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
name, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_dep),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
|
|
||||||
|
|
||||||
|
def BodyDependency(
|
||||||
|
model: type[ModelType],
|
||||||
|
field: Any,
|
||||||
|
*,
|
||||||
|
session_dep: SessionDependency,
|
||||||
|
body_field: str,
|
||||||
|
) -> ModelType:
|
||||||
|
"""Create a dependency that fetches a DB object from a body field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
field: Model field to filter by (e.g., User.id)
|
||||||
|
session_dep: Session dependency function (e.g., get_db)
|
||||||
|
body_field: Name of the field in the request body
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Depends() instance that resolves to the model instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If no matching record is found
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
UserDep = BodyDependency(
|
||||||
|
User, User.ctfd_id, session_dep=get_db, body_field="user_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/assign")
|
||||||
|
async def assign(
|
||||||
|
user: User = UserDep,
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
crud = CrudFactory(model)
|
||||||
|
python_type = field.type.python_type
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||||
|
) -> ModelType:
|
||||||
|
value = kwargs[body_field]
|
||||||
|
return await crud.get(session, filters=[field == value])
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
dependency,
|
||||||
|
"__signature__",
|
||||||
|
inspect.Signature(
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
body_field, inspect.Parameter.KEYWORD_ONLY, annotation=python_type
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"session",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=AsyncSession,
|
||||||
|
default=Depends(session_dep),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(ModelType, Depends(cast(Callable[..., ModelType], dependency)))
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Standardized API exceptions and error response handlers."""
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
ApiError,
|
ApiError,
|
||||||
ApiException,
|
ApiException,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class ApiException(Exception):
|
|||||||
The exception handler will use api_error to generate the response.
|
The exception handler will use api_error to generate the response.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
class CustomError(ApiException):
|
class CustomError(ApiException):
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=400,
|
code=400,
|
||||||
@@ -19,6 +20,7 @@ class ApiException(Exception):
|
|||||||
desc="The request was invalid.",
|
desc="The request was invalid.",
|
||||||
err_code="CUSTOM-400",
|
err_code="CUSTOM-400",
|
||||||
)
|
)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
api_error: ClassVar[ApiError]
|
api_error: ClassVar[ApiError]
|
||||||
@@ -76,49 +78,6 @@ class ConflictError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InsufficientRolesError(ForbiddenError):
|
|
||||||
"""User does not have the required roles."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=403,
|
|
||||||
msg="Insufficient Roles",
|
|
||||||
desc="You do not have the required roles to access this resource.",
|
|
||||||
err_code="RBAC-403",
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
|
|
||||||
self.required_roles = required_roles
|
|
||||||
self.user_roles = user_roles
|
|
||||||
|
|
||||||
desc = f"Required roles: {', '.join(required_roles)}"
|
|
||||||
if user_roles is not None:
|
|
||||||
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
|
|
||||||
|
|
||||||
super().__init__(desc)
|
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundError(NotFoundError):
|
|
||||||
"""User was not found."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=404,
|
|
||||||
msg="User Not Found",
|
|
||||||
desc="The requested user was not found.",
|
|
||||||
err_code="USER-404",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RoleNotFoundError(NotFoundError):
|
|
||||||
"""Role was not found."""
|
|
||||||
|
|
||||||
api_error = ApiError(
|
|
||||||
code=404,
|
|
||||||
msg="Role Not Found",
|
|
||||||
desc="The requested role was not found.",
|
|
||||||
err_code="ROLE-404",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NoSearchableFieldsError(ApiException):
|
class NoSearchableFieldsError(ApiException):
|
||||||
"""Raised when search is requested but no searchable fields are available."""
|
"""Raised when search is requested but no searchable fields are available."""
|
||||||
|
|
||||||
@@ -130,6 +89,11 @@ class NoSearchableFieldsError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, model: type) -> None:
|
def __init__(self, model: type) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The SQLAlchemy model class that has no searchable fields
|
||||||
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
detail = (
|
detail = (
|
||||||
f"No searchable fields found for model '{model.__name__}'. "
|
f"No searchable fields found for model '{model.__name__}'. "
|
||||||
@@ -152,6 +116,7 @@ def generate_error_responses(
|
|||||||
Dict suitable for FastAPI's responses parameter
|
Dict suitable for FastAPI's responses parameter
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
||||||
|
|
||||||
@app.get(
|
@app.get(
|
||||||
@@ -160,6 +125,7 @@ def generate_error_responses(
|
|||||||
)
|
)
|
||||||
async def admin_endpoint():
|
async def admin_endpoint():
|
||||||
...
|
...
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
responses: dict[int | str, dict[str, Any]] = {}
|
responses: dict[int | str, dict[str, Any]] = {}
|
||||||
|
|
||||||
@@ -172,7 +138,7 @@ def generate_error_responses(
|
|||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"example": {
|
||||||
"data": None,
|
"data": api_error.data,
|
||||||
"status": ResponseStatus.FAIL.value,
|
"status": ResponseStatus.FAIL.value,
|
||||||
"message": api_error.msg,
|
"message": api_error.msg,
|
||||||
"description": api_error.desc,
|
"description": api_error.desc,
|
||||||
|
|||||||
@@ -7,11 +7,32 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
|||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..schemas import ResponseStatus
|
from ..schemas import ErrorResponse, ResponseStatus
|
||||||
from .exceptions import ApiException
|
from .exceptions import ApiException
|
||||||
|
|
||||||
|
|
||||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||||
|
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||||
|
|
||||||
|
Installs handlers for :class:`ApiException`, validation errors, and
|
||||||
|
unhandled exceptions, and replaces the default 422 schema with a
|
||||||
|
consistent error format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
```
|
||||||
|
"""
|
||||||
_register_exception_handlers(app)
|
_register_exception_handlers(app)
|
||||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
||||||
return app
|
return app
|
||||||
@@ -35,16 +56,16 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||||
"""Handle custom API exceptions with structured response."""
|
"""Handle custom API exceptions with structured response."""
|
||||||
api_error = exc.api_error
|
api_error = exc.api_error
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data=api_error.data,
|
||||||
|
message=api_error.msg,
|
||||||
|
description=api_error.desc,
|
||||||
|
error_code=api_error.err_code,
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=api_error.code,
|
status_code=api_error.code,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": api_error.msg,
|
|
||||||
"description": api_error.desc,
|
|
||||||
"error_code": api_error.err_code,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
@@ -64,15 +85,15 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message="Internal Server Error",
|
||||||
|
description="An unexpected error occurred. Please try again later.",
|
||||||
|
error_code="SERVER-500",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": None,
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
"description": "An unexpected error occurred. Please try again later.",
|
|
||||||
"error_code": "SERVER-500",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -97,15 +118,16 @@ def _format_validation_error(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
data={"errors": formatted_errors},
|
||||||
|
message="Validation Error",
|
||||||
|
description=f"{len(formatted_errors)} validation error(s) detected",
|
||||||
|
error_code="VAL-422",
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
content={
|
content=error_response.model_dump(),
|
||||||
"data": {"errors": formatted_errors},
|
|
||||||
"status": ResponseStatus.FAIL.value,
|
|
||||||
"message": "Validation Error",
|
|
||||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
|
||||||
"error_code": "VAL-422",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Fixture system for seeding databases with dependency resolution."""
|
||||||
|
|
||||||
from .enum import LoadStrategy
|
from .enum import LoadStrategy
|
||||||
from .registry import Context, FixtureRegistry
|
from .registry import Context, FixtureRegistry
|
||||||
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Enums for fixture loading strategies and contexts."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
"""Fixture system with dependency management and context support."""
|
"""Fixture system with dependency management and context support."""
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
from .enum import Context
|
from .enum import Context
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -26,6 +26,7 @@ class FixtureRegistry:
|
|||||||
"""Registry for managing fixtures with dependencies.
|
"""Registry for managing fixtures with dependencies.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||||
|
|
||||||
fixtures = FixtureRegistry()
|
fixtures = FixtureRegistry()
|
||||||
@@ -48,10 +49,19 @@ class FixtureRegistry:
|
|||||||
return [
|
return [
|
||||||
Post(id=1, title="Test", user_id=1),
|
Post(id=1, title="Test", user_id=1),
|
||||||
]
|
]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
contexts: list[str | Context] | None = None,
|
||||||
|
) -> None:
|
||||||
self._fixtures: dict[str, Fixture] = {}
|
self._fixtures: dict[str, Fixture] = {}
|
||||||
|
self._default_contexts: list[str] | None = (
|
||||||
|
[c.value if isinstance(c, Context) else c for c in contexts]
|
||||||
|
if contexts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
@@ -72,6 +82,7 @@ class FixtureRegistry:
|
|||||||
contexts: List of contexts this fixture belongs to
|
contexts: List of contexts this fixture belongs to
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
@fixtures.register
|
@fixtures.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=1, name="admin")]
|
||||||
@@ -79,16 +90,21 @@ class FixtureRegistry:
|
|||||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
def test_users():
|
def test_users():
|
||||||
return [User(id=1, username="test", role_id=1)]
|
return [User(id=1, username="test", role_id=1)]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||||
fixture_name = name or cast(Any, fn).__name__
|
fixture_name = name or cast(Any, fn).__name__
|
||||||
fixture_contexts = [
|
if contexts is not None:
|
||||||
c.value if isinstance(c, Context) else c
|
fixture_contexts = [
|
||||||
for c in (contexts or [Context.BASE])
|
c.value if isinstance(c, Context) else c for c in contexts
|
||||||
]
|
]
|
||||||
|
elif self._default_contexts is not None:
|
||||||
|
fixture_contexts = self._default_contexts
|
||||||
|
else:
|
||||||
|
fixture_contexts = [Context.BASE.value]
|
||||||
|
|
||||||
self._fixtures[fixture_name] = Fixture(
|
self._fixtures[fixture_name] = Fixture(
|
||||||
name=fixture_name,
|
name=fixture_name,
|
||||||
@@ -102,6 +118,34 @@ class FixtureRegistry:
|
|||||||
return decorator(func)
|
return decorator(func)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||||
|
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The `FixtureRegistry` to include
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a fixture name already exists in the current registry
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return [...]
|
||||||
|
|
||||||
|
registry.include_registry(registry=dev_registry)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for name, fixture in registry._fixtures.items():
|
||||||
|
if name in self._fixtures:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._fixtures[name] = fixture
|
||||||
|
|
||||||
def get(self, name: str) -> Fixture:
|
def get(self, name: str) -> Fixture:
|
||||||
"""Get a fixture by name."""
|
"""Get a fixture by name."""
|
||||||
if name not in self._fixtures:
|
if name not in self._fixtures:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
"""Fixture loading utilities for database seeding."""
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
@@ -6,10 +7,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
|
from ..logger import get_logger
|
||||||
from .enum import LoadStrategy
|
from .enum import LoadStrategy
|
||||||
from .registry import Context, FixtureRegistry
|
from .registry import Context, FixtureRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger()
|
||||||
|
|
||||||
T = TypeVar("T", bound=DeclarativeBase)
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
@@ -29,9 +31,14 @@ def get_obj_by_attr(
|
|||||||
The first model instance where the attribute matches the given value.
|
The first model instance where the attribute matches the given value.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StopIteration: If no matching object is found.
|
StopIteration: If no matching object is found in the fixture group.
|
||||||
"""
|
"""
|
||||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
try:
|
||||||
|
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration(
|
||||||
|
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures(
|
async def load_fixtures(
|
||||||
@@ -52,9 +59,11 @@ async def load_fixtures(
|
|||||||
Dict mapping fixture names to loaded instances
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# Loads 'roles' first (dependency), then 'users'
|
# Loads 'roles' first (dependency), then 'users'
|
||||||
result = await load_fixtures(session, fixtures, "users")
|
result = await load_fixtures(session, fixtures, "users")
|
||||||
print(result["users"]) # [User(...), ...]
|
print(result["users"]) # [User(...), ...]
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
ordered = registry.resolve_dependencies(*names)
|
ordered = registry.resolve_dependencies(*names)
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
@@ -78,11 +87,13 @@ async def load_fixtures_by_context(
|
|||||||
Dict mapping fixture names to loaded instances
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# Load base + testing fixtures
|
# Load base + testing fixtures
|
||||||
await load_fixtures_by_context(
|
await load_fixtures_by_context(
|
||||||
session, fixtures,
|
session, fixtures,
|
||||||
Context.BASE, Context.TESTING
|
Context.BASE, Context.TESTING
|
||||||
)
|
)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
ordered = registry.resolve_context_dependencies(*contexts)
|
ordered = registry.resolve_context_dependencies(*contexts)
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|||||||
98
src/fastapi_toolsets/logger.py
Normal file
98
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||||
|
|
||||||
|
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||||
|
|
||||||
|
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging(
|
||||||
|
level: LogLevel | int = "INFO",
|
||||||
|
fmt: str = DEFAULT_FORMAT,
|
||||||
|
logger_name: str | None = None,
|
||||||
|
) -> logging.Logger:
|
||||||
|
"""Configure logging with a stdout handler and consistent format.
|
||||||
|
|
||||||
|
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||||
|
given format and level. Also configures the uvicorn loggers so that
|
||||||
|
FastAPI access logs use the same format.
|
||||||
|
|
||||||
|
Calling this function multiple times is safe -- existing handlers are
|
||||||
|
replaced rather than duplicated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||||
|
fmt: Log format string. Defaults to
|
||||||
|
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||||
|
logger_name: Logger name to configure. ``None`` (the default)
|
||||||
|
configures the root logger so all loggers inherit the settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import configure_logging
|
||||||
|
|
||||||
|
logger = configure_logging("DEBUG")
|
||||||
|
logger.info("Application started")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
formatter = logging.Formatter(fmt)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.handlers.clear()
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.setLevel(level)
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
uv_logger.handlers.clear()
|
||||||
|
uv_logger.addHandler(handler)
|
||||||
|
uv_logger.setLevel(level)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
|
||||||
|
"""Return a logger with the given *name*.
|
||||||
|
|
||||||
|
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||||
|
logging imports consistent across the codebase.
|
||||||
|
|
||||||
|
When called without arguments, the caller's ``__name__`` is used
|
||||||
|
automatically, so ``get_logger()`` in a module is equivalent to
|
||||||
|
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||||
|
root logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Logger name. Defaults to the caller's ``__name__``.
|
||||||
|
Pass ``None`` to get the root logger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Logger instance.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger() # uses caller's __name__
|
||||||
|
logger = get_logger("myapp") # explicit name
|
||||||
|
logger = get_logger(None) # root logger
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if name is _SENTINEL:
|
||||||
|
name = sys._getframe(1).f_globals.get("__name__")
|
||||||
|
return logging.getLogger(name)
|
||||||
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
21
src/fastapi_toolsets/metrics/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Prometheus metrics integration for FastAPI applications."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .registry import Metric, MetricsRegistry
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .handler import init_metrics
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def init_metrics(*_args: Any, **_kwargs: Any) -> None:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Metric",
|
||||||
|
"MetricsRegistry",
|
||||||
|
"init_metrics",
|
||||||
|
]
|
||||||
75
src/fastapi_toolsets/metrics/handler.py
Normal file
75
src/fastapi_toolsets/metrics/handler.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""Prometheus metrics endpoint for FastAPI applications."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from prometheus_client import (
|
||||||
|
CONTENT_TYPE_LATEST,
|
||||||
|
CollectorRegistry,
|
||||||
|
generate_latest,
|
||||||
|
multiprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
from .registry import MetricsRegistry
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_multiprocess() -> bool:
|
||||||
|
"""Check if prometheus multi-process mode is enabled."""
|
||||||
|
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def init_metrics(
|
||||||
|
app: FastAPI,
|
||||||
|
registry: MetricsRegistry,
|
||||||
|
*,
|
||||||
|
path: str = "/metrics",
|
||||||
|
) -> FastAPI:
|
||||||
|
"""Register a Prometheus ``/metrics`` endpoint on a FastAPI app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
registry: A :class:`MetricsRegistry` containing providers and collectors.
|
||||||
|
path: URL path for the metrics endpoint (default ``/metrics``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same FastAPI instance (for chaining).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
init_metrics(app, registry=metrics)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for provider in registry.get_providers():
|
||||||
|
logger.debug("Initialising metric provider '%s'", provider.name)
|
||||||
|
provider.func()
|
||||||
|
|
||||||
|
collectors = registry.get_collectors()
|
||||||
|
|
||||||
|
@app.get(path, include_in_schema=False)
|
||||||
|
async def metrics_endpoint() -> Response:
|
||||||
|
for collector in collectors:
|
||||||
|
if asyncio.iscoroutinefunction(collector.func):
|
||||||
|
await collector.func()
|
||||||
|
else:
|
||||||
|
collector.func()
|
||||||
|
|
||||||
|
if _is_multiprocess():
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
multiprocess.MultiProcessCollector(prom_registry)
|
||||||
|
output = generate_latest(prom_registry)
|
||||||
|
else:
|
||||||
|
output = generate_latest()
|
||||||
|
|
||||||
|
return Response(content=output, media_type=CONTENT_TYPE_LATEST)
|
||||||
|
|
||||||
|
return app
|
||||||
128
src/fastapi_toolsets/metrics/registry.py
Normal file
128
src/fastapi_toolsets/metrics/registry.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Metrics registry with decorator-based registration."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Metric:
|
||||||
|
"""A metric definition with metadata."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
func: Callable[..., Any]
|
||||||
|
collect: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsRegistry:
|
||||||
|
"""Registry for managing Prometheus metric providers and collectors.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from prometheus_client import Counter, Gauge
|
||||||
|
from fastapi_toolsets.metrics import MetricsRegistry
|
||||||
|
|
||||||
|
metrics = MetricsRegistry()
|
||||||
|
|
||||||
|
@metrics.register
|
||||||
|
def http_requests():
|
||||||
|
return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
|
||||||
|
|
||||||
|
@metrics.register(name="db_pool")
|
||||||
|
def database_pool_size():
|
||||||
|
return Gauge("db_pool_size", "Database connection pool size")
|
||||||
|
|
||||||
|
@metrics.register(collect=True)
|
||||||
|
def collect_queue_depth(gauge=Gauge("queue_depth", "Current queue depth")):
|
||||||
|
gauge.set(get_current_queue_depth())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._metrics: dict[str, Metric] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
func: Callable[..., Any] | None = None,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
collect: bool = False,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""Register a metric provider or collector function.
|
||||||
|
|
||||||
|
Can be used as a decorator with or without arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The metric function to register.
|
||||||
|
name: Metric name (defaults to function name).
|
||||||
|
collect: If ``True``, the function is called on every scrape.
|
||||||
|
If ``False`` (default), called once at init time.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@metrics.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("my_counter", "A counter")
|
||||||
|
|
||||||
|
@metrics.register(collect=True, name="queue")
|
||||||
|
def collect_queue_depth():
|
||||||
|
gauge.set(compute_depth())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
metric_name = name or cast(Any, fn).__name__
|
||||||
|
self._metrics[metric_name] = Metric(
|
||||||
|
name=metric_name,
|
||||||
|
func=fn,
|
||||||
|
collect=collect,
|
||||||
|
)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "MetricsRegistry") -> None:
|
||||||
|
"""Include another :class:`MetricsRegistry` into this one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The registry to merge in.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a metric name already exists in the current registry.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub = MetricsRegistry()
|
||||||
|
|
||||||
|
@sub.register
|
||||||
|
def sub_metric():
|
||||||
|
return Counter("sub_total", "Sub counter")
|
||||||
|
|
||||||
|
main.include_registry(sub)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
for metric_name, definition in registry._metrics.items():
|
||||||
|
if metric_name in self._metrics:
|
||||||
|
raise ValueError(
|
||||||
|
f"Metric '{metric_name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._metrics[metric_name] = definition
|
||||||
|
|
||||||
|
def get_all(self) -> list[Metric]:
|
||||||
|
"""Get all registered metric definitions."""
|
||||||
|
return list(self._metrics.values())
|
||||||
|
|
||||||
|
def get_providers(self) -> list[Metric]:
|
||||||
|
"""Get metric providers (called once at init)."""
|
||||||
|
return [m for m in self._metrics.values() if not m.collect]
|
||||||
|
|
||||||
|
def get_collectors(self) -> list[Metric]:
|
||||||
|
"""Get collectors (called on each scrape)."""
|
||||||
|
return [m for m in self._metrics.values() if m.collect]
|
||||||
@@ -1,8 +1,30 @@
|
|||||||
from .plugin import register_fixtures
|
"""Pytest helpers for FastAPI testing: sessions, clients, and fixtures."""
|
||||||
from .utils import create_async_client, create_db_session
|
|
||||||
|
try:
|
||||||
|
from .plugin import register_fixtures
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="pytest", extra="pytest")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .utils import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
from .._imports import require_extra
|
||||||
|
|
||||||
|
require_extra(package="httpx", extra="pytest")
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"cleanup_tables",
|
||||||
"create_async_client",
|
"create_async_client",
|
||||||
"create_db_session",
|
"create_db_session",
|
||||||
|
"create_worker_database",
|
||||||
"register_fixtures",
|
"register_fixtures",
|
||||||
|
"worker_database_url",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,55 +1,4 @@
|
|||||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||||
|
|
||||||
This module provides utilities to automatically generate pytest fixtures
|
|
||||||
from your FixtureRegistry, with proper dependency resolution.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# conftest.py
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
||||||
|
|
||||||
from app.fixtures import fixtures # Your FixtureRegistry
|
|
||||||
from app.models import Base
|
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def engine():
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
|
||||||
yield engine
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def db_session(engine):
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
session = session_factory()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
|
||||||
|
|
||||||
# Automatically generate pytest fixtures from registry
|
|
||||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
|
||||||
register_fixtures(fixtures, globals())
|
|
||||||
|
|
||||||
Usage in tests:
|
|
||||||
# test_users.py
|
|
||||||
async def test_user_count(db_session, fixture_users):
|
|
||||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
|
||||||
# and returns the list of User models
|
|
||||||
assert len(fixture_users) > 0
|
|
||||||
|
|
||||||
async def test_user_role(db_session, fixture_users):
|
|
||||||
user = fixture_users[0]
|
|
||||||
assert user.role_id is not None
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -86,6 +35,7 @@ def register_fixtures(
|
|||||||
List of created fixture names
|
List of created fixture names
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
# conftest.py
|
# conftest.py
|
||||||
from app.fixtures import fixtures
|
from app.fixtures import fixtures
|
||||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||||
@@ -96,6 +46,7 @@ def register_fixtures(
|
|||||||
# - fixture_roles
|
# - fixture_roles
|
||||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
created_fixtures: list[str] = []
|
created_fixtures: list[str] = []
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,18 @@
|
|||||||
"""Pytest helper utilities for FastAPI testing."""
|
"""Pytest helper utilities for FastAPI testing."""
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
import os
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import create_db_context
|
from ..db import create_db_context
|
||||||
@@ -15,12 +22,16 @@ from ..db import create_db_context
|
|||||||
async def create_async_client(
|
async def create_async_client(
|
||||||
app: Any,
|
app: Any,
|
||||||
base_url: str = "http://test",
|
base_url: str = "http://test",
|
||||||
|
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
|
||||||
) -> AsyncGenerator[AsyncClient, None]:
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
"""Create an async httpx client for testing FastAPI applications.
|
"""Create an async httpx client for testing FastAPI applications.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: FastAPI application instance.
|
app: FastAPI application instance.
|
||||||
base_url: Base URL for requests. Defaults to "http://test".
|
base_url: Base URL for requests. Defaults to "http://test".
|
||||||
|
dependency_overrides: Optional mapping of original dependencies to
|
||||||
|
their test replacements. Applied via ``app.dependency_overrides``
|
||||||
|
before yielding and cleaned up after.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
An AsyncClient configured for the app.
|
An AsyncClient configured for the app.
|
||||||
@@ -41,10 +52,39 @@ async def create_async_client(
|
|||||||
response = await client.get("/health")
|
response = await client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Example with dependency overrides:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_async_client, create_db_session
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(db_session):
|
||||||
|
async def override():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={get_db: override}
|
||||||
|
) as c:
|
||||||
|
yield c
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
if dependency_overrides:
|
||||||
|
app.dependency_overrides.update(dependency_overrides)
|
||||||
|
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
try:
|
||||||
yield client
|
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||||
|
yield client
|
||||||
|
finally:
|
||||||
|
if dependency_overrides:
|
||||||
|
for key in dependency_overrides:
|
||||||
|
app.dependency_overrides.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -55,6 +95,7 @@ async def create_db_session(
|
|||||||
echo: bool = False,
|
echo: bool = False,
|
||||||
expire_on_commit: bool = False,
|
expire_on_commit: bool = False,
|
||||||
drop_tables: bool = True,
|
drop_tables: bool = True,
|
||||||
|
cleanup: bool = False,
|
||||||
) -> AsyncGenerator[AsyncSession, None]:
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""Create a database session for testing.
|
"""Create a database session for testing.
|
||||||
|
|
||||||
@@ -67,6 +108,8 @@ async def create_db_session(
|
|||||||
echo: Enable SQLAlchemy query logging. Defaults to False.
|
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||||
expire_on_commit: Expire objects after commit. Defaults to False.
|
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||||
drop_tables: Drop tables after test. Defaults to True.
|
drop_tables: Drop tables after test. Defaults to True.
|
||||||
|
cleanup: Truncate all tables after test using
|
||||||
|
:func:`cleanup_tables`. Defaults to False.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
An AsyncSession ready for database operations.
|
An AsyncSession ready for database operations.
|
||||||
@@ -80,7 +123,9 @@ async def create_db_session(
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def db_session():
|
async def db_session():
|
||||||
async with create_db_session(DATABASE_URL, Base) as session:
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
async def test_create_user(db_session: AsyncSession):
|
async def test_create_user(db_session: AsyncSession):
|
||||||
@@ -103,8 +148,168 @@ async def create_db_session(
|
|||||||
async with get_session() as session:
|
async with get_session() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
if cleanup:
|
||||||
|
await cleanup_tables(session, base)
|
||||||
|
|
||||||
if drop_tables:
|
if drop_tables:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(base.metadata.drop_all)
|
await conn.run_sync(base.metadata.drop_all)
|
||||||
finally:
|
finally:
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_xdist_worker(default_test_db: str) -> str:
|
||||||
|
"""Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
|
||||||
|
|
||||||
|
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
|
||||||
|
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
|
||||||
|
When xdist is not installed or not active, the variable is absent and
|
||||||
|
*default_test_db* is returned instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
|
||||||
|
is not set.
|
||||||
|
"""
|
||||||
|
return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
|
||||||
|
|
||||||
|
|
||||||
|
def worker_database_url(database_url: str, default_test_db: str) -> str:
|
||||||
|
"""Derive a per-worker database URL for pytest-xdist parallel runs.
|
||||||
|
|
||||||
|
Appends ``_{worker_name}`` to the database name so each xdist worker
|
||||||
|
operates on its own database. When not running under xdist,
|
||||||
|
``_{default_test_db}`` is appended instead.
|
||||||
|
|
||||||
|
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
|
||||||
|
variable (set automatically by xdist in each worker process).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A database URL with a worker- or default-specific database name.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# With PYTEST_XDIST_WORKER="gw0":
|
||||||
|
url = worker_database_url(
|
||||||
|
"postgresql+asyncpg://user:pass@localhost/test_db",
|
||||||
|
default_test_db="test",
|
||||||
|
)
|
||||||
|
# "postgresql+asyncpg://user:pass@localhost/test_db_gw0"
|
||||||
|
|
||||||
|
# Without PYTEST_XDIST_WORKER:
|
||||||
|
url = worker_database_url(
|
||||||
|
"postgresql+asyncpg://user:pass@localhost/test_db",
|
||||||
|
default_test_db="test",
|
||||||
|
)
|
||||||
|
# "postgresql+asyncpg://user:pass@localhost/test_db_test"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker = _get_xdist_worker(default_test_db=default_test_db)
|
||||||
|
|
||||||
|
url = make_url(database_url)
|
||||||
|
url = url.set(database=f"{url.database}_{worker}")
|
||||||
|
return url.render_as_string(hide_password=False)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_worker_database(
|
||||||
|
database_url: str,
|
||||||
|
default_test_db: str = "test_db",
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Create and drop a per-worker database for pytest-xdist isolation.
|
||||||
|
|
||||||
|
Intended for use as a **session-scoped** fixture. Connects to the server
|
||||||
|
using the original *database_url* (with ``AUTOCOMMIT`` isolation for DDL),
|
||||||
|
creates a dedicated database for the worker, and yields the worker-specific
|
||||||
|
URL. On cleanup the worker database is dropped.
|
||||||
|
|
||||||
|
When running under xdist the database name is suffixed with the worker
|
||||||
|
name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Original database connection URL.
|
||||||
|
default_test_db: Suffix appended to the database name when
|
||||||
|
``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The worker-specific database URL.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
create_worker_database, create_db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def worker_db_url():
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
yield url
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(
|
||||||
|
worker_db_url, Base, cleanup=True
|
||||||
|
) as session:
|
||||||
|
yield session
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
worker_url = worker_database_url(
|
||||||
|
database_url=database_url, default_test_db=default_test_db
|
||||||
|
)
|
||||||
|
worker_db_name = make_url(worker_url).database
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
database_url,
|
||||||
|
isolation_level="AUTOCOMMIT",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {worker_db_name}"))
|
||||||
|
|
||||||
|
yield worker_url
|
||||||
|
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_tables(
|
||||||
|
session: AsyncSession,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""Truncate all tables for fast between-test cleanup.
|
||||||
|
|
||||||
|
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
|
||||||
|
across every table in *base*'s metadata, which is significantly faster
|
||||||
|
than dropping and re-creating tables between tests.
|
||||||
|
|
||||||
|
This is a no-op when the metadata contains no tables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: An active async database session.
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(worker_db_url):
|
||||||
|
async with create_db_session(worker_db_url, Base) as session:
|
||||||
|
yield session
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
tables = base.metadata.sorted_tables
|
||||||
|
if not tables:
|
||||||
|
return
|
||||||
|
|
||||||
|
table_names = ", ".join(f'"{t.name}"' for t in tables)
|
||||||
|
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
|
||||||
|
await session.commit()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Base Pydantic schemas for API responses."""
|
"""Base Pydantic schemas for API responses."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import ClassVar, Generic, TypeVar
|
from typing import Any, ClassVar, Generic, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
@@ -10,6 +10,7 @@ __all__ = [
|
|||||||
"ErrorResponse",
|
"ErrorResponse",
|
||||||
"Pagination",
|
"Pagination",
|
||||||
"PaginatedResponse",
|
"PaginatedResponse",
|
||||||
|
"PydanticBase",
|
||||||
"Response",
|
"Response",
|
||||||
"ResponseStatus",
|
"ResponseStatus",
|
||||||
]
|
]
|
||||||
@@ -49,6 +50,7 @@ class ApiError(PydanticBase):
|
|||||||
msg: str
|
msg: str
|
||||||
desc: str
|
desc: str
|
||||||
err_code: str
|
err_code: str
|
||||||
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(PydanticBase):
|
class BaseResponse(PydanticBase):
|
||||||
@@ -69,7 +71,9 @@ class Response(BaseResponse, Generic[DataT]):
|
|||||||
"""Generic API response with data payload.
|
"""Generic API response with data payload.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
Response[UserRead](data=user, message="User retrieved")
|
Response[UserRead](data=user, message="User retrieved")
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: DataT | None = None
|
data: DataT | None = None
|
||||||
@@ -83,7 +87,7 @@ class ErrorResponse(BaseResponse):
|
|||||||
|
|
||||||
status: ResponseStatus = ResponseStatus.FAIL
|
status: ResponseStatus = ResponseStatus.FAIL
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
data: None = None
|
data: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
class Pagination(PydanticBase):
|
class Pagination(PydanticBase):
|
||||||
@@ -106,10 +110,12 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
|
|||||||
"""Paginated API response for list endpoints.
|
"""Paginated API response for list endpoints.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
```python
|
||||||
PaginatedResponse[UserRead](
|
PaginatedResponse[UserRead](
|
||||||
data=users,
|
data=users,
|
||||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
||||||
)
|
)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: list[DataT]
|
data: list[DataT]
|
||||||
|
|||||||
@@ -5,24 +5,18 @@ import uuid
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import ForeignKey, String, Uuid
|
from sqlalchemy import Column, ForeignKey, String, Table, Uuid
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
|
||||||
# PostgreSQL connection URL from environment or default for local development
|
DATABASE_URL = os.getenv(
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL") or os.getenv(
|
key="DATABASE_URL",
|
||||||
"TEST_DATABASE_URL",
|
default="postgresql+asyncpg://postgres:postgres@localhost:5432/postgres",
|
||||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/fastapi_toolsets_test",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Test Models
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for test models."""
|
"""Base class for test models."""
|
||||||
|
|
||||||
@@ -56,6 +50,25 @@ class User(Base):
|
|||||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Base):
|
||||||
|
"""Test tag model."""
|
||||||
|
|
||||||
|
__tablename__ = "tags"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
post_tags = Table(
|
||||||
|
"post_tags",
|
||||||
|
Base.metadata,
|
||||||
|
Column(
|
||||||
|
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
),
|
||||||
|
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
"""Test post model."""
|
"""Test post model."""
|
||||||
|
|
||||||
@@ -67,10 +80,7 @@ class Post(Base):
|
|||||||
is_published: Mapped[bool] = mapped_column(default=False)
|
is_published: Mapped[bool] = mapped_column(default=False)
|
||||||
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||||
# =============================================================================
|
|
||||||
# Test Schemas
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class RoleCreate(BaseModel):
|
class RoleCreate(BaseModel):
|
||||||
@@ -105,6 +115,13 @@ class UserUpdate(BaseModel):
|
|||||||
role_id: uuid.UUID | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TagCreate(BaseModel):
|
||||||
|
"""Schema for creating a tag."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
class PostCreate(BaseModel):
|
||||||
"""Schema for creating a post."""
|
"""Schema for creating a post."""
|
||||||
|
|
||||||
@@ -123,18 +140,31 @@ class PostUpdate(BaseModel):
|
|||||||
is_published: bool | None = None
|
is_published: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
class PostM2MCreate(BaseModel):
|
||||||
# CRUD Classes
|
"""Schema for creating a post with M2M tag IDs."""
|
||||||
# =============================================================================
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
content: str = ""
|
||||||
|
is_published: bool = False
|
||||||
|
author_id: uuid.UUID
|
||||||
|
tag_ids: list[uuid.UUID] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PostM2MUpdate(BaseModel):
|
||||||
|
"""Schema for updating a post with M2M tag IDs."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
is_published: bool | None = None
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
|
||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
UserCrud = CrudFactory(User)
|
UserCrud = CrudFactory(User)
|
||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
|
TagCrud = CrudFactory(Tag)
|
||||||
|
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
# =============================================================================
|
|
||||||
# Fixtures
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
543
tests/test_cli.py
Normal file
543
tests/test_cli.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
"""Tests for fastapi_toolsets.cli module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli.config import (
|
||||||
|
get_config_value,
|
||||||
|
get_custom_cli,
|
||||||
|
get_db_context,
|
||||||
|
get_fixtures_registry,
|
||||||
|
import_from_string,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.cli.pyproject import find_pyproject, load_pyproject
|
||||||
|
from fastapi_toolsets.cli.utils import async_command
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPyproject:
|
||||||
|
"""Tests for pyproject.toml discovery and loading."""
|
||||||
|
|
||||||
|
def test_find_pyproject_in_current_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in current directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_in_parent_dir(self, tmp_path, monkeypatch):
|
||||||
|
"""Finds pyproject.toml in parent directory."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
subdir = tmp_path / "src" / "app"
|
||||||
|
subdir.mkdir(parents=True)
|
||||||
|
monkeypatch.chdir(subdir)
|
||||||
|
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result == pyproject
|
||||||
|
|
||||||
|
def test_find_pyproject_not_found(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = find_pyproject()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_load_pyproject_returns_tool_config(self, tmp_path, monkeypatch):
|
||||||
|
"""load_pyproject returns the [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {"fixtures": "app.fixtures:registry"}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_file(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_empty_when_no_tool_section(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when no [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_load_pyproject_invalid_toml(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty dict when pyproject.toml is invalid."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("invalid toml {{{")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = load_pyproject()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportFromString:
|
||||||
|
"""Tests for import_from_string function."""
|
||||||
|
|
||||||
|
def test_import_valid_path(self):
|
||||||
|
"""Import valid module:attribute path."""
|
||||||
|
result = import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
|
||||||
|
assert result is FixtureRegistry
|
||||||
|
|
||||||
|
def test_import_without_colon_raises_error(self):
|
||||||
|
"""Import path without colon raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
|
||||||
|
assert "Expected format: 'module:attribute'" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_module_raises_error(self):
|
||||||
|
"""Import nonexistent module raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("nonexistent.module:something")
|
||||||
|
assert "Cannot import module" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_attribute_raises_error(self):
|
||||||
|
"""Import nonexistent attribute raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
|
||||||
|
assert "has no attribute" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConfigValue:
|
||||||
|
"""Tests for get_config_value function."""
|
||||||
|
|
||||||
|
def test_get_existing_value(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns value when key exists."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\nfixtures = "app:registry"\n')
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result == "app:registry"
|
||||||
|
|
||||||
|
def test_get_missing_value_returns_none(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when key is missing and not required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_config_value("fixtures")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_missing_value_required_raises_error(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when key is missing and required."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_config_value("fixtures", required=True)
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetFixturesRegistry:
|
||||||
|
"""Tests for get_fixtures_registry function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when fixtures not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "No 'fixtures' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_raises_when_not_registry_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a FixtureRegistry."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "my_fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
fixtures_file = tmp_path / "my_fixtures.py"
|
||||||
|
fixtures_file.write_text("registry = 'not a registry'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_fixtures_registry()
|
||||||
|
assert "must be a FixtureRegistry instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_fixtures" in sys.modules:
|
||||||
|
del sys.modules["my_fixtures"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDbContext:
|
||||||
|
"""Tests for get_db_context function."""
|
||||||
|
|
||||||
|
def test_raises_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when db_context not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_db_context()
|
||||||
|
assert "No 'db_context' configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCustomCli:
|
||||||
|
"""Tests for get_custom_cli function."""
|
||||||
|
|
||||||
|
def test_returns_none_when_not_configured(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns None when custom_cli not configured."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[tool.fastapi-toolsets]\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
result = get_custom_cli()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_raises_when_not_typer_instance(self, tmp_path, monkeypatch):
|
||||||
|
"""Raises error when imported object is not a Typer instance."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text("cli = 'not a typer'\n")
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
get_custom_cli()
|
||||||
|
assert "must be a Typer instance" in str(exc_info.value)
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliApp:
|
||||||
|
"""Tests for CLI application."""
|
||||||
|
|
||||||
|
def test_cli_help(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI shows help without fixtures."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Need to reload the module to pick up new cwd
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "CLI utilities for FastAPI projects" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixturesCli:
|
||||||
|
"""Tests for fixtures CLI commands."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_env(self, tmp_path, monkeypatch):
|
||||||
|
"""Set up CLI environment with fixtures config."""
|
||||||
|
# Create pyproject.toml
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
"\n"
|
||||||
|
"@registry.register(contexts=[Context.BASE])\n"
|
||||||
|
"def roles():\n"
|
||||||
|
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
|
||||||
|
"\n"
|
||||||
|
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
|
||||||
|
"def users():\n"
|
||||||
|
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
yield tmp_path, app.cli
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
def test_fixtures_list(self, cli_env):
|
||||||
|
"""fixtures list shows registered fixtures."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" in result.output
|
||||||
|
assert "Total: 2 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_list_with_context(self, cli_env):
|
||||||
|
"""fixtures list --context filters by context."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" not in result.output
|
||||||
|
assert "Total: 1 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_dry_run(self, cli_env):
|
||||||
|
"""fixtures load --dry-run shows what would be loaded."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Fixtures to load" in result.output
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "[Dry run - no changes made]" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_invalid_strategy(self, cli_env):
|
||||||
|
"""fixtures load with invalid strategy shows error."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(
|
||||||
|
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliWithoutFixturesConfig:
|
||||||
|
"""Tests for CLI when fixtures is not configured."""
|
||||||
|
|
||||||
|
def test_no_fixtures_command(self, tmp_path, monkeypatch):
|
||||||
|
"""fixtures command is not available when not configured."""
|
||||||
|
# Create pyproject.toml without fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[project]\nname = "test"\n')
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Reload the CLI module
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "fixtures" not in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomCliConfig:
|
||||||
|
"""Tests for custom CLI configuration."""
|
||||||
|
|
||||||
|
def test_cli_with_custom_cli(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI uses custom Typer instance when configured."""
|
||||||
|
import typer
|
||||||
|
|
||||||
|
# Create pyproject.toml with custom_cli config
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[tool.fastapi-toolsets]\ncustom_cli = "my_cli:cli"\n')
|
||||||
|
|
||||||
|
# Create custom CLI module with its own Typer and commands
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello from custom CLI!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Remove my_cli from sys.modules if it was previously loaded
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify custom CLI is used
|
||||||
|
assert isinstance(app.cli, typer.Typer)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "My custom CLI" in result.output
|
||||||
|
assert "hello" in result.output
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["hello"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Hello from custom CLI!" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
if "my_cli" in sys.modules:
|
||||||
|
del sys.modules["my_cli"]
|
||||||
|
|
||||||
|
def test_custom_cli_with_fixtures(self, tmp_path, monkeypatch):
|
||||||
|
"""Custom CLI gets fixtures command added when configured."""
|
||||||
|
# Create pyproject.toml with both custom_cli and fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'custom_cli = "my_cli:cli"\n'
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create custom CLI module
|
||||||
|
cli_file = tmp_path / "my_cli.py"
|
||||||
|
cli_file.write_text(
|
||||||
|
"import typer\n"
|
||||||
|
"\n"
|
||||||
|
"cli = typer.Typer(name='my-app', help='My custom CLI')\n"
|
||||||
|
"\n"
|
||||||
|
"@cli.command()\n"
|
||||||
|
"def hello():\n"
|
||||||
|
' print("Hello!")\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Should have both custom command and fixtures
|
||||||
|
assert "hello" in result.output
|
||||||
|
assert "fixtures" in result.output
|
||||||
|
finally:
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
for mod in ["my_cli", "fixtures", "db"]:
|
||||||
|
if mod in sys.modules:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommand:
|
||||||
|
"""Tests for async_command decorator."""
|
||||||
|
|
||||||
|
def test_async_command_runs_coroutine(self):
|
||||||
|
"""async_command runs async function synchronously."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(value: int) -> int:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
result = async_func(21)
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
def test_async_command_preserves_signature(self):
|
||||||
|
"""async_command preserves function signature."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(name: str, count: int = 1) -> str:
|
||||||
|
return f"{name} x {count}"
|
||||||
|
|
||||||
|
result = async_func("test", count=3)
|
||||||
|
assert result == "test x 3"
|
||||||
|
|
||||||
|
def test_async_command_preserves_docstring(self):
|
||||||
|
"""async_command preserves function docstring."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func() -> None:
|
||||||
|
"""This is a docstring."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert async_func.__doc__ == """This is a docstring."""
|
||||||
|
|
||||||
|
def test_async_command_preserves_name(self):
|
||||||
|
"""async_command preserves function name."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def my_async_function() -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert my_async_function.__name__ == "my_async_function"
|
||||||
@@ -4,16 +4,25 @@ import uuid
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
from fastapi_toolsets.exceptions import NotFoundError
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
|
Post,
|
||||||
|
PostCreate,
|
||||||
|
PostCrud,
|
||||||
|
PostM2MCreate,
|
||||||
|
PostM2MCrud,
|
||||||
|
PostM2MUpdate,
|
||||||
Role,
|
Role,
|
||||||
RoleCreate,
|
RoleCreate,
|
||||||
RoleCrud,
|
RoleCrud,
|
||||||
RoleUpdate,
|
RoleUpdate,
|
||||||
|
TagCreate,
|
||||||
|
TagCrud,
|
||||||
User,
|
User,
|
||||||
UserCreate,
|
UserCreate,
|
||||||
UserCrud,
|
UserCrud,
|
||||||
@@ -426,11 +435,11 @@ class TestCrudPaginate:
|
|||||||
|
|
||||||
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
||||||
|
|
||||||
assert len(result["data"]) == 10
|
assert len(result.data) == 10
|
||||||
assert result["pagination"]["total_count"] == 25
|
assert result.pagination.total_count == 25
|
||||||
assert result["pagination"]["page"] == 1
|
assert result.pagination.page == 1
|
||||||
assert result["pagination"]["items_per_page"] == 10
|
assert result.pagination.items_per_page == 10
|
||||||
assert result["pagination"]["has_more"] is True
|
assert result.pagination.has_more is True
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_paginate_last_page(self, db_session: AsyncSession):
|
async def test_paginate_last_page(self, db_session: AsyncSession):
|
||||||
@@ -440,8 +449,8 @@ class TestCrudPaginate:
|
|||||||
|
|
||||||
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10)
|
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10)
|
||||||
|
|
||||||
assert len(result["data"]) == 5
|
assert len(result.data) == 5
|
||||||
assert result["pagination"]["has_more"] is False
|
assert result.pagination.has_more is False
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_paginate_with_filters(self, db_session: AsyncSession):
|
async def test_paginate_with_filters(self, db_session: AsyncSession):
|
||||||
@@ -463,7 +472,7 @@ class TestCrudPaginate:
|
|||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 5
|
assert result.pagination.total_count == 5
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_paginate_with_ordering(self, db_session: AsyncSession):
|
async def test_paginate_with_ordering(self, db_session: AsyncSession):
|
||||||
@@ -479,5 +488,713 @@ class TestCrudPaginate:
|
|||||||
items_per_page=10,
|
items_per_page=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
names = [r.name for r in result["data"]]
|
names = [r.name for r in result.data]
|
||||||
assert names == ["alpha", "bravo", "charlie"]
|
assert names == ["alpha", "bravo", "charlie"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrudJoins:
|
||||||
|
"""Tests for CRUD operations with joins."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Get with inner join filters correctly."""
|
||||||
|
# Create user with posts
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="author", email="author@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Post 1", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user with join on published posts
|
||||||
|
fetched = await UserCrud.get(
|
||||||
|
db_session,
|
||||||
|
filters=[User.id == user.id, Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert fetched.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_join(self, db_session: AsyncSession):
|
||||||
|
"""First with join returns matching record."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="writer", email="writer@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Draft", author_id=user.id, is_published=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find user with unpublished posts
|
||||||
|
result = await UserCrud.first(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == False], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""First with outer join includes records without related data."""
|
||||||
|
# User without posts
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_posts", email="no_posts@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# With outer join, user should be found even without posts
|
||||||
|
result = await UserCrud.first(
|
||||||
|
db_session,
|
||||||
|
filters=[User.id == user.id],
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_with_inner_join(self, db_session: AsyncSession):
|
||||||
|
"""Get multiple with inner join only returns matching records."""
|
||||||
|
# User with published post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="publisher", email="pub@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Published", author_id=user1.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without posts
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="lurker", email="lurk@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inner join should only return user with published post
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
)
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].username == "publisher"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""Get multiple with outer join includes all records."""
|
||||||
|
# User with post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="has_post", email="has@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="My Post", author_id=user1.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without posts
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_post", email="no@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Outer join should return both users
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
|
assert len(users) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_count_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Count with join counts correctly."""
|
||||||
|
# Create users with different post statuses
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="active_author", email="active@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Published 1", author_id=user1.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
user2 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="draft_author", email="draft@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Draft 1", author_id=user2.id, is_published=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count users with published posts
|
||||||
|
count = await UserCrud.count(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_exists_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Exists with join checks correctly."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="poster", email="poster@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Exists Post", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user with published post exists
|
||||||
|
exists = await UserCrud.exists(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert exists is True
|
||||||
|
|
||||||
|
# Check if user with specific title exists
|
||||||
|
exists = await UserCrud.exists(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.title == "Nonexistent"],
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert exists is False
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_paginate_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Paginate with join works correctly."""
|
||||||
|
# Create users with posts
|
||||||
|
for i in range(5):
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username=f"author{i}", email=f"author{i}@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(
|
||||||
|
title=f"Post {i}",
|
||||||
|
author_id=user.id,
|
||||||
|
is_published=i % 2 == 0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paginate users with published posts
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
page=1,
|
||||||
|
items_per_page=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.pagination.total_count == 3
|
||||||
|
assert len(result.data) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""Paginate with outer join includes all records."""
|
||||||
|
# User with post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="with_post", email="with@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="A Post", author_id=user1.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without post
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="without_post", email="without@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paginate with outer join
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
page=1,
|
||||||
|
items_per_page=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
assert len(result.data) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_joins(self, db_session: AsyncSession):
|
||||||
|
"""Multiple joins can be applied."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="author_role"))
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(
|
||||||
|
username="multi_join",
|
||||||
|
email="multi@test.com",
|
||||||
|
role_id=role.id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Multi Join Post", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Join both Role and Post
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[
|
||||||
|
(Role, Role.id == User.role_id),
|
||||||
|
(Post, Post.author_id == User.id),
|
||||||
|
],
|
||||||
|
filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712
|
||||||
|
)
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].username == "multi_join"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsResponse:
|
||||||
|
"""Tests for as_response parameter."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_as_response(self, db_session: AsyncSession):
|
||||||
|
"""Create with as_response=True returns Response."""
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
data = RoleCreate(name="response_role")
|
||||||
|
result = await RoleCrud.create(db_session, data, as_response=True)
|
||||||
|
|
||||||
|
assert isinstance(result, Response)
|
||||||
|
assert result.data is not None
|
||||||
|
assert result.data.name == "response_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_as_response(self, db_session: AsyncSession):
|
||||||
|
"""Get with as_response=True returns Response."""
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
|
||||||
|
result = await RoleCrud.get(
|
||||||
|
db_session, [Role.id == created.id], as_response=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, Response)
|
||||||
|
assert result.data is not None
|
||||||
|
assert result.data.id == created.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_as_response(self, db_session: AsyncSession):
|
||||||
|
"""Update with as_response=True returns Response."""
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
|
||||||
|
result = await RoleCrud.update(
|
||||||
|
db_session,
|
||||||
|
RoleUpdate(name="new_name"),
|
||||||
|
[Role.id == created.id],
|
||||||
|
as_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, Response)
|
||||||
|
assert result.data is not None
|
||||||
|
assert result.data.name == "new_name"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_as_response(self, db_session: AsyncSession):
|
||||||
|
"""Delete with as_response=True returns Response."""
|
||||||
|
from fastapi_toolsets.schemas import Response
|
||||||
|
|
||||||
|
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
|
||||||
|
result = await RoleCrud.delete(
|
||||||
|
db_session, [Role.id == created.id], as_response=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, Response)
|
||||||
|
assert result.data is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrudFactoryM2M:
|
||||||
|
"""Tests for CrudFactory with m2m_fields parameter."""
|
||||||
|
|
||||||
|
def test_creates_crud_with_m2m_fields(self):
|
||||||
|
"""CrudFactory configures m2m_fields on the class."""
|
||||||
|
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
|
assert crud.m2m_fields is not None
|
||||||
|
assert "tag_ids" in crud.m2m_fields
|
||||||
|
|
||||||
|
def test_creates_crud_without_m2m_fields(self):
|
||||||
|
"""CrudFactory without m2m_fields has None."""
|
||||||
|
crud = CrudFactory(Post)
|
||||||
|
assert crud.m2m_fields is None
|
||||||
|
|
||||||
|
def test_m2m_schema_fields(self):
|
||||||
|
"""_m2m_schema_fields returns correct field names."""
|
||||||
|
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
|
assert crud._m2m_schema_fields() == {"tag_ids"}
|
||||||
|
|
||||||
|
def test_m2m_schema_fields_empty_when_none(self):
|
||||||
|
"""_m2m_schema_fields returns empty set when no m2m_fields."""
|
||||||
|
crud = CrudFactory(Post)
|
||||||
|
assert crud._m2m_schema_fields() == set()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_resolve_m2m_returns_empty_without_m2m_fields(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""_resolve_m2m returns empty dict when m2m_fields is not configured."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class DummySchema(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
result = await PostCrud._resolve_m2m(db_session, DummySchema(name="test"))
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MResolveNone:
|
||||||
|
"""Tests for _resolve_m2m when IDs field is None."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_resolve_m2m_with_none_ids(self, db_session: AsyncSession):
|
||||||
|
"""_resolve_m2m sets empty list when ids value is None."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class SchemaWithNullableTags(BaseModel):
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
result = await PostM2MCrud._resolve_m2m(
|
||||||
|
db_session, SchemaWithNullableTags(tag_ids=None)
|
||||||
|
)
|
||||||
|
assert result == {"tags": []}
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MCreate:
|
||||||
|
"""Tests for create with M2M relationships."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_m2m_tags(self, db_session: AsyncSession):
|
||||||
|
"""Create a post with M2M tags resolves tag IDs."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="python"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="fastapi"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="M2M Post",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id, tag2.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert post.id is not None
|
||||||
|
assert post.title == "M2M Post"
|
||||||
|
|
||||||
|
# Reload with tags eagerly loaded
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
tag_names = sorted(t.name for t in loaded.tags)
|
||||||
|
assert tag_names == ["fastapi", "python"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_empty_m2m(self, db_session: AsyncSession):
|
||||||
|
"""Create a post with empty tag_ids list works."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="No Tags Post",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert post.id is not None
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.tags == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_default_m2m(self, db_session: AsyncSession):
|
||||||
|
"""Create a post using default tag_ids (empty list) works."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(title="Default Tags", author_id=user.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.tags == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_nonexistent_tag_id_raises(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Create with a nonexistent tag ID raises NotFoundError."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="valid"))
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Bad Tags",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id, fake_id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_single_tag(self, db_session: AsyncSession):
|
||||||
|
"""Create with a single tag works correctly."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="solo"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Single Tag",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert len(loaded.tags) == 1
|
||||||
|
assert loaded.tags[0].name == "solo"
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MUpdate:
|
||||||
|
"""Tests for update with M2M relationships."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_m2m_tags(self, db_session: AsyncSession):
|
||||||
|
"""Update replaces M2M tags when tag_ids is set."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="old_tag"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="new_tag"))
|
||||||
|
|
||||||
|
# Create with tag1
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Update Test",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update to tag2
|
||||||
|
updated = await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(tag_ids=[tag2.id]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == updated.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert len(loaded.tags) == 1
|
||||||
|
assert loaded.tags[0].name == "new_tag"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_without_m2m_preserves_tags(self, db_session: AsyncSession):
|
||||||
|
"""Update without setting tag_ids preserves existing tags."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="keep_me"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Keep Tags",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update only title, tag_ids not set
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(title="Updated Title"),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.title == "Updated Title"
|
||||||
|
assert len(loaded.tags) == 1
|
||||||
|
assert loaded.tags[0].name == "keep_me"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_clear_m2m_tags(self, db_session: AsyncSession):
|
||||||
|
"""Update with empty tag_ids clears all tags."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="remove_me"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Clear Tags",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicitly set tag_ids to empty list
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(tag_ids=[]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.tags == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_m2m_with_nonexistent_id_raises(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Update with nonexistent tag ID raises NotFoundError."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="existing"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Bad Update",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(tag_ids=[fake_id]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_m2m_and_scalar_fields(self, db_session: AsyncSession):
|
||||||
|
"""Update both scalar fields and M2M tags together."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="tag1"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="tag2"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Original",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update title and tags simultaneously
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(title="Updated", tag_ids=[tag1.id, tag2.id]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.title == "Updated"
|
||||||
|
tag_names = sorted(t.name for t in loaded.tags)
|
||||||
|
assert tag_names == ["tag1", "tag2"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MWithNonM2MCrud:
|
||||||
|
"""Tests that non-M2M CRUD classes are unaffected."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_without_m2m_unchanged(self, db_session: AsyncSession):
|
||||||
|
"""Regular PostCrud.create still works without M2M logic."""
|
||||||
|
from .conftest import PostCreate
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
post = await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Plain Post", author_id=user.id),
|
||||||
|
)
|
||||||
|
assert post.id is not None
|
||||||
|
assert post.title == "Plain Post"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_without_m2m_unchanged(self, db_session: AsyncSession):
|
||||||
|
"""Regular PostCrud.update still works without M2M logic."""
|
||||||
|
from .conftest import PostCreate, PostUpdate
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
post = await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Plain Post", author_id=user.id),
|
||||||
|
)
|
||||||
|
updated = await PostCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostUpdate(title="Updated Plain"),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
assert updated.title == "Updated Plain"
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
||||||
@@ -57,7 +57,7 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username, User.email],
|
search_fields=[User.username, User.email],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
||||||
@@ -84,7 +84,7 @@ class TestPaginateSearch:
|
|||||||
search_fields=[(User.role, Role.name)],
|
search_fields=[(User.role, Role.name)],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
||||||
@@ -102,7 +102,7 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username, (User.role, Role.name)],
|
search_fields=[User.username, (User.role, Role.name)],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
||||||
@@ -117,7 +117,7 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
||||||
@@ -132,7 +132,7 @@ class TestPaginateSearch:
|
|||||||
search=SearchConfig(query="johndoe", case_sensitive=True),
|
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
assert result["pagination"]["total_count"] == 0
|
assert result.pagination.total_count == 0
|
||||||
|
|
||||||
# Should find (case match)
|
# Should find (case match)
|
||||||
result = await UserCrud.paginate(
|
result = await UserCrud.paginate(
|
||||||
@@ -140,7 +140,7 @@ class TestPaginateSearch:
|
|||||||
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_empty_query(self, db_session: AsyncSession):
|
async def test_search_empty_query(self, db_session: AsyncSession):
|
||||||
@@ -153,10 +153,10 @@ class TestPaginateSearch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await UserCrud.paginate(db_session, search="")
|
result = await UserCrud.paginate(db_session, search="")
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
result = await UserCrud.paginate(db_session, search=None)
|
result = await UserCrud.paginate(db_session, search=None)
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
||||||
@@ -177,8 +177,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
assert result["data"][0].username == "active_john"
|
assert result.data[0].username == "active_john"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
||||||
@@ -189,7 +189,7 @@ class TestPaginateSearch:
|
|||||||
|
|
||||||
result = await UserCrud.paginate(db_session, search="findme")
|
result = await UserCrud.paginate(db_session, search="findme")
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_no_results(self, db_session: AsyncSession):
|
async def test_search_no_results(self, db_session: AsyncSession):
|
||||||
@@ -204,8 +204,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 0
|
assert result.pagination.total_count == 0
|
||||||
assert result["data"] == []
|
assert result.data == []
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_with_pagination(self, db_session: AsyncSession):
|
async def test_search_with_pagination(self, db_session: AsyncSession):
|
||||||
@@ -224,9 +224,9 @@ class TestPaginateSearch:
|
|||||||
items_per_page=5,
|
items_per_page=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 15
|
assert result.pagination.total_count == 15
|
||||||
assert len(result["data"]) == 5
|
assert len(result.data) == 5
|
||||||
assert result["pagination"]["has_more"] is True
|
assert result.pagination.has_more is True
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_null_relationship(self, db_session: AsyncSession):
|
async def test_search_null_relationship(self, db_session: AsyncSession):
|
||||||
@@ -248,7 +248,7 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.username],
|
search_fields=[User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_with_order_by(self, db_session: AsyncSession):
|
async def test_search_with_order_by(self, db_session: AsyncSession):
|
||||||
@@ -270,8 +270,8 @@ class TestPaginateSearch:
|
|||||||
order_by=User.username,
|
order_by=User.username,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 3
|
assert result.pagination.total_count == 3
|
||||||
usernames = [u.username for u in result["data"]]
|
usernames = [u.username for u in result.data]
|
||||||
assert usernames == ["alice", "bob", "charlie"]
|
assert usernames == ["alice", "bob", "charlie"]
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -292,8 +292,8 @@ class TestPaginateSearch:
|
|||||||
search_fields=[User.id, User.username],
|
search_fields=[User.id, User.username],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
assert result["data"][0].id == user_id
|
assert result.data[0].id == user_id
|
||||||
|
|
||||||
|
|
||||||
class TestSearchConfig:
|
class TestSearchConfig:
|
||||||
@@ -318,8 +318,8 @@ class TestSearchConfig:
|
|||||||
search_fields=[User.username, User.email],
|
search_fields=[User.username, User.email],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
assert result["data"][0].username == "john_test"
|
assert result.data[0].username == "john_test"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
||||||
@@ -333,7 +333,7 @@ class TestSearchConfig:
|
|||||||
search=SearchConfig(query="findme", fields=[User.email]),
|
search=SearchConfig(query="findme", fields=[User.email]),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["pagination"]["total_count"] == 1
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
|
||||||
class TestNoSearchableFieldsError:
|
class TestNoSearchableFieldsError:
|
||||||
|
|||||||
102
tests/test_db.py
102
tests/test_db.py
@@ -1,5 +1,8 @@
|
|||||||
"""Tests for fastapi_toolsets.db module."""
|
"""Tests for fastapi_toolsets.db module."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
@@ -9,6 +12,7 @@ from fastapi_toolsets.db import (
|
|||||||
create_db_dependency,
|
create_db_dependency,
|
||||||
get_transaction,
|
get_transaction,
|
||||||
lock_tables,
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
||||||
@@ -241,3 +245,101 @@ class TestLockTables:
|
|||||||
|
|
||||||
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWaitForRowChange:
|
||||||
|
"""Tests for wait_for_row_change polling function."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_detects_update(self, db_session: AsyncSession, engine):
|
||||||
|
"""Returns updated instance when a column value changes."""
|
||||||
|
role = Role(name="watch_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
assert r is not None
|
||||||
|
r.name = "updated_role"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.name == "updated_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
|
||||||
|
"""Only triggers on changes to specified columns."""
|
||||||
|
user = User(username="testuser", email="test@example.com")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
# First: change email (not watched) — should not trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.email = "new@example.com"
|
||||||
|
await other.commit()
|
||||||
|
# Second: change username (watched) — should trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.username = "newuser"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(
|
||||||
|
db_session, User, user.id, columns=["username"], interval=0.05
|
||||||
|
)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.username == "newuser"
|
||||||
|
assert result.email == "new@example.com"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises LookupError when the row does not exist."""
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
with pytest.raises(LookupError, match="not found"):
|
||||||
|
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_timeout_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises TimeoutError when no change is detected within timeout."""
|
||||||
|
role = Role(name="timeout_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
await wait_for_row_change(
|
||||||
|
db_session, Role, role.id, interval=0.05, timeout=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
|
||||||
|
"""Raises LookupError when the row is deleted during polling."""
|
||||||
|
role = Role(name="delete_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def delete_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
await other.delete(r)
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
delete_task = asyncio.create_task(delete_later())
|
||||||
|
with pytest.raises(LookupError):
|
||||||
|
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await delete_task
|
||||||
|
|||||||
186
tests/test_dependencies.py
Normal file
186
tests/test_dependencies.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Tests for fastapi_toolsets.dependencies module."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.params import Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from fastapi_toolsets.dependencies import BodyDependency, PathDependency
|
||||||
|
|
||||||
|
from .conftest import Role, RoleCreate, RoleCrud, User
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Mock session dependency for testing."""
|
||||||
|
yield None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathDependency:
|
||||||
|
"""Tests for PathDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""PathDependency returns a Depends instance."""
|
||||||
|
dep = PathDependency(Role, Role.id, session_dep=mock_get_db)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_default_param_name(self):
|
||||||
|
"""PathDependency uses model_field as default param name."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""PathDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""PathDependency session param has Depends as default."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_custom_param_name_in_signature(self):
|
||||||
|
"""PathDependency uses custom param_name in signature."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
PathDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, param_name="role_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
def test_string_field_type(self):
|
||||||
|
"""PathDependency handles string field types."""
|
||||||
|
dep = cast(Any, PathDependency(User, User.username, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["user_username"].annotation is str
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""PathDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="test_role"))
|
||||||
|
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=mock_get_db))
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "test_role"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBodyDependency:
|
||||||
|
"""Tests for BodyDependency factory."""
|
||||||
|
|
||||||
|
def test_returns_depends_instance(self):
|
||||||
|
"""BodyDependency returns a Depends instance."""
|
||||||
|
dep = BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_signature_has_body_field_as_param(self):
|
||||||
|
"""BodyDependency uses body_field as param name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "role_id" in params
|
||||||
|
assert "session" in params
|
||||||
|
|
||||||
|
def test_signature_has_correct_type_annotation(self):
|
||||||
|
"""BodyDependency uses field's python type for annotation."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert sig.parameters["role_id"].annotation == uuid.UUID
|
||||||
|
assert sig.parameters["session"].annotation == AsyncSession
|
||||||
|
|
||||||
|
def test_signature_session_has_depends_default(self):
|
||||||
|
"""BodyDependency session param has Depends as default."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_different_body_field_name(self):
|
||||||
|
"""BodyDependency can use any body_field name."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
User, User.id, session_dep=mock_get_db, body_field="user_uuid"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
params = list(sig.parameters.keys())
|
||||||
|
|
||||||
|
assert "user_uuid" in params
|
||||||
|
assert "id" not in params
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_fetches_object(self, db_session):
|
||||||
|
"""BodyDependency inner function fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="body_test_role"))
|
||||||
|
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=mock_get_db, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
func = dep.dependency
|
||||||
|
|
||||||
|
result = await func(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "body_test_role"
|
||||||
@@ -108,6 +108,24 @@ class TestGenerateErrorResponses:
|
|||||||
assert example["status"] == "FAIL"
|
assert example["status"] == "FAIL"
|
||||||
assert example["error_code"] == "RES-404"
|
assert example["error_code"] == "RES-404"
|
||||||
assert example["message"] == "Not Found"
|
assert example["message"] == "Not Found"
|
||||||
|
assert example["data"] is None
|
||||||
|
|
||||||
|
def test_response_example_with_data(self):
|
||||||
|
"""Generated response includes data when set on ApiError."""
|
||||||
|
|
||||||
|
class ErrorWithData(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Bad Request",
|
||||||
|
desc="Invalid input.",
|
||||||
|
err_code="BAD-400",
|
||||||
|
data={"details": "some context"},
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(ErrorWithData)
|
||||||
|
example = responses[400]["content"]["application/json"]["example"]
|
||||||
|
|
||||||
|
assert example["data"] == {"details": "some context"}
|
||||||
|
|
||||||
|
|
||||||
class TestInitExceptionsHandlers:
|
class TestInitExceptionsHandlers:
|
||||||
@@ -137,6 +155,59 @@ class TestInitExceptionsHandlers:
|
|||||||
assert data["error_code"] == "RES-404"
|
assert data["error_code"] == "RES-404"
|
||||||
assert data["message"] == "Not Found"
|
assert data["message"] == "Not Found"
|
||||||
|
|
||||||
|
def test_handles_api_exception_without_data(self):
|
||||||
|
"""ApiException without data returns null data field."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise NotFoundError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json()["data"] is None
|
||||||
|
|
||||||
|
def test_handles_api_exception_with_data(self):
|
||||||
|
"""ApiException with data returns the data payload."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
class CustomValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="1 validation error(s) detected",
|
||||||
|
err_code="CUSTOM-422",
|
||||||
|
data={
|
||||||
|
"errors": [
|
||||||
|
{
|
||||||
|
"field": "email",
|
||||||
|
"message": "invalid format",
|
||||||
|
"type": "value_error",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise CustomValidationError()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"] == {
|
||||||
|
"errors": [
|
||||||
|
{"field": "email", "message": "invalid format", "type": "value_error"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert data["error_code"] == "CUSTOM-422"
|
||||||
|
|
||||||
def test_handles_validation_error(self):
|
def test_handles_validation_error(self):
|
||||||
"""Handles validation errors with structured response."""
|
"""Handles validation errors with structured response."""
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|||||||
@@ -159,6 +159,178 @@ class TestFixtureRegistry:
|
|||||||
assert names == {"test_data"}
|
assert names == {"test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for FixtureRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
assert len(main_registry.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_fixtures(self):
|
||||||
|
"""Include registry adds all fixtures from the other registry."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def posts():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"roles", "users", "posts"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_dependencies(self):
|
||||||
|
"""Include registry preserves fixture dependencies."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("users")
|
||||||
|
assert fixture.depends_on == ["roles"]
|
||||||
|
|
||||||
|
def test_include_registry_preserves_contexts(self):
|
||||||
|
"""Include registry preserves fixture contexts."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT])
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("test_data")
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate fixture names."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register(name="roles")
|
||||||
|
def roles_main():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(name="roles")
|
||||||
|
def roles_other():
|
||||||
|
return []
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def base():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@test_registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(dev_registry)
|
||||||
|
main_registry.include_registry(test_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"base", "dev_data", "test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultContexts:
|
||||||
|
"""Tests for FixtureRegistry default contexts."""
|
||||||
|
|
||||||
|
def test_default_contexts_applied_to_fixtures(self):
|
||||||
|
"""Default contexts are applied when no contexts specified."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("test_data")
|
||||||
|
assert fixture.contexts == [Context.TESTING.value]
|
||||||
|
|
||||||
|
def test_explicit_contexts_override_default(self):
|
||||||
|
"""Explicit contexts override default contexts."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.PRODUCTION])
|
||||||
|
def prod_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("prod_data")
|
||||||
|
assert fixture.contexts == [Context.PRODUCTION.value]
|
||||||
|
|
||||||
|
def test_no_default_contexts_uses_base(self):
|
||||||
|
"""Without default contexts, BASE is used."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("data")
|
||||||
|
assert fixture.contexts == [Context.BASE.value]
|
||||||
|
|
||||||
|
def test_multiple_default_contexts(self):
|
||||||
|
"""Multiple default contexts are applied."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def dev_test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("dev_test_data")
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_default_contexts_with_string_values(self):
|
||||||
|
"""Default contexts work with string values."""
|
||||||
|
registry = FixtureRegistry(contexts=["custom_context"])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def custom_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("custom_data")
|
||||||
|
assert fixture.contexts == ["custom_context"]
|
||||||
|
|
||||||
|
|
||||||
class TestDependencyResolution:
|
class TestDependencyResolution:
|
||||||
"""Tests for fixture dependency resolution."""
|
"""Tests for fixture dependency resolution."""
|
||||||
|
|
||||||
@@ -572,8 +744,11 @@ class TestGetObjByAttr:
|
|||||||
assert user.username == "alice"
|
assert user.username == "alice"
|
||||||
|
|
||||||
def test_no_match_raises_stop_iteration(self):
|
def test_no_match_raises_stop_iteration(self):
|
||||||
"""Raises StopIteration when no object matches."""
|
"""Raises StopIteration with contextual message when no object matches."""
|
||||||
with pytest.raises(StopIteration):
|
with pytest.raises(
|
||||||
|
StopIteration,
|
||||||
|
match="No object with name=nonexistent found in fixture 'roles'",
|
||||||
|
):
|
||||||
get_obj_by_attr(self.roles, "name", "nonexistent")
|
get_obj_by_attr(self.roles, "name", "nonexistent")
|
||||||
|
|
||||||
def test_no_match_on_wrong_value_type(self):
|
def test_no_match_on_wrong_value_type(self):
|
||||||
|
|||||||
229
tests/test_imports.py
Normal file
229
tests/test_imports.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""Tests for optional dependency import guards."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets._imports import require_extra
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequireExtra:
|
||||||
|
"""Tests for the require_extra helper."""
|
||||||
|
|
||||||
|
def test_raises_import_error(self):
|
||||||
|
"""require_extra raises ImportError."""
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
require_extra(package="some_pkg", extra="some_extra")
|
||||||
|
|
||||||
|
def test_error_message_contains_package_name(self):
|
||||||
|
"""Error message mentions the missing package."""
|
||||||
|
with pytest.raises(ImportError, match="'prometheus_client'"):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
def test_error_message_contains_install_instruction(self):
|
||||||
|
"""Error message contains the pip install command."""
|
||||||
|
with pytest.raises(
|
||||||
|
ImportError, match=r"pip install fastapi-toolsets\[metrics\]"
|
||||||
|
):
|
||||||
|
require_extra(package="prometheus_client", extra="metrics")
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_without_package(module_path: str, blocked_packages: list[str]):
|
||||||
|
"""Reload a module while blocking specific package imports.
|
||||||
|
|
||||||
|
Removes the target module and its parents from sys.modules so they
|
||||||
|
get re-imported, and patches builtins.__import__ to raise ImportError
|
||||||
|
for *blocked_packages*.
|
||||||
|
"""
|
||||||
|
# Remove cached modules so they get re-imported
|
||||||
|
to_remove = [
|
||||||
|
key
|
||||||
|
for key in sys.modules
|
||||||
|
if key == module_path or key.startswith(module_path + ".")
|
||||||
|
]
|
||||||
|
saved = {}
|
||||||
|
for key in to_remove:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
# Also remove parent package to force re-execution of __init__.py
|
||||||
|
parts = module_path.rsplit(".", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
parent = parts[0]
|
||||||
|
parent_keys = [
|
||||||
|
key for key in sys.modules if key == parent or key.startswith(parent + ".")
|
||||||
|
]
|
||||||
|
for key in parent_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
original_import = (
|
||||||
|
__builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__
|
||||||
|
)
|
||||||
|
|
||||||
|
def blocking_import(name, *args, **kwargs):
|
||||||
|
for blocked in blocked_packages:
|
||||||
|
if name == blocked or name.startswith(blocked + "."):
|
||||||
|
raise ImportError(f"Mocked: No module named '{name}'")
|
||||||
|
return original_import(name, *args, **kwargs)
|
||||||
|
|
||||||
|
return saved, blocking_import
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsImportGuard:
|
||||||
|
"""Tests for metrics module import guard when prometheus_client is missing."""
|
||||||
|
|
||||||
|
def test_registry_imports_without_prometheus(self):
|
||||||
|
"""Metric and MetricsRegistry are importable without prometheus_client."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
# Registry types should be available (they're stdlib-only)
|
||||||
|
assert hasattr(mod, "Metric")
|
||||||
|
assert hasattr(mod, "MetricsRegistry")
|
||||||
|
finally:
|
||||||
|
# Restore original modules
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_stub_raises_without_prometheus(self):
|
||||||
|
"""init_metrics raises ImportError when prometheus_client is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.metrics", ["prometheus_client"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
mod = importlib.import_module("fastapi_toolsets.metrics")
|
||||||
|
with pytest.raises(ImportError, match="prometheus_client"):
|
||||||
|
mod.init_metrics(None, None) # type: ignore[arg-type]
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.metrics"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_init_metrics_works_with_prometheus(self):
|
||||||
|
"""init_metrics is the real function when prometheus_client is available."""
|
||||||
|
from fastapi_toolsets.metrics import init_metrics
|
||||||
|
|
||||||
|
# Should be the real function, not a stub
|
||||||
|
assert init_metrics.__module__ == "fastapi_toolsets.metrics.handler"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPytestImportGuard:
|
||||||
|
"""Tests for pytest module import guard when dependencies are missing."""
|
||||||
|
|
||||||
|
def test_import_raises_without_pytest_package(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when pytest is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["pytest"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="pytest"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_import_raises_without_httpx(self):
|
||||||
|
"""Importing fastapi_toolsets.pytest raises when httpx is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.pytest", ["httpx"]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="httpx"):
|
||||||
|
importlib.import_module("fastapi_toolsets.pytest")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.pytest"):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_all_exports_available_with_deps(self):
|
||||||
|
"""All expected exports are available when deps are installed."""
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
cleanup_tables,
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callable(register_fixtures)
|
||||||
|
assert callable(create_async_client)
|
||||||
|
assert callable(create_db_session)
|
||||||
|
assert callable(create_worker_database)
|
||||||
|
assert callable(worker_database_url)
|
||||||
|
assert callable(cleanup_tables)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliImportGuard:
|
||||||
|
"""Tests for CLI module import guard when typer is missing."""
|
||||||
|
|
||||||
|
def test_import_raises_without_typer(self):
|
||||||
|
"""Importing cli.app raises when typer is missing."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.cli.app", ["typer"]
|
||||||
|
)
|
||||||
|
# Also remove cli.config since it imports typer too
|
||||||
|
config_keys = [
|
||||||
|
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||||
|
]
|
||||||
|
for key in config_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(ImportError, match="typer"):
|
||||||
|
importlib.import_module("fastapi_toolsets.cli.app")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||||
|
"fastapi_toolsets.cli.config"
|
||||||
|
):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_error_message_suggests_cli_extra(self):
|
||||||
|
"""Error message suggests installing the cli extra."""
|
||||||
|
saved, blocking_import = _reload_without_package(
|
||||||
|
"fastapi_toolsets.cli.app", ["typer"]
|
||||||
|
)
|
||||||
|
config_keys = [
|
||||||
|
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
|
||||||
|
]
|
||||||
|
for key in config_keys:
|
||||||
|
if key not in saved:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("builtins.__import__", side_effect=blocking_import):
|
||||||
|
with pytest.raises(
|
||||||
|
ImportError, match=r"pip install fastapi-toolsets\[cli\]"
|
||||||
|
):
|
||||||
|
importlib.import_module("fastapi_toolsets.cli.app")
|
||||||
|
finally:
|
||||||
|
for key in list(sys.modules):
|
||||||
|
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
|
||||||
|
"fastapi_toolsets.cli.config"
|
||||||
|
):
|
||||||
|
sys.modules.pop(key, None)
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
def test_async_command_imports_without_typer(self):
|
||||||
|
"""async_command is importable without typer (stdlib only)."""
|
||||||
|
from fastapi_toolsets.cli import async_command
|
||||||
|
|
||||||
|
assert callable(async_command)
|
||||||
118
tests/test_logger.py
Normal file
118
tests/test_logger.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi_toolsets.logger import (
|
||||||
|
DEFAULT_FORMAT,
|
||||||
|
UVICORN_LOGGERS,
|
||||||
|
configure_logging,
|
||||||
|
get_logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_loggers():
|
||||||
|
"""Reset the root and uvicorn loggers after each test."""
|
||||||
|
yield
|
||||||
|
root = logging.getLogger()
|
||||||
|
root.handlers.clear()
|
||||||
|
root.setLevel(logging.WARNING)
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv = logging.getLogger(name)
|
||||||
|
uv.handlers.clear()
|
||||||
|
uv.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigureLogging:
|
||||||
|
def test_sets_up_handler_and_format(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert isinstance(handler, logging.StreamHandler)
|
||||||
|
assert handler.stream is sys.stdout
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_default_level_is_info(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger.level == logging.INFO
|
||||||
|
|
||||||
|
def test_custom_level_string(self):
|
||||||
|
logger = configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
assert logger.level == logging.DEBUG
|
||||||
|
|
||||||
|
def test_custom_level_int(self):
|
||||||
|
logger = configure_logging(level=logging.WARNING)
|
||||||
|
|
||||||
|
assert logger.level == logging.WARNING
|
||||||
|
|
||||||
|
def test_custom_format(self):
|
||||||
|
custom_fmt = "%(levelname)s: %(message)s"
|
||||||
|
logger = configure_logging(fmt=custom_fmt)
|
||||||
|
|
||||||
|
handler = logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == custom_fmt
|
||||||
|
|
||||||
|
def test_named_logger(self):
|
||||||
|
logger = configure_logging(logger_name="myapp")
|
||||||
|
|
||||||
|
assert logger.name == "myapp"
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_default_configures_root_logger(self):
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_idempotent_no_duplicate_handlers(self):
|
||||||
|
configure_logging()
|
||||||
|
configure_logging()
|
||||||
|
logger = configure_logging()
|
||||||
|
|
||||||
|
assert len(logger.handlers) == 1
|
||||||
|
|
||||||
|
def test_configures_uvicorn_loggers(self):
|
||||||
|
configure_logging(level="DEBUG")
|
||||||
|
|
||||||
|
for name in UVICORN_LOGGERS:
|
||||||
|
uv_logger = logging.getLogger(name)
|
||||||
|
assert len(uv_logger.handlers) == 1
|
||||||
|
assert uv_logger.level == logging.DEBUG
|
||||||
|
handler = uv_logger.handlers[0]
|
||||||
|
assert handler.formatter is not None
|
||||||
|
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||||
|
|
||||||
|
def test_returns_configured_logger(self):
|
||||||
|
logger = configure_logging(logger_name="test.return")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "test.return"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLogger:
|
||||||
|
def test_returns_named_logger(self):
|
||||||
|
logger = get_logger("myapp.services")
|
||||||
|
|
||||||
|
assert isinstance(logger, logging.Logger)
|
||||||
|
assert logger.name == "myapp.services"
|
||||||
|
|
||||||
|
def test_returns_root_logger_when_none(self):
|
||||||
|
logger = get_logger(None)
|
||||||
|
|
||||||
|
assert logger is logging.getLogger()
|
||||||
|
|
||||||
|
def test_defaults_to_caller_module_name(self):
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
assert logger.name == __name__
|
||||||
|
|
||||||
|
def test_same_name_returns_same_logger(self):
|
||||||
|
a = get_logger("myapp")
|
||||||
|
b = get_logger("myapp")
|
||||||
|
|
||||||
|
assert a is b
|
||||||
519
tests/test_metrics.py
Normal file
519
tests/test_metrics.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
"""Tests for fastapi_toolsets.metrics module."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge
|
||||||
|
|
||||||
|
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clean_prometheus_registry():
|
||||||
|
"""Unregister test collectors from the global registry after each test."""
|
||||||
|
yield
|
||||||
|
collectors = list(REGISTRY._names_to_collectors.values())
|
||||||
|
for collector in collectors:
|
||||||
|
try:
|
||||||
|
REGISTRY.unregister(collector)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetric:
|
||||||
|
"""Tests for Metric dataclass."""
|
||||||
|
|
||||||
|
def test_default_collect_is_false(self):
|
||||||
|
"""Default collect is False (provider mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None)
|
||||||
|
assert definition.collect is False
|
||||||
|
|
||||||
|
def test_collect_true(self):
|
||||||
|
"""Collect can be set to True (collector mode)."""
|
||||||
|
definition = Metric(name="test", func=lambda: None, collect=True)
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsRegistry:
|
||||||
|
"""Tests for MetricsRegistry class."""
|
||||||
|
|
||||||
|
def test_register_with_decorator(self):
|
||||||
|
"""Register metric with bare decorator."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter", "A test counter")
|
||||||
|
|
||||||
|
names = [m.name for m in registry.get_all()]
|
||||||
|
assert "my_counter" in names
|
||||||
|
|
||||||
|
def test_register_with_custom_name(self):
|
||||||
|
"""Register metric with custom name."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="custom_name")
|
||||||
|
def my_counter():
|
||||||
|
return Counter("test_counter_2", "A test counter")
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.name == "custom_name"
|
||||||
|
|
||||||
|
def test_register_as_collector(self):
|
||||||
|
"""Register metric with collect=True."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_something():
|
||||||
|
pass
|
||||||
|
|
||||||
|
definition = registry.get_all()[0]
|
||||||
|
assert definition.collect is True
|
||||||
|
|
||||||
|
def test_register_preserves_function(self):
|
||||||
|
"""Decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_register_parameterized_preserves_function(self):
|
||||||
|
"""Parameterized decorator returns the original function unchanged."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
result = registry.register(name="custom")(my_func)
|
||||||
|
assert result is my_func
|
||||||
|
assert result() == "original"
|
||||||
|
|
||||||
|
def test_get_all(self):
|
||||||
|
"""Get all registered metrics."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
names = {m.name for m in registry.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b"}
|
||||||
|
|
||||||
|
def test_get_providers(self):
|
||||||
|
"""Get only provider metrics (collect=False)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
providers = registry.get_providers()
|
||||||
|
assert len(providers) == 1
|
||||||
|
assert providers[0].name == "provider"
|
||||||
|
|
||||||
|
def test_get_collectors(self):
|
||||||
|
"""Get only collector metrics (collect=True)."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def provider():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
collectors = registry.get_collectors()
|
||||||
|
assert len(collectors) == 1
|
||||||
|
assert collectors[0].name == "collector"
|
||||||
|
|
||||||
|
def test_register_overwrites_same_name(self):
|
||||||
|
"""Registering with the same name overwrites the previous entry."""
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def first():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@registry.register(name="metric")
|
||||||
|
def second():
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert len(registry.get_all()) == 1
|
||||||
|
assert registry.get_all()[0].func is second
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for MetricsRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert len(main.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_metrics(self):
|
||||||
|
"""Include registry adds all metrics from the other registry."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def metric_a():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_b():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register
|
||||||
|
def metric_c():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"metric_a", "metric_b", "metric_c"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_collect_flag(self):
|
||||||
|
"""Include registry preserves the collect flag."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@other.register(collect=True)
|
||||||
|
def collector():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(other)
|
||||||
|
assert main.get_all()[0].collect is True
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate metric names."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
other = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register(name="metric")
|
||||||
|
def metric_main():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@other.register(name="metric")
|
||||||
|
def metric_other():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main.include_registry(other)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub1 = MetricsRegistry()
|
||||||
|
sub2 = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def base():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub1.register
|
||||||
|
def sub1_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@sub2.register
|
||||||
|
def sub2_metric():
|
||||||
|
pass
|
||||||
|
|
||||||
|
main.include_registry(sub1)
|
||||||
|
main.include_registry(sub2)
|
||||||
|
|
||||||
|
names = {m.name for m in main.get_all()}
|
||||||
|
assert names == {"base", "sub1_metric", "sub2_metric"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitMetrics:
|
||||||
|
"""Tests for init_metrics function."""
|
||||||
|
|
||||||
|
def test_returns_app(self):
|
||||||
|
"""Returns the FastAPI app."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
result = init_metrics(app, registry)
|
||||||
|
assert result is app
|
||||||
|
|
||||||
|
def test_metrics_endpoint_responds(self):
|
||||||
|
"""The /metrics endpoint returns 200."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_metrics_endpoint_content_type(self):
|
||||||
|
"""The /metrics endpoint returns prometheus content type."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert "text/plain" in response.headers["content-type"]
|
||||||
|
|
||||||
|
def test_custom_path(self):
|
||||||
|
"""Custom path is used for the metrics endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry, path="/custom-metrics")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
assert client.get("/custom-metrics").status_code == 200
|
||||||
|
assert client.get("/metrics").status_code == 404
|
||||||
|
|
||||||
|
def test_providers_called_at_init(self):
|
||||||
|
"""Provider functions are called once at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_provider():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_collectors_called_on_scrape(self):
|
||||||
|
"""Collector functions are called on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_collectors_not_called_at_init(self):
|
||||||
|
"""Collector functions are not called at init time."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def my_collector():
|
||||||
|
mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
mock.assert_not_called()
|
||||||
|
|
||||||
|
def test_async_collectors_called_on_scrape(self):
|
||||||
|
"""Async collector functions are awaited on each scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def my_async_collector():
|
||||||
|
await mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
def test_mixed_sync_and_async_collectors(self):
|
||||||
|
"""Both sync and async collectors are called on scrape."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
sync_mock = MagicMock()
|
||||||
|
async_mock = AsyncMock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def sync_collector():
|
||||||
|
sync_mock()
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
async def async_collector():
|
||||||
|
await async_mock()
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
client.get("/metrics")
|
||||||
|
|
||||||
|
sync_mock.assert_called_once()
|
||||||
|
async_mock.assert_called_once()
|
||||||
|
|
||||||
|
def test_registered_metrics_appear_in_output(self):
|
||||||
|
"""Metrics created by providers appear in /metrics output."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def my_gauge():
|
||||||
|
g = Gauge("test_gauge_value", "A test gauge")
|
||||||
|
g.set(42)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"test_gauge_value" in response.content
|
||||||
|
assert b"42.0" in response.content
|
||||||
|
|
||||||
|
def test_endpoint_not_in_openapi_schema(self):
|
||||||
|
"""The /metrics endpoint is not included in the OpenAPI schema."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
assert "/metrics" not in schema.get("paths", {})
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiProcessMode:
|
||||||
|
"""Tests for multi-process Prometheus mode."""
|
||||||
|
|
||||||
|
def test_multiprocess_with_env_var(self):
|
||||||
|
"""Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir
|
||||||
|
try:
|
||||||
|
# Use a separate registry to avoid conflicts with default
|
||||||
|
prom_registry = CollectorRegistry()
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def mp_counter():
|
||||||
|
return Counter(
|
||||||
|
"mp_test_counter",
|
||||||
|
"A multiprocess counter",
|
||||||
|
registry=prom_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
finally:
|
||||||
|
del os.environ["PROMETHEUS_MULTIPROC_DIR"]
|
||||||
|
|
||||||
|
def test_single_process_without_env_var(self):
|
||||||
|
"""Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
|
||||||
|
os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def sp_gauge():
|
||||||
|
g = Gauge("sp_test_gauge", "A single-process gauge")
|
||||||
|
g.set(99)
|
||||||
|
return g
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"sp_test_gauge" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricsIntegration:
|
||||||
|
"""Integration tests for the metrics module."""
|
||||||
|
|
||||||
|
def test_full_workflow(self):
|
||||||
|
"""Full workflow: registry, providers, collectors, endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
registry = MetricsRegistry()
|
||||||
|
call_count = {"value": 0}
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def request_counter():
|
||||||
|
return Counter(
|
||||||
|
"integration_requests_total",
|
||||||
|
"Total requests",
|
||||||
|
["method"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@registry.register(collect=True)
|
||||||
|
def collect_uptime():
|
||||||
|
call_count["value"] += 1
|
||||||
|
|
||||||
|
init_metrics(app, registry)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert b"integration_requests_total" in response.content
|
||||||
|
assert call_count["value"] == 1
|
||||||
|
|
||||||
|
response = client.get("/metrics")
|
||||||
|
assert call_count["value"] == 2
|
||||||
|
|
||||||
|
def test_multiple_registries_merged(self):
|
||||||
|
"""Multiple registries can be merged and used together."""
|
||||||
|
app = FastAPI()
|
||||||
|
main = MetricsRegistry()
|
||||||
|
sub = MetricsRegistry()
|
||||||
|
|
||||||
|
@main.register
|
||||||
|
def main_gauge():
|
||||||
|
g = Gauge("main_gauge_val", "Main gauge")
|
||||||
|
g.set(1)
|
||||||
|
return g
|
||||||
|
|
||||||
|
@sub.register
|
||||||
|
def sub_gauge():
|
||||||
|
g = Gauge("sub_gauge_val", "Sub gauge")
|
||||||
|
g.set(2)
|
||||||
|
return g
|
||||||
|
|
||||||
|
main.include_registry(sub)
|
||||||
|
init_metrics(app, main)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/metrics")
|
||||||
|
|
||||||
|
assert b"main_gauge_val" in response.content
|
||||||
|
assert b"sub_gauge_val" in response.content
|
||||||
@@ -3,18 +3,23 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.engine import make_url
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||||
from fastapi_toolsets.pytest import (
|
from fastapi_toolsets.pytest import (
|
||||||
|
cleanup_tables,
|
||||||
create_async_client,
|
create_async_client,
|
||||||
create_db_session,
|
create_db_session,
|
||||||
|
create_worker_database,
|
||||||
register_fixtures,
|
register_fixtures,
|
||||||
|
worker_database_url,
|
||||||
)
|
)
|
||||||
|
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
@@ -231,6 +236,30 @@ class TestCreateAsyncClient:
|
|||||||
|
|
||||||
assert client_ref.is_closed
|
assert client_ref.is_closed
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dependency_overrides_applied_and_cleaned(self):
|
||||||
|
"""Dependency overrides are applied during the context and removed after."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async def original_dep() -> str:
|
||||||
|
return "original"
|
||||||
|
|
||||||
|
async def override_dep() -> str:
|
||||||
|
return "overridden"
|
||||||
|
|
||||||
|
@app.get("/dep")
|
||||||
|
async def dep_endpoint(value: str = Depends(original_dep)):
|
||||||
|
return {"value": value}
|
||||||
|
|
||||||
|
async with create_async_client(
|
||||||
|
app, dependency_overrides={original_dep: override_dep}
|
||||||
|
) as client:
|
||||||
|
response = await client.get("/dep")
|
||||||
|
assert response.json() == {"value": "overridden"}
|
||||||
|
|
||||||
|
# Overrides should be cleaned up
|
||||||
|
assert original_dep not in app.dependency_overrides
|
||||||
|
|
||||||
|
|
||||||
class TestCreateDbSession:
|
class TestCreateDbSession:
|
||||||
"""Tests for create_db_session helper."""
|
"""Tests for create_db_session helper."""
|
||||||
@@ -291,3 +320,216 @@ class TestCreateDbSession:
|
|||||||
# Cleanup: drop tables manually
|
# Cleanup: drop tables manually
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleanup_truncates_tables(self):
|
||||||
|
"""Tables are truncated after session closes when cleanup=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
async with create_db_session(
|
||||||
|
DATABASE_URL, Base, cleanup=True, drop_tables=False
|
||||||
|
) as session:
|
||||||
|
role = Role(id=role_id, name="will_be_cleaned")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Data should have been truncated, but tables still exist
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetXdistWorker:
|
||||||
|
"""Tests for _get_xdist_worker helper."""
|
||||||
|
|
||||||
|
def test_returns_default_test_db_without_env_var(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Returns default_test_db when PYTEST_XDIST_WORKER is not set."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
assert _get_xdist_worker("my_default") == "my_default"
|
||||||
|
|
||||||
|
def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Returns the worker name from the environment variable."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
assert _get_xdist_worker("ignored") == "gw0"
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerDatabaseUrl:
|
||||||
|
"""Tests for worker_database_url helper."""
|
||||||
|
|
||||||
|
def test_appends_default_test_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""default_test_db is appended when not running under xdist."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
|
||||||
|
result = worker_database_url(url, default_test_db="fallback")
|
||||||
|
assert make_url(result).database == "mydb_fallback"
|
||||||
|
|
||||||
|
def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Worker name is appended to the database name."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||||
|
url = "postgresql+asyncpg://user:pass@localhost:5432/db"
|
||||||
|
result = worker_database_url(url, default_test_db="unused")
|
||||||
|
assert make_url(result).database == "db_gw0"
|
||||||
|
|
||||||
|
def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Host, port, username, password, and driver are preserved."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
|
||||||
|
url = "postgresql+asyncpg://myuser:secret@dbhost:6543/testdb"
|
||||||
|
result = make_url(worker_database_url(url, default_test_db="unused"))
|
||||||
|
|
||||||
|
assert result.drivername == "postgresql+asyncpg"
|
||||||
|
assert result.username == "myuser"
|
||||||
|
assert result.password == "secret"
|
||||||
|
assert result.host == "dbhost"
|
||||||
|
assert result.port == 6543
|
||||||
|
assert result.database == "testdb_gw2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateWorkerDatabase:
|
||||||
|
"""Tests for create_worker_database context manager."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_default_db_without_xdist(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Without xdist, creates a database suffixed with default_test_db."""
|
||||||
|
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||||
|
default_test_db = "no_xdist_default"
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db=default_test_db)
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(
|
||||||
|
DATABASE_URL, default_test_db=default_test_db
|
||||||
|
) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify the database exists while inside the context
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# After context exit the database should be dropped
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_and_drops_worker_database(
|
||||||
|
self, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
"""Worker database exists inside the context and is dropped after."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify the database exists while inside the context
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() == 1
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# After context exit the database should be dropped
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""A pre-existing worker database is dropped and recreated."""
|
||||||
|
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
|
||||||
|
expected_db = make_url(
|
||||||
|
worker_database_url(DATABASE_URL, default_test_db="unused")
|
||||||
|
).database
|
||||||
|
|
||||||
|
# Pre-create the database to simulate a stale leftover
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {expected_db}"))
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
# Should succeed despite the database already existing
|
||||||
|
async with create_worker_database(DATABASE_URL) as url:
|
||||||
|
assert make_url(url).database == expected_db
|
||||||
|
|
||||||
|
# Verify cleanup after context exit
|
||||||
|
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
result = await conn.execute(
|
||||||
|
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||||
|
{"name": expected_db},
|
||||||
|
)
|
||||||
|
assert result.scalar() is None
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCleanupTables:
|
||||||
|
"""Tests for cleanup_tables helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_truncates_all_tables(self):
|
||||||
|
"""All table rows are removed after cleanup_tables."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=uuid.uuid4(), name="cleanup_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
username="cleanup_user",
|
||||||
|
email="cleanup@test.com",
|
||||||
|
role_id=role.id,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify rows exist
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 1
|
||||||
|
assert users_count == 1
|
||||||
|
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
|
||||||
|
# Verify tables are empty
|
||||||
|
roles_count = await RoleCrud.count(session)
|
||||||
|
users_count = await UserCrud.count(session)
|
||||||
|
assert roles_count == 0
|
||||||
|
assert users_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_noop_for_empty_metadata(self):
|
||||||
|
"""cleanup_tables does not raise when metadata has no tables."""
|
||||||
|
|
||||||
|
class EmptyBase(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
# Should not raise
|
||||||
|
await cleanup_tables(session, EmptyBase)
|
||||||
|
|||||||
@@ -46,6 +46,31 @@ class TestApiError:
|
|||||||
assert error.desc == "The resource was not found."
|
assert error.desc == "The resource was not found."
|
||||||
assert error.err_code == "RES-404"
|
assert error.err_code == "RES-404"
|
||||||
|
|
||||||
|
def test_data_defaults_to_none(self):
|
||||||
|
"""ApiError data field defaults to None."""
|
||||||
|
error = ApiError(
|
||||||
|
code=404,
|
||||||
|
msg="Not Found",
|
||||||
|
desc="The resource was not found.",
|
||||||
|
err_code="RES-404",
|
||||||
|
)
|
||||||
|
assert error.data is None
|
||||||
|
|
||||||
|
def test_create_with_data(self):
|
||||||
|
"""ApiError can be created with a data payload."""
|
||||||
|
error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Validation Error",
|
||||||
|
desc="2 validation error(s) detected",
|
||||||
|
err_code="VAL-422",
|
||||||
|
data={
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert error.data == {
|
||||||
|
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||||
|
}
|
||||||
|
|
||||||
def test_requires_all_fields(self):
|
def test_requires_all_fields(self):
|
||||||
"""ApiError requires all fields."""
|
"""ApiError requires all fields."""
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
|
|||||||
318
zensical.toml
Normal file
318
zensical.toml
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
# ============================================================================
|
||||||
|
#
|
||||||
|
# The configuration produced by default is meant to highlight the features
|
||||||
|
# that Zensical provides and to serve as a starting point for your own
|
||||||
|
# projects.
|
||||||
|
#
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
[project]
|
||||||
|
|
||||||
|
# The site_name is shown in the page header and the browser window title
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_name
|
||||||
|
site_name = "FastAPI Toolsets"
|
||||||
|
|
||||||
|
# The site_description is included in the HTML head and should contain a
|
||||||
|
# meaningful description of the site content for use by search engines.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_description
|
||||||
|
site_description = "Production-ready utilities for FastAPI applications."
|
||||||
|
|
||||||
|
# The site_author attribute. This is used in the HTML head element.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_author
|
||||||
|
site_author = "d3vyce"
|
||||||
|
|
||||||
|
# The site_url is the canonical URL for your site. When building online
|
||||||
|
# documentation you should set this.
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#site_url
|
||||||
|
site_url = "https://fastapi-toolsets.d3vyce.fr"
|
||||||
|
|
||||||
|
# The copyright notice appears in the page footer and can contain an HTML
|
||||||
|
# fragment.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/basics/#copyright
|
||||||
|
copyright = """
|
||||||
|
Copyright © 2026 d3vyce
|
||||||
|
"""
|
||||||
|
|
||||||
|
repo_url = "https://github.com/d3vyce/fastapi-toolsets"
|
||||||
|
|
||||||
|
# Zensical supports both implicit navigation and explicitly defined navigation.
|
||||||
|
# If you decide not to define a navigation here then Zensical will simply
|
||||||
|
# derive the navigation structure from the directory structure of your
|
||||||
|
# "docs_dir". The definition below demonstrates how a navigation structure
|
||||||
|
# can be defined using TOML syntax.
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/setup/navigation/
|
||||||
|
# nav = [
|
||||||
|
# { "Get started" = "index.md" },
|
||||||
|
# { "Markdown in 5min" = "markdown.md" },
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# With the "extra_css" option you can add your own CSS styling to customize
|
||||||
|
# your Zensical project according to your needs. You can add any number of
|
||||||
|
# CSS files.
|
||||||
|
#
|
||||||
|
# The path provided should be relative to the "docs_dir".
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/customization/#additional-css
|
||||||
|
#
|
||||||
|
#extra_css = ["stylesheets/extra.css"]
|
||||||
|
|
||||||
|
# With the `extra_javascript` option you can add your own JavaScript to your
|
||||||
|
# project to customize the behavior according to your needs.
|
||||||
|
#
|
||||||
|
# The path provided should be relative to the "docs_dir".
|
||||||
|
#
|
||||||
|
# Read more: https://zensical.org/docs/customization/#additional-javascript
|
||||||
|
#extra_javascript = ["javascripts/extra.js"]
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# Section for configuring theme options
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[project.theme]
|
||||||
|
|
||||||
|
# change this to "classic" to use the traditional Material for MkDocs look.
|
||||||
|
#variant = "classic"
|
||||||
|
|
||||||
|
# Zensical allows you to override specific blocks, partials, or whole
|
||||||
|
# templates as well as to define your own templates. To do this, uncomment
|
||||||
|
# the custom_dir setting below and set it to a directory in which you
|
||||||
|
# keep your template overrides.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/customization/#extending-the-theme
|
||||||
|
#
|
||||||
|
custom_dir = "docs/overrides"
|
||||||
|
|
||||||
|
# With the "favicon" option you can set your own image to use as the icon
|
||||||
|
# browsers will use in the browser title bar or tab bar. The path provided
|
||||||
|
# must be relative to the "docs_dir".
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/logo-and-icons/#favicon
|
||||||
|
# - https://developer.mozilla.org/en-US/docs/Glossary/Favicon
|
||||||
|
#
|
||||||
|
#favicon = "images/favicon.png"
|
||||||
|
|
||||||
|
# Zensical supports more than 60 different languages. This means that the
|
||||||
|
# labels and tooltips that Zensical's templates produce are translated.
|
||||||
|
# The "language" option allows you to set the language used. This language
|
||||||
|
# is also indicated in the HTML head element to help with accessibility
|
||||||
|
# and guide search engines and translation tools.
|
||||||
|
#
|
||||||
|
# The default language is "en" (English). It is possible to create
|
||||||
|
# sites with multiple languages and configure a language selector. See
|
||||||
|
# the documentation for details.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/language/
|
||||||
|
#
|
||||||
|
language = "en"
|
||||||
|
|
||||||
|
# Zensical provides a number of feature toggles that change the behavior
|
||||||
|
# of the documentation site.
|
||||||
|
features = [
|
||||||
|
# Zensical includes an announcement bar. This feature allows users to
|
||||||
|
# dismiss it when they have read the announcement.
|
||||||
|
# https://zensical.org/docs/setup/header/#announcement-bar
|
||||||
|
"announce.dismiss",
|
||||||
|
|
||||||
|
# If you have a repository configured and turn on this feature, Zensical
|
||||||
|
# will generate an edit button for the page. This works for common
|
||||||
|
# repository hosting services.
|
||||||
|
# https://zensical.org/docs/setup/repository/#content-actions
|
||||||
|
#"content.action.edit",
|
||||||
|
|
||||||
|
# If you have a repository configured and turn on this feature, Zensical
|
||||||
|
# will generate a button that allows the user to view the Markdown
|
||||||
|
# code for the current page.
|
||||||
|
# https://zensical.org/docs/setup/repository/#content-actions
|
||||||
|
"content.action.view",
|
||||||
|
|
||||||
|
# Code annotations allow you to add an icon with a tooltip to your
|
||||||
|
# code blocks to provide explanations at crucial points.
|
||||||
|
# https://zensical.org/docs/authoring/code-blocks/#code-annotations
|
||||||
|
"content.code.annotate",
|
||||||
|
|
||||||
|
# This feature turns on a button in code blocks that allow users to
|
||||||
|
# copy the content to their clipboard without first selecting it.
|
||||||
|
# https://zensical.org/docs/authoring/code-blocks/#code-copy-button
|
||||||
|
"content.code.copy",
|
||||||
|
|
||||||
|
# Code blocks can include a button to allow for the selection of line
|
||||||
|
# ranges by the user.
|
||||||
|
# https://zensical.org/docs/authoring/code-blocks/#code-selection-button
|
||||||
|
"content.code.select",
|
||||||
|
|
||||||
|
# Zensical can render footnotes as inline tooltips, so the user can read
|
||||||
|
# the footnote without leaving the context of the document.
|
||||||
|
# https://zensical.org/docs/authoring/footnotes/#footnote-tooltips
|
||||||
|
"content.footnote.tooltips",
|
||||||
|
|
||||||
|
# If you have many content tabs that have the same titles (e.g., "Python",
|
||||||
|
# "JavaScript", "Cobol"), this feature causes all of them to switch to
|
||||||
|
# at the same time when the user chooses their language in one.
|
||||||
|
# https://zensical.org/docs/authoring/content-tabs/#linked-content-tabs
|
||||||
|
"content.tabs.link",
|
||||||
|
|
||||||
|
# With this feature enabled users can add tooltips to links that will be
|
||||||
|
# displayed when the mouse pointer hovers the link.
|
||||||
|
# https://zensical.org/docs/authoring/tooltips/#improved-tooltips
|
||||||
|
"content.tooltips",
|
||||||
|
|
||||||
|
# With this feature enabled, Zensical will automatically hide parts
|
||||||
|
# of the header when the user scrolls past a certain point.
|
||||||
|
# https://zensical.org/docs/setup/header/#automatic-hiding
|
||||||
|
# "header.autohide",
|
||||||
|
|
||||||
|
# Turn on this feature to expand all collapsible sections in the
|
||||||
|
# navigation sidebar by default.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-expansion
|
||||||
|
# "navigation.expand",
|
||||||
|
|
||||||
|
# This feature turns on navigation elements in the footer that allow the
|
||||||
|
# user to navigate to a next or previous page.
|
||||||
|
# https://zensical.org/docs/setup/footer/#navigation
|
||||||
|
"navigation.footer",
|
||||||
|
|
||||||
|
# When section index pages are enabled, documents can be directly attached
|
||||||
|
# to sections, which is particularly useful for providing overview pages.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#section-index-pages
|
||||||
|
"navigation.indexes",
|
||||||
|
|
||||||
|
# When instant navigation is enabled, clicks on all internal links will be
|
||||||
|
# intercepted and dispatched via XHR without fully reloading the page.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#instant-navigation
|
||||||
|
"navigation.instant",
|
||||||
|
|
||||||
|
# With instant prefetching, your site will start to fetch a page once the
|
||||||
|
# user hovers over a link. This will reduce the perceived loading time
|
||||||
|
# for the user.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#instant-prefetching
|
||||||
|
"navigation.instant.prefetch",
|
||||||
|
|
||||||
|
# In order to provide a better user experience on slow connections when
|
||||||
|
# using instant navigation, a progress indicator can be enabled.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#progress-indicator
|
||||||
|
#"navigation.instant.progress",
|
||||||
|
|
||||||
|
# When navigation paths are activated, a breadcrumb navigation is rendered
|
||||||
|
# above the title of each page
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-path
|
||||||
|
"navigation.path",
|
||||||
|
|
||||||
|
# When pruning is enabled, only the visible navigation items are included
|
||||||
|
# in the rendered HTML, reducing the size of the built site by 33% or more.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-pruning
|
||||||
|
#"navigation.prune",
|
||||||
|
|
||||||
|
# When sections are enabled, top-level sections are rendered as groups in
|
||||||
|
# the sidebar for viewports above 1220px, but remain as-is on mobile.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-sections
|
||||||
|
"navigation.sections",
|
||||||
|
|
||||||
|
# When tabs are enabled, top-level sections are rendered in a menu layer
|
||||||
|
# below the header for viewports above 1220px, but remain as-is on mobile.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-tabs
|
||||||
|
"navigation.tabs",
|
||||||
|
|
||||||
|
# When sticky tabs are enabled, navigation tabs will lock below the header
|
||||||
|
# and always remain visible when scrolling down.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#sticky-navigation-tabs
|
||||||
|
#"navigation.tabs.sticky",
|
||||||
|
|
||||||
|
# A back-to-top button can be shown when the user, after scrolling down,
|
||||||
|
# starts to scroll up again.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#back-to-top-button
|
||||||
|
"navigation.top",
|
||||||
|
|
||||||
|
# When anchor tracking is enabled, the URL in the address bar is
|
||||||
|
# automatically updated with the active anchor as highlighted in the table
|
||||||
|
# of contents.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#anchor-tracking
|
||||||
|
"navigation.tracking",
|
||||||
|
|
||||||
|
# When search highlighting is enabled and a user clicks on a search result,
|
||||||
|
# Zensical will highlight all occurrences after following the link.
|
||||||
|
# https://zensical.org/docs/setup/search/#search-highlighting
|
||||||
|
"search.highlight",
|
||||||
|
|
||||||
|
# When anchor following for the table of contents is enabled, the sidebar
|
||||||
|
# is automatically scrolled so that the active anchor is always visible.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#anchor-following
|
||||||
|
# "toc.follow",
|
||||||
|
|
||||||
|
# When navigation integration for the table of contents is enabled, it is
|
||||||
|
# always rendered as part of the navigation sidebar on the left.
|
||||||
|
# https://zensical.org/docs/setup/navigation/#navigation-integration
|
||||||
|
#"toc.integrate",
|
||||||
|
]
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# In the "palette" subsection you can configure options for the color scheme.
|
||||||
|
# You can configure different color # schemes, e.g., to turn on dark mode,
|
||||||
|
# that the user can switch between. Each color scheme can be further
|
||||||
|
# customized.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/colors/
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "default"
|
||||||
|
toggle.icon = "lucide/sun"
|
||||||
|
toggle.name = "Switch to dark mode"
|
||||||
|
|
||||||
|
[[project.theme.palette]]
|
||||||
|
scheme = "slate"
|
||||||
|
toggle.icon = "lucide/moon"
|
||||||
|
toggle.name = "Switch to light mode"
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# In the "font" subsection you can configure the fonts used. By default, fonts
|
||||||
|
# are loaded from Google Fonts, giving you a wide range of choices from a set
|
||||||
|
# of suitably licensed fonts. There are options for a normal text font and for
|
||||||
|
# a monospaced font used in code blocks.
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[project.theme.font]
|
||||||
|
text = "Inter"
|
||||||
|
code = "Jetbrains Mono"
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# You can configure your own logo to be shown in the header using the "logo"
|
||||||
|
# option in the "icons" subsection. The logo can be a path to a file in your
|
||||||
|
# "docs_dir" or it can be a path to an icon.
|
||||||
|
#
|
||||||
|
# Likewise, you can customize the logo used for the repository section of the
|
||||||
|
# header. Zensical derives the default logo for this from the repository URL.
|
||||||
|
# See below...
|
||||||
|
#
|
||||||
|
# There are other icons you can customize. See the documentation for details.
|
||||||
|
#
|
||||||
|
# Read more:
|
||||||
|
# - https://zensical.org/docs/setup/logo-and-icons
|
||||||
|
# - https://zensical.org/docs/authoring/icons-emojis/#search
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
[project.theme.icon]
|
||||||
|
#logo = "lucide/smile"
|
||||||
|
repo = "fontawesome/brands/github"
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# The "extra" section contains miscellaneous settings.
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
#[[project.extra.social]]
|
||||||
|
#icon = "fontawesome/brands/github"
|
||||||
|
#link = "https://github.com/user/repo"
|
||||||
|
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python]
|
||||||
|
inventories = ["https://docs.python.org/3/objects.inv"]
|
||||||
|
paths = ["src"]
|
||||||
|
|
||||||
|
[project.plugins.mkdocstrings.handlers.python.options]
|
||||||
|
docstring_style = "google"
|
||||||
|
inherited_members = true
|
||||||
|
show_source = false
|
||||||
|
show_root_heading = true
|
||||||
Reference in New Issue
Block a user