9 Commits

Author SHA1 Message Date
0fc86d3c34 refactor: simplify and deduplicate across crud, metrics, cli, and
exceptions
2026-03-01 09:47:36 -05:00
82ef96082e test: add missing tests for fixtures/utils.py 2026-03-01 08:05:20 -05:00
e0828c7e71 refactor: centralize type aliases in types.py and simplify crud layer 2026-03-01 07:46:49 -05:00
59d028d00e refactor: remove deprecated parameter and function 2026-03-01 07:21:07 -05:00
56d365d14b Version 1.3.0 2026-03-01 05:22:16 -05:00
d3vyce
a257d85d45 Add sort_params helper in CrudFactory (#103)
* feat: add sort_params helper in CrudFactory

* docs: add sorting

* fix: change sort_by to order_by
2026-03-01 11:20:43 +01:00
117675d02f Version 1.2.1 2026-02-27 13:57:03 -05:00
d3vyce
d7ad7308c5 Add examples in documentations (#99)
* docs: fix crud

* docs: update README features

* docs: add pagination/search example

* docs: update zensical.toml

* docs: cleanup

* docs: update status to Stable + update description

* docs: add example run commands
2026-02-27 19:56:09 +01:00
8d57bf9525 Version 1.2.0 2026-02-26 09:34:35 -05:00
37 changed files with 1599 additions and 969 deletions

View File

@@ -44,7 +44,7 @@ uv add "fastapi-toolsets[all]"
### Core ### Core
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in search with relationship traversal - **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and Offset/Cursor pagination.
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration - **Fixtures**: Fixture system with dependency management, context support, and pytest integration

View File

@@ -0,0 +1,134 @@
# Pagination & search
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, **faceted filtering**, and **sorting** — all from a single `CrudFactory` definition.
## Models
```python title="models.py"
--8<-- "docs_src/examples/pagination_search/models.py"
```
## Schemas
```python title="schemas.py"
--8<-- "docs_src/examples/pagination_search/schemas.py"
```
## Crud
Declare `searchable_fields`, `facet_fields`, and `order_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
```python title="crud.py"
--8<-- "docs_src/examples/pagination_search/crud.py"
```
## Session dependency
```python title="db.py"
--8<-- "docs_src/examples/pagination_search/db.py"
```
!!! info "Deploy a Postgres DB with docker"
```bash
docker run -d --name postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres -p 5432:5432 postgres:18-alpine
```
## App
```python title="app.py"
--8<-- "docs_src/examples/pagination_search/app.py"
```
## Routes
### Offset pagination
Best for admin panels or any UI that needs a total item count and numbered pages.
```python title="routes.py:1:36"
--8<-- "docs_src/examples/pagination_search/routes.py:1:36"
```
**Example request**
```
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&order_by=title&order=asc
```
**Example response**
```json
{
"status": "SUCCESS",
"data": [
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
],
"pagination": {
"total_count": 42,
"page": 2,
"items_per_page": 10,
"has_more": true
},
"filter_attributes": {
"status": ["archived", "draft", "published"],
"name": ["backend", "frontend", "python"]
}
}
```
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
### Cursor pagination
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
```python title="routes.py:39:59"
--8<-- "docs_src/examples/pagination_search/routes.py:39:59"
```
**Example request**
```
GET /articles/cursor?items_per_page=10&status=published&order_by=created_at&order=desc
```
**Example response**
```json
{
"status": "SUCCESS",
"data": [
{ "id": "3f47ac69-...", "title": "FastAPI tips", "status": "published", ... }
],
"pagination": {
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
"prev_cursor": null,
"items_per_page": 10,
"has_more": true
},
"filter_attributes": {
"status": ["published"],
"name": ["backend", "python"]
}
}
```
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page.
## Search behaviour
Both endpoints inherit the same `searchable_fields` declared on `ArticleCrud`:
Search is **case-insensitive** and uses a `LIKE %query%` pattern. Pass a [`SearchConfig`](../reference/crud.md#fastapi_toolsets.crud.search.SearchConfig) instead of a plain string to control case sensitivity or switch to `match_mode="all"` (AND across all fields instead of OR).
```python
from fastapi_toolsets.crud import SearchConfig
# Both title AND body must contain "fastapi"
result = await ArticleCrud.offset_paginate(
session,
search=SearchConfig(query="fastapi", case_sensitive=True, match_mode="all"),
search_fields=[Article.title, Article.body],
)
```

View File

@@ -44,7 +44,7 @@ uv add "fastapi-toolsets[all]"
### Core ### Core
- **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in search with relationship traversal - **CRUD**: Generic async CRUD operations with `CrudFactory`, built-in full-text/faceted search and offset/cursor pagination.
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration - **Fixtures**: Fixture system with dependency management, context support, and pytest integration

View File

@@ -95,9 +95,6 @@ The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.Async
} }
``` ```
!!! warning "Deprecated: `paginate`"
The `paginate` function is a backward-compatible alias for `offset_paginate`. This function is **deprecated** and will be removed in **v2.0**.
### Cursor pagination ### Cursor pagination
```python ```python
@@ -170,7 +167,7 @@ PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate). Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate).
| | Full-text search | Filter attributes | | | Full-text search | Faceted search |
|---|---|---| |---|---|---|
| Input | Free-text string | Exact column values | | Input | Free-text string | Exact column values |
| Relationship support | Yes | Yes | | Relationship support | Yes | Yes |
@@ -242,7 +239,7 @@ async def get_users(
) )
``` ```
### Filter attributes ### Faceted search
!!! info "Added in `v1.2`" !!! info "Added in `v1.2`"
@@ -295,6 +292,8 @@ Use `filter_by` to pass the client's chosen filter values directly — no need t
Use [`filter_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.filter_params) to generate a dict with the facet filter values from the query parameters: Use [`filter_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.filter_params) to generate a dict with the facet filter values from the query parameters:
```python ```python
from typing import Annotated
from fastapi import Depends from fastapi import Depends
UserCrud = CrudFactory( UserCrud = CrudFactory(
@@ -306,7 +305,7 @@ UserCrud = CrudFactory(
async def list_users( async def list_users(
session: SessionDep, session: SessionDep,
page: int = 1, page: int = 1,
filter_by: dict[str, list[str]] = Depends(UserCrud.filter_params()), filter_by: Annotated[dict[str, list[str]], Depends(UserCrud.filter_params())],
) -> PaginatedResponse[UserRead]: ) -> PaginatedResponse[UserRead]:
return await UserCrud.offset_paginate( return await UserCrud.offset_paginate(
session=session, session=session,
@@ -323,6 +322,58 @@ GET /users?status=active&country=FR → filter_by={"status": ["active"], "coun
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause) GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause)
``` ```
## Sorting
!!! info "Added in `v1.3`"
Declare `order_fields` on the CRUD class to expose client-driven column ordering via `order_by` and `order` query parameters.
```python
UserCrud = CrudFactory(
model=User,
order_fields=[
User.name,
User.created_at,
],
)
```
Call [`order_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.order_params) to generate a FastAPI dependency that maps the query parameters to an [`OrderByClause`](../reference/crud.md#fastapi_toolsets.crud.factory.OrderByClause) expression:
```python
from typing import Annotated
from fastapi import Depends
from fastapi_toolsets.crud import OrderByClause
@router.get("")
async def list_users(
session: SessionDep,
order_by: Annotated[OrderByClause | None, Depends(UserCrud.order_params())],
) -> PaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, order_by=order_by)
```
The dependency adds two query parameters to the endpoint:
| Parameter | Type |
| ---------- | --------------- |
| `order_by` | `str | null` |
| `order` | `asc` or `desc` |
```
GET /users?order_by=name&order=asc → ORDER BY users.name ASC
GET /users?order_by=name&order=desc → ORDER BY users.name DESC
```
An unknown `order_by` value raises [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) (HTTP 422).
You can also pass `order_fields` directly to `order_params()` to override the class-level defaults without modifying them:
```python
UserOrderParams = UserCrud.order_params(order_fields=[User.name])
```
## Relationship loading ## Relationship loading
!!! info "Added in `v1.1`" !!! info "Added in `v1.1`"
@@ -384,7 +435,7 @@ await UserCrud.upsert(
) )
``` ```
## `schema` — typed response serialization ## Response serialization
!!! info "Added in `v1.1`" !!! info "Added in `v1.1`"
@@ -417,9 +468,6 @@ async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[Us
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances. The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
!!! warning "Deprecated: `as_response`"
The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`.
--- ---
[:material-api: API Reference](../reference/crud.md) [:material-api: API Reference](../reference/crud.md)

View File

@@ -13,6 +13,7 @@ from fastapi_toolsets.exceptions import (
ConflictError, ConflictError,
NoSearchableFieldsError, NoSearchableFieldsError,
InvalidFacetFilterError, InvalidFacetFilterError,
InvalidOrderFieldError,
generate_error_responses, generate_error_responses,
init_exceptions_handlers, init_exceptions_handlers,
) )
@@ -32,6 +33,8 @@ from fastapi_toolsets.exceptions import (
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError ## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
## ::: fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses ## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers ## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers

0
docs_src/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,9 @@
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
from .routes import router
app = FastAPI()
init_exceptions_handlers(app=app)
app.include_router(router=router)

View File

@@ -0,0 +1,21 @@
from fastapi_toolsets.crud import CrudFactory
from .models import Article, Category
ArticleCrud = CrudFactory(
model=Article,
cursor_column=Article.created_at,
searchable_fields=[ # default fields for full-text search
Article.title,
Article.body,
(Article.category, Category.name),
],
facet_fields=[ # fields exposed as filter dropdowns
Article.status,
(Article.category, Category.name),
],
order_fields=[ # fields exposed for client-driven ordering
Article.title,
Article.created_at,
],
)

View File

@@ -0,0 +1,17 @@
from typing import Annotated
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from fastapi_toolsets.db import create_db_context, create_db_dependency
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres"
engine = create_async_engine(url=DATABASE_URL, future=True)
async_session_maker = async_sessionmaker(bind=engine, expire_on_commit=False)
get_db = create_db_dependency(session_maker=async_session_maker)
get_db_context = create_db_context(session_maker=async_session_maker)
SessionDep = Annotated[AsyncSession, Depends(get_db)]

View File

@@ -0,0 +1,36 @@
import datetime
import uuid
from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
class Base(DeclarativeBase):
pass
class Category(Base):
__tablename__ = "categories"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(64), unique=True)
articles: Mapped[list["Article"]] = relationship(back_populates="category")
class Article(Base):
__tablename__ = "articles"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
title: Mapped[str] = mapped_column(String(256))
body: Mapped[str] = mapped_column(Text)
status: Mapped[str] = mapped_column(String(32))
published: Mapped[bool] = mapped_column(Boolean, default=False)
category_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("categories.id"), nullable=True
)
category: Mapped["Category | None"] = relationship(back_populates="articles")

View File

@@ -0,0 +1,59 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from fastapi_toolsets.crud import OrderByClause
from fastapi_toolsets.schemas import PaginatedResponse
from .crud import ArticleCrud
from .db import SessionDep
from .models import Article
from .schemas import ArticleRead
router = APIRouter(prefix="/articles")
@router.get("/offset")
async def list_articles_offset(
session: SessionDep,
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[
OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
],
page: int = Query(1, ge=1),
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None,
) -> PaginatedResponse[ArticleRead]:
return await ArticleCrud.offset_paginate(
session=session,
page=page,
items_per_page=items_per_page,
search=search,
filter_by=filter_by or None,
order_by=order_by,
schema=ArticleRead,
)
@router.get("/cursor")
async def list_articles_cursor(
session: SessionDep,
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[
OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
],
cursor: str | None = None,
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None,
) -> PaginatedResponse[ArticleRead]:
return await ArticleCrud.cursor_paginate(
session=session,
cursor=cursor,
items_per_page=items_per_page,
search=search,
filter_by=filter_by or None,
order_by=order_by,
schema=ArticleRead,
)

View File

@@ -0,0 +1,13 @@
import datetime
import uuid
from fastapi_toolsets.schemas import PydanticBase
class ArticleRead(PydanticBase):
id: uuid.UUID
created_at: datetime.datetime
title: str
status: str
published: bool
category_id: uuid.UUID | None

View File

@@ -1,7 +1,7 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "1.1.2" version = "1.3.0"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Production-ready utilities for FastAPI applications"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
license-files = ["LICENSE"] license-files = ["LICENSE"]
@@ -11,7 +11,7 @@ authors = [
] ]
keywords = ["fastapi", "sqlalchemy", "postgresql"] keywords = ["fastapi", "sqlalchemy", "postgresql"]
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 5 - Production/Stable",
"Framework :: AsyncIO", "Framework :: AsyncIO",
"Framework :: FastAPI", "Framework :: FastAPI",
"Framework :: Pydantic", "Framework :: Pydantic",

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "1.1.2" __version__ = "1.3.0"

View File

@@ -72,7 +72,7 @@ async def load(
registry = get_fixtures_registry() registry = get_fixtures_registry()
db_context = get_db_context() db_context = get_db_context()
context_list = [c.value for c in contexts] if contexts else [Context.BASE] context_list = list(contexts) if contexts else [Context.BASE]
ordered = registry.resolve_context_dependencies(*context_list) ordered = registry.resolve_context_dependencies(*context_list)

View File

@@ -1,12 +1,9 @@
"""Generic async CRUD operations for SQLAlchemy models.""" """Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from .factory import CrudFactory, JoinType, M2MFieldType from ..types import FacetFieldType, JoinType, M2MFieldType, OrderByClause
from .search import ( from .factory import CrudFactory
FacetFieldType, from .search import SearchConfig, get_searchable_fields
SearchConfig,
get_searchable_fields,
)
__all__ = [ __all__ = [
"CrudFactory", "CrudFactory",
@@ -16,5 +13,6 @@ __all__ = [
"JoinType", "JoinType",
"M2MFieldType", "M2MFieldType",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"OrderByClause",
"SearchConfig", "SearchConfig",
] ]

View File

@@ -6,11 +6,10 @@ import base64
import inspect import inspect
import json import json
import uuid as uuid_module import uuid as uuid_module
import warnings from collections.abc import Awaitable, Callable, Sequence
from collections.abc import Awaitable, Callable, Mapping, Sequence
from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query from fastapi import Query
from pydantic import BaseModel from pydantic import BaseModel
@@ -24,23 +23,25 @@ from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
from ..exceptions import NotFoundError from ..exceptions import InvalidOrderFieldError, NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import ( from ..types import (
FacetFieldType, FacetFieldType,
SearchConfig, JoinType,
M2MFieldType,
ModelType,
OrderByClause,
SchemaType,
SearchFieldType, SearchFieldType,
)
from .search import (
SearchConfig,
build_facets, build_facets,
build_filter_by, build_filter_by,
build_search_filters, build_search_filters,
facet_keys, facet_keys,
) )
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
def _encode_cursor(value: Any) -> str: def _encode_cursor(value: Any) -> str:
"""Encode cursor column value as an base64 string.""" """Encode cursor column value as an base64 string."""
@@ -52,6 +53,22 @@ def _decode_cursor(cursor: str) -> str:
return json.loads(base64.b64decode(cursor.encode()).decode()) return json.loads(base64.b64decode(cursor.encode()).decode())
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
"""Apply a list of (model, condition) joins to a SQLAlchemy select query."""
if not joins:
return q
for model, condition in joins:
q = q.outerjoin(model, condition) if outer_join else q.join(model, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
for join_rel in search_joins:
q = q.outerjoin(join_rel)
return q
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models. """Generic async CRUD operations for SQLAlchemy models.
@@ -61,6 +78,7 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
order_fields: ClassVar[Sequence[QueryableAttribute[Any]] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None default_load_options: ClassVar[list[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None cursor_column: ClassVar[Any | None] = None
@@ -130,6 +148,48 @@ class AsyncCrud(Generic[ModelType]):
return set() return set()
return set(cls.m2m_fields.keys()) return set(cls.m2m_fields.keys())
@classmethod
def _resolve_facet_fields(
cls: type[Self],
facet_fields: Sequence[FacetFieldType] | None,
) -> Sequence[FacetFieldType] | None:
"""Return facet_fields if given, otherwise fall back to the class-level default."""
return facet_fields if facet_fields is not None else cls.facet_fields
@classmethod
def _prepare_filter_by(
cls: type[Self],
filter_by: dict[str, Any] | BaseModel | None,
facet_fields: Sequence[FacetFieldType] | None,
) -> tuple[list[Any], list[Any]]:
"""Normalize filter_by and return (filters, joins) to apply to the query."""
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True)
if not filter_by:
return [], []
resolved = cls._resolve_facet_fields(facet_fields)
return build_filter_by(filter_by, resolved or [])
@classmethod
async def _build_filter_attributes(
cls: type[Self],
session: AsyncSession,
facet_fields: Sequence[FacetFieldType] | None,
filters: list[Any],
search_joins: list[Any],
) -> dict[str, list[Any]] | None:
"""Build facet filter_attributes, or return None if no facet fields configured."""
resolved = cls._resolve_facet_fields(facet_fields)
if not resolved:
return None
return await build_facets(
session,
cls.model,
resolved,
base_filters=filters,
base_joins=search_joins,
)
@classmethod @classmethod
def filter_params( def filter_params(
cls: type[Self], cls: type[Self],
@@ -150,7 +210,7 @@ class AsyncCrud(Generic[ModelType]):
ValueError: If no facet fields are configured on this CRUD class and none are ValueError: If no facet fields are configured on this CRUD class and none are
provided via ``facet_fields``. provided via ``facet_fields``.
""" """
fields = facet_fields if facet_fields is not None else cls.facet_fields fields = cls._resolve_facet_fields(facet_fields)
if not fields: if not fields:
raise ValueError( raise ValueError(
f"{cls.__name__} has no facet_fields configured. " f"{cls.__name__} has no facet_fields configured. "
@@ -176,6 +236,63 @@ class AsyncCrud(Generic[ModelType]):
return dependency return dependency
@classmethod
def order_params(
cls: type[Self],
*,
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
default_field: QueryableAttribute[Any] | None = None,
default_order: Literal["asc", "desc"] = "asc",
) -> Callable[..., Awaitable[OrderByClause | None]]:
"""Return a FastAPI dependency that resolves order query params into an order_by clause.
Args:
order_fields: Override the allowed order fields. Falls back to the class-level
``order_fields`` if not provided.
default_field: Field to order by when ``order_by`` query param is absent.
If ``None`` and no ``order_by`` is provided, no ordering is applied.
default_order: Default order direction when ``order`` is absent
(``"asc"`` or ``"desc"``).
Returns:
An async dependency function named ``{Model}OrderParams`` that resolves to an
``OrderByClause`` (or ``None``). Pass it to ``Depends()`` in your route.
Raises:
ValueError: If no order fields are configured on this CRUD class and none are
provided via ``order_fields``.
InvalidOrderFieldError: When the request provides an unknown ``order_by`` value.
"""
fields = order_fields if order_fields is not None else cls.order_fields
if not fields:
raise ValueError(
f"{cls.__name__} has no order_fields configured. "
"Pass order_fields= or set them on CrudFactory."
)
field_map: dict[str, QueryableAttribute[Any]] = {f.key: f for f in fields}
valid_keys = sorted(field_map.keys())
async def dependency(
order_by: str | None = Query(
None, description=f"Field to order by. Valid values: {valid_keys}"
),
order: Literal["asc", "desc"] = Query(
default_order, description="Sort direction"
),
) -> OrderByClause | None:
if order_by is None:
if default_field is None:
return None
field = default_field
elif order_by not in field_map:
raise InvalidOrderFieldError(order_by, valid_keys)
else:
field = field_map[order_by]
return field.asc() if order == "asc" else field.desc()
dependency.__name__ = f"{cls.model.__name__}OrderParams"
return dependency
@overload @overload
@classmethod @classmethod
async def create( # pragma: no cover async def create( # pragma: no cover
@@ -184,10 +301,8 @@ class AsyncCrud(Generic[ModelType]):
obj: BaseModel, obj: BaseModel,
*, *,
schema: type[SchemaType], schema: type[SchemaType],
as_response: bool = ...,
) -> Response[SchemaType]: ... ) -> Response[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload @overload
@classmethod @classmethod
async def create( # pragma: no cover async def create( # pragma: no cover
@@ -195,18 +310,6 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
obj: BaseModel, obj: BaseModel,
*, *,
as_response: Literal[True],
schema: None = ...,
) -> Response[ModelType]: ...
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[False] = ...,
schema: None = ..., schema: None = ...,
) -> ModelType: ... ) -> ModelType: ...
@@ -216,29 +319,19 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
obj: BaseModel, obj: BaseModel,
*, *,
as_response: bool = False,
schema: type[BaseModel] | None = None, schema: type[BaseModel] | None = None,
) -> ModelType | Response[ModelType] | Response[Any]: ) -> ModelType | Response[Any]:
"""Create a new record in the database. """Create a new record in the database.
Args: Args:
session: DB async session session: DB async session
obj: Pydantic model with data to create obj: Pydantic model with data to create
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
schema: Pydantic schema to serialize the result into. When provided, schema: Pydantic schema to serialize the result into. When provided,
the result is automatically wrapped in a ``Response[schema]``. the result is automatically wrapped in a ``Response[schema]``.
Returns: Returns:
Created model instance, or ``Response[schema]`` when ``schema`` is given, Created model instance, or ``Response[schema]`` when ``schema`` is given.
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
""" """
if as_response and schema is None:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
async with get_transaction(session): async with get_transaction(session):
m2m_exclude = cls._m2m_schema_fields() m2m_exclude = cls._m2m_schema_fields()
data = ( data = (
@@ -254,9 +347,8 @@ class AsyncCrud(Generic[ModelType]):
session.add(db_model) session.add(db_model)
await session.refresh(db_model) await session.refresh(db_model)
result = cast(ModelType, db_model) result = cast(ModelType, db_model)
if as_response or schema: if schema:
data_out = schema.model_validate(result) if schema else result return Response(data=schema.model_validate(result))
return Response(data=data_out)
return result return result
@overload @overload
@@ -271,10 +363,8 @@ class AsyncCrud(Generic[ModelType]):
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
schema: type[SchemaType], schema: type[SchemaType],
as_response: bool = ...,
) -> Response[SchemaType]: ... ) -> Response[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload @overload
@classmethod @classmethod
async def get( # pragma: no cover async def get( # pragma: no cover
@@ -286,22 +376,6 @@ class AsyncCrud(Generic[ModelType]):
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
as_response: Literal[True],
schema: None = ...,
) -> Response[ModelType]: ...
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[ExecutableOption] | None = None,
as_response: Literal[False] = ...,
schema: None = ..., schema: None = ...,
) -> ModelType: ... ) -> ModelType: ...
@@ -315,9 +389,8 @@ class AsyncCrud(Generic[ModelType]):
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
as_response: bool = False,
schema: type[BaseModel] | None = None, schema: type[BaseModel] | None = None,
) -> ModelType | Response[ModelType] | Response[Any]: ) -> ModelType | Response[Any]:
"""Get exactly one record. Raises NotFoundError if not found. """Get exactly one record. Raises NotFoundError if not found.
Args: Args:
@@ -327,33 +400,18 @@ class AsyncCrud(Generic[ModelType]):
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload) load_options: SQLAlchemy loader options (e.g., selectinload)
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
schema: Pydantic schema to serialize the result into. When provided, schema: Pydantic schema to serialize the result into. When provided,
the result is automatically wrapped in a ``Response[schema]``. the result is automatically wrapped in a ``Response[schema]``.
Returns: Returns:
Model instance, or ``Response[schema]`` when ``schema`` is given, Model instance, or ``Response[schema]`` when ``schema`` is given.
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
Raises: Raises:
NotFoundError: If no record found NotFoundError: If no record found
MultipleResultsFound: If more than one record found MultipleResultsFound: If more than one record found
""" """
if as_response and schema is None:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options): if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved) q = q.options(*resolved)
@@ -364,9 +422,8 @@ class AsyncCrud(Generic[ModelType]):
if not item: if not item:
raise NotFoundError() raise NotFoundError()
result = cast(ModelType, item) result = cast(ModelType, item)
if as_response or schema: if schema:
data_out = schema.model_validate(result) if schema else result return Response(data=schema.model_validate(result))
return Response(data=data_out)
return result return result
@classmethod @classmethod
@@ -392,13 +449,7 @@ class AsyncCrud(Generic[ModelType]):
Model instance or None Model instance or None
""" """
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options): if resolved := cls._resolve_load_options(load_options):
@@ -415,7 +466,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
limit: int | None = None, limit: int | None = None,
offset: int | None = None, offset: int | None = None,
) -> Sequence[ModelType]: ) -> Sequence[ModelType]:
@@ -435,13 +486,7 @@ class AsyncCrud(Generic[ModelType]):
List of model instances List of model instances
""" """
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options): if resolved := cls._resolve_load_options(load_options):
@@ -466,10 +511,8 @@ class AsyncCrud(Generic[ModelType]):
exclude_unset: bool = True, exclude_unset: bool = True,
exclude_none: bool = False, exclude_none: bool = False,
schema: type[SchemaType], schema: type[SchemaType],
as_response: bool = ...,
) -> Response[SchemaType]: ... ) -> Response[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload @overload
@classmethod @classmethod
async def update( # pragma: no cover async def update( # pragma: no cover
@@ -480,21 +523,6 @@ class AsyncCrud(Generic[ModelType]):
*, *,
exclude_unset: bool = True, exclude_unset: bool = True,
exclude_none: bool = False, exclude_none: bool = False,
as_response: Literal[True],
schema: None = ...,
) -> Response[ModelType]: ...
@overload
@classmethod
async def update( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
as_response: Literal[False] = ...,
schema: None = ..., schema: None = ...,
) -> ModelType: ... ) -> ModelType: ...
@@ -507,9 +535,8 @@ class AsyncCrud(Generic[ModelType]):
*, *,
exclude_unset: bool = True, exclude_unset: bool = True,
exclude_none: bool = False, exclude_none: bool = False,
as_response: bool = False,
schema: type[BaseModel] | None = None, schema: type[BaseModel] | None = None,
) -> ModelType | Response[ModelType] | Response[Any]: ) -> ModelType | Response[Any]:
"""Update a record in the database. """Update a record in the database.
Args: Args:
@@ -518,24 +545,15 @@ class AsyncCrud(Generic[ModelType]):
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
exclude_unset: Exclude fields not explicitly set in the schema exclude_unset: Exclude fields not explicitly set in the schema
exclude_none: Exclude fields with None value exclude_none: Exclude fields with None value
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
schema: Pydantic schema to serialize the result into. When provided, schema: Pydantic schema to serialize the result into. When provided,
the result is automatically wrapped in a ``Response[schema]``. the result is automatically wrapped in a ``Response[schema]``.
Returns: Returns:
Updated model instance, or ``Response[schema]`` when ``schema`` is given, Updated model instance, or ``Response[schema]`` when ``schema`` is given.
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
Raises: Raises:
NotFoundError: If no record found NotFoundError: If no record found
""" """
if as_response and schema is None:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
async with get_transaction(session): async with get_transaction(session):
m2m_exclude = cls._m2m_schema_fields() m2m_exclude = cls._m2m_schema_fields()
@@ -565,9 +583,8 @@ class AsyncCrud(Generic[ModelType]):
for rel_attr, related_instances in m2m_resolved.items(): for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances) setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model) await session.refresh(db_model)
if as_response or schema: if schema:
data_out = schema.model_validate(db_model) if schema else db_model return Response(data=schema.model_validate(db_model))
return Response(data=data_out)
return db_model return db_model
@classmethod @classmethod
@@ -623,7 +640,7 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*, *,
as_response: Literal[True], return_response: Literal[True],
) -> Response[None]: ... ) -> Response[None]: ...
@overload @overload
@@ -633,8 +650,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*, *,
as_response: Literal[False] = ..., return_response: Literal[False] = ...,
) -> bool: ... ) -> None: ...
@classmethod @classmethod
async def delete( async def delete(
@@ -642,33 +659,26 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*, *,
as_response: bool = False, return_response: bool = False,
) -> bool | Response[None]: ) -> None | Response[None]:
"""Delete records from the database. """Delete records from the database.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
as_response: Deprecated. Will be removed in v2.0. When ``True``, return_response: When ``True``, returns ``Response[None]`` instead
returns ``Response[None]`` instead of ``bool``. of ``None``. Useful for API endpoints that expect a consistent
response envelope.
Returns: Returns:
``True`` if deletion was executed, or ``Response[None]`` when ``None``, or ``Response[None]`` when ``return_response=True``.
``as_response=True`` (deprecated).
""" """
if as_response:
warnings.warn(
"as_response is deprecated and will be removed in v2.0. "
"Use schema=YourSchema instead.",
DeprecationWarning,
stacklevel=2,
)
async with get_transaction(session): async with get_transaction(session):
q = sql_delete(cls.model).where(and_(*filters)) q = sql_delete(cls.model).where(and_(*filters))
await session.execute(q) await session.execute(q)
if as_response: if return_response:
return Response(data=None) return Response(data=None)
return True return None
@classmethod @classmethod
async def count( async def count(
@@ -691,13 +701,7 @@ class AsyncCrud(Generic[ModelType]):
Number of matching records Number of matching records
""" """
q = select(func.count()).select_from(cls.model) q = select(func.count()).select_from(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
result = await session.execute(q) result = await session.execute(q)
@@ -724,58 +728,11 @@ class AsyncCrud(Generic[ModelType]):
True if at least one record matches True if at least one record matches
""" """
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select() q = q.where(and_(*filters)).exists().select()
result = await session.execute(q) result = await session.execute(q)
return bool(result.scalar()) return bool(result.scalar())
@overload
@classmethod
async def offset_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def offset_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@classmethod @classmethod
async def offset_paginate( async def offset_paginate(
cls: type[Self], cls: type[Self],
@@ -785,15 +742,15 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None, filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[BaseModel] | None = None, schema: type[BaseModel],
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: ) -> PaginatedResponse[Any]:
"""Get paginated results using offset-based pagination. """Get paginated results using offset-based pagination.
Args: Args:
@@ -811,54 +768,36 @@ class AsyncCrud(Generic[ModelType]):
filter_by: Dict of {column_key: value} to filter by declared facet fields. filter_by: Dict of {column_key: value} to filter by declared facet fields.
Keys must match the column.key of a facet field. Scalar → equality, Keys must match the column.key of a facet field. Scalar → equality,
list → IN clause. Raises InvalidFacetFilterError for unknown keys. list → IN clause. Raises InvalidFacetFilterError for unknown keys.
schema: Optional Pydantic schema to serialize each item into. schema: Pydantic schema to serialize each item into.
Returns: Returns:
PaginatedResponse with OffsetPagination metadata PaginatedResponse with OffsetPagination metadata
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel): fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filter_by = filter_by.model_dump(exclude_none=True) or None filters.extend(fb_filters)
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
# Build search filters # Build search filters
if search: if search:
search_filters, search_joins = build_search_filters( search_filters, new_search_joins = build_search_filters(
cls.model, cls.model,
search, search,
search_fields=search_fields, search_fields=search_fields,
default_fields=cls.searchable_fields, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters) filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query with joins # Build query with joins
q = select(cls.model) q = select(cls.model)
# Apply explicit joins # Apply explicit joins
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins for search) # Apply search joins (always outer joins for search)
for join_rel in search_joins: q = _apply_search_joins(q, search_joins)
q = q.outerjoin(join_rel)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
@@ -870,9 +809,7 @@ class AsyncCrud(Generic[ModelType]):
q = q.offset(offset).limit(items_per_page) q = q.offset(offset).limit(items_per_page)
result = await session.execute(q) result = await session.execute(q)
raw_items = cast(list[ModelType], result.unique().scalars().all()) raw_items = cast(list[ModelType], result.unique().scalars().all())
items: list[Any] = ( items: list[Any] = [schema.model_validate(item) for item in raw_items]
[schema.model_validate(item) for item in raw_items] if schema else raw_items
)
# Count query (with same joins and filters) # Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0] pk_col = cls.model.__mapper__.primary_key[0]
@@ -880,17 +817,10 @@ class AsyncCrud(Generic[ModelType]):
count_q = count_q.select_from(cls.model) count_q = count_q.select_from(cls.model)
# Apply explicit joins to count query # Apply explicit joins to count query
if joins: count_q = _apply_joins(count_q, joins, outer_join)
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
# Apply search joins to count query # Apply search joins to count query
for join_rel in search_joins: count_q = _apply_search_joins(count_q, search_joins)
count_q = count_q.outerjoin(join_rel)
if filters: if filters:
count_q = count_q.where(and_(*filters)) count_q = count_q.where(and_(*filters))
@@ -898,19 +828,9 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q) count_result = await session.execute(count_q)
total_count = count_result.scalar_one() total_count = count_result.scalar_one()
# Build facets filter_attributes = await cls._build_filter_attributes(
resolved_facet_fields = ( session, facet_fields, filters, search_joins
facet_fields if facet_fields is not None else cls.facet_fields
) )
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse( return PaginatedResponse(
data=items, data=items,
@@ -923,50 +843,6 @@ class AsyncCrud(Generic[ModelType]):
filter_attributes=filter_attributes, filter_attributes=filter_attributes,
) )
# Backward-compatible - will be removed in v2.0
paginate = offset_paginate
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@classmethod @classmethod
async def cursor_paginate( async def cursor_paginate(
cls: type[Self], cls: type[Self],
@@ -977,14 +853,14 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None, filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[BaseModel] | None = None, schema: type[BaseModel],
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: ) -> PaginatedResponse[Any]:
"""Get paginated results using cursor-based pagination. """Get paginated results using cursor-based pagination.
Args: Args:
@@ -1011,21 +887,9 @@ class AsyncCrud(Generic[ModelType]):
PaginatedResponse with CursorPagination metadata PaginatedResponse with CursorPagination metadata
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel): fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filter_by = filter_by.model_dump(exclude_none=True) or None filters.extend(fb_filters)
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
if cls.cursor_column is None: if cls.cursor_column is None:
raise ValueError( raise ValueError(
@@ -1058,29 +922,23 @@ class AsyncCrud(Generic[ModelType]):
# Build search filters # Build search filters
if search: if search:
search_filters, search_joins = build_search_filters( search_filters, new_search_joins = build_search_filters(
cls.model, cls.model,
search, search,
search_fields=search_fields, search_fields=search_fields,
default_fields=cls.searchable_fields, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters) filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query # Build query
q = select(cls.model) q = select(cls.model)
# Apply explicit joins # Apply explicit joins
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins) # Apply search joins (always outer joins)
for join_rel in search_joins: q = _apply_search_joins(q, search_joins)
q = q.outerjoin(join_rel)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
@@ -1110,25 +968,11 @@ class AsyncCrud(Generic[ModelType]):
if cursor is not None and items_page: if cursor is not None and items_page:
prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name)) prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))
items: list[Any] = ( items: list[Any] = [schema.model_validate(item) for item in items_page]
[schema.model_validate(item) for item in items_page]
if schema
else items_page
)
# Build facets filter_attributes = await cls._build_filter_attributes(
resolved_facet_fields = ( session, facet_fields, filters, search_joins
facet_fields if facet_fields is not None else cls.facet_fields
) )
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse( return PaginatedResponse(
data=items, data=items,
@@ -1147,6 +991,7 @@ def CrudFactory(
*, *,
searchable_fields: Sequence[SearchFieldType] | None = None, searchable_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
m2m_fields: M2MFieldType | None = None, m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None, default_load_options: list[ExecutableOption] | None = None,
cursor_column: Any | None = None, cursor_column: Any | None = None,
@@ -1159,6 +1004,8 @@ def CrudFactory(
facet_fields: Optional list of columns to compute distinct values for in paginated facet_fields: Optional list of columns to compute distinct values for in paginated
responses. Supports direct columns (``User.status``) and relationship tuples responses. Supports direct columns (``User.status``) and relationship tuples
(``(User.role, Role.name)``). Can be overridden per call. (``(User.role, Role.name)``). Can be overridden per call.
order_fields: Optional list of model attributes that callers are allowed to order by
via ``order_params()``. Can be overridden per call.
m2m_fields: Optional mapping for many-to-many relationships. m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes. SQLAlchemy relationship attributes.
@@ -1252,6 +1099,7 @@ def CrudFactory(
"model": model, "model": model,
"searchable_fields": searchable_fields, "searchable_fields": searchable_fields,
"facet_fields": facet_fields, "facet_fields": facet_fields,
"order_fields": order_fields,
"m2m_fields": m2m_fields, "m2m_fields": m2m_fields,
"default_load_options": default_load_options, "default_load_options": default_load_options,
"cursor_column": cursor_column, "cursor_column": cursor_column,

View File

@@ -1,24 +1,23 @@
"""Search utilities for AsyncCrud.""" """Search utilities for AsyncCrud."""
import asyncio import asyncio
import functools
from collections import Counter from collections import Counter
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_, select from sqlalchemy import String, and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from ..types import FacetFieldType, SearchFieldType
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
@dataclass @dataclass
class SearchConfig: class SearchConfig:
@@ -37,6 +36,7 @@ class SearchConfig:
match_mode: Literal["any", "all"] = "any" match_mode: Literal["any", "all"] = "any"
@functools.lru_cache(maxsize=128)
def get_searchable_fields( def get_searchable_fields(
model: type[DeclarativeBase], model: type[DeclarativeBase],
*, *,
@@ -101,14 +101,11 @@ def build_search_filters(
if isinstance(search, str): if isinstance(search, str):
config = SearchConfig(query=search, fields=search_fields) config = SearchConfig(query=search, fields=search_fields)
else: else:
config = search config = (
if search_fields is not None: replace(search, fields=search_fields)
config = SearchConfig( if search_fields is not None
query=config.query, else search
fields=search_fields, )
case_sensitive=config.case_sensitive,
match_mode=config.match_mode,
)
if not config.query or not config.query.strip(): if not config.query or not config.query.strip():
return [], [] return [], []
@@ -227,8 +224,6 @@ async def build_facets(
q = q.outerjoin(rel) q = q.outerjoin(rel)
if base_filters: if base_filters:
from sqlalchemy import and_
q = q.where(and_(*base_filters)) q = q.where(and_(*base_filters))
q = q.order_by(column) q = q.order_by(column)

View File

@@ -1,20 +1,17 @@
"""Dependency factories for FastAPI routes.""" """Dependency factories for FastAPI routes."""
import inspect import inspect
from collections.abc import AsyncGenerator, Callable from collections.abc import Callable
from typing import Any, TypeVar, cast from typing import Any, cast
from fastapi import Depends from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from .crud import CrudFactory from .crud import CrudFactory
from .types import ModelType, SessionDependency
__all__ = ["BodyDependency", "PathDependency"] __all__ = ["BodyDependency", "PathDependency"]
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
def PathDependency( def PathDependency(
model: type[ModelType], model: type[ModelType],

View File

@@ -6,6 +6,7 @@ from .exceptions import (
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
InvalidFacetFilterError, InvalidFacetFilterError,
InvalidOrderFieldError,
NoSearchableFieldsError, NoSearchableFieldsError,
NotFoundError, NotFoundError,
UnauthorizedError, UnauthorizedError,
@@ -21,6 +22,7 @@ __all__ = [
"generate_error_responses", "generate_error_responses",
"init_exceptions_handlers", "init_exceptions_handlers",
"InvalidFacetFilterError", "InvalidFacetFilterError",
"InvalidOrderFieldError",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"NotFoundError", "NotFoundError",
"UnauthorizedError", "UnauthorizedError",

View File

@@ -128,6 +128,31 @@ class InvalidFacetFilterError(ApiException):
super().__init__(detail) super().__init__(detail)
class InvalidOrderFieldError(ApiException):
"""Raised when order_by contains a field not in the allowed order fields."""
api_error = ApiError(
code=422,
msg="Invalid Order Field",
desc="The requested order field is not allowed for this resource.",
err_code="SORT-422",
)
def __init__(self, field: str, valid_fields: list[str]) -> None:
"""Initialize the exception.
Args:
field: The unknown order field provided by the caller
valid_fields: List of valid field names
"""
self.field = field
self.valid_fields = valid_fields
detail = (
f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
)
super().__init__(detail)
def generate_error_responses( def generate_error_responses(
*errors: type[ApiException], *errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]: ) -> dict[int | str, dict[str, Any]]:

View File

@@ -10,6 +10,10 @@ from fastapi.responses import JSONResponse
from ..schemas import ErrorResponse, ResponseStatus from ..schemas import ErrorResponse, ResponseStatus
from .exceptions import ApiException from .exceptions import ApiException
_VALIDATION_LOCATION_PARAMS: frozenset[str] = frozenset(
{"body", "query", "path", "header", "cookie"}
)
def init_exceptions_handlers(app: FastAPI) -> FastAPI: def init_exceptions_handlers(app: FastAPI) -> FastAPI:
"""Register exception handlers and custom OpenAPI schema on a FastAPI app. """Register exception handlers and custom OpenAPI schema on a FastAPI app.
@@ -106,9 +110,7 @@ def _format_validation_error(
for error in errors: for error in errors:
field_path = ".".join( field_path = ".".join(
str(loc) str(loc) for loc in error["loc"] if loc not in _VALIDATION_LOCATION_PARAMS
for loc in error["loc"]
if loc not in ("body", "query", "path", "header", "cookie")
) )
formatted_errors.append( formatted_errors.append(
{ {

View File

@@ -1,24 +1,84 @@
"""Fixture loading utilities for database seeding.""" """Fixture loading utilities for database seeding."""
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, TypeVar from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction from ..db import get_transaction
from ..logger import get_logger from ..logger import get_logger
from ..types import ModelType
from .enum import LoadStrategy from .enum import LoadStrategy
from .registry import Context, FixtureRegistry from .registry import Context, FixtureRegistry
logger = get_logger() logger = get_logger()
T = TypeVar("T", bound=DeclarativeBase)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
else: # LoadStrategy.SKIP_EXISTING
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None
def get_obj_by_attr( def get_obj_by_attr(
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> T: ) -> ModelType:
"""Get a SQLAlchemy model instance by matching an attribute value. """Get a SQLAlchemy model instance by matching an attribute value.
Args: Args:
@@ -57,13 +117,6 @@ async def load_fixtures(
Returns: Returns:
Dict mapping fixture names to loaded instances Dict mapping fixture names to loaded instances
Example:
```python
# Loads 'roles' first (dependency), then 'users'
result = await load_fixtures(session, fixtures, "users")
print(result["users"]) # [User(...), ...]
```
""" """
ordered = registry.resolve_dependencies(*names) ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy) return await _load_ordered(session, registry, ordered, strategy)
@@ -85,76 +138,6 @@ async def load_fixtures_by_context(
Returns: Returns:
Dict mapping fixture names to loaded instances Dict mapping fixture names to loaded instances
Example:
```python
# Load base + testing fixtures
await load_fixtures_by_context(
session, fixtures,
Context.BASE, Context.TESTING
)
```
""" """
ordered = registry.resolve_context_dependencies(*contexts) ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy) return await _load_ordered(session, registry, ordered, strategy)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None

View File

@@ -53,17 +53,23 @@ def init_metrics(
logger.debug("Initialising metric provider '%s'", provider.name) logger.debug("Initialising metric provider '%s'", provider.name)
provider.func() provider.func()
collectors = registry.get_collectors() # Partition collectors and cache env check at startup — both are stable for the app lifetime.
async_collectors = [
c for c in registry.get_collectors() if asyncio.iscoroutinefunction(c.func)
]
sync_collectors = [
c for c in registry.get_collectors() if not asyncio.iscoroutinefunction(c.func)
]
multiprocess_mode = _is_multiprocess()
@app.get(path, include_in_schema=False) @app.get(path, include_in_schema=False)
async def metrics_endpoint() -> Response: async def metrics_endpoint() -> Response:
for collector in collectors: for collector in sync_collectors:
if asyncio.iscoroutinefunction(collector.func): collector.func()
await collector.func() for collector in async_collectors:
else: await collector.func()
collector.func()
if _is_multiprocess(): if multiprocess_mode:
prom_registry = CollectorRegistry() prom_registry = CollectorRegistry()
multiprocess.MultiProcessCollector(prom_registry) multiprocess.MultiProcessCollector(prom_registry)
output = generate_latest(prom_registry) output = generate_latest(prom_registry)

View File

@@ -1,24 +1,23 @@
"""Base Pydantic schemas for API responses.""" """Base Pydantic schemas for API responses."""
from enum import Enum from enum import Enum
from typing import Any, ClassVar, Generic, TypeVar from typing import Any, ClassVar, Generic
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from .types import DataT
__all__ = [ __all__ = [
"ApiError", "ApiError",
"CursorPagination", "CursorPagination",
"ErrorResponse", "ErrorResponse",
"OffsetPagination", "OffsetPagination",
"Pagination",
"PaginatedResponse", "PaginatedResponse",
"PydanticBase", "PydanticBase",
"Response", "Response",
"ResponseStatus", "ResponseStatus",
] ]
DataT = TypeVar("DataT")
class PydanticBase(BaseModel): class PydanticBase(BaseModel):
"""Base class for all Pydantic models with common configuration.""" """Base class for all Pydantic models with common configuration."""
@@ -108,10 +107,6 @@ class OffsetPagination(PydanticBase):
has_more: bool has_more: bool
# Backward-compatible - will be removed in v2.0
Pagination = OffsetPagination
class CursorPagination(PydanticBase): class CursorPagination(PydanticBase):
"""Pagination metadata for cursor-based list responses. """Pagination metadata for cursor-based list responses.

View File

@@ -0,0 +1,27 @@
"""Shared type aliases for the fastapi-toolsets package."""
from collections.abc import AsyncGenerator, Callable, Mapping
from typing import Any, TypeVar
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnElement
# Generic TypeVars
DataT = TypeVar("DataT")
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
# CRUD type aliases
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
# Search / facet type aliases
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
# Dependency type aliases
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]

View File

@@ -92,6 +92,15 @@ class IntRole(Base):
name: Mapped[str] = mapped_column(String(50), unique=True) name: Mapped[str] = mapped_column(String(50), unique=True)
class Permission(Base):
"""Test model with composite primary key."""
__tablename__ = "permissions"
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
action: Mapped[str] = mapped_column(String(50), primary_key=True)
class Event(Base): class Event(Base):
"""Test model with DateTime and Date cursor columns.""" """Test model with DateTime and Date cursor columns."""
@@ -162,6 +171,7 @@ class UserRead(PydanticBase):
id: uuid.UUID id: uuid.UUID
username: str username: str
is_active: bool = True
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
@@ -218,12 +228,26 @@ class PostM2MUpdate(BaseModel):
tag_ids: list[uuid.UUID] | None = None tag_ids: list[uuid.UUID] | None = None
class IntRoleRead(PydanticBase):
"""Schema for reading an IntRole."""
id: int
name: str
class IntRoleCreate(BaseModel): class IntRoleCreate(BaseModel):
"""Schema for creating an IntRole.""" """Schema for creating an IntRole."""
name: str name: str
class EventRead(PydanticBase):
"""Schema for reading an Event."""
id: uuid.UUID
name: str
class EventCreate(BaseModel): class EventCreate(BaseModel):
"""Schema for creating an Event.""" """Schema for creating an Event."""
@@ -232,6 +256,13 @@ class EventCreate(BaseModel):
scheduled_date: datetime.date scheduled_date: datetime.date
class ProductRead(PydanticBase):
"""Schema for reading a Product."""
id: uuid.UUID
name: str
class ProductCreate(BaseModel): class ProductCreate(BaseModel):
"""Schema for creating a Product.""" """Schema for creating a Product."""

View File

@@ -15,8 +15,10 @@ from .conftest import (
EventCrud, EventCrud,
EventDateCursorCrud, EventDateCursorCrud,
EventDateTimeCursorCrud, EventDateTimeCursorCrud,
EventRead,
IntRoleCreate, IntRoleCreate,
IntRoleCursorCrud, IntRoleCursorCrud,
IntRoleRead,
Post, Post,
PostCreate, PostCreate,
PostCrud, PostCrud,
@@ -26,6 +28,7 @@ from .conftest import (
ProductCreate, ProductCreate,
ProductCrud, ProductCrud,
ProductNumericCursorCrud, ProductNumericCursorCrud,
ProductRead,
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
@@ -169,7 +172,14 @@ class TestDefaultLoadOptionsIntegration:
async def test_default_load_options_applied_to_paginate( async def test_default_load_options_applied_to_paginate(
self, db_session: AsyncSession self, db_session: AsyncSession
): ):
"""default_load_options loads relationships automatically on paginate().""" """default_load_options loads relationships automatically on offset_paginate()."""
from fastapi_toolsets.schemas import PydanticBase
class UserWithRoleRead(PydanticBase):
id: uuid.UUID
username: str
role: RoleRead | None = None
UserWithDefaultLoad = CrudFactory( UserWithDefaultLoad = CrudFactory(
User, default_load_options=[selectinload(User.role)] User, default_load_options=[selectinload(User.role)]
) )
@@ -178,7 +188,9 @@ class TestDefaultLoadOptionsIntegration:
db_session, db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id), UserCreate(username="alice", email="alice@test.com", role_id=role.id),
) )
result = await UserWithDefaultLoad.paginate(db_session) result = await UserWithDefaultLoad.offset_paginate(
db_session, schema=UserWithRoleRead
)
assert result.data[0].role is not None assert result.data[0].role is not None
assert result.data[0].role.name == "admin" assert result.data[0].role.name == "admin"
@@ -430,7 +442,7 @@ class TestCrudDelete:
role = await RoleCrud.create(db_session, RoleCreate(name="to_delete")) role = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
result = await RoleCrud.delete(db_session, [Role.id == role.id]) result = await RoleCrud.delete(db_session, [Role.id == role.id])
assert result is True assert result is None
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
@pytest.mark.anyio @pytest.mark.anyio
@@ -454,6 +466,20 @@ class TestCrudDelete:
assert len(remaining) == 1 assert len(remaining) == 1
assert remaining[0].username == "u3" assert remaining[0].username == "u3"
@pytest.mark.anyio
async def test_delete_return_response(self, db_session: AsyncSession):
"""Delete with return_response=True returns Response[None]."""
from fastapi_toolsets.schemas import Response
role = await RoleCrud.create(db_session, RoleCreate(name="to_delete_resp"))
result = await RoleCrud.delete(
db_session, [Role.id == role.id], return_response=True
)
assert isinstance(result, Response)
assert result.data is None
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
class TestCrudExists: class TestCrudExists:
"""Tests for CRUD exists operations.""" """Tests for CRUD exists operations."""
@@ -594,7 +620,9 @@ class TestCrudPaginate:
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10) result = await RoleCrud.offset_paginate(
db_session, page=1, items_per_page=10, schema=RoleRead
)
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert len(result.data) == 10 assert len(result.data) == 10
@@ -609,7 +637,9 @@ class TestCrudPaginate:
for i in range(25): for i in range(25):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10) result = await RoleCrud.offset_paginate(
db_session, page=3, items_per_page=10, schema=RoleRead
)
assert len(result.data) == 5 assert len(result.data) == 5
assert result.pagination.has_more is False assert result.pagination.has_more is False
@@ -629,11 +659,12 @@ class TestCrudPaginate:
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
filters=[User.is_active == True], # noqa: E712 filters=[User.is_active == True], # noqa: E712
page=1, page=1,
items_per_page=10, items_per_page=10,
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -646,11 +677,12 @@ class TestCrudPaginate:
await RoleCrud.create(db_session, RoleCreate(name="alpha")) await RoleCrud.create(db_session, RoleCreate(name="alpha"))
await RoleCrud.create(db_session, RoleCreate(name="bravo")) await RoleCrud.create(db_session, RoleCreate(name="bravo"))
result = await RoleCrud.paginate( result = await RoleCrud.offset_paginate(
db_session, db_session,
order_by=Role.name, order_by=Role.name,
page=1, page=1,
items_per_page=10, items_per_page=10,
schema=RoleRead,
) )
names = [r.name for r in result.data] names = [r.name for r in result.data]
@@ -855,12 +887,13 @@ class TestCrudJoins:
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination
# Paginate users with published posts # Paginate users with published posts
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
joins=[(Post, Post.author_id == User.id)], joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712 filters=[Post.is_published == True], # noqa: E712
page=1, page=1,
items_per_page=10, items_per_page=10,
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -889,12 +922,13 @@ class TestCrudJoins:
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination
# Paginate with outer join # Paginate with outer join
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
joins=[(Post, Post.author_id == User.id)], joins=[(Post, Post.author_id == User.id)],
outer_join=True, outer_join=True,
page=1, page=1,
items_per_page=10, items_per_page=10,
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -931,70 +965,6 @@ class TestCrudJoins:
assert users[0].username == "multi_join" assert users[0].username == "multi_join"
class TestAsResponse:
"""Tests for as_response parameter (deprecated, kept for backward compat)."""
@pytest.mark.anyio
async def test_create_as_response(self, db_session: AsyncSession):
"""Create with as_response=True returns Response and emits DeprecationWarning."""
from fastapi_toolsets.schemas import Response
data = RoleCreate(name="response_role")
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
result = await RoleCrud.create(db_session, data, as_response=True)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.name == "response_role"
@pytest.mark.anyio
async def test_get_as_response(self, db_session: AsyncSession):
"""Get with as_response=True returns Response and emits DeprecationWarning."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
result = await RoleCrud.get(
db_session, [Role.id == created.id], as_response=True
)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.id == created.id
@pytest.mark.anyio
async def test_update_as_response(self, db_session: AsyncSession):
"""Update with as_response=True returns Response and emits DeprecationWarning."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
result = await RoleCrud.update(
db_session,
RoleUpdate(name="new_name"),
[Role.id == created.id],
as_response=True,
)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.name == "new_name"
@pytest.mark.anyio
async def test_delete_as_response(self, db_session: AsyncSession):
"""Delete with as_response=True returns Response and emits DeprecationWarning."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
result = await RoleCrud.delete(
db_session, [Role.id == created.id], as_response=True
)
assert isinstance(result, Response)
assert result.data is None
class TestCrudFactoryM2M: class TestCrudFactoryM2M:
"""Tests for CrudFactory with m2m_fields parameter.""" """Tests for CrudFactory with m2m_fields parameter."""
@@ -1475,92 +1445,35 @@ class TestSchemaResponse:
assert isinstance(result, Response) assert isinstance(result, Response)
@pytest.mark.anyio @pytest.mark.anyio
async def test_paginate_with_schema(self, db_session: AsyncSession): async def test_offset_paginate_with_schema(self, db_session: AsyncSession):
"""paginate with schema returns PaginatedResponse[SchemaType].""" """offset_paginate with schema returns PaginatedResponse[SchemaType]."""
from fastapi_toolsets.schemas import PaginatedResponse from fastapi_toolsets.schemas import PaginatedResponse
await RoleCrud.create(db_session, RoleCreate(name="p_role1")) await RoleCrud.create(db_session, RoleCreate(name="p_role1"))
await RoleCrud.create(db_session, RoleCreate(name="p_role2")) await RoleCrud.create(db_session, RoleCreate(name="p_role2"))
result = await RoleCrud.paginate(db_session, schema=RoleRead) result = await RoleCrud.offset_paginate(db_session, schema=RoleRead)
assert isinstance(result, PaginatedResponse) assert isinstance(result, PaginatedResponse)
assert len(result.data) == 2 assert len(result.data) == 2
assert all(isinstance(item, RoleRead) for item in result.data) assert all(isinstance(item, RoleRead) for item in result.data)
@pytest.mark.anyio @pytest.mark.anyio
async def test_paginate_schema_filters_fields(self, db_session: AsyncSession): async def test_offset_paginate_schema_filters_fields(
"""paginate with schema only exposes schema fields per item.""" self, db_session: AsyncSession
):
"""offset_paginate with schema only exposes schema fields per item."""
await UserCrud.create( await UserCrud.create(
db_session, db_session,
UserCreate(username="pg_user", email="pg@test.com"), UserCreate(username="pg_user", email="pg@test.com"),
) )
result = await UserCrud.paginate(db_session, schema=UserRead) result = await UserCrud.offset_paginate(db_session, schema=UserRead)
assert isinstance(result.data[0], UserRead) assert isinstance(result.data[0], UserRead)
assert result.data[0].username == "pg_user" assert result.data[0].username == "pg_user"
assert not hasattr(result.data[0], "email") assert not hasattr(result.data[0], "email")
@pytest.mark.anyio
async def test_as_response_true_without_schema_unchanged(
self, db_session: AsyncSession
):
"""as_response=True without schema still returns Response[ModelType] with a warning."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="compat"))
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
result = await RoleCrud.get(
db_session, [Role.id == created.id], as_response=True
)
assert isinstance(result, Response)
assert isinstance(result.data, Role)
@pytest.mark.anyio
async def test_schema_with_explicit_as_response_true(
self, db_session: AsyncSession
):
"""schema combined with explicit as_response=True works correctly."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="combined"))
result = await RoleCrud.get(
db_session,
[Role.id == created.id],
as_response=True,
schema=RoleRead,
)
assert isinstance(result, Response)
assert isinstance(result.data, RoleRead)
class TestPaginateAlias:
"""Tests that paginate is a backward-compatible alias for offset_paginate."""
def test_paginate_is_alias_of_offset_paginate(self):
"""paginate and offset_paginate are the same underlying function."""
assert RoleCrud.paginate.__func__ is RoleCrud.offset_paginate.__func__
@pytest.mark.anyio
async def test_paginate_alias_returns_offset_pagination(
self, db_session: AsyncSession
):
"""paginate() still works and returns PaginatedResponse with OffsetPagination."""
from fastapi_toolsets.schemas import OffsetPagination, PaginatedResponse
for i in range(3):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
assert isinstance(result, PaginatedResponse)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 3
assert result.pagination.page == 1
class TestCursorPaginate: class TestCursorPaginate:
"""Tests for cursor-based pagination via cursor_paginate().""" """Tests for cursor-based pagination via cursor_paginate()."""
@@ -1573,7 +1486,9 @@ class TestCursorPaginate:
for i in range(25): for i in range(25):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) result = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=10, schema=RoleRead
)
assert isinstance(result, PaginatedResponse) assert isinstance(result, PaginatedResponse)
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
@@ -1591,7 +1506,9 @@ class TestCursorPaginate:
for i in range(5): for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) result = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=10, schema=RoleRead
)
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 5 assert len(result.data) == 5
@@ -1606,14 +1523,16 @@ class TestCursorPaginate:
from fastapi_toolsets.schemas import CursorPagination from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) page1 = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=10, schema=RoleRead
)
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
assert len(page1.data) == 10 assert len(page1.data) == 10
assert page1.pagination.has_more is True assert page1.pagination.has_more is True
cursor = page1.pagination.next_cursor cursor = page1.pagination.next_cursor
page2 = await RoleCursorCrud.cursor_paginate( page2 = await RoleCursorCrud.cursor_paginate(
db_session, cursor=cursor, items_per_page=10 db_session, cursor=cursor, items_per_page=10, schema=RoleRead
) )
assert isinstance(page2.pagination, CursorPagination) assert isinstance(page2.pagination, CursorPagination)
assert len(page2.data) == 5 assert len(page2.data) == 5
@@ -1628,12 +1547,15 @@ class TestCursorPaginate:
from fastapi_toolsets.schemas import CursorPagination from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=4) page1 = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=4, schema=RoleRead
)
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
page2 = await RoleCursorCrud.cursor_paginate( page2 = await RoleCursorCrud.cursor_paginate(
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=4, items_per_page=4,
schema=RoleRead,
) )
ids_page1 = {r.id for r in page1.data} ids_page1 = {r.id for r in page1.data}
@@ -1646,7 +1568,9 @@ class TestCursorPaginate:
"""cursor_paginate on an empty table returns empty data with no cursor.""" """cursor_paginate on an empty table returns empty data with no cursor."""
from fastapi_toolsets.schemas import CursorPagination from fastapi_toolsets.schemas import CursorPagination
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) result = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=10, schema=RoleRead
)
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
assert result.data == [] assert result.data == []
@@ -1671,6 +1595,7 @@ class TestCursorPaginate:
db_session, db_session,
filters=[User.is_active == True], # noqa: E712 filters=[User.is_active == True], # noqa: E712
items_per_page=20, items_per_page=20,
schema=UserRead,
) )
assert len(result.data) == 5 assert len(result.data) == 5
@@ -1703,7 +1628,9 @@ class TestCursorPaginate:
for i in range(5): for i in range(5):
await RoleNameCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) await RoleNameCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleNameCrud.cursor_paginate(db_session, items_per_page=3) result = await RoleNameCrud.cursor_paginate(
db_session, items_per_page=3, schema=RoleRead
)
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 3 assert len(result.data) == 3
@@ -1714,7 +1641,7 @@ class TestCursorPaginate:
async def test_raises_without_cursor_column(self, db_session: AsyncSession): async def test_raises_without_cursor_column(self, db_session: AsyncSession):
"""cursor_paginate raises ValueError when cursor_column is not configured.""" """cursor_paginate raises ValueError when cursor_column is not configured."""
with pytest.raises(ValueError, match="cursor_column is not set"): with pytest.raises(ValueError, match="cursor_column is not set"):
await RoleCrud.cursor_paginate(db_session) await RoleCrud.cursor_paginate(db_session, schema=RoleRead)
class TestCursorPaginatePrevCursor: class TestCursorPaginatePrevCursor:
@@ -1728,7 +1655,9 @@ class TestCursorPaginatePrevCursor:
from fastapi_toolsets.schemas import CursorPagination from fastapi_toolsets.schemas import CursorPagination
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=3) result = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=3, schema=RoleRead
)
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
assert result.pagination.prev_cursor is None assert result.pagination.prev_cursor is None
@@ -1741,12 +1670,15 @@ class TestCursorPaginatePrevCursor:
from fastapi_toolsets.schemas import CursorPagination from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5) page1 = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=5, schema=RoleRead
)
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
page2 = await RoleCursorCrud.cursor_paginate( page2 = await RoleCursorCrud.cursor_paginate(
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=5, items_per_page=5,
schema=RoleRead,
) )
assert isinstance(page2.pagination, CursorPagination) assert isinstance(page2.pagination, CursorPagination)
assert page2.pagination.prev_cursor is not None assert page2.pagination.prev_cursor is not None
@@ -1762,12 +1694,15 @@ class TestCursorPaginatePrevCursor:
from fastapi_toolsets.schemas import CursorPagination from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5) page1 = await RoleCursorCrud.cursor_paginate(
db_session, items_per_page=5, schema=RoleRead
)
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
page2 = await RoleCursorCrud.cursor_paginate( page2 = await RoleCursorCrud.cursor_paginate(
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=5, items_per_page=5,
schema=RoleRead,
) )
assert isinstance(page2.pagination, CursorPagination) assert isinstance(page2.pagination, CursorPagination)
assert page2.pagination.prev_cursor is not None assert page2.pagination.prev_cursor is not None
@@ -1802,6 +1737,7 @@ class TestCursorPaginateWithSearch:
db_session, db_session,
search="admin", search="admin",
items_per_page=20, items_per_page=20,
schema=RoleRead,
) )
assert len(result.data) == 5 assert len(result.data) == 5
@@ -1836,6 +1772,7 @@ class TestCursorPaginateExtraOptions:
db_session, db_session,
joins=[(Role, User.role_id == Role.id)], joins=[(Role, User.role_id == Role.id)],
items_per_page=20, items_per_page=20,
schema=UserRead,
) )
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
@@ -1867,6 +1804,7 @@ class TestCursorPaginateExtraOptions:
joins=[(Role, User.role_id == Role.id)], joins=[(Role, User.role_id == Role.id)],
outer_join=True, outer_join=True,
items_per_page=20, items_per_page=20,
schema=UserRead,
) )
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
@@ -1876,7 +1814,12 @@ class TestCursorPaginateExtraOptions:
@pytest.mark.anyio @pytest.mark.anyio
async def test_with_load_options(self, db_session: AsyncSession): async def test_with_load_options(self, db_session: AsyncSession):
"""cursor_paginate passes load_options to the query.""" """cursor_paginate passes load_options to the query."""
from fastapi_toolsets.schemas import CursorPagination from fastapi_toolsets.schemas import CursorPagination, PydanticBase
class UserWithRoleRead(PydanticBase):
id: uuid.UUID
username: str
role: RoleRead | None = None
role = await RoleCrud.create(db_session, RoleCreate(name="manager")) role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
for i in range(3): for i in range(3):
@@ -1893,6 +1836,7 @@ class TestCursorPaginateExtraOptions:
db_session, db_session,
load_options=[selectinload(User.role)], load_options=[selectinload(User.role)],
items_per_page=20, items_per_page=20,
schema=UserWithRoleRead,
) )
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
@@ -1912,6 +1856,7 @@ class TestCursorPaginateExtraOptions:
db_session, db_session,
order_by=Role.name.desc(), order_by=Role.name.desc(),
items_per_page=3, items_per_page=3,
schema=RoleRead,
) )
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
@@ -1925,7 +1870,9 @@ class TestCursorPaginateExtraOptions:
for i in range(5): for i in range(5):
await IntRoleCursorCrud.create(db_session, IntRoleCreate(name=f"role{i}")) await IntRoleCursorCrud.create(db_session, IntRoleCreate(name=f"role{i}"))
page1 = await IntRoleCursorCrud.cursor_paginate(db_session, items_per_page=3) page1 = await IntRoleCursorCrud.cursor_paginate(
db_session, items_per_page=3, schema=IntRoleRead
)
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
assert len(page1.data) == 3 assert len(page1.data) == 3
@@ -1935,6 +1882,7 @@ class TestCursorPaginateExtraOptions:
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=3, items_per_page=3,
schema=IntRoleRead,
) )
assert isinstance(page2.pagination, CursorPagination) assert isinstance(page2.pagination, CursorPagination)
@@ -1955,7 +1903,9 @@ class TestCursorPaginateExtraOptions:
await RoleCrud.create(db_session, RoleCreate(name="role01")) await RoleCrud.create(db_session, RoleCreate(name="role01"))
# First page succeeds (no cursor to decode) # First page succeeds (no cursor to decode)
page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=1) page1 = await RoleNameCursorCrud.cursor_paginate(
db_session, items_per_page=1, schema=RoleRead
)
assert page1.pagination.has_more is True assert page1.pagination.has_more is True
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
@@ -1965,6 +1915,7 @@ class TestCursorPaginateExtraOptions:
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=1, items_per_page=1,
schema=RoleRead,
) )
@@ -2003,6 +1954,7 @@ class TestCursorPaginateSearchJoins:
search="administrator", search="administrator",
search_fields=[(User.role, Role.name)], search_fields=[(User.role, Role.name)],
items_per_page=20, items_per_page=20,
schema=UserRead,
) )
assert isinstance(result.pagination, CursorPagination) assert isinstance(result.pagination, CursorPagination)
@@ -2049,7 +2001,7 @@ class TestCursorPaginateColumnTypes:
) )
page1 = await EventDateTimeCursorCrud.cursor_paginate( page1 = await EventDateTimeCursorCrud.cursor_paginate(
db_session, items_per_page=3 db_session, items_per_page=3, schema=EventRead
) )
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
@@ -2060,6 +2012,7 @@ class TestCursorPaginateColumnTypes:
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=3, items_per_page=3,
schema=EventRead,
) )
assert isinstance(page2.pagination, CursorPagination) assert isinstance(page2.pagination, CursorPagination)
@@ -2087,7 +2040,9 @@ class TestCursorPaginateColumnTypes:
), ),
) )
page1 = await EventDateCursorCrud.cursor_paginate(db_session, items_per_page=3) page1 = await EventDateCursorCrud.cursor_paginate(
db_session, items_per_page=3, schema=EventRead
)
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
assert len(page1.data) == 3 assert len(page1.data) == 3
@@ -2097,6 +2052,7 @@ class TestCursorPaginateColumnTypes:
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=3, items_per_page=3,
schema=EventRead,
) )
assert isinstance(page2.pagination, CursorPagination) assert isinstance(page2.pagination, CursorPagination)
@@ -2123,7 +2079,7 @@ class TestCursorPaginateColumnTypes:
) )
page1 = await ProductNumericCursorCrud.cursor_paginate( page1 = await ProductNumericCursorCrud.cursor_paginate(
db_session, items_per_page=3 db_session, items_per_page=3, schema=ProductRead
) )
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
@@ -2134,6 +2090,7 @@ class TestCursorPaginateColumnTypes:
db_session, db_session,
cursor=page1.pagination.next_cursor, cursor=page1.pagination.next_cursor,
items_per_page=3, items_per_page=3,
schema=ProductRead,
) )
assert isinstance(page2.pagination, CursorPagination) assert isinstance(page2.pagination, CursorPagination)

View File

@@ -1,9 +1,11 @@
"""Tests for CRUD search functionality.""" """Tests for CRUD search functionality."""
import inspect
import uuid import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
from fastapi_toolsets.crud import ( from fastapi_toolsets.crud import (
CrudFactory, CrudFactory,
@@ -11,6 +13,7 @@ from fastapi_toolsets.crud import (
SearchConfig, SearchConfig,
get_searchable_fields, get_searchable_fields,
) )
from fastapi_toolsets.exceptions import InvalidOrderFieldError
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination
from .conftest import ( from .conftest import (
@@ -20,6 +23,7 @@ from .conftest import (
User, User,
UserCreate, UserCreate,
UserCrud, UserCrud,
UserRead,
) )
@@ -39,10 +43,11 @@ class TestPaginateSearch:
db_session, UserCreate(username="bob_smith", email="bob@test.com") db_session, UserCreate(username="bob_smith", email="bob@test.com")
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="doe", search="doe",
search_fields=[User.username], search_fields=[User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -58,10 +63,11 @@ class TestPaginateSearch:
db_session, UserCreate(username="company_bob", email="bob@other.com") db_session, UserCreate(username="company_bob", email="bob@other.com")
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="company", search="company",
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -86,10 +92,11 @@ class TestPaginateSearch:
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id), UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="admin", search="admin",
search_fields=[(User.role, Role.name)], search_fields=[(User.role, Role.name)],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -105,10 +112,11 @@ class TestPaginateSearch:
) )
# Search "admin" in username OR role.name # Search "admin" in username OR role.name
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="admin", search="admin",
search_fields=[User.username, (User.role, Role.name)], search_fields=[User.username, (User.role, Role.name)],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -121,10 +129,11 @@ class TestPaginateSearch:
db_session, UserCreate(username="JohnDoe", email="j@test.com") db_session, UserCreate(username="JohnDoe", email="j@test.com")
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="johndoe", search="johndoe",
search_fields=[User.username], search_fields=[User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -138,19 +147,21 @@ class TestPaginateSearch:
) )
# Should not find (case mismatch) # Should not find (case mismatch)
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search=SearchConfig(query="johndoe", case_sensitive=True), search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 0 assert result.pagination.total_count == 0
# Should find (case match) # Should find (case match)
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search=SearchConfig(query="JohnDoe", case_sensitive=True), search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@@ -165,11 +176,13 @@ class TestPaginateSearch:
db_session, UserCreate(username="user2", email="u2@test.com") db_session, UserCreate(username="user2", email="u2@test.com")
) )
result = await UserCrud.paginate(db_session, search="") result = await UserCrud.offset_paginate(db_session, search="", schema=UserRead)
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
result = await UserCrud.paginate(db_session, search=None) result = await UserCrud.offset_paginate(
db_session, search=None, schema=UserRead
)
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@@ -185,11 +198,12 @@ class TestPaginateSearch:
UserCreate(username="inactive_john", email="ij@test.com", is_active=False), UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
filters=[User.is_active == True], # noqa: E712 filters=[User.is_active == True], # noqa: E712
search="john", search="john",
search_fields=[User.username], search_fields=[User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -203,7 +217,9 @@ class TestPaginateSearch:
db_session, UserCreate(username="findme", email="other@test.com") db_session, UserCreate(username="findme", email="other@test.com")
) )
result = await UserCrud.paginate(db_session, search="findme") result = await UserCrud.offset_paginate(
db_session, search="findme", schema=UserRead
)
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@@ -215,10 +231,11 @@ class TestPaginateSearch:
db_session, UserCreate(username="john", email="j@test.com") db_session, UserCreate(username="john", email="j@test.com")
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="nonexistent", search="nonexistent",
search_fields=[User.username], search_fields=[User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -234,12 +251,13 @@ class TestPaginateSearch:
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"), UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="user_", search="user_",
search_fields=[User.username], search_fields=[User.username],
page=1, page=1,
items_per_page=5, items_per_page=5,
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -261,10 +279,11 @@ class TestPaginateSearch:
) )
# Search in username, not in role # Search in username, not in role
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="role", search="role",
search_fields=[User.username], search_fields=[User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -283,11 +302,12 @@ class TestPaginateSearch:
db_session, UserCreate(username="bob", email="b@test.com") db_session, UserCreate(username="bob", email="b@test.com")
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="@test.com", search="@test.com",
search_fields=[User.email], search_fields=[User.email],
order_by=User.username, order_by=User.username,
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -307,10 +327,11 @@ class TestPaginateSearch:
) )
# Search by UUID (partial match) # Search by UUID (partial match)
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search="12345678", search="12345678",
search_fields=[User.id, User.username], search_fields=[User.id, User.username],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -360,10 +381,11 @@ class TestSearchConfig:
) )
# 'john' must be in username AND email # 'john' must be in username AND email
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search=SearchConfig(query="john", match_mode="all"), search=SearchConfig(query="john", match_mode="all"),
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -377,9 +399,10 @@ class TestSearchConfig:
db_session, UserCreate(username="test", email="findme@test.com") db_session, UserCreate(username="test", email="findme@test.com")
) )
result = await UserCrud.paginate( result = await UserCrud.offset_paginate(
db_session, db_session,
search=SearchConfig(query="findme", fields=[User.email]), search=SearchConfig(query="findme", fields=[User.email]),
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -475,7 +498,7 @@ class TestFacetsNotSet:
db_session, UserCreate(username="alice", email="a@test.com") db_session, UserCreate(username="alice", email="a@test.com")
) )
result = await UserCrud.offset_paginate(db_session) result = await UserCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is None assert result.filter_attributes is None
@@ -487,7 +510,7 @@ class TestFacetsNotSet:
db_session, UserCreate(username="alice", email="a@test.com") db_session, UserCreate(username="alice", email="a@test.com")
) )
result = await UserCursorCrud.cursor_paginate(db_session) result = await UserCursorCrud.cursor_paginate(db_session, schema=UserRead)
assert result.filter_attributes is None assert result.filter_attributes is None
@@ -506,7 +529,7 @@ class TestFacetsDirectColumn:
db_session, UserCreate(username="bob", email="b@test.com") db_session, UserCreate(username="bob", email="b@test.com")
) )
result = await UserFacetCrud.offset_paginate(db_session) result = await UserFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
# Distinct usernames, sorted # Distinct usernames, sorted
@@ -525,7 +548,7 @@ class TestFacetsDirectColumn:
db_session, UserCreate(username="bob", email="b@test.com") db_session, UserCreate(username="bob", email="b@test.com")
) )
result = await UserFacetCursorCrud.cursor_paginate(db_session) result = await UserFacetCursorCrud.cursor_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert set(result.filter_attributes["email"]) == {"a@test.com", "b@test.com"} assert set(result.filter_attributes["email"]) == {"a@test.com", "b@test.com"}
@@ -541,7 +564,7 @@ class TestFacetsDirectColumn:
db_session, UserCreate(username="bob", email="b@test.com") db_session, UserCreate(username="bob", email="b@test.com")
) )
result = await UserFacetCrud.offset_paginate(db_session) result = await UserFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert "username" in result.filter_attributes assert "username" in result.filter_attributes
@@ -558,7 +581,7 @@ class TestFacetsDirectColumn:
# Override: ask for email instead of username # Override: ask for email instead of username
result = await UserFacetCrud.offset_paginate( result = await UserFacetCrud.offset_paginate(
db_session, facet_fields=[User.email] db_session, facet_fields=[User.email], schema=UserRead
) )
assert result.filter_attributes is not None assert result.filter_attributes is not None
@@ -584,6 +607,7 @@ class TestFacetsRespectFilters:
result = await UserFacetCrud.offset_paginate( result = await UserFacetCrud.offset_paginate(
db_session, db_session,
filters=[User.is_active == True], # noqa: E712 filters=[User.is_active == True], # noqa: E712
schema=UserRead,
) )
assert result.filter_attributes is not None assert result.filter_attributes is not None
@@ -614,7 +638,7 @@ class TestFacetsRelationship:
db_session, UserCreate(username="charlie", email="c@test.com") db_session, UserCreate(username="charlie", email="c@test.com")
) )
result = await UserRelFacetCrud.offset_paginate(db_session) result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert set(result.filter_attributes["name"]) == {"admin", "editor"} assert set(result.filter_attributes["name"]) == {"admin", "editor"}
@@ -629,7 +653,7 @@ class TestFacetsRelationship:
db_session, UserCreate(username="norole", email="n@test.com") db_session, UserCreate(username="norole", email="n@test.com")
) )
result = await UserRelFacetCrud.offset_paginate(db_session) result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert result.filter_attributes["name"] == [] assert result.filter_attributes["name"] == []
@@ -653,7 +677,10 @@ class TestFacetsRelationship:
) )
result = await UserSearchFacetCrud.offset_paginate( result = await UserSearchFacetCrud.offset_paginate(
db_session, search="admin", search_fields=[(User.role, Role.name)] db_session,
search="admin",
search_fields=[(User.role, Role.name)],
schema=UserRead,
) )
assert result.filter_attributes is not None assert result.filter_attributes is not None
@@ -675,7 +702,7 @@ class TestFilterBy:
) )
result = await UserFacetCrud.offset_paginate( result = await UserFacetCrud.offset_paginate(
db_session, filter_by={"username": "alice"} db_session, filter_by={"username": "alice"}, schema=UserRead
) )
assert len(result.data) == 1 assert len(result.data) == 1
@@ -698,7 +725,7 @@ class TestFilterBy:
) )
result = await UserFacetCrud.offset_paginate( result = await UserFacetCrud.offset_paginate(
db_session, filter_by={"username": ["alice", "bob"]} db_session, filter_by={"username": ["alice", "bob"]}, schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -723,7 +750,7 @@ class TestFilterBy:
) )
result = await UserRelFacetCrud.offset_paginate( result = await UserRelFacetCrud.offset_paginate(
db_session, filter_by={"name": "admin"} db_session, filter_by={"name": "admin"}, schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -746,6 +773,7 @@ class TestFilterBy:
db_session, db_session,
filters=[User.is_active == True], # noqa: E712 filters=[User.is_active == True], # noqa: E712
filter_by={"username": ["alice", "alice2"]}, filter_by={"username": ["alice", "alice2"]},
schema=UserRead,
) )
# Only alice passes both: is_active=True AND username IN [alice, alice2] # Only alice passes both: is_active=True AND username IN [alice, alice2]
@@ -760,7 +788,7 @@ class TestFilterBy:
with pytest.raises(InvalidFacetFilterError) as exc_info: with pytest.raises(InvalidFacetFilterError) as exc_info:
await UserFacetCrud.offset_paginate( await UserFacetCrud.offset_paginate(
db_session, filter_by={"nonexistent": "value"} db_session, filter_by={"nonexistent": "value"}, schema=UserRead
) )
assert exc_info.value.key == "nonexistent" assert exc_info.value.key == "nonexistent"
@@ -792,6 +820,7 @@ class TestFilterBy:
result = await UserRoleFacetCrud.offset_paginate( result = await UserRoleFacetCrud.offset_paginate(
db_session, db_session,
filter_by={"name": "admin", "id": str(admin.id)}, filter_by={"name": "admin", "id": str(admin.id)},
schema=UserRead,
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -812,7 +841,7 @@ class TestFilterBy:
) )
result = await UserFacetCursorCrud.cursor_paginate( result = await UserFacetCursorCrud.cursor_paginate(
db_session, filter_by={"username": "alice"} db_session, filter_by={"username": "alice"}, schema=UserRead
) )
assert len(result.data) == 1 assert len(result.data) == 1
@@ -836,7 +865,7 @@ class TestFilterBy:
) )
result = await UserFacetCrud.offset_paginate( result = await UserFacetCrud.offset_paginate(
db_session, filter_by=UserFilter(username="alice") db_session, filter_by=UserFilter(username="alice"), schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -862,7 +891,7 @@ class TestFilterBy:
) )
result = await UserFacetCursorCrud.cursor_paginate( result = await UserFacetCursorCrud.cursor_paginate(
db_session, filter_by=UserFilter(username="alice") db_session, filter_by=UserFilter(username="alice"), schema=UserRead
) )
assert len(result.data) == 1 assert len(result.data) == 1
@@ -971,7 +1000,9 @@ class TestFilterParamsSchema:
dep = UserFacetCrud.filter_params() dep = UserFacetCrud.filter_params()
f = await dep(username=["alice"]) f = await dep(username=["alice"])
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f) result = await UserFacetCrud.offset_paginate(
db_session, filter_by=f, schema=UserRead
)
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@@ -992,7 +1023,9 @@ class TestFilterParamsSchema:
dep = UserFacetCursorCrud.filter_params() dep = UserFacetCursorCrud.filter_params()
f = await dep(username=["alice"]) f = await dep(username=["alice"])
result = await UserFacetCursorCrud.cursor_paginate(db_session, filter_by=f) result = await UserFacetCursorCrud.cursor_paginate(
db_session, filter_by=f, schema=UserRead
)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0].username == "alice" assert result.data[0].username == "alice"
@@ -1010,7 +1043,150 @@ class TestFilterParamsSchema:
dep = UserFacetCrud.filter_params() dep = UserFacetCrud.filter_params()
f = await dep() # all fields None f = await dep() # all fields None
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f) result = await UserFacetCrud.offset_paginate(
db_session, filter_by=f, schema=UserRead
)
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
class TestOrderParamsSchema:
"""Tests for AsyncCrud.order_params()."""
def test_generates_order_by_and_order_params(self):
"""Returned dependency has order_by and order query params."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
dep = UserOrderCrud.order_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"order_by", "order"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep = UserOrderCrud.order_params()
assert getattr(dep, "__name__") == "UserOrderParams"
def test_raises_when_no_order_fields(self):
"""ValueError raised when no order_fields are configured or provided."""
with pytest.raises(ValueError, match="no order_fields"):
UserCrud.order_params()
def test_order_fields_override(self):
"""order_fields= parameter overrides the class-level default."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
dep = UserOrderCrud.order_params(order_fields=[User.email])
param_names = set(inspect.signature(dep).parameters)
assert "order_by" in param_names
# description should only mention email, not username
sig = inspect.signature(dep)
description = sig.parameters["order_by"].default.description
assert "email" in description
assert "username" not in description
def test_order_by_description_lists_valid_fields(self):
"""order_by query param description mentions each allowed field."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
dep = UserOrderCrud.order_params()
sig = inspect.signature(dep)
description = sig.parameters["order_by"].default.description
assert "username" in description
assert "email" in description
def test_default_order_reflected_in_order_default(self):
"""default_order is used as the default value for order."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep_asc = UserOrderCrud.order_params(default_order="asc")
dep_desc = UserOrderCrud.order_params(default_order="desc")
sig_asc = inspect.signature(dep_asc)
sig_desc = inspect.signature(dep_desc)
assert sig_asc.parameters["order"].default.default == "asc"
assert sig_desc.parameters["order"].default.default == "desc"
@pytest.mark.anyio
async def test_no_order_by_no_default_returns_none(self):
"""Returns None when order_by is absent and no default_field is set."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep = UserOrderCrud.order_params()
result = await dep(order_by=None, order="asc")
assert result is None
@pytest.mark.anyio
async def test_no_order_by_with_default_field_returns_asc_expression(self):
"""Returns default_field.asc() when order_by absent and order=asc."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep = UserOrderCrud.order_params(default_field=User.username)
result = await dep(order_by=None, order="asc")
assert isinstance(result, UnaryExpression)
assert "ASC" in str(result)
@pytest.mark.anyio
async def test_no_order_by_with_default_field_returns_desc_expression(self):
"""Returns default_field.desc() when order_by absent and order=desc."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep = UserOrderCrud.order_params(default_field=User.username)
result = await dep(order_by=None, order="desc")
assert isinstance(result, UnaryExpression)
assert "DESC" in str(result)
@pytest.mark.anyio
async def test_valid_order_by_asc(self):
"""Returns field.asc() for a valid order_by with order=asc."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep = UserOrderCrud.order_params()
result = await dep(order_by="username", order="asc")
assert isinstance(result, UnaryExpression)
assert "ASC" in str(result)
@pytest.mark.anyio
async def test_valid_order_by_desc(self):
"""Returns field.desc() for a valid order_by with order=desc."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep = UserOrderCrud.order_params()
result = await dep(order_by="username", order="desc")
assert isinstance(result, UnaryExpression)
assert "DESC" in str(result)
@pytest.mark.anyio
async def test_invalid_order_by_raises_invalid_order_field_error(self):
"""Raises InvalidOrderFieldError for an unknown order_by value."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
dep = UserOrderCrud.order_params()
with pytest.raises(InvalidOrderFieldError) as exc_info:
await dep(order_by="nonexistent", order="asc")
assert exc_info.value.field == "nonexistent"
assert "username" in exc_info.value.valid_fields
@pytest.mark.anyio
async def test_multiple_fields_all_resolve(self):
"""All configured fields resolve correctly via order_by."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
dep = UserOrderCrud.order_params()
result_username = await dep(order_by="username", order="asc")
result_email = await dep(order_by="email", order="desc")
assert isinstance(result_username, ColumnElement)
assert isinstance(result_email, ColumnElement)
@pytest.mark.anyio
async def test_order_params_integrates_with_get_multi(
self, db_session: AsyncSession
):
"""order_params output is accepted by get_multi(order_by=...)."""
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
await UserCrud.create(
db_session, UserCreate(username="charlie", email="c@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com")
)
dep = UserOrderCrud.order_params()
order_by = await dep(order_by="username", order="asc")
results = await UserOrderCrud.get_multi(db_session, order_by=order_by)
assert results[0].username == "alice"
assert results[1].username == "charlie"

View File

@@ -0,0 +1,395 @@
"""Live test for the docs/examples/pagination-search.md example.
Spins up the exact FastAPI app described in the example (sourced from
docs_src/examples/pagination_search/) and exercises it through a real HTTP
client against a real PostgreSQL database.
"""
import datetime
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from docs_src.examples.pagination_search.db import get_db
from docs_src.examples.pagination_search.models import Article, Base, Category
from docs_src.examples.pagination_search.routes import router
from fastapi_toolsets.exceptions import init_exceptions_handlers
from .conftest import DATABASE_URL
def build_app(session: AsyncSession) -> FastAPI:
app = FastAPI()
init_exceptions_handlers(app)
async def override_get_db():
yield session
app.dependency_overrides[get_db] = override_get_db
app.include_router(router)
return app
@pytest.fixture(scope="function")
async def ex_db_session():
"""Isolated session for the example models (separate tables from conftest)."""
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
session = session_factory()
try:
yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest.fixture
async def client(ex_db_session: AsyncSession):
app = build_app(ex_db_session)
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as ac:
yield ac
async def seed(session: AsyncSession):
"""Insert representative fixture data."""
python = Category(name="python")
backend = Category(name="backend")
session.add_all([python, backend])
await session.flush()
now = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
session.add_all(
[
Article(
title="FastAPI tips",
body="Ten useful tips for FastAPI.",
status="published",
published=True,
category_id=python.id,
created_at=now,
),
Article(
title="SQLAlchemy async",
body="How to use async SQLAlchemy.",
status="published",
published=True,
category_id=backend.id,
created_at=now + datetime.timedelta(seconds=1),
),
Article(
title="Draft notes",
body="Work in progress.",
status="draft",
published=False,
category_id=None,
created_at=now + datetime.timedelta(seconds=2),
),
]
)
await session.commit()
class TestAppSessionDep:
@pytest.mark.anyio
async def test_get_db_yields_async_session(self):
"""get_db yields a real AsyncSession when called directly."""
from docs_src.examples.pagination_search.db import get_db
gen = get_db()
session = await gen.__anext__()
assert isinstance(session, AsyncSession)
await session.close()
class TestOffsetPagination:
@pytest.mark.anyio
async def test_returns_all_articles(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 3
assert len(body["data"]) == 3
@pytest.mark.anyio
async def test_pagination_page_size(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?items_per_page=2&page=1")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 2
assert body["pagination"]["total_count"] == 3
assert body["pagination"]["has_more"] is True
@pytest.mark.anyio
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?search=fastapi")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 1
assert body["data"][0]["title"] == "FastAPI tips"
@pytest.mark.anyio
async def test_search_traverses_relationship(
self, client: AsyncClient, ex_db_session
):
await seed(ex_db_session)
# "python" matches Category.name, not Article.title or body
resp = await client.get("/articles/offset?search=python")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 1
assert body["data"][0]["title"] == "FastAPI tips"
@pytest.mark.anyio
async def test_facet_filter_scalar(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?status=published")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 2
assert all(a["status"] == "published" for a in body["data"])
@pytest.mark.anyio
async def test_facet_filter_multi_value(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?status=published&status=draft")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 3
@pytest.mark.anyio
async def test_filter_attributes_in_response(
self, client: AsyncClient, ex_db_session
):
await seed(ex_db_session)
resp = await client.get("/articles/offset")
assert resp.status_code == 200
body = resp.json()
fa = body["filter_attributes"]
assert set(fa["status"]) == {"draft", "published"}
# "name" is unique across all facet fields — no prefix needed
assert set(fa["name"]) == {"backend", "python"}
@pytest.mark.anyio
async def test_filter_attributes_scoped_to_filter(
self, client: AsyncClient, ex_db_session
):
await seed(ex_db_session)
resp = await client.get("/articles/offset?status=published")
body = resp.json()
# draft is filtered out → should not appear in filter_attributes
assert "draft" not in body["filter_attributes"]["status"]
@pytest.mark.anyio
async def test_search_and_filter_combined(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/offset?search=async&status=published")
assert resp.status_code == 200
body = resp.json()
assert body["pagination"]["total_count"] == 1
assert body["data"][0]["title"] == "SQLAlchemy async"
class TestCursorPagination:
@pytest.mark.anyio
async def test_first_page(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/cursor?items_per_page=2")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 2
assert body["pagination"]["has_more"] is True
assert body["pagination"]["next_cursor"] is not None
assert body["pagination"]["prev_cursor"] is None
@pytest.mark.anyio
async def test_second_page(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
first = await client.get("/articles/cursor?items_per_page=2")
next_cursor = first.json()["pagination"]["next_cursor"]
resp = await client.get(
f"/articles/cursor?items_per_page=2&cursor={next_cursor}"
)
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 1
assert body["pagination"]["has_more"] is False
@pytest.mark.anyio
async def test_facet_filter(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/cursor?status=draft")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 1
assert body["data"][0]["status"] == "draft"
@pytest.mark.anyio
async def test_full_text_search(self, client: AsyncClient, ex_db_session):
await seed(ex_db_session)
resp = await client.get("/articles/cursor?search=sqlalchemy")
assert resp.status_code == 200
body = resp.json()
assert len(body["data"]) == 1
assert body["data"][0]["title"] == "SQLAlchemy async"
class TestOffsetSorting:
"""Tests for order_by / order query parameters on the offset endpoint."""
@pytest.mark.anyio
async def test_default_order_uses_created_at_asc(
self, client: AsyncClient, ex_db_session
):
"""No order_by → default field (created_at) ASC."""
await seed(ex_db_session)
resp = await client.get("/articles/offset")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
@pytest.mark.anyio
async def test_order_by_title_asc(self, client: AsyncClient, ex_db_session):
"""order_by=title&order=asc returns alphabetical order."""
await seed(ex_db_session)
resp = await client.get("/articles/offset?order_by=title&order=asc")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["Draft notes", "FastAPI tips", "SQLAlchemy async"]
@pytest.mark.anyio
async def test_order_by_title_desc(self, client: AsyncClient, ex_db_session):
"""order_by=title&order=desc returns reverse alphabetical order."""
await seed(ex_db_session)
resp = await client.get("/articles/offset?order_by=title&order=desc")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["SQLAlchemy async", "FastAPI tips", "Draft notes"]
@pytest.mark.anyio
async def test_order_by_created_at_desc(self, client: AsyncClient, ex_db_session):
"""order_by=created_at&order=desc returns newest-first."""
await seed(ex_db_session)
resp = await client.get("/articles/offset?order_by=created_at&order=desc")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["Draft notes", "SQLAlchemy async", "FastAPI tips"]
@pytest.mark.anyio
async def test_invalid_order_by_returns_422(
self, client: AsyncClient, ex_db_session
):
"""Unknown order_by field returns 422 with SORT-422 error code."""
resp = await client.get("/articles/offset?order_by=nonexistent_field")
assert resp.status_code == 422
body = resp.json()
assert body["error_code"] == "SORT-422"
assert body["status"] == "FAIL"
class TestCursorSorting:
"""Tests for order_by / order query parameters on the cursor endpoint.
In cursor_paginate the cursor_column is always the primary sort; order_by
acts as a secondary tiebreaker. With the seeded articles (all having unique
created_at values) the overall ordering is always created_at ASC regardless
of the order_by value — only the valid/invalid field check and the response
shape are meaningful here.
"""
@pytest.mark.anyio
async def test_default_order_uses_created_at_asc(
self, client: AsyncClient, ex_db_session
):
"""No order_by → default field (created_at) ASC."""
await seed(ex_db_session)
resp = await client.get("/articles/cursor")
assert resp.status_code == 200
titles = [a["title"] for a in resp.json()["data"]]
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
@pytest.mark.anyio
async def test_order_by_title_asc_accepted(
self, client: AsyncClient, ex_db_session
):
"""order_by=title is a valid field — request succeeds and returns all articles."""
await seed(ex_db_session)
resp = await client.get("/articles/cursor?order_by=title&order=asc")
assert resp.status_code == 200
assert len(resp.json()["data"]) == 3
@pytest.mark.anyio
async def test_order_by_title_desc_accepted(
self, client: AsyncClient, ex_db_session
):
"""order_by=title&order=desc is valid — request succeeds and returns all articles."""
await seed(ex_db_session)
resp = await client.get("/articles/cursor?order_by=title&order=desc")
assert resp.status_code == 200
assert len(resp.json()["data"]) == 3
@pytest.mark.anyio
async def test_invalid_order_by_returns_422(
self, client: AsyncClient, ex_db_session
):
"""Unknown order_by field returns 422 with SORT-422 error code."""
resp = await client.get("/articles/cursor?order_by=nonexistent_field")
assert resp.status_code == 422
body = resp.json()
assert body["error_code"] == "SORT-422"
assert body["status"] == "FAIL"

View File

@@ -8,6 +8,7 @@ from fastapi_toolsets.exceptions import (
ApiException, ApiException,
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
InvalidOrderFieldError,
NotFoundError, NotFoundError,
UnauthorizedError, UnauthorizedError,
generate_error_responses, generate_error_responses,
@@ -334,3 +335,43 @@ class TestExceptionIntegration:
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"id": 1} assert response.json() == {"id": 1}
class TestInvalidOrderFieldError:
"""Tests for InvalidOrderFieldError exception."""
def test_api_error_attributes(self):
"""InvalidOrderFieldError has correct api_error metadata."""
assert InvalidOrderFieldError.api_error.code == 422
assert InvalidOrderFieldError.api_error.err_code == "SORT-422"
assert InvalidOrderFieldError.api_error.msg == "Invalid Order Field"
def test_stores_field_and_valid_fields(self):
"""InvalidOrderFieldError stores field and valid_fields on the instance."""
error = InvalidOrderFieldError("unknown", ["name", "created_at"])
assert error.field == "unknown"
assert error.valid_fields == ["name", "created_at"]
def test_message_contains_field_and_valid_fields(self):
"""Exception message mentions the bad field and valid options."""
error = InvalidOrderFieldError("bad_field", ["name", "email"])
assert "bad_field" in str(error)
assert "name" in str(error)
assert "email" in str(error)
def test_handled_as_422_by_exception_handler(self):
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/items")
async def list_items():
raise InvalidOrderFieldError("bad", ["name"])
client = TestClient(app)
response = client.get("/items")
assert response.status_code == 422
data = response.json()
assert data["error_code"] == "SORT-422"
assert data["status"] == "FAIL"

View File

@@ -14,7 +14,9 @@ from fastapi_toolsets.fixtures import (
load_fixtures_by_context, load_fixtures_by_context,
) )
from .conftest import Role, User from fastapi_toolsets.fixtures.utils import _get_primary_key
from .conftest import IntRole, Permission, Role, User
class TestContext: class TestContext:
@@ -597,6 +599,46 @@ class TestLoadFixtures:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@pytest.mark.anyio
async def test_skip_existing_skips_if_record_exists(self, db_session: AsyncSession):
"""SKIP_EXISTING returns empty loaded list when the record already exists."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register
def roles():
return [Role(id=role_id, name="admin")]
# First load — inserts the record.
result1 = await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
)
assert len(result1["roles"]) == 1
# Remove from identity map so session.get() queries the DB in the second load.
db_session.expunge_all()
# Second load — record exists in DB, nothing should be added.
result2 = await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
)
assert result2["roles"] == []
@pytest.mark.anyio
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
"""SKIP_EXISTING inserts when the instance has no PK set (auto-increment)."""
registry = FixtureRegistry()
@registry.register
def int_roles():
# No id provided — PK is None before INSERT (autoincrement).
return [IntRole(name="member")]
result = await load_fixtures(
db_session, registry, "int_roles", strategy=LoadStrategy.SKIP_EXISTING
)
assert len(result["int_roles"]) == 1
class TestLoadFixturesByContext: class TestLoadFixturesByContext:
"""Tests for load_fixtures_by_context function.""" """Tests for load_fixtures_by_context function."""
@@ -755,3 +797,19 @@ class TestGetObjByAttr:
"""Raises StopIteration when value type doesn't match.""" """Raises StopIteration when value type doesn't match."""
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
get_obj_by_attr(self.roles, "id", "not-a-uuid") get_obj_by_attr(self.roles, "id", "not-a-uuid")
class TestGetPrimaryKey:
"""Unit tests for the _get_primary_key helper (composite PK paths)."""
def test_composite_pk_all_set(self):
"""Returns a tuple when all composite PK values are set."""
instance = Permission(subject="post", action="read")
pk = _get_primary_key(instance)
assert pk == ("post", "read")
def test_composite_pk_partial_none(self):
"""Returns None when any composite PK value is None."""
instance = Permission(subject="post") # action is None
pk = _get_primary_key(instance)
assert pk is None

View File

@@ -9,7 +9,6 @@ from fastapi_toolsets.schemas import (
ErrorResponse, ErrorResponse,
OffsetPagination, OffsetPagination,
PaginatedResponse, PaginatedResponse,
Pagination,
Response, Response,
ResponseStatus, ResponseStatus,
) )
@@ -199,20 +198,6 @@ class TestOffsetPagination:
assert data["page"] == 2 assert data["page"] == 2
assert data["has_more"] is True assert data["has_more"] is True
def test_pagination_alias_is_offset_pagination(self):
"""Pagination is a backward-compatible alias for OffsetPagination."""
assert Pagination is OffsetPagination
def test_pagination_alias_constructs_offset_pagination(self):
"""Code using Pagination(...) still works unchanged."""
pagination = Pagination(
total_count=10,
items_per_page=5,
page=2,
has_more=False,
)
assert isinstance(pagination, OffsetPagination)
class TestCursorPagination: class TestCursorPagination:
"""Tests for CursorPagination schema.""" """Tests for CursorPagination schema."""
@@ -276,7 +261,7 @@ class TestPaginatedResponse:
def test_create_paginated_response(self): def test_create_paginated_response(self):
"""Create PaginatedResponse with data and pagination.""" """Create PaginatedResponse with data and pagination."""
pagination = Pagination( pagination = OffsetPagination(
total_count=30, total_count=30,
items_per_page=10, items_per_page=10,
page=1, page=1,
@@ -294,7 +279,7 @@ class TestPaginatedResponse:
def test_with_custom_message(self): def test_with_custom_message(self):
"""PaginatedResponse with custom message.""" """PaginatedResponse with custom message."""
pagination = Pagination( pagination = OffsetPagination(
total_count=5, total_count=5,
items_per_page=10, items_per_page=10,
page=1, page=1,
@@ -310,7 +295,7 @@ class TestPaginatedResponse:
def test_empty_data(self): def test_empty_data(self):
"""PaginatedResponse with empty data.""" """PaginatedResponse with empty data."""
pagination = Pagination( pagination = OffsetPagination(
total_count=0, total_count=0,
items_per_page=10, items_per_page=10,
page=1, page=1,
@@ -332,7 +317,7 @@ class TestPaginatedResponse:
id: int id: int
name: str name: str
pagination = Pagination( pagination = OffsetPagination(
total_count=1, total_count=1,
items_per_page=10, items_per_page=10,
page=1, page=1,
@@ -347,7 +332,7 @@ class TestPaginatedResponse:
def test_serialization(self): def test_serialization(self):
"""PaginatedResponse serializes correctly.""" """PaginatedResponse serializes correctly."""
pagination = Pagination( pagination = OffsetPagination(
total_count=100, total_count=100,
items_per_page=10, items_per_page=10,
page=5, page=5,
@@ -385,16 +370,6 @@ class TestPaginatedResponse:
) )
assert isinstance(response.pagination, CursorPagination) assert isinstance(response.pagination, CursorPagination)
def test_pagination_alias_accepted(self):
"""Constructing PaginatedResponse with Pagination (alias) still works."""
response = PaginatedResponse(
data=[],
pagination=Pagination(
total_count=0, items_per_page=10, page=1, has_more=False
),
)
assert isinstance(response.pagination, OffsetPagination)
class TestFromAttributes: class TestFromAttributes:
"""Tests for from_attributes config (ORM mode).""" """Tests for from_attributes config (ORM mode)."""

2
uv.lock generated
View File

@@ -251,7 +251,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "1.1.2" version = "1.3.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },

View File

@@ -1,265 +1,35 @@
# ============================================================================
#
# The configuration produced by default is meant to highlight the features
# that Zensical provides and to serve as a starting point for your own
# projects.
#
# ============================================================================
[project] [project]
# The site_name is shown in the page header and the browser window title
#
# Read more: https://zensical.org/docs/setup/basics/#site_name
site_name = "FastAPI Toolsets" site_name = "FastAPI Toolsets"
# The site_description is included in the HTML head and should contain a
# meaningful description of the site content for use by search engines.
#
# Read more: https://zensical.org/docs/setup/basics/#site_description
site_description = "Production-ready utilities for FastAPI applications." site_description = "Production-ready utilities for FastAPI applications."
# The site_author attribute. This is used in the HTML head element.
#
# Read more: https://zensical.org/docs/setup/basics/#site_author
site_author = "d3vyce" site_author = "d3vyce"
# The site_url is the canonical URL for your site. When building online
# documentation you should set this.
# Read more: https://zensical.org/docs/setup/basics/#site_url
site_url = "https://fastapi-toolsets.d3vyce.fr" site_url = "https://fastapi-toolsets.d3vyce.fr"
copyright = "Copyright &copy; 2026 d3vyce"
# The copyright notice appears in the page footer and can contain an HTML
# fragment.
#
# Read more: https://zensical.org/docs/setup/basics/#copyright
copyright = """
Copyright &copy; 2026 d3vyce
"""
repo_url = "https://github.com/d3vyce/fastapi-toolsets" repo_url = "https://github.com/d3vyce/fastapi-toolsets"
# Zensical supports both implicit navigation and explicitly defined navigation.
# If you decide not to define a navigation here then Zensical will simply
# derive the navigation structure from the directory structure of your
# "docs_dir". The definition below demonstrates how a navigation structure
# can be defined using TOML syntax.
#
# Read more: https://zensical.org/docs/setup/navigation/
# nav = [
# { "Get started" = "index.md" },
# { "Markdown in 5min" = "markdown.md" },
# ]
# With the "extra_css" option you can add your own CSS styling to customize
# your Zensical project according to your needs. You can add any number of
# CSS files.
#
# The path provided should be relative to the "docs_dir".
#
# Read more: https://zensical.org/docs/customization/#additional-css
#
#extra_css = ["stylesheets/extra.css"]
# With the `extra_javascript` option you can add your own JavaScript to your
# project to customize the behavior according to your needs.
#
# The path provided should be relative to the "docs_dir".
#
# Read more: https://zensical.org/docs/customization/#additional-javascript
#extra_javascript = ["javascripts/extra.js"]
# ----------------------------------------------------------------------------
# Section for configuring theme options
# ----------------------------------------------------------------------------
[project.theme] [project.theme]
# change this to "classic" to use the traditional Material for MkDocs look.
#variant = "classic"
# Zensical allows you to override specific blocks, partials, or whole
# templates as well as to define your own templates. To do this, uncomment
# the custom_dir setting below and set it to a directory in which you
# keep your template overrides.
#
# Read more:
# - https://zensical.org/docs/customization/#extending-the-theme
#
custom_dir = "docs/overrides" custom_dir = "docs/overrides"
# With the "favicon" option you can set your own image to use as the icon
# browsers will use in the browser title bar or tab bar. The path provided
# must be relative to the "docs_dir".
#
# Read more:
# - https://zensical.org/docs/setup/logo-and-icons/#favicon
# - https://developer.mozilla.org/en-US/docs/Glossary/Favicon
#
#favicon = "images/favicon.png"
# Zensical supports more than 60 different languages. This means that the
# labels and tooltips that Zensical's templates produce are translated.
# The "language" option allows you to set the language used. This language
# is also indicated in the HTML head element to help with accessibility
# and guide search engines and translation tools.
#
# The default language is "en" (English). It is possible to create
# sites with multiple languages and configure a language selector. See
# the documentation for details.
#
# Read more:
# - https://zensical.org/docs/setup/language/
#
language = "en" language = "en"
# Zensical provides a number of feature toggles that change the behavior
# of the documentation site.
features = [ features = [
# Zensical includes an announcement bar. This feature allows users to
# dismiss it when they have read the announcement.
# https://zensical.org/docs/setup/header/#announcement-bar
"announce.dismiss", "announce.dismiss",
# If you have a repository configured and turn on this feature, Zensical
# will generate an edit button for the page. This works for common
# repository hosting services.
# https://zensical.org/docs/setup/repository/#content-actions
#"content.action.edit",
# If you have a repository configured and turn on this feature, Zensical
# will generate a button that allows the user to view the Markdown
# code for the current page.
# https://zensical.org/docs/setup/repository/#content-actions
"content.action.view", "content.action.view",
# Code annotations allow you to add an icon with a tooltip to your
# code blocks to provide explanations at crucial points.
# https://zensical.org/docs/authoring/code-blocks/#code-annotations
"content.code.annotate", "content.code.annotate",
# This feature turns on a button in code blocks that allow users to
# copy the content to their clipboard without first selecting it.
# https://zensical.org/docs/authoring/code-blocks/#code-copy-button
"content.code.copy", "content.code.copy",
# Code blocks can include a button to allow for the selection of line
# ranges by the user.
# https://zensical.org/docs/authoring/code-blocks/#code-selection-button
"content.code.select", "content.code.select",
# Zensical can render footnotes as inline tooltips, so the user can read
# the footnote without leaving the context of the document.
# https://zensical.org/docs/authoring/footnotes/#footnote-tooltips
"content.footnote.tooltips", "content.footnote.tooltips",
# If you have many content tabs that have the same titles (e.g., "Python",
# "JavaScript", "Cobol"), this feature causes all of them to switch to
# at the same time when the user chooses their language in one.
# https://zensical.org/docs/authoring/content-tabs/#linked-content-tabs
"content.tabs.link", "content.tabs.link",
# With this feature enabled users can add tooltips to links that will be
# displayed when the mouse pointer hovers the link.
# https://zensical.org/docs/authoring/tooltips/#improved-tooltips
"content.tooltips", "content.tooltips",
# With this feature enabled, Zensical will automatically hide parts
# of the header when the user scrolls past a certain point.
# https://zensical.org/docs/setup/header/#automatic-hiding
# "header.autohide",
# Turn on this feature to expand all collapsible sections in the
# navigation sidebar by default.
# https://zensical.org/docs/setup/navigation/#navigation-expansion
# "navigation.expand",
# This feature turns on navigation elements in the footer that allow the
# user to navigate to a next or previous page.
# https://zensical.org/docs/setup/footer/#navigation
"navigation.footer", "navigation.footer",
# When section index pages are enabled, documents can be directly attached
# to sections, which is particularly useful for providing overview pages.
# https://zensical.org/docs/setup/navigation/#section-index-pages
"navigation.indexes", "navigation.indexes",
# When instant navigation is enabled, clicks on all internal links will be
# intercepted and dispatched via XHR without fully reloading the page.
# https://zensical.org/docs/setup/navigation/#instant-navigation
"navigation.instant", "navigation.instant",
# With instant prefetching, your site will start to fetch a page once the
# user hovers over a link. This will reduce the perceived loading time
# for the user.
# https://zensical.org/docs/setup/navigation/#instant-prefetching
"navigation.instant.prefetch", "navigation.instant.prefetch",
# In order to provide a better user experience on slow connections when
# using instant navigation, a progress indicator can be enabled.
# https://zensical.org/docs/setup/navigation/#progress-indicator
#"navigation.instant.progress",
# When navigation paths are activated, a breadcrumb navigation is rendered
# above the title of each page
# https://zensical.org/docs/setup/navigation/#navigation-path
"navigation.path", "navigation.path",
# When pruning is enabled, only the visible navigation items are included
# in the rendered HTML, reducing the size of the built site by 33% or more.
# https://zensical.org/docs/setup/navigation/#navigation-pruning
#"navigation.prune",
# When sections are enabled, top-level sections are rendered as groups in
# the sidebar for viewports above 1220px, but remain as-is on mobile.
# https://zensical.org/docs/setup/navigation/#navigation-sections
"navigation.sections", "navigation.sections",
# When tabs are enabled, top-level sections are rendered in a menu layer
# below the header for viewports above 1220px, but remain as-is on mobile.
# https://zensical.org/docs/setup/navigation/#navigation-tabs
"navigation.tabs", "navigation.tabs",
# When sticky tabs are enabled, navigation tabs will lock below the header
# and always remain visible when scrolling down.
# https://zensical.org/docs/setup/navigation/#sticky-navigation-tabs
#"navigation.tabs.sticky",
# A back-to-top button can be shown when the user, after scrolling down,
# starts to scroll up again.
# https://zensical.org/docs/setup/navigation/#back-to-top-button
"navigation.top", "navigation.top",
# When anchor tracking is enabled, the URL in the address bar is
# automatically updated with the active anchor as highlighted in the table
# of contents.
# https://zensical.org/docs/setup/navigation/#anchor-tracking
"navigation.tracking", "navigation.tracking",
# When search highlighting is enabled and a user clicks on a search result,
# Zensical will highlight all occurrences after following the link.
# https://zensical.org/docs/setup/search/#search-highlighting
"search.highlight", "search.highlight",
# When anchor following for the table of contents is enabled, the sidebar
# is automatically scrolled so that the active anchor is always visible.
# https://zensical.org/docs/setup/navigation/#anchor-following
# "toc.follow",
# When navigation integration for the table of contents is enabled, it is
# always rendered as part of the navigation sidebar on the left.
# https://zensical.org/docs/setup/navigation/#navigation-integration
#"toc.integrate",
] ]
# ----------------------------------------------------------------------------
# In the "palette" subsection you can configure options for the color scheme.
# You can configure different color # schemes, e.g., to turn on dark mode,
# that the user can switch between. Each color scheme can be further
# customized.
#
# Read more:
# - https://zensical.org/docs/setup/colors/
# ----------------------------------------------------------------------------
[[project.theme.palette]] [[project.theme.palette]]
scheme = "default" scheme = "default"
toggle.icon = "lucide/sun" toggle.icon = "lucide/sun"
@@ -270,43 +40,13 @@ scheme = "slate"
toggle.icon = "lucide/moon" toggle.icon = "lucide/moon"
toggle.name = "Switch to light mode" toggle.name = "Switch to light mode"
# ----------------------------------------------------------------------------
# In the "font" subsection you can configure the fonts used. By default, fonts
# are loaded from Google Fonts, giving you a wide range of choices from a set
# of suitably licensed fonts. There are options for a normal text font and for
# a monospaced font used in code blocks.
# ----------------------------------------------------------------------------
[project.theme.font] [project.theme.font]
text = "Inter" text = "Inter"
code = "Jetbrains Mono" code = "Jetbrains Mono"
# ----------------------------------------------------------------------------
# You can configure your own logo to be shown in the header using the "logo"
# option in the "icons" subsection. The logo can be a path to a file in your
# "docs_dir" or it can be a path to an icon.
#
# Likewise, you can customize the logo used for the repository section of the
# header. Zensical derives the default logo for this from the repository URL.
# See below...
#
# There are other icons you can customize. See the documentation for details.
#
# Read more:
# - https://zensical.org/docs/setup/logo-and-icons
# - https://zensical.org/docs/authoring/icons-emojis/#search
# ----------------------------------------------------------------------------
[project.theme.icon] [project.theme.icon]
#logo = "lucide/smile"
repo = "fontawesome/brands/github" repo = "fontawesome/brands/github"
# ----------------------------------------------------------------------------
# The "extra" section contains miscellaneous settings.
# ----------------------------------------------------------------------------
#[[project.extra.social]]
#icon = "fontawesome/brands/github"
#link = "https://github.com/user/repo"
[project.plugins.mkdocstrings.handlers.python] [project.plugins.mkdocstrings.handlers.python]
inventories = ["https://docs.python.org/3/objects.inv"] inventories = ["https://docs.python.org/3/objects.inv"]
paths = ["src"] paths = ["src"]
@@ -316,3 +56,42 @@ docstring_style = "google"
inherited_members = true inherited_members = true
show_source = false show_source = false
show_root_heading = true show_root_heading = true
[project.markdown_extensions]
abbr = {}
admonition = {}
attr_list = {}
def_list = {}
footnotes = {}
md_in_html = {}
"pymdownx.arithmatex" = {generic = true}
"pymdownx.betterem" = {}
"pymdownx.caret" = {}
"pymdownx.details" = {}
"pymdownx.emoji" = {}
"pymdownx.inlinehilite" = {}
"pymdownx.keys" = {}
"pymdownx.magiclink" = {}
"pymdownx.mark" = {}
"pymdownx.smartsymbols" = {}
"pymdownx.tasklist" = {custom_checkbox = true}
"pymdownx.tilde" = {}
[project.markdown_extensions."pymdownx.highlight"]
anchor_linenums = true
line_spans = "__span"
pygments_lang_class = true
[project.markdown_extensions."pymdownx.superfences"]
custom_fences = [{name = "mermaid", class = "mermaid"}]
[project.markdown_extensions."pymdownx.tabbed"]
alternate_style = true
combine_header_slug = true
[project.markdown_extensions."toc"]
permalink = true
[project.markdown_extensions."pymdownx.snippets"]
base_path = ["."]
check_paths = true