9 Commits

Author SHA1 Message Date
7ec407834a Version 1.1.2 2026-02-23 14:59:33 -05:00
d3vyce
7da34f33a2 fix: handle Date, Float, Numeric cursor column types in cursor_paginate (#90) 2026-02-23 20:58:43 +01:00
8c8911fb27 Version 1.1.1 2026-02-23 11:04:10 -05:00
d3vyce
c0c3b38054 fix: NameError in cli/config.py (#88) 2026-02-23 14:40:36 +01:00
e17d385910 Version 1.1.0 2026-02-23 08:09:59 -05:00
d3vyce
6cf7df55ef feat: add cursor based pagination in CrudFactory (#86) 2026-02-23 13:51:34 +01:00
d3vyce
7482bc5dad feat: add schema parameter to CRUD methods for typed response serialization (#84) 2026-02-23 10:02:52 +01:00
d3vyce
9d07dfea85 feat: add opt-in default_load_options parameter in CrudFactory (#82)
* feat: add opt-in default_load_options parameter in CrudFactory

* docs: add Relationship loading in CRUD
2026-02-21 12:35:15 +01:00
d3vyce
31678935aa Version 1.0.0 (#80)
* docs: fix typos

* chore: build docs only when release

* Version 1.0.0
2026-02-20 14:09:01 +01:00
21 changed files with 2007 additions and 221 deletions

View File

@@ -1,12 +1,14 @@
name: Documentation name: Documentation
on: on:
push: release:
branches: types: [published]
- main
permissions: permissions:
contents: read contents: read
pages: write pages: write
id-token: write id-token: write
jobs: jobs:
deploy: deploy:
environment: environment:

View File

@@ -1,6 +1,6 @@
# CLI # CLI
Typer-based command-line interface for managing your FastAPI application, with built-in fixture loading. Typer-based command-line interface for managing your FastAPI application, with built-in fixture commands integration.
## Installation ## Installation
@@ -16,7 +16,7 @@ Typer-based command-line interface for managing your FastAPI application, with b
## Overview ## Overview
The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It auto-discovers fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured. The `cli` module provides a `manager` entry point built with [Typer](https://typer.tiangolo.com/). It allow custom commands to be added in addition of the fixture commands when a [`FixtureRegistry`](../reference/fixtures.md#fastapi_toolsets.fixtures.registry.FixtureRegistry) and a database context are configured.
## Configuration ## Configuration
@@ -24,24 +24,48 @@ Configure the CLI in your `pyproject.toml`:
```toml ```toml
[tool.fastapi-toolsets] [tool.fastapi-toolsets]
cli = "myapp.cli:cli" # optional: your custom Typer app cli = "myapp.cli:cli" # Custom Typer app
fixtures = "myapp.fixtures:registry" # FixtureRegistry instance fixtures = "myapp.fixtures:registry" # FixtureRegistry instance
db_context = "myapp.db:db_context" # async context manager for sessions db_context = "myapp.db:db_context" # Async context manager for sessions
``` ```
All fields are optional. Without configuration, the `manager` command still works but only includes the built-in commands. All fields are optional. Without configuration, the `manager` command still works but no command are available.
## Usage ## Usage
```bash ```bash
# List available commands # Manager commands
manager --help manager --help
# Load fixtures for a specific context Usage: manager [OPTIONS] COMMAND [ARGS]...
manager fixtures load --context testing
# Load all fixtures (no context filter) FastAPI utilities CLI.
manager fixtures load
╭─ Options ────────────────────────────────────────────────────────────────────────╮
│ --install-completion Install completion for the current shell. │
│ --show-completion Show completion for the current shell, to copy it │
│ or customize the installation. │
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────╯
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
│ check-db │
│ fixtures Manage database fixtures. │
╰──────────────────────────────────────────────────────────────────────────────────╯
# Fixtures commands
manager fixtures --help
Usage: manager fixtures [OPTIONS] COMMAND [ARGS]...
Manage database fixtures.
╭─ Options ────────────────────────────────────────────────────────────────────────╮
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────╯
╭─ Commands ───────────────────────────────────────────────────────────────────────╮
│ list List all registered fixtures. │
│ load Load fixtures into the database. │
╰──────────────────────────────────────────────────────────────────────────────────╯
``` ```
## Custom CLI ## Custom CLI
@@ -64,14 +88,6 @@ def hello():
cli = "myapp.cli:cli" cli = "myapp.cli:cli"
``` ```
## Entry point
The `manager` script is registered automatically when the package is installed:
```bash
manager --help
```
--- ---
[:material-api: API Reference](../reference/cli.md) [:material-api: API Reference](../reference/cli.md)

View File

@@ -2,6 +2,9 @@
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support. Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
!!! info
This module has been coded and tested to be compatible with PostgreSQL only.
## Overview ## Overview
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), an abstract base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model. The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), an abstract base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
@@ -12,10 +15,7 @@ The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.c
from fastapi_toolsets.crud import CrudFactory from fastapi_toolsets.crud import CrudFactory
from myapp.models import User from myapp.models import User
UserCrud = CrudFactory( UserCrud = CrudFactory(model=User)
User,
searchable_fields=[User.username, User.email],
)
``` ```
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model.
@@ -24,42 +24,146 @@ UserCrud = CrudFactory(
```python ```python
# Create # Create
user = await UserCrud.create(session, obj=UserCreateSchema(username="alice")) user = await UserCrud.create(session=session, obj=UserCreateSchema(username="alice"))
# Get one (raises NotFoundError if not found) # Get one (raises NotFoundError if not found)
user = await UserCrud.get(session, filters=[User.id == user_id]) user = await UserCrud.get(session=session, filters=[User.id == user_id])
# Get first or None # Get first or None
user = await UserCrud.first(session, filters=[User.email == email]) user = await UserCrud.first(session=session, filters=[User.email == email])
# Get multiple # Get multiple
users = await UserCrud.get_multi(session, filters=[User.is_active == True]) users = await UserCrud.get_multi(session=session, filters=[User.is_active == True])
# Update # Update
user = await UserCrud.update(session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id]) user = await UserCrud.update(session=session, obj=UserUpdateSchema(username="bob"), filters=[User.id == user_id])
# Delete # Delete
await UserCrud.delete(session, filters=[User.id == user_id]) await UserCrud.delete(session=session, filters=[User.id == user_id])
# Count / exists # Count / exists
count = await UserCrud.count(session, filters=[User.is_active == True]) count = await UserCrud.count(session=session, filters=[User.is_active == True])
exists = await UserCrud.exists(session, filters=[User.email == email]) exists = await UserCrud.exists(session=session, filters=[User.email == email])
``` ```
## Pagination ## Pagination
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
Two pagination strategies are available. Both return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) but differ in how they navigate through results.
| | `offset_paginate` | `cursor_paginate` |
|---|---|---|
| Total count | Yes | No |
| Jump to arbitrary page | Yes | No |
| Performance on deep pages | Degrades | Constant |
| Stable under concurrent inserts | No | Yes |
| Search compatible | Yes | Yes |
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll |
### Offset pagination
```python ```python
result = await UserCrud.paginate( @router.get(
session, "",
filters=[User.is_active == True], response_model=PaginatedResponse[User],
order_by=[User.created_at.desc()],
page=1,
items_per_page=20,
search="alice",
search_fields=[User.username, User.email],
) )
# result.data: list of users async def get_users(
# result.pagination: Pagination(total_count, items_per_page, page, has_more) session: SessionDep,
items_per_page: int = 50,
page: int = 1,
):
return await crud.UserCrud.offset_paginate(
session=session,
items_per_page=items_per_page,
page=page,
)
```
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object:
```json
{
"status": "SUCCESS",
"data": ["..."],
"pagination": {
"total_count": 100,
"page": 1,
"items_per_page": 20,
"has_more": true
}
}
```
!!! warning "Deprecated: `paginate`"
The `paginate` function is a backward-compatible alias for `offset_paginate`. This function is **deprecated** and will be removed in **v2.0**.
### Cursor pagination
```python
@router.get(
"",
response_model=PaginatedResponse[UserRead],
)
async def list_users(
session: SessionDep,
cursor: str | None = None,
items_per_page: int = 20,
):
return await UserCrud.cursor_paginate(
session=session,
cursor=cursor,
items_per_page=items_per_page,
)
```
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object:
```json
{
"status": "SUCCESS",
"data": ["..."],
"pagination": {
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
"prev_cursor": null,
"items_per_page": 20,
"has_more": true
}
}
```
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
#### Choosing a cursor column
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
- Auto-increment integer PKs
- UUID v7 PKs
- Timestamps
!!! warning
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
!!! note
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
The cursor value is base64-encoded when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
| SQLAlchemy type | Python type |
|---|---|
| `Integer`, `BigInteger`, `SmallInteger` | `int` |
| `Uuid` | `uuid.UUID` |
| `DateTime` | `datetime.datetime` |
| `Date` | `datetime.date` |
| `Float`, `Numeric` | `decimal.Decimal` |
```python
# Paginate by the primary key
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
# Paginate by a timestamp column instead
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
``` ```
## Search ## Search
@@ -68,7 +172,7 @@ Declare searchable fields on the CRUD class. Relationship traversal is supported
```python ```python
PostCrud = CrudFactory( PostCrud = CrudFactory(
Post, model=Post,
searchable_fields=[ searchable_fields=[
Post.title, Post.title,
Post.content, Post.content,
@@ -77,18 +181,92 @@ PostCrud = CrudFactory(
) )
``` ```
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
```python
@router.get(
"",
response_model=PaginatedResponse[User],
)
async def get_users(
session: SessionDep,
items_per_page: int = 50,
page: int = 1,
search: str | None = None,
):
return await crud.UserCrud.offset_paginate(
session=session,
items_per_page=items_per_page,
page=page,
search=search,
)
```
```python
@router.get(
"",
response_model=PaginatedResponse[User],
)
async def get_users(
session: SessionDep,
cursor: str | None = None,
items_per_page: int = 50,
search: str | None = None,
):
return await crud.UserCrud.cursor_paginate(
session=session,
items_per_page=items_per_page,
cursor=cursor,
search=search,
)
```
## Relationship loading
!!! info "Added in `v1.1`"
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
!!! warning
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
```python
from sqlalchemy.orm import selectinload
ArticleCrud = CrudFactory(
model=Article,
default_load_options=[
selectinload(Article.category),
selectinload(Article.tags),
],
)
```
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
```python
# Only loads category, tags are not loaded
article = await ArticleCrud.get(
session=session,
filters=[Article.id == article_id],
load_options=[selectinload(Article.category)],
)
# Loads nothing — useful for write-then-refresh flows or lightweight checks
articles = await ArticleCrud.get_multi(session=session, load_options=[])
```
## Many-to-many relationships ## Many-to-many relationships
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting: Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
```python ```python
PostCrud = CrudFactory( PostCrud = CrudFactory(
Post, model=Post,
m2m_fields={"tag_ids": Post.tags}, m2m_fields={"tag_ids": Post.tags},
) )
# schema: PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]) post = await PostCrud.create(session=session, obj=PostCreateSchema(title="Hello", tag_ids=[1, 2, 3]))
post = await PostCrud.create(session, obj=PostCreateSchema(...))
``` ```
## Upsert ## Upsert
@@ -97,20 +275,49 @@ Atomic `INSERT ... ON CONFLICT DO UPDATE` using PostgreSQL:
```python ```python
await UserCrud.upsert( await UserCrud.upsert(
session, session=session,
obj=UserCreateSchema(email="alice@example.com", username="alice"), obj=UserCreateSchema(email="alice@example.com", username="alice"),
index_elements=[User.email], index_elements=[User.email],
set_={"username"}, set_={"username"},
) )
``` ```
## `as_response` ## `schema` — typed response serialization
Pass `as_response=True` to any write operation to get a [`Response[ModelType]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) back directly: !!! info "Added in `v1.1`"
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
```python ```python
response = await UserCrud.create(session, obj=schema, as_response=True) class UserRead(PydanticBase):
# response: Response[User] id: UUID
username: str
@router.get(
"/{uuid}",
responses=generate_error_responses(NotFoundError),
)
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
return await crud.UserCrud.get(
session=session,
filters=[User.id == uuid],
schema=UserRead,
)
@router.get("")
async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]:
return await crud.UserCrud.offset_paginate(
session=session,
page=page,
schema=UserRead,
)
``` ```
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
!!! warning "Deprecated: `as_response`"
The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`.
---
[:material-api: API Reference](../reference/crud.md) [:material-api: API Reference](../reference/crud.md)

View File

@@ -2,9 +2,12 @@
SQLAlchemy async session management with transactions, table locking, and row-change polling. SQLAlchemy async session management with transactions, table locking, and row-change polling.
!!! info
This module has been coded and tested to be compatible with PostgreSQL only.
## Overview ## Overview
The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, PostgreSQL advisory locks, and polling for row changes. The `db` module provides helpers to create FastAPI dependencies and context managers for `AsyncSession`, along with utilities for nested transactions, table lock and polling for row changes.
## Session dependency ## Session dependency
@@ -14,10 +17,10 @@ Use [`create_db_dependency`](../reference/db.md#fastapi_toolsets.db.create_db_de
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from fastapi_toolsets.db import create_db_dependency from fastapi_toolsets.db import create_db_dependency
engine = create_async_engine("postgresql+asyncpg://...") engine = create_async_engine(url="postgresql+asyncpg://...", future=True)
session_maker = async_sessionmaker(engine) session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
get_db = create_db_dependency(session_maker) get_db = create_db_dependency(session_maker=session_maker)
@router.get("/users") @router.get("/users")
async def list_users(session: AsyncSession = Depends(get_db)): async def list_users(session: AsyncSession = Depends(get_db)):
@@ -31,11 +34,11 @@ Use [`create_db_context`](../reference/db.md#fastapi_toolsets.db.create_db_conte
```python ```python
from fastapi_toolsets.db import create_db_context from fastapi_toolsets.db import create_db_context
db_context = create_db_context(session_maker) db_context = create_db_context(session_maker=session_maker)
async def seed(): async def seed():
async with db_context() as session: async with db_context() as session:
session.add(User(name="admin")) ...
``` ```
## Nested transactions ## Nested transactions
@@ -45,11 +48,11 @@ async def seed():
```python ```python
from fastapi_toolsets.db import get_transaction from fastapi_toolsets.db import get_transaction
async def create_user_with_role(session): async def create_user_with_role(session=session):
async with get_transaction(session): async with get_transaction(session=session):
session.add(role) ...
async with get_transaction(session): # uses savepoint async with get_transaction(session=session): # uses savepoint
session.add(user) ...
``` ```
## Table locking ## Table locking
@@ -59,7 +62,7 @@ async def create_user_with_role(session):
```python ```python
from fastapi_toolsets.db import lock_tables from fastapi_toolsets.db import lock_tables
async with lock_tables(session, tables=[User], mode="EXCLUSIVE"): async with lock_tables(session=session, tables=[User], mode="EXCLUSIVE"):
# No other transaction can modify User until this block exits # No other transaction can modify User until this block exits
... ...
``` ```
@@ -75,7 +78,7 @@ from fastapi_toolsets.db import wait_for_row_change
# Wait up to 30s for order.status to change # Wait up to 30s for order.status to change
await wait_for_row_change( await wait_for_row_change(
session, session=session,
model=Order, model=Order,
pk_value=order_id, pk_value=order_id,
columns=[Order.status], columns=[Order.status],

View File

@@ -13,17 +13,17 @@ The `dependencies` module provides two factory functions that create FastAPI dep
```python ```python
from fastapi_toolsets.dependencies import PathDependency from fastapi_toolsets.dependencies import PathDependency
UserDep = PathDependency(User, User.id, session_dep=get_db) UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
@router.get("/users/{user_id}") @router.get("/users/{user_id}")
async def get_user(user: User = UserDep): async def get_user(user: User = UserDep):
return user return user
``` ```
The parameter name is inferred from the field (`user_id` for `User.id`). You can override it: By default the parameter name is inferred from the field (`user_id` for `User.id`). You can override it:
```python ```python
UserDep = PathDependency(User, User.id, session_dep=get_db, param_name="id") UserDep = PathDependency(model=User, field=User.id, session_dep=get_db, param_name="id")
@router.get("/users/{id}") @router.get("/users/{id}")
async def get_user(user: User = UserDep): async def get_user(user: User = UserDep):
@@ -37,7 +37,7 @@ async def get_user(user: User = UserDep):
```python ```python
from fastapi_toolsets.dependencies import BodyDependency from fastapi_toolsets.dependencies import BodyDependency
RoleDep = BodyDependency(Role, Role.id, session_dep=get_db, body_field="role_id") RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
@router.post("/users") @router.post("/users")
async def create_user(body: UserCreateSchema, role: Role = RoleDep): async def create_user(body: UserCreateSchema, role: Role = RoleDep):
@@ -45,4 +45,6 @@ async def create_user(body: UserCreateSchema, role: Role = RoleDep):
... ...
``` ```
---
[:material-api: API Reference](../reference/dependencies.md) [:material-api: API Reference](../reference/dependencies.md)

View File

@@ -4,7 +4,7 @@ Structured API exceptions with consistent error responses and automatic OpenAPI
## Overview ## Overview
The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse) shape. The `exceptions` module provides a set of pre-built HTTP exceptions and a FastAPI exception handler that formats all errors — including validation errors — into a uniform [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse).
## Setup ## Setup
@@ -15,7 +15,7 @@ from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers from fastapi_toolsets.exceptions import init_exceptions_handlers
app = FastAPI() app = FastAPI()
init_exceptions_handlers(app) init_exceptions_handlers(app=app)
``` ```
This registers handlers for: This registers handlers for:
@@ -36,11 +36,11 @@ This registers handlers for:
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields | | [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields |
```python ```python
from fastapi_toolsets.exceptions import NotFoundError, ForbiddenError from fastapi_toolsets.exceptions import NotFoundError
@router.get("/users/{id}") @router.get("/users/{id}")
async def get_user(id: int, session: AsyncSession = Depends(get_db)): async def get_user(id: int, session: AsyncSession = Depends(get_db)):
user = await UserCrud.first(session, filters=[User.id == id]) user = await UserCrud.first(session=session, filters=[User.id == id])
if not user: if not user:
raise NotFoundError raise NotFoundError
return user return user
@@ -77,6 +77,9 @@ from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError,
async def get_user(...): ... async def get_user(...): ...
``` ```
!!! info
The pydantic validation error is automatically added by FastAPI.
--- ---
[:material-api: API Reference](../reference/exceptions.md) [:material-api: API Reference](../reference/exceptions.md)

View File

@@ -32,22 +32,22 @@ Dependencies declared via `depends_on` are resolved topologically — `roles` wi
## Loading fixtures ## Loading fixtures
### By context By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures_by_context):
```python ```python
from fastapi_toolsets.fixtures import load_fixtures_by_context from fastapi_toolsets.fixtures import load_fixtures_by_context
async with db_context() as session: async with db_context() as session:
await load_fixtures_by_context(session, registry=fixtures, context=Context.TESTING) await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING)
``` ```
### Directly Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
```python ```python
from fastapi_toolsets.fixtures import load_fixtures from fastapi_toolsets.fixtures import load_fixtures
async with db_context() as session: async with db_context() as session:
await load_fixtures(session, registry=fixtures) await load_fixtures(session=session, registry=fixtures)
``` ```
## Contexts ## Contexts
@@ -60,7 +60,7 @@ async with db_context() as session:
| `Context.TESTING` | Data only loaded during tests | | `Context.TESTING` | Data only loaded during tests |
| `Context.PRODUCTION` | Data only loaded in production | | `Context.PRODUCTION` | Data only loaded in production |
A fixture with no `contexts` argument is loaded in all contexts. A fixture with no `contexts` defined takes `Context.BASE` by default.
## Load strategies ## Load strategies
@@ -72,9 +72,21 @@ A fixture with no `contexts` argument is loaded in all contexts.
| `LoadStrategy.UPSERT` | Insert or update on conflict | | `LoadStrategy.UPSERT` | Insert or update on conflict |
| `LoadStrategy.SKIP` | Skip rows that already exist | | `LoadStrategy.SKIP` | Skip rows that already exist |
## Merging registries
Split fixtures definitions across modules and merge them:
```python
from myapp.fixtures.dev import dev_fixtures
from myapp.fixtures.prod import prod_fixtures
fixtures = fixturesRegistry()
fixtures.include_registry(registry=dev_fixtures)
fixtures.include_registry(registry=prod_fixtures)
## Pytest integration ## Pytest integration
Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}`: Use [`register_fixtures`](../reference/pytest.md#fastapi_toolsets.pytest.plugin.register_fixtures) to expose each fixture in your registry as an injectable pytest fixture named `fixture_{name}` by default:
```python ```python
# conftest.py # conftest.py
@@ -95,10 +107,8 @@ register_fixtures(registry=registry, namespace=globals())
```python ```python
# test_users.py # test_users.py
async def test_user_can_login(fixture_users, fixture_roles, client): async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Role]):
# fixture_roles is loaded first (dependency), then fixture_users ...
response = await client.post("/auth/login", json={"username": "alice"})
assert response.status_code == 200
``` ```

View File

@@ -34,4 +34,6 @@ When called without arguments, [`get_logger`](../reference/logger.md#fastapi_too
logger = get_logger() logger = get_logger()
``` ```
---
[:material-api: API Reference](../reference/logger.md) [:material-api: API Reference](../reference/logger.md)

View File

@@ -27,7 +27,7 @@ from fastapi_toolsets.metrics import MetricsRegistry, init_metrics
app = FastAPI() app = FastAPI()
metrics = MetricsRegistry() metrics = MetricsRegistry()
init_metrics(app, registry=metrics) init_metrics(app=app, registry=metrics)
``` ```
This mounts the `/metrics` endpoint that Prometheus can scrape. This mounts the `/metrics` endpoint that Prometheus can scrape.
@@ -70,8 +70,8 @@ from myapp.metrics.http import http_metrics
from myapp.metrics.db import db_metrics from myapp.metrics.db import db_metrics
metrics = MetricsRegistry() metrics = MetricsRegistry()
metrics.include_registry(http_metrics) metrics.include_registry(registry=http_metrics)
metrics.include_registry(db_metrics) metrics.include_registry(registry=db_metrics)
``` ```
## Multi-process mode ## Multi-process mode
@@ -81,10 +81,6 @@ Multi-process support is enabled automatically when the `PROMETHEUS_MULTIPROC_DI
!!! warning "Environment variable name" !!! warning "Environment variable name"
The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`). The correct variable is `PROMETHEUS_MULTIPROC_DIR` (not `PROMETHEUS_MULTIPROCESS_DIR`).
```bash
export PROMETHEUS_MULTIPROC_DIR=/tmp/prometheus
```
--- ---
[:material-api: API Reference](../reference/metrics.md) [:material-api: API Reference](../reference/metrics.md)

View File

@@ -26,8 +26,15 @@ Use [`create_async_client`](../reference/pytest.md#fastapi_toolsets.pytest.utils
from fastapi_toolsets.pytest import create_async_client from fastapi_toolsets.pytest import create_async_client
@pytest.fixture @pytest.fixture
async def client(app): async def http_client(db_session):
async with create_async_client(app=app) as c: async def _override_get_db():
yield db_session
async with create_async_client(
app=app,
base_url="http://127.0.0.1/api/v1",
dependency_overrides={get_db: _override_get_db},
) as c:
yield c yield c
``` ```
@@ -36,36 +43,34 @@ async def client(app):
Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test: Use [`create_db_session`](../reference/pytest.md#fastapi_toolsets.pytest.utils.create_db_session) to create an isolated `AsyncSession` for a test:
```python ```python
from fastapi_toolsets.pytest import create_db_session from fastapi_toolsets.pytest import create_db_session, create_worker_database
@pytest.fixture
async def db_session():
async with create_db_session(database_url=DATABASE_URL, base=Base, cleanup=True) as session:
yield session
```
## Parallel testing with pytest-xdist
When running tests in parallel, each worker needs its own database. Use these helpers to create and identify worker databases:
```python
from fastapi_toolsets.pytest import create_worker_database, create_db_session
# In conftest.py session-scoped fixture
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
async def worker_db_url(): async def worker_db_url():
async with create_worker_database(database_url=DATABASE_URL) as url: async with create_worker_database(
database_url=str(settings.SQLALCHEMY_DATABASE_URI)
) as url:
yield url yield url
@pytest.fixture @pytest.fixture
async def db_session(worker_db_url): async def db_session(worker_db_url):
async with create_db_session(database_url=worker_db_url, base=Base, cleanup=True) as session: async with create_db_session(
database_url=worker_db_url, base=Base, cleanup=True
) as session:
yield session yield session
``` ```
!!! info
In this example, the database is reset between each test using the argument `cleanup=True`.
## Parallel testing with pytest-xdist
The examples above are already compatible with parallel test execution with `pytest-xdist`.
## Cleaning up tables ## Cleaning up tables
[`cleanup_tables`](../reference/pytest.md#fastapi_toolsets.pytest.utils.cleanup_tables) truncates all tables between tests for fast isolation: If you want to manually clean up a database you can use [`cleanup_tables`](../reference/pytest.md#fastapi_toolsets.pytest.utils.cleanup_tables), this will truncates all tables between tests for fast isolation:
```python ```python
from fastapi_toolsets.pytest import cleanup_tables from fastapi_toolsets.pytest import cleanup_tables

View File

@@ -8,7 +8,7 @@ The `schemas` module provides generic response wrappers that enforce a uniform r
## Response models ## Response models
### `Response[T]` ### [`Response[T]`](../reference/schemas.md#fastapi_toolsets.schemas.Response)
The most common wrapper for a single resource response. The most common wrapper for a single resource response.
@@ -20,7 +20,7 @@ async def get_user(user: User = UserDep) -> Response[UserSchema]:
return Response(data=user, message="User retrieved") return Response(data=user, message="User retrieved")
``` ```
### `PaginatedResponse[T]` ### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
Wraps a list of items with pagination metadata. Wraps a list of items with pagination metadata.
@@ -40,15 +40,10 @@ async def list_users() -> PaginatedResponse[UserSchema]:
) )
``` ```
### `ErrorResponse` ### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
Returned automatically by the exceptions handler. Can also be used as a response model for OpenAPI docs. Returned automatically by the exceptions handler.
```python ---
from fastapi_toolsets.schemas import ErrorResponse
@router.delete("/users/{id}", responses={404: {"model": ErrorResponse}})
async def delete_user(...): ...
```
[:material-api: API Reference](../reference/schemas.md) [:material-api: API Reference](../reference/schemas.md)

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.10.0" version = "1.1.2"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "0.10.0" __version__ = "1.1.2"

View File

@@ -1,5 +1,7 @@
"""CLI configuration and dynamic imports.""" """CLI configuration and dynamic imports."""
from __future__ import annotations
import importlib import importlib
import sys import sys
from typing import TYPE_CHECKING, Any, Literal, overload from typing import TYPE_CHECKING, Any, Literal, overload

View File

@@ -2,28 +2,46 @@
from __future__ import annotations from __future__ import annotations
import base64
import json
import uuid as uuid_module
import warnings
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import date, datetime
from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import and_, func, select from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy import delete as sql_delete from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
from ..exceptions import NotFoundError from ..exceptions import NotFoundError
from ..schemas import PaginatedResponse, Pagination, Response from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import SearchConfig, SearchFieldType, build_search_filters from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
JoinType = list[tuple[type[DeclarativeBase], Any]] JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]]
def _encode_cursor(value: Any) -> str:
"""Encode cursor column value as an base64 string."""
return base64.b64encode(json.dumps(str(value)).encode()).decode()
def _decode_cursor(cursor: str) -> str:
"""Decode cursor base64 string."""
return json.loads(base64.b64decode(cursor.encode()).decode())
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models. """Generic async CRUD operations for SQLAlchemy models.
@@ -33,26 +51,17 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None
@overload
@classmethod @classmethod
async def create( # pragma: no cover def _resolve_load_options(
cls: type[Self], cls, load_options: list[ExecutableOption] | None
session: AsyncSession, ) -> list[ExecutableOption] | None:
obj: BaseModel, """Return load_options if provided, else fall back to default_load_options."""
*, if load_options is not None:
as_response: Literal[True], return load_options
) -> Response[ModelType]: ... return cls.default_load_options
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod @classmethod
async def _resolve_m2m( async def _resolve_m2m(
@@ -110,6 +119,40 @@ class AsyncCrud(Generic[ModelType]):
return set() return set()
return set(cls.m2m_fields.keys()) return set(cls.m2m_fields.keys())
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
schema: type[SchemaType],
as_response: bool = ...,
) -> Response[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[True],
schema: None = ...,
) -> Response[ModelType]: ...
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[False] = ...,
schema: None = ...,
) -> ModelType: ...
@classmethod @classmethod
async def create( async def create(
cls: type[Self], cls: type[Self],
@@ -117,17 +160,28 @@ class AsyncCrud(Generic[ModelType]):
obj: BaseModel, obj: BaseModel,
*, *,
as_response: bool = False, as_response: bool = False,
) -> ModelType | Response[ModelType]: schema: type[BaseModel] | None = None,
) -> ModelType | Response[ModelType] | Response[Any]:
"""Create a new record in the database. """Create a new record in the database.
Args: Args:
session: DB async session session: DB async session
obj: Pydantic model with data to create obj: Pydantic model with data to create
as_response: If True, wrap result in Response object as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
schema: Pydantic schema to serialize the result into. When provided,
the result is automatically wrapped in a ``Response[schema]``.
Returns: Returns:
Created model instance or Response wrapping it Created model instance, or ``Response[schema]`` when ``schema`` is given,
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
""" """
if as_response and schema is None:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
async with get_transaction(session): async with get_transaction(session):
m2m_exclude = cls._m2m_schema_fields() m2m_exclude = cls._m2m_schema_fields()
data = ( data = (
@@ -143,8 +197,9 @@ class AsyncCrud(Generic[ModelType]):
session.add(db_model) session.add(db_model)
await session.refresh(db_model) await session.refresh(db_model)
result = cast(ModelType, db_model) result = cast(ModelType, db_model)
if as_response: if as_response or schema:
return Response(data=result) data_out = schema.model_validate(result) if schema else result
return Response(data=data_out)
return result return result
@overload @overload
@@ -157,8 +212,25 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
schema: type[SchemaType],
as_response: bool = ...,
) -> Response[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[ExecutableOption] | None = None,
as_response: Literal[True], as_response: Literal[True],
schema: None = ...,
) -> Response[ModelType]: ... ) -> Response[ModelType]: ...
@overload @overload
@@ -171,8 +243,9 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
as_response: Literal[False] = ..., as_response: Literal[False] = ...,
schema: None = ...,
) -> ModelType: ... ) -> ModelType: ...
@classmethod @classmethod
@@ -184,9 +257,10 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
as_response: bool = False, as_response: bool = False,
) -> ModelType | Response[ModelType]: schema: type[BaseModel] | None = None,
) -> ModelType | Response[ModelType] | Response[Any]:
"""Get exactly one record. Raises NotFoundError if not found. """Get exactly one record. Raises NotFoundError if not found.
Args: Args:
@@ -196,15 +270,25 @@ class AsyncCrud(Generic[ModelType]):
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload) load_options: SQLAlchemy loader options (e.g., selectinload)
as_response: If True, wrap result in Response object as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
schema: Pydantic schema to serialize the result into. When provided,
the result is automatically wrapped in a ``Response[schema]``.
Returns: Returns:
Model instance or Response wrapping it Model instance, or ``Response[schema]`` when ``schema`` is given,
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
Raises: Raises:
NotFoundError: If no record found NotFoundError: If no record found
MultipleResultsFound: If more than one record found MultipleResultsFound: If more than one record found
""" """
if as_response and schema is None:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
q = select(cls.model) q = select(cls.model)
if joins: if joins:
for model, condition in joins: for model, condition in joins:
@@ -214,8 +298,8 @@ class AsyncCrud(Generic[ModelType]):
else q.join(model, condition) else q.join(model, condition)
) )
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
if with_for_update: if with_for_update:
q = q.with_for_update() q = q.with_for_update()
result = await session.execute(q) result = await session.execute(q)
@@ -223,8 +307,9 @@ class AsyncCrud(Generic[ModelType]):
if not item: if not item:
raise NotFoundError() raise NotFoundError()
result = cast(ModelType, item) result = cast(ModelType, item)
if as_response: if as_response or schema:
return Response(data=result) data_out = schema.model_validate(result) if schema else result
return Response(data=data_out)
return result return result
@classmethod @classmethod
@@ -235,7 +320,7 @@ class AsyncCrud(Generic[ModelType]):
*, *,
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
) -> ModelType | None: ) -> ModelType | None:
"""Get the first matching record, or None. """Get the first matching record, or None.
@@ -259,8 +344,8 @@ class AsyncCrud(Generic[ModelType]):
) )
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
result = await session.execute(q) result = await session.execute(q)
return cast(ModelType | None, result.unique().scalars().first()) return cast(ModelType | None, result.unique().scalars().first())
@@ -272,7 +357,7 @@ class AsyncCrud(Generic[ModelType]):
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
limit: int | None = None, limit: int | None = None,
offset: int | None = None, offset: int | None = None,
@@ -302,8 +387,8 @@ class AsyncCrud(Generic[ModelType]):
) )
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
if order_by is not None: if order_by is not None:
q = q.order_by(order_by) q = q.order_by(order_by)
if offset is not None: if offset is not None:
@@ -313,6 +398,21 @@ class AsyncCrud(Generic[ModelType]):
result = await session.execute(q) result = await session.execute(q)
return cast(Sequence[ModelType], result.unique().scalars().all()) return cast(Sequence[ModelType], result.unique().scalars().all())
@overload
@classmethod
async def update( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
schema: type[SchemaType],
as_response: bool = ...,
) -> Response[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload @overload
@classmethod @classmethod
async def update( # pragma: no cover async def update( # pragma: no cover
@@ -324,6 +424,7 @@ class AsyncCrud(Generic[ModelType]):
exclude_unset: bool = True, exclude_unset: bool = True,
exclude_none: bool = False, exclude_none: bool = False,
as_response: Literal[True], as_response: Literal[True],
schema: None = ...,
) -> Response[ModelType]: ... ) -> Response[ModelType]: ...
@overload @overload
@@ -337,6 +438,7 @@ class AsyncCrud(Generic[ModelType]):
exclude_unset: bool = True, exclude_unset: bool = True,
exclude_none: bool = False, exclude_none: bool = False,
as_response: Literal[False] = ..., as_response: Literal[False] = ...,
schema: None = ...,
) -> ModelType: ... ) -> ModelType: ...
@classmethod @classmethod
@@ -349,7 +451,8 @@ class AsyncCrud(Generic[ModelType]):
exclude_unset: bool = True, exclude_unset: bool = True,
exclude_none: bool = False, exclude_none: bool = False,
as_response: bool = False, as_response: bool = False,
) -> ModelType | Response[ModelType]: schema: type[BaseModel] | None = None,
) -> ModelType | Response[ModelType] | Response[Any]:
"""Update a record in the database. """Update a record in the database.
Args: Args:
@@ -358,20 +461,30 @@ class AsyncCrud(Generic[ModelType]):
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
exclude_unset: Exclude fields not explicitly set in the schema exclude_unset: Exclude fields not explicitly set in the schema
exclude_none: Exclude fields with None value exclude_none: Exclude fields with None value
as_response: If True, wrap result in Response object as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
schema: Pydantic schema to serialize the result into. When provided,
the result is automatically wrapped in a ``Response[schema]``.
Returns: Returns:
Updated model instance or Response wrapping it Updated model instance, or ``Response[schema]`` when ``schema`` is given,
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
Raises: Raises:
NotFoundError: If no record found NotFoundError: If no record found
""" """
if as_response and schema is None:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
async with get_transaction(session): async with get_transaction(session):
m2m_exclude = cls._m2m_schema_fields() m2m_exclude = cls._m2m_schema_fields()
# Eagerly load M2M relationships that will be updated so that # Eagerly load M2M relationships that will be updated so that
# setattr does not trigger a lazy load (which fails in async). # setattr does not trigger a lazy load (which fails in async).
m2m_load_options: list[Any] = [] m2m_load_options: list[ExecutableOption] = []
if m2m_exclude and cls.m2m_fields: if m2m_exclude and cls.m2m_fields:
for schema_field, rel in cls.m2m_fields.items(): for schema_field, rel in cls.m2m_fields.items():
if schema_field in obj.model_fields_set: if schema_field in obj.model_fields_set:
@@ -395,8 +508,9 @@ class AsyncCrud(Generic[ModelType]):
for rel_attr, related_instances in m2m_resolved.items(): for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances) setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model) await session.refresh(db_model)
if as_response: if as_response or schema:
return Response(data=db_model) data_out = schema.model_validate(db_model) if schema else db_model
return Response(data=data_out)
return db_model return db_model
@classmethod @classmethod
@@ -478,11 +592,20 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
as_response: If True, wrap result in Response object as_response: Deprecated. Will be removed in v2.0. When ``True``,
returns ``Response[None]`` instead of ``bool``.
Returns: Returns:
True if deletion was executed, or Response wrapping it ``True`` if deletion was executed, or ``Response[None]`` when
``as_response=True`` (deprecated).
""" """
if as_response:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
async with get_transaction(session): async with get_transaction(session):
q = sql_delete(cls.model).where(and_(*filters)) q = sql_delete(cls.model).where(and_(*filters))
await session.execute(q) await session.execute(q)
@@ -555,22 +678,60 @@ class AsyncCrud(Generic[ModelType]):
result = await session.execute(q) result = await session.execute(q)
return bool(result.scalar()) return bool(result.scalar())
@overload
@classmethod @classmethod
async def paginate( async def offset_paginate( # pragma: no cover
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
*, *,
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
) -> PaginatedResponse[ModelType]: schema: type[SchemaType],
"""Get paginated results with metadata. ) -> PaginatedResponse[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def offset_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@classmethod
async def offset_paginate(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results using offset-based pagination.
Args: Args:
session: DB async session session: DB async session
@@ -583,9 +744,10 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: Number of items per page items_per_page: Number of items per page
search: Search query string or SearchConfig object search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default) search_fields: Fields to search in (overrides class default)
schema: Optional Pydantic schema to serialize each item into.
Returns: Returns:
Dict with 'data' and 'pagination' keys PaginatedResponse with OffsetPagination metadata
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
@@ -619,14 +781,17 @@ class AsyncCrud(Generic[ModelType]):
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
if order_by is not None: if order_by is not None:
q = q.order_by(order_by) q = q.order_by(order_by)
q = q.offset(offset).limit(items_per_page) q = q.offset(offset).limit(items_per_page)
result = await session.execute(q) result = await session.execute(q)
items = cast(list[ModelType], result.unique().scalars().all()) raw_items = cast(list[ModelType], result.unique().scalars().all())
items: list[Any] = (
[schema.model_validate(item) for item in raw_items] if schema else raw_items
)
# Count query (with same joins and filters) # Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0] pk_col = cls.model.__mapper__.primary_key[0]
@@ -654,7 +819,7 @@ class AsyncCrud(Generic[ModelType]):
return PaginatedResponse( return PaginatedResponse(
data=items, data=items,
pagination=Pagination( pagination=OffsetPagination(
total_count=total_count, total_count=total_count,
items_per_page=items_per_page, items_per_page=items_per_page,
page=page, page=page,
@@ -662,12 +827,193 @@ class AsyncCrud(Generic[ModelType]):
), ),
) )
# Backward-compatible - will be removed in v2.0
paginate = offset_paginate
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@classmethod
async def cursor_paginate(
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results using cursor-based pagination.
Args:
session: DB async session.
cursor: Cursor string from a previous ``CursorPagination``.
Omit (or pass ``None``) to start from the beginning.
filters: List of SQLAlchemy filter conditions.
joins: List of ``(model, condition)`` tuples for joining related
tables.
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
load_options: SQLAlchemy loader options. Falls back to
``default_load_options`` when not provided.
order_by: Additional ordering applied after the cursor column.
items_per_page: Number of items per page (default 20).
search: Search query string or SearchConfig object.
search_fields: Fields to search in (overrides class default).
schema: Optional Pydantic schema to serialize each item into.
Returns:
PaginatedResponse with CursorPagination metadata
"""
filters = list(filters) if filters else []
search_joins: list[Any] = []
if cls.cursor_column is None:
raise ValueError(
f"{cls.__name__}.cursor_column is not set. "
"Pass cursor_column=<column> to CrudFactory() to use cursor_paginate."
)
cursor_column: Any = cls.cursor_column
cursor_col_name: str = cursor_column.key
if cursor is not None:
raw_val = _decode_cursor(cursor)
col_type = cursor_column.property.columns[0].type
if isinstance(col_type, Integer):
cursor_val: Any = int(raw_val)
elif isinstance(col_type, Uuid):
cursor_val = uuid_module.UUID(raw_val)
elif isinstance(col_type, DateTime):
cursor_val = datetime.fromisoformat(raw_val)
elif isinstance(col_type, Date):
cursor_val = date.fromisoformat(raw_val)
elif isinstance(col_type, (Float, Numeric)):
cursor_val = Decimal(raw_val)
else:
raise ValueError(
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
"DateTime, Date, Float, Numeric."
)
filters.append(cursor_column > cursor_val)
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
# Build query
q = select(cls.model)
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
# Cursor column is always the primary sort
q = q.order_by(cursor_column)
if order_by is not None:
q = q.order_by(order_by)
# Fetch one extra to detect whether a next page exists
q = q.limit(items_per_page + 1)
result = await session.execute(q)
raw_items = cast(list[ModelType], result.unique().scalars().all())
has_more = len(raw_items) > items_per_page
items_page = raw_items[:items_per_page]
# next_cursor points past the last item on this page
next_cursor: str | None = None
if has_more and items_page:
next_cursor = _encode_cursor(getattr(items_page[-1], cursor_col_name))
# prev_cursor points to the first item on this page or None when on the first page
prev_cursor: str | None = None
if cursor is not None and items_page:
prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))
items: list[Any] = (
[schema.model_validate(item) for item in items_page]
if schema
else items_page
)
return PaginatedResponse(
data=items,
pagination=CursorPagination(
next_cursor=next_cursor,
prev_cursor=prev_cursor,
items_per_page=items_per_page,
has_more=has_more,
),
)
def CrudFactory( def CrudFactory(
model: type[ModelType], model: type[ModelType],
*, *,
searchable_fields: Sequence[SearchFieldType] | None = None, searchable_fields: Sequence[SearchFieldType] | None = None,
m2m_fields: M2MFieldType | None = None, m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None,
cursor_column: Any | None = None,
) -> type[AsyncCrud[ModelType]]: ) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model. """Create a CRUD class for a specific model.
@@ -677,6 +1023,14 @@ def CrudFactory(
m2m_fields: Optional mapping for many-to-many relationships. m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes. SQLAlchemy relationship attributes.
default_load_options: Default SQLAlchemy loader options applied to all read
queries when no explicit ``load_options`` are passed. Use this
instead of ``lazy="selectin"`` on the model so that loading
strategy is explicit and per-CRUD. Overridden entirely (not
merged) when ``load_options`` is provided at call-site.
cursor_column: Required to call ``cursor_paginate``.
Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
See the cursor pagination docs for supported column types.
Returns: Returns:
AsyncCrud subclass bound to the model AsyncCrud subclass bound to the model
@@ -702,6 +1056,25 @@ def CrudFactory(
m2m_fields={"tag_ids": Post.tags}, m2m_fields={"tag_ids": Post.tags},
) )
# With a fixed cursor column for cursor_paginate:
PostCrud = CrudFactory(
Post,
cursor_column=Post.created_at,
)
# With default load strategy (replaces lazy="selectin" on the model):
ArticleCrud = CrudFactory(
Article,
default_load_options=[selectinload(Article.category), selectinload(Article.tags)],
)
# Override default_load_options for a specific call:
article = await ArticleCrud.get(
session,
[Article.id == 1],
load_options=[selectinload(Article.category)], # tags won't load
)
# Usage # Usage
user = await UserCrud.get(session, [User.id == 1]) user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
@@ -710,7 +1083,7 @@ def CrudFactory(
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2])) post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
# With search # With search
result = await UserCrud.paginate(session, search="john") result = await UserCrud.offset_paginate(session, search="john")
# With joins (inner join by default): # With joins (inner join by default):
users = await UserCrud.get_multi( users = await UserCrud.get_multi(
@@ -734,6 +1107,8 @@ def CrudFactory(
"model": model, "model": model,
"searchable_fields": searchable_fields, "searchable_fields": searchable_fields,
"m2m_fields": m2m_fields, "m2m_fields": m2m_fields,
"default_load_options": default_load_options,
"cursor_column": cursor_column,
}, },
) )
return cast(type[AsyncCrud[ModelType]], cls) return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict
__all__ = [ __all__ = [
"ApiError", "ApiError",
"CursorPagination",
"ErrorResponse", "ErrorResponse",
"OffsetPagination",
"Pagination", "Pagination",
"PaginatedResponse", "PaginatedResponse",
"PydanticBase", "PydanticBase",
@@ -90,8 +92,8 @@ class ErrorResponse(BaseResponse):
data: Any | None = None data: Any | None = None
class Pagination(PydanticBase): class OffsetPagination(PydanticBase):
"""Pagination metadata for list responses. """Pagination metadata for offset-based list responses.
Attributes: Attributes:
total_count: Total number of items across all pages total_count: Total number of items across all pages
@@ -106,17 +108,28 @@ class Pagination(PydanticBase):
has_more: bool has_more: bool
class PaginatedResponse(BaseResponse, Generic[DataT]): # Backward-compatible - will be removed in v2.0
"""Paginated API response for list endpoints. Pagination = OffsetPagination
Example:
```python class CursorPagination(PydanticBase):
PaginatedResponse[UserRead]( """Pagination metadata for cursor-based list responses.
data=users,
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True) Attributes:
) next_cursor: Encoded cursor for the next page, or None on the last page.
``` prev_cursor: Encoded cursor for the previous page, or None on the first page.
items_per_page: Number of items requested per page.
has_more: Whether there is at least one more page after this one.
""" """
next_cursor: str | None
prev_cursor: str | None = None
items_per_page: int
has_more: bool
class PaginatedResponse(BaseResponse, Generic[DataT]):
"""Paginated API response for list endpoints."""
data: list[DataT] data: list[DataT]
pagination: Pagination pagination: OffsetPagination | CursorPagination

View File

@@ -5,11 +5,25 @@ import uuid
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import Column, ForeignKey, String, Table, Uuid import datetime
import decimal
from sqlalchemy import (
Column,
Date,
DateTime,
ForeignKey,
Integer,
Numeric,
String,
Table,
Uuid,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.crud import CrudFactory from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.schemas import PydanticBase
DATABASE_URL = os.getenv( DATABASE_URL = os.getenv(
key="DATABASE_URL", key="DATABASE_URL",
@@ -69,6 +83,36 @@ post_tags = Table(
) )
class IntRole(Base):
"""Test role model with auto-increment integer PK."""
__tablename__ = "int_roles"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50), unique=True)
class Event(Base):
"""Test model with DateTime and Date cursor columns."""
__tablename__ = "events"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(100))
occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime)
scheduled_date: Mapped[datetime.date] = mapped_column(Date)
class Product(Base):
"""Test model with Numeric cursor column."""
__tablename__ = "products"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(100))
price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2))
class Post(Base): class Post(Base):
"""Test post model.""" """Test post model."""
@@ -90,6 +134,13 @@ class RoleCreate(BaseModel):
name: str name: str
class RoleRead(PydanticBase):
"""Schema for reading a role."""
id: uuid.UUID
name: str
class RoleUpdate(BaseModel): class RoleUpdate(BaseModel):
"""Schema for updating a role.""" """Schema for updating a role."""
@@ -106,6 +157,13 @@ class UserCreate(BaseModel):
role_id: uuid.UUID | None = None role_id: uuid.UUID | None = None
class UserRead(PydanticBase):
"""Schema for reading a user (subset of fields — no email)."""
id: uuid.UUID
username: str
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
"""Schema for updating a user.""" """Schema for updating a user."""
@@ -160,11 +218,40 @@ class PostM2MUpdate(BaseModel):
tag_ids: list[uuid.UUID] | None = None tag_ids: list[uuid.UUID] | None = None
class IntRoleCreate(BaseModel):
"""Schema for creating an IntRole."""
name: str
class EventCreate(BaseModel):
"""Schema for creating an Event."""
name: str
occurred_at: datetime.datetime
scheduled_date: datetime.date
class ProductCreate(BaseModel):
"""Schema for creating a Product."""
name: str
price: decimal.Decimal
RoleCrud = CrudFactory(Role) RoleCrud = CrudFactory(Role)
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
UserCrud = CrudFactory(User) UserCrud = CrudFactory(User)
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
PostCrud = CrudFactory(Post) PostCrud = CrudFactory(Post)
TagCrud = CrudFactory(Tag) TagCrud = CrudFactory(Tag)
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags}) PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
EventCrud = CrudFactory(Event)
EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at)
EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date)
ProductCrud = CrudFactory(Product)
ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price)
@pytest.fixture @pytest.fixture

File diff suppressed because it is too large Load Diff

View File

@@ -6,6 +6,7 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
from fastapi_toolsets.schemas import OffsetPagination
from .conftest import ( from .conftest import (
Role, Role,
@@ -39,6 +40,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -57,6 +59,7 @@ class TestPaginateSearch:
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -84,6 +87,7 @@ class TestPaginateSearch:
search_fields=[(User.role, Role.name)], search_fields=[(User.role, Role.name)],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -102,6 +106,7 @@ class TestPaginateSearch:
search_fields=[User.username, (User.role, Role.name)], search_fields=[User.username, (User.role, Role.name)],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -117,6 +122,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -132,6 +138,7 @@ class TestPaginateSearch:
search=SearchConfig(query="johndoe", case_sensitive=True), search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 0 assert result.pagination.total_count == 0
# Should find (case match) # Should find (case match)
@@ -140,6 +147,7 @@ class TestPaginateSearch:
search=SearchConfig(query="JohnDoe", case_sensitive=True), search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -153,9 +161,11 @@ class TestPaginateSearch:
) )
result = await UserCrud.paginate(db_session, search="") result = await UserCrud.paginate(db_session, search="")
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
result = await UserCrud.paginate(db_session, search=None) result = await UserCrud.paginate(db_session, search=None)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -177,6 +187,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
assert result.data[0].username == "active_john" assert result.data[0].username == "active_john"
@@ -189,6 +200,7 @@ class TestPaginateSearch:
result = await UserCrud.paginate(db_session, search="findme") result = await UserCrud.paginate(db_session, search="findme")
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -204,6 +216,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 0 assert result.pagination.total_count == 0
assert result.data == [] assert result.data == []
@@ -224,6 +237,7 @@ class TestPaginateSearch:
items_per_page=5, items_per_page=5,
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 15 assert result.pagination.total_count == 15
assert len(result.data) == 5 assert len(result.data) == 5
assert result.pagination.has_more is True assert result.pagination.has_more is True
@@ -248,6 +262,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -270,6 +285,7 @@ class TestPaginateSearch:
order_by=User.username, order_by=User.username,
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 3 assert result.pagination.total_count == 3
usernames = [u.username for u in result.data] usernames = [u.username for u in result.data]
assert usernames == ["alice", "bob", "charlie"] assert usernames == ["alice", "bob", "charlie"]
@@ -292,6 +308,7 @@ class TestPaginateSearch:
search_fields=[User.id, User.username], search_fields=[User.id, User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
assert result.data[0].id == user_id assert result.data[0].id == user_id
@@ -318,6 +335,7 @@ class TestSearchConfig:
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
assert result.data[0].username == "john_test" assert result.data[0].username == "john_test"
@@ -333,6 +351,7 @@ class TestSearchConfig:
search=SearchConfig(query="findme", fields=[User.email]), search=SearchConfig(query="findme", fields=[User.email]),
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1

View File

@@ -5,7 +5,9 @@ from pydantic import ValidationError
from fastapi_toolsets.schemas import ( from fastapi_toolsets.schemas import (
ApiError, ApiError,
CursorPagination,
ErrorResponse, ErrorResponse,
OffsetPagination,
PaginatedResponse, PaginatedResponse,
Pagination, Pagination,
Response, Response,
@@ -154,12 +156,12 @@ class TestErrorResponse:
assert data["description"] == "Details" assert data["description"] == "Details"
class TestPagination: class TestOffsetPagination:
"""Tests for Pagination schema.""" """Tests for OffsetPagination schema (canonical name for offset-based pagination)."""
def test_create_pagination(self): def test_create_pagination(self):
"""Create Pagination with all fields.""" """Create OffsetPagination with all fields."""
pagination = Pagination( pagination = OffsetPagination(
total_count=100, total_count=100,
items_per_page=10, items_per_page=10,
page=1, page=1,
@@ -173,7 +175,7 @@ class TestPagination:
def test_last_page_has_more_false(self): def test_last_page_has_more_false(self):
"""Last page has has_more=False.""" """Last page has has_more=False."""
pagination = Pagination( pagination = OffsetPagination(
total_count=25, total_count=25,
items_per_page=10, items_per_page=10,
page=3, page=3,
@@ -183,8 +185,8 @@ class TestPagination:
assert pagination.has_more is False assert pagination.has_more is False
def test_serialization(self): def test_serialization(self):
"""Pagination serializes correctly.""" """OffsetPagination serializes correctly."""
pagination = Pagination( pagination = OffsetPagination(
total_count=50, total_count=50,
items_per_page=20, items_per_page=20,
page=2, page=2,
@@ -197,6 +199,77 @@ class TestPagination:
assert data["page"] == 2 assert data["page"] == 2
assert data["has_more"] is True assert data["has_more"] is True
def test_pagination_alias_is_offset_pagination(self):
"""Pagination is a backward-compatible alias for OffsetPagination."""
assert Pagination is OffsetPagination
def test_pagination_alias_constructs_offset_pagination(self):
"""Code using Pagination(...) still works unchanged."""
pagination = Pagination(
total_count=10,
items_per_page=5,
page=2,
has_more=False,
)
assert isinstance(pagination, OffsetPagination)
class TestCursorPagination:
"""Tests for CursorPagination schema."""
def test_create_with_next_cursor(self):
"""CursorPagination with a next cursor indicates more pages."""
pagination = CursorPagination(
next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==",
items_per_page=20,
has_more=True,
)
assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ=="
assert pagination.prev_cursor is None
assert pagination.items_per_page == 20
assert pagination.has_more is True
def test_create_last_page(self):
"""CursorPagination for the last page has next_cursor=None and has_more=False."""
pagination = CursorPagination(
next_cursor=None,
items_per_page=20,
has_more=False,
)
assert pagination.next_cursor is None
assert pagination.has_more is False
def test_prev_cursor_defaults_to_none(self):
"""prev_cursor defaults to None."""
pagination = CursorPagination(
next_cursor=None, items_per_page=10, has_more=False
)
assert pagination.prev_cursor is None
def test_prev_cursor_can_be_set(self):
"""prev_cursor can be explicitly set."""
pagination = CursorPagination(
next_cursor="next123",
prev_cursor="prev456",
items_per_page=10,
has_more=True,
)
assert pagination.prev_cursor == "prev456"
def test_serialization(self):
"""CursorPagination serializes correctly."""
pagination = CursorPagination(
next_cursor="abc123",
prev_cursor="xyz789",
items_per_page=20,
has_more=True,
)
data = pagination.model_dump()
assert data["next_cursor"] == "abc123"
assert data["prev_cursor"] == "xyz789"
assert data["items_per_page"] == 20
assert data["has_more"] is True
class TestPaginatedResponse: class TestPaginatedResponse:
"""Tests for PaginatedResponse schema.""" """Tests for PaginatedResponse schema."""
@@ -214,6 +287,7 @@ class TestPaginatedResponse:
pagination=pagination, pagination=pagination,
) )
assert isinstance(response.pagination, OffsetPagination)
assert len(response.data) == 2 assert len(response.data) == 2
assert response.pagination.total_count == 30 assert response.pagination.total_count == 30
assert response.status == ResponseStatus.SUCCESS assert response.status == ResponseStatus.SUCCESS
@@ -247,6 +321,7 @@ class TestPaginatedResponse:
pagination=pagination, pagination=pagination,
) )
assert isinstance(response.pagination, OffsetPagination)
assert response.data == [] assert response.data == []
assert response.pagination.total_count == 0 assert response.pagination.total_count == 0
@@ -290,6 +365,36 @@ class TestPaginatedResponse:
assert data["data"] == ["item1", "item2"] assert data["data"] == ["item1", "item2"]
assert data["pagination"]["page"] == 5 assert data["pagination"]["page"] == 5
def test_pagination_field_accepts_offset_pagination(self):
"""PaginatedResponse.pagination accepts OffsetPagination."""
response = PaginatedResponse(
data=[1, 2],
pagination=OffsetPagination(
total_count=2, items_per_page=10, page=1, has_more=False
),
)
assert isinstance(response.pagination, OffsetPagination)
def test_pagination_field_accepts_cursor_pagination(self):
"""PaginatedResponse.pagination accepts CursorPagination."""
response = PaginatedResponse(
data=[1, 2],
pagination=CursorPagination(
next_cursor=None, items_per_page=10, has_more=False
),
)
assert isinstance(response.pagination, CursorPagination)
def test_pagination_alias_accepted(self):
"""Constructing PaginatedResponse with Pagination (alias) still works."""
response = PaginatedResponse(
data=[],
pagination=Pagination(
total_count=0, items_per_page=10, page=1, has_more=False
),
)
assert isinstance(response.pagination, OffsetPagination)
class TestFromAttributes: class TestFromAttributes:
"""Tests for from_attributes config (ORM mode).""" """Tests for from_attributes config (ORM mode)."""

2
uv.lock generated
View File

@@ -251,7 +251,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.10.0" version = "1.1.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },