Compare commits

...

7 Commits

13 changed files with 840 additions and 122 deletions

View File

@@ -72,6 +72,7 @@ GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&or
], ],
"pagination": { "pagination": {
"total_count": 42, "total_count": 42,
"pages": 5,
"page": 2, "page": 2,
"items_per_page": 10, "items_per_page": 10,
"has_more": true "has_more": true
@@ -85,6 +86,8 @@ GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&or
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client. `filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
To skip the `COUNT(*)` query for better performance on large tables, pass `include_total=False`. `pagination.total_count` will be `null` in the response, while `has_more` remains accurate.
### Cursor pagination ### Cursor pagination
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades. Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
@@ -144,7 +147,7 @@ GET /articles/?pagination_type=offset&page=1&items_per_page=10
"status": "SUCCESS", "status": "SUCCESS",
"pagination_type": "offset", "pagination_type": "offset",
"data": ["..."], "data": ["..."],
"pagination": { "total_count": 42, "page": 1, "items_per_page": 10, "has_more": true } "pagination": { "total_count": 42, "pages": 5, "page": 1, "items_per_page": 10, "has_more": true }
} }
``` ```

View File

@@ -182,6 +182,7 @@ The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.Async
"data": ["..."], "data": ["..."],
"pagination": { "pagination": {
"total_count": 100, "total_count": 100,
"pages": 5,
"page": 1, "page": 1,
"items_per_page": 20, "items_per_page": 20,
"has_more": true "has_more": true
@@ -189,6 +190,40 @@ The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.Async
} }
``` ```
#### Skipping the COUNT query
!!! info "Added in `v2.4.1`"
By default `offset_paginate` runs two queries: one for the page items and one `COUNT(*)` for `total_count`. On large tables the `COUNT` can be expensive. Pass `include_total=False` to skip it:
```python
result = await UserCrud.offset_paginate(
session=session,
page=page,
items_per_page=items_per_page,
include_total=False,
schema=UserRead,
)
```
#### Pagination params dependency
!!! info "Added in `v2.4.1`"
Use [`offset_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_params) to generate a FastAPI dependency that injects `page` and `items_per_page` from query parameters with configurable defaults and a `max_page_size` cap:
```python
from typing import Annotated
from fastapi import Depends
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.offset_params(default_page_size=20, max_page_size=100))],
) -> OffsetPaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
### Cursor pagination ### Cursor pagination
```python ```python
@@ -238,7 +273,7 @@ The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_to
!!! note !!! note
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`. `cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
The cursor value is base64-encoded when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported: The cursor value is URL-safe base64-encoded (no padding) when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
| SQLAlchemy type | Python type | | SQLAlchemy type | Python type |
|---|---| |---|---|
@@ -256,6 +291,24 @@ PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at) PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
``` ```
#### Pagination params dependency
!!! info "Added in `v2.4.1`"
Use [`cursor_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_params) to inject `cursor` and `items_per_page` from query parameters with a `max_page_size` cap:
```python
from typing import Annotated
from fastapi import Depends
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.cursor_params(default_page_size=20, max_page_size=100))],
) -> CursorPaginatedResponse[UserRead]:
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
```
### Unified endpoint (both strategies) ### Unified endpoint (both strategies)
!!! info "Added in `v2.3.0`" !!! info "Added in `v2.3.0`"
@@ -289,7 +342,24 @@ GET /users?pagination_type=offset&page=2&items_per_page=10
GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10 GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10
``` ```
Both `page` and `cursor` are always accepted by the endpoint — unused parameters are silently ignored by `paginate()`. #### Pagination params dependency
!!! info "Added in `v2.4.1`"
Use [`paginate_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate_params) to inject all parameters at once with configurable defaults and a `max_page_size` cap:
```python
from typing import Annotated
from fastapi import Depends
from fastapi_toolsets.schemas import PaginatedResponse
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.paginate_params(default_page_size=20, max_page_size=100))],
) -> PaginatedResponse[UserRead]:
return await UserCrud.paginate(session, **params, schema=UserRead)
```
## Search ## Search

View File

@@ -1,8 +1,8 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends
from fastapi_toolsets.crud import OrderByClause, PaginationType from fastapi_toolsets.crud import OrderByClause
from fastapi_toolsets.schemas import ( from fastapi_toolsets.schemas import (
CursorPaginatedResponse, CursorPaginatedResponse,
OffsetPaginatedResponse, OffsetPaginatedResponse,
@@ -20,19 +20,20 @@ router = APIRouter(prefix="/articles")
@router.get("/offset") @router.get("/offset")
async def list_articles_offset( async def list_articles_offset(
session: SessionDep, session: SessionDep,
params: Annotated[
dict,
Depends(ArticleCrud.offset_params(default_page_size=20, max_page_size=100)),
],
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())], filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[ order_by: Annotated[
OrderByClause | None, OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)), Depends(ArticleCrud.order_params(default_field=Article.created_at)),
], ],
page: int = Query(1, ge=1),
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None, search: str | None = None,
) -> OffsetPaginatedResponse[ArticleRead]: ) -> OffsetPaginatedResponse[ArticleRead]:
return await ArticleCrud.offset_paginate( return await ArticleCrud.offset_paginate(
session=session, session=session,
page=page, **params,
items_per_page=items_per_page,
search=search, search=search,
filter_by=filter_by or None, filter_by=filter_by or None,
order_by=order_by, order_by=order_by,
@@ -43,19 +44,20 @@ async def list_articles_offset(
@router.get("/cursor") @router.get("/cursor")
async def list_articles_cursor( async def list_articles_cursor(
session: SessionDep, session: SessionDep,
params: Annotated[
dict,
Depends(ArticleCrud.cursor_params(default_page_size=20, max_page_size=100)),
],
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())], filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[ order_by: Annotated[
OrderByClause | None, OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)), Depends(ArticleCrud.order_params(default_field=Article.created_at)),
], ],
cursor: str | None = None,
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None, search: str | None = None,
) -> CursorPaginatedResponse[ArticleRead]: ) -> CursorPaginatedResponse[ArticleRead]:
return await ArticleCrud.cursor_paginate( return await ArticleCrud.cursor_paginate(
session=session, session=session,
cursor=cursor, **params,
items_per_page=items_per_page,
search=search, search=search,
filter_by=filter_by or None, filter_by=filter_by or None,
order_by=order_by, order_by=order_by,
@@ -66,23 +68,20 @@ async def list_articles_cursor(
@router.get("/") @router.get("/")
async def list_articles( async def list_articles(
session: SessionDep, session: SessionDep,
params: Annotated[
dict,
Depends(ArticleCrud.paginate_params(default_page_size=20, max_page_size=100)),
],
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())], filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[ order_by: Annotated[
OrderByClause | None, OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)), Depends(ArticleCrud.order_params(default_field=Article.created_at)),
], ],
pagination_type: PaginationType = PaginationType.OFFSET,
page: int = Query(1, ge=1),
cursor: str | None = None,
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None, search: str | None = None,
) -> PaginatedResponse[ArticleRead]: ) -> PaginatedResponse[ArticleRead]:
return await ArticleCrud.paginate( return await ArticleCrud.paginate(
session, session,
pagination_type=pagination_type, **params,
page=page,
cursor=cursor,
items_per_page=items_per_page,
search=search, search=search,
filter_by=filter_by or None, filter_by=filter_by or None,
order_by=order_by, order_by=order_by,

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "2.3.0" version = "2.4.1"
description = "Production-ready utilities for FastAPI applications" description = "Production-ready utilities for FastAPI applications"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "2.3.0" __version__ = "2.4.1"

View File

@@ -58,18 +58,33 @@ class _CursorDirection(str, Enum):
def _encode_cursor( def _encode_cursor(
value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
) -> str: ) -> str:
"""Encode a cursor column value and navigation direction as a base64 string.""" """Encode a cursor column value and navigation direction as a URL-safe base64 string."""
return base64.b64encode( return (
json.dumps({"val": str(value), "dir": direction}).encode() base64.urlsafe_b64encode(
).decode() json.dumps({"val": str(value), "dir": direction}).encode()
)
.decode()
.rstrip("=")
)
def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]: def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
"""Decode a cursor base64 string into ``(raw_value, direction)``.""" """Decode a URL-safe base64 cursor string into ``(raw_value, direction)``."""
payload = json.loads(base64.b64decode(cursor.encode()).decode()) padded = cursor + "=" * (-len(cursor) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded).decode())
return payload["val"], _CursorDirection(payload["dir"]) return payload["val"], _CursorDirection(payload["dir"])
def _page_size_query(default: int, max_size: int) -> int:
"""Return a FastAPI ``Query`` for the ``items_per_page`` parameter."""
return Query(
default,
ge=1,
le=max_size,
description=f"Number of items per page (max {max_size})",
)
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any: def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
"""Parse a raw cursor string value back into the appropriate Python type.""" """Parse a raw cursor string value back into the appropriate Python type."""
if isinstance(col_type, Integer): if isinstance(col_type, Integer):
@@ -254,6 +269,7 @@ class AsyncCrud(Generic[ModelType]):
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
) -> Callable[..., Awaitable[dict[str, list[str]]]]: ) -> Callable[..., Awaitable[dict[str, list[str]]]]:
"""Return a FastAPI dependency that collects facet filter values from query parameters. """Return a FastAPI dependency that collects facet filter values from query parameters.
Args: Args:
facet_fields: Override the facet fields for this dependency. Falls back to the facet_fields: Override the facet fields for this dependency. Falls back to the
class-level ``facet_fields`` if not provided. class-level ``facet_fields`` if not provided.
@@ -293,6 +309,121 @@ class AsyncCrud(Generic[ModelType]):
return dependency return dependency
@classmethod
def offset_params(
cls: type[Self],
*,
default_page_size: int = 20,
max_page_size: int = 100,
include_total: bool = True,
) -> Callable[..., Awaitable[dict[str, Any]]]:
"""Return a FastAPI dependency that collects offset pagination params from query params.
Args:
default_page_size: Default value for the ``items_per_page`` query parameter.
max_page_size: Maximum allowed value for ``items_per_page`` (enforced via
``le`` on the ``Query``).
include_total: Server-side flag forwarded as-is to ``include_total`` in
:meth:`offset_paginate`. Not exposed as a query parameter.
Returns:
An async dependency that resolves to a dict with ``page``,
``items_per_page``, and ``include_total`` keys, ready to be
unpacked into :meth:`offset_paginate`.
"""
async def dependency(
page: int = Query(1, ge=1, description="Page number (1-indexed)"),
items_per_page: int = _page_size_query(default_page_size, max_page_size),
) -> dict[str, Any]:
return {
"page": page,
"items_per_page": items_per_page,
"include_total": include_total,
}
dependency.__name__ = f"{cls.model.__name__}OffsetParams"
return dependency
@classmethod
def cursor_params(
cls: type[Self],
*,
default_page_size: int = 20,
max_page_size: int = 100,
) -> Callable[..., Awaitable[dict[str, Any]]]:
"""Return a FastAPI dependency that collects cursor pagination params from query params.
Args:
default_page_size: Default value for the ``items_per_page`` query parameter.
max_page_size: Maximum allowed value for ``items_per_page`` (enforced via
``le`` on the ``Query``).
Returns:
An async dependency that resolves to a dict with ``cursor`` and
``items_per_page`` keys, ready to be unpacked into
:meth:`cursor_paginate`.
"""
async def dependency(
cursor: str | None = Query(
None, description="Cursor token from a previous response"
),
items_per_page: int = _page_size_query(default_page_size, max_page_size),
) -> dict[str, Any]:
return {"cursor": cursor, "items_per_page": items_per_page}
dependency.__name__ = f"{cls.model.__name__}CursorParams"
return dependency
@classmethod
def paginate_params(
cls: type[Self],
*,
default_page_size: int = 20,
max_page_size: int = 100,
default_pagination_type: PaginationType = PaginationType.OFFSET,
include_total: bool = True,
) -> Callable[..., Awaitable[dict[str, Any]]]:
"""Return a FastAPI dependency that collects all pagination params from query params.
Args:
default_page_size: Default value for the ``items_per_page`` query parameter.
max_page_size: Maximum allowed value for ``items_per_page`` (enforced via
``le`` on the ``Query``).
default_pagination_type: Default pagination strategy.
include_total: Server-side flag forwarded as-is to ``include_total`` in
:meth:`paginate`. Not exposed as a query parameter.
Returns:
An async dependency that resolves to a dict with ``pagination_type``,
``page``, ``cursor``, ``items_per_page``, and ``include_total`` keys,
ready to be unpacked into :meth:`paginate`.
"""
async def dependency(
pagination_type: PaginationType = Query(
default_pagination_type, description="Pagination strategy"
),
page: int = Query(
1, ge=1, description="Page number (1-indexed, offset only)"
),
cursor: str | None = Query(
None, description="Cursor token from a previous response (cursor only)"
),
items_per_page: int = _page_size_query(default_page_size, max_page_size),
) -> dict[str, Any]:
return {
"pagination_type": pagination_type,
"page": page,
"cursor": cursor,
"items_per_page": items_per_page,
"include_total": include_total,
}
dependency.__name__ = f"{cls.model.__name__}PaginateParams"
return dependency
@classmethod @classmethod
def order_params( def order_params(
cls: type[Self], cls: type[Self],
@@ -922,6 +1053,7 @@ class AsyncCrud(Generic[ModelType]):
order_by: OrderByClause | None = None, order_by: OrderByClause | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
include_total: bool = True,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
@@ -939,6 +1071,8 @@ class AsyncCrud(Generic[ModelType]):
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
page: Page number (1-indexed) page: Page number (1-indexed)
items_per_page: Number of items per page items_per_page: Number of items per page
include_total: When ``False``, skip the ``COUNT`` query;
``pagination.total_count`` will be ``None``.
search: Search query string or SearchConfig object search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default) search_fields: Fields to search in (overrides class default)
facet_fields: Columns to compute distinct values for (overrides class default) facet_fields: Columns to compute distinct values for (overrides class default)
@@ -983,28 +1117,39 @@ class AsyncCrud(Generic[ModelType]):
if order_by is not None: if order_by is not None:
q = q.order_by(order_by) q = q.order_by(order_by)
q = q.offset(offset).limit(items_per_page) if include_total:
result = await session.execute(q) q = q.offset(offset).limit(items_per_page)
raw_items = cast(list[ModelType], result.unique().scalars().all()) result = await session.execute(q)
raw_items = cast(list[ModelType], result.unique().scalars().all())
# Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
# Apply explicit joins to count query
count_q = _apply_joins(count_q, joins, outer_join)
# Apply search joins to count query
count_q = _apply_search_joins(count_q, search_joins)
if filters:
count_q = count_q.where(and_(*filters))
count_result = await session.execute(count_q)
total_count: int | None = count_result.scalar_one()
has_more = page * items_per_page < total_count
else:
# Fetch one extra row to detect if a next page exists without COUNT
q = q.offset(offset).limit(items_per_page + 1)
result = await session.execute(q)
raw_items = cast(list[ModelType], result.unique().scalars().all())
has_more = len(raw_items) > items_per_page
raw_items = raw_items[:items_per_page]
total_count = None
items: list[Any] = [schema.model_validate(item) for item in raw_items] items: list[Any] = [schema.model_validate(item) for item in raw_items]
# Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
# Apply explicit joins to count query
count_q = _apply_joins(count_q, joins, outer_join)
# Apply search joins to count query
count_q = _apply_search_joins(count_q, search_joins)
if filters:
count_q = count_q.where(and_(*filters))
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
filter_attributes = await cls._build_filter_attributes( filter_attributes = await cls._build_filter_attributes(
session, facet_fields, filters, search_joins session, facet_fields, filters, search_joins
) )
@@ -1015,7 +1160,7 @@ class AsyncCrud(Generic[ModelType]):
total_count=total_count, total_count=total_count,
items_per_page=items_per_page, items_per_page=items_per_page,
page=page, page=page,
has_more=page * items_per_page < total_count, has_more=has_more,
), ),
filter_attributes=filter_attributes, filter_attributes=filter_attributes,
) )
@@ -1190,6 +1335,7 @@ class AsyncCrud(Generic[ModelType]):
page: int = ..., page: int = ...,
cursor: str | None = ..., cursor: str | None = ...,
items_per_page: int = ..., items_per_page: int = ...,
include_total: bool = ...,
search: str | SearchConfig | None = ..., search: str | SearchConfig | None = ...,
search_fields: Sequence[SearchFieldType] | None = ..., search_fields: Sequence[SearchFieldType] | None = ...,
facet_fields: Sequence[FacetFieldType] | None = ..., facet_fields: Sequence[FacetFieldType] | None = ...,
@@ -1212,6 +1358,7 @@ class AsyncCrud(Generic[ModelType]):
page: int = ..., page: int = ...,
cursor: str | None = ..., cursor: str | None = ...,
items_per_page: int = ..., items_per_page: int = ...,
include_total: bool = ...,
search: str | SearchConfig | None = ..., search: str | SearchConfig | None = ...,
search_fields: Sequence[SearchFieldType] | None = ..., search_fields: Sequence[SearchFieldType] | None = ...,
facet_fields: Sequence[FacetFieldType] | None = ..., facet_fields: Sequence[FacetFieldType] | None = ...,
@@ -1233,6 +1380,7 @@ class AsyncCrud(Generic[ModelType]):
page: int = 1, page: int = 1,
cursor: str | None = None, cursor: str | None = None,
items_per_page: int = 20, items_per_page: int = 20,
include_total: bool = True,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
@@ -1258,6 +1406,8 @@ class AsyncCrud(Generic[ModelType]):
:class:`.CursorPaginatedResponse`. Only used when :class:`.CursorPaginatedResponse`. Only used when
``pagination_type`` is ``CURSOR``. ``pagination_type`` is ``CURSOR``.
items_per_page: Number of items per page (default 20). items_per_page: Number of items per page (default 20).
include_total: When ``False``, skip the ``COUNT`` query;
only applies when ``pagination_type`` is ``OFFSET``.
search: Search query string or :class:`.SearchConfig` object. search: Search query string or :class:`.SearchConfig` object.
search_fields: Fields to search in (overrides class default). search_fields: Fields to search in (overrides class default).
facet_fields: Columns to compute distinct values for (overrides facet_fields: Columns to compute distinct values for (overrides
@@ -1304,6 +1454,7 @@ class AsyncCrud(Generic[ModelType]):
order_by=order_by, order_by=order_by,
page=page, page=page,
items_per_page=items_per_page, items_per_page=items_per_page,
include_total=include_total,
search=search, search=search,
search_fields=search_fields, search_fields=search_fields,
facet_fields=facet_fields, facet_fields=facet_fields,

View File

@@ -9,6 +9,7 @@ from typing import Any, TypeVar
from sqlalchemy import event from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
from ..logger import get_logger from ..logger import get_logger
@@ -53,6 +54,17 @@ def watch(*fields: str) -> Any:
return decorator return decorator
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState
state_dict = state.dict
return {
prop.key: state_dict[prop.key]
for prop in state.mapper.column_attrs
if prop.key in state_dict
}
def _upsert_changes( def _upsert_changes(
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]], pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
obj: Any, obj: Any,
@@ -139,16 +151,31 @@ def _task_error_handler(task: asyncio.Task[Any]) -> None:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None: def _schedule_with_snapshot(
"""Dispatch *fn* with *args*, handling both sync and async callables.""" loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
try: ) -> None:
result = fn(*args) """Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
except Exception as exc: then schedule a coroutine that restores the snapshot and calls *fn*.
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) """
return snapshot = _snapshot_column_attrs(obj)
if asyncio.iscoroutine(result):
task = loop.create_task(result) async def _run(
task.add_done_callback(_task_error_handler) obj: Any = obj,
fn: Any = fn,
snapshot: dict[str, Any] = snapshot,
args: tuple = args,
) -> None:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
try:
result = fn(*args)
if asyncio.iscoroutine(result):
await result
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
task = loop.create_task(_run())
task.add_done_callback(_task_error_handler)
@event.listens_for(AsyncSession.sync_session_class, "after_commit") @event.listens_for(AsyncSession.sync_session_class, "after_commit")
@@ -168,13 +195,13 @@ def _after_commit(session: Any) -> None:
return return
for obj in creates: for obj in creates:
_call_callback(loop, obj.on_create) _schedule_with_snapshot(loop, obj, obj.on_create)
for obj in deletes: for obj in deletes:
_call_callback(loop, obj.on_delete) _schedule_with_snapshot(loop, obj, obj.on_delete)
for obj, changes in field_changes.values(): for obj, changes in field_changes.values():
_call_callback(loop, obj.on_update, changes) _schedule_with_snapshot(loop, obj, obj.on_update, changes)
class WatchedFieldsMixin: class WatchedFieldsMixin:

View File

@@ -1,9 +1,10 @@
"""Base Pydantic schemas for API responses.""" """Base Pydantic schemas for API responses."""
import math
from enum import Enum from enum import Enum
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, computed_field
from .types import DataT from .types import DataT
@@ -98,17 +99,29 @@ class OffsetPagination(PydanticBase):
"""Pagination metadata for offset-based list responses. """Pagination metadata for offset-based list responses.
Attributes: Attributes:
total_count: Total number of items across all pages total_count: Total number of items across all pages.
``None`` when ``include_total=False``.
items_per_page: Number of items per page items_per_page: Number of items per page
page: Current page number (1-indexed) page: Current page number (1-indexed)
has_more: Whether there are more pages has_more: Whether there are more pages
pages: Total number of pages
""" """
total_count: int total_count: int | None
items_per_page: int items_per_page: int
page: int page: int
has_more: bool has_more: bool
@computed_field
@property
def pages(self) -> int | None:
"""Total number of pages, or ``None`` when ``total_count`` is unknown."""
if self.total_count is None:
return None
if self.items_per_page == 0:
return 0
return math.ceil(self.total_count / self.items_per_page)
class CursorPagination(PydanticBase): class CursorPagination(PydanticBase):
"""Pagination metadata for cursor-based list responses. """Pagination metadata for cursor-based list responses.

View File

@@ -1759,6 +1759,52 @@ class TestSchemaResponse:
assert result.data[0].username == "pg_user" assert result.data[0].username == "pg_user"
assert not hasattr(result.data[0], "email") assert not hasattr(result.data[0], "email")
@pytest.mark.anyio
async def test_include_total_false_skips_count(self, db_session: AsyncSession):
"""offset_paginate with include_total=False returns total_count=None."""
from fastapi_toolsets.schemas import OffsetPagination
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.offset_paginate(
db_session, items_per_page=10, include_total=False, schema=RoleRead
)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count is None
assert len(result.data) == 5
assert result.pagination.has_more is False
@pytest.mark.anyio
async def test_include_total_false_has_more_true(self, db_session: AsyncSession):
"""offset_paginate with include_total=False sets has_more via extra-row probe."""
for i in range(15):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.offset_paginate(
db_session, items_per_page=10, include_total=False, schema=RoleRead
)
assert result.pagination.total_count is None
assert result.pagination.has_more is True
assert len(result.data) == 10
@pytest.mark.anyio
async def test_include_total_false_exact_page_boundary(
self, db_session: AsyncSession
):
"""offset_paginate with include_total=False: has_more=False when items == page size."""
for i in range(10):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.offset_paginate(
db_session, items_per_page=10, include_total=False, schema=RoleRead
)
assert result.pagination.has_more is False
assert len(result.data) == 10
class TestCursorPaginate: class TestCursorPaginate:
"""Tests for cursor-based pagination via cursor_paginate().""" """Tests for cursor-based pagination via cursor_paginate()."""
@@ -2521,3 +2567,20 @@ class TestPaginate:
pagination_type="unknown", pagination_type="unknown",
schema=RoleRead, schema=RoleRead,
) # type: ignore[no-matching-overload] ) # type: ignore[no-matching-overload]
@pytest.mark.anyio
async def test_offset_include_total_false(self, db_session: AsyncSession):
"""paginate() passes include_total=False through to offset_paginate."""
from fastapi_toolsets.schemas import OffsetPagination
await RoleCrud.create(db_session, RoleCreate(name="admin"))
result = await RoleCrud.paginate(
db_session,
pagination_type=PaginationType.OFFSET,
include_total=False,
schema=RoleRead,
)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count is None

View File

@@ -14,12 +14,14 @@ from fastapi_toolsets.crud import (
get_searchable_fields, get_searchable_fields,
) )
from fastapi_toolsets.exceptions import InvalidOrderFieldError from fastapi_toolsets.exceptions import InvalidOrderFieldError
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination, PaginationType
from .conftest import ( from .conftest import (
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
RoleCursorCrud,
RoleRead,
User, User,
UserCreate, UserCreate,
UserCrud, UserCrud,
@@ -1193,3 +1195,245 @@ class TestOrderParamsSchema:
assert results[0].username == "alice" assert results[0].username == "alice"
assert results[1].username == "charlie" assert results[1].username == "charlie"
class TestOffsetParamsSchema:
"""Tests for AsyncCrud.offset_params()."""
def test_returns_page_and_items_per_page_params(self):
"""Returned dependency has page and items_per_page params only."""
dep = RoleCrud.offset_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"page", "items_per_page"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
dep = RoleCrud.offset_params()
assert getattr(dep, "__name__") == "RoleOffsetParams"
def test_default_page_size_reflected_in_items_per_page_default(self):
"""default_page_size is used as the default for items_per_page."""
dep = RoleCrud.offset_params(default_page_size=42)
sig = inspect.signature(dep)
assert sig.parameters["items_per_page"].default.default == 42
def test_max_page_size_reflected_in_items_per_page_le(self):
"""max_page_size is used as le constraint on items_per_page."""
dep = RoleCrud.offset_params(max_page_size=50)
sig = inspect.signature(dep)
le = next(
m.le
for m in sig.parameters["items_per_page"].default.metadata
if hasattr(m, "le")
)
assert le == 50
def test_include_total_not_a_query_param(self):
"""include_total is not exposed as a query parameter."""
dep = RoleCrud.offset_params()
param_names = set(inspect.signature(dep).parameters)
assert "include_total" not in param_names
@pytest.mark.anyio
async def test_include_total_true_forwarded_in_result(self):
"""include_total=True factory arg appears in the resolved dict."""
result = await RoleCrud.offset_params(include_total=True)(
page=1, items_per_page=10
)
assert result["include_total"] is True
@pytest.mark.anyio
async def test_include_total_false_forwarded_in_result(self):
"""include_total=False factory arg appears in the resolved dict."""
result = await RoleCrud.offset_params(include_total=False)(
page=1, items_per_page=10
)
assert result["include_total"] is False
@pytest.mark.anyio
async def test_awaiting_dep_returns_dict(self):
"""Awaiting the dependency returns a dict with page, items_per_page, include_total."""
dep = RoleCrud.offset_params(include_total=False)
result = await dep(page=2, items_per_page=10)
assert result == {"page": 2, "items_per_page": 10, "include_total": False}
@pytest.mark.anyio
async def test_integrates_with_offset_paginate(self, db_session: AsyncSession):
"""offset_params output can be unpacked directly into offset_paginate."""
await RoleCrud.create(db_session, RoleCreate(name="admin"))
dep = RoleCrud.offset_params()
params = await dep(page=1, items_per_page=10)
result = await RoleCrud.offset_paginate(db_session, **params, schema=RoleRead)
assert result.pagination.page == 1
assert result.pagination.items_per_page == 10
class TestCursorParamsSchema:
"""Tests for AsyncCrud.cursor_params()."""
def test_returns_cursor_and_items_per_page_params(self):
"""Returned dependency has cursor and items_per_page params."""
dep = RoleCursorCrud.cursor_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"cursor", "items_per_page"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
dep = RoleCursorCrud.cursor_params()
assert getattr(dep, "__name__") == "RoleCursorParams"
def test_default_page_size_reflected_in_items_per_page_default(self):
"""default_page_size is used as the default for items_per_page."""
dep = RoleCursorCrud.cursor_params(default_page_size=15)
sig = inspect.signature(dep)
assert sig.parameters["items_per_page"].default.default == 15
def test_max_page_size_reflected_in_items_per_page_le(self):
"""max_page_size is used as le constraint on items_per_page."""
dep = RoleCursorCrud.cursor_params(max_page_size=75)
sig = inspect.signature(dep)
le = next(
m.le
for m in sig.parameters["items_per_page"].default.metadata
if hasattr(m, "le")
)
assert le == 75
def test_cursor_defaults_to_none(self):
"""cursor defaults to None."""
dep = RoleCursorCrud.cursor_params()
sig = inspect.signature(dep)
assert sig.parameters["cursor"].default.default is None
@pytest.mark.anyio
async def test_awaiting_dep_returns_dict(self):
"""Awaiting the dependency returns a dict with cursor and items_per_page."""
dep = RoleCursorCrud.cursor_params()
result = await dep(cursor=None, items_per_page=5)
assert result == {"cursor": None, "items_per_page": 5}
@pytest.mark.anyio
async def test_integrates_with_cursor_paginate(self, db_session: AsyncSession):
"""cursor_params output can be unpacked directly into cursor_paginate."""
await RoleCrud.create(db_session, RoleCreate(name="admin"))
dep = RoleCursorCrud.cursor_params()
params = await dep(cursor=None, items_per_page=10)
result = await RoleCursorCrud.cursor_paginate(
db_session, **params, schema=RoleRead
)
assert result.pagination.items_per_page == 10
class TestPaginateParamsSchema:
"""Tests for AsyncCrud.paginate_params()."""
def test_returns_all_params(self):
"""Returned dependency has pagination_type, page, cursor, items_per_page (no include_total)."""
dep = RoleCursorCrud.paginate_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"pagination_type", "page", "cursor", "items_per_page"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
dep = RoleCursorCrud.paginate_params()
assert getattr(dep, "__name__") == "RolePaginateParams"
def test_default_pagination_type(self):
"""default_pagination_type is reflected in pagination_type default."""
from fastapi_toolsets.schemas import PaginationType
dep = RoleCursorCrud.paginate_params(
default_pagination_type=PaginationType.CURSOR
)
sig = inspect.signature(dep)
assert (
sig.parameters["pagination_type"].default.default == PaginationType.CURSOR
)
def test_default_page_size(self):
"""default_page_size is reflected in items_per_page default."""
dep = RoleCursorCrud.paginate_params(default_page_size=15)
sig = inspect.signature(dep)
assert sig.parameters["items_per_page"].default.default == 15
def test_max_page_size_le_constraint(self):
"""max_page_size is used as le constraint on items_per_page."""
dep = RoleCursorCrud.paginate_params(max_page_size=60)
sig = inspect.signature(dep)
le = next(
m.le
for m in sig.parameters["items_per_page"].default.metadata
if hasattr(m, "le")
)
assert le == 60
def test_include_total_not_a_query_param(self):
"""include_total is not exposed as a query parameter."""
dep = RoleCursorCrud.paginate_params()
assert "include_total" not in set(inspect.signature(dep).parameters)
@pytest.mark.anyio
async def test_include_total_forwarded_in_result(self):
"""include_total factory arg appears in the resolved dict."""
result_true = await RoleCursorCrud.paginate_params(include_total=True)(
pagination_type=PaginationType.OFFSET,
page=1,
cursor=None,
items_per_page=10,
)
result_false = await RoleCursorCrud.paginate_params(include_total=False)(
pagination_type=PaginationType.OFFSET,
page=1,
cursor=None,
items_per_page=10,
)
assert result_true["include_total"] is True
assert result_false["include_total"] is False
@pytest.mark.anyio
async def test_awaiting_dep_returns_dict(self):
"""Awaiting the dependency returns a dict with all pagination keys."""
dep = RoleCursorCrud.paginate_params()
result = await dep(
pagination_type=PaginationType.OFFSET,
page=2,
cursor=None,
items_per_page=10,
)
assert result == {
"pagination_type": PaginationType.OFFSET,
"page": 2,
"cursor": None,
"items_per_page": 10,
"include_total": True,
}
@pytest.mark.anyio
async def test_integrates_with_paginate_offset(self, db_session: AsyncSession):
"""paginate_params output unpacks into paginate() for offset strategy."""
from fastapi_toolsets.schemas import OffsetPagination
await RoleCrud.create(db_session, RoleCreate(name="admin"))
params = await RoleCursorCrud.paginate_params()(
pagination_type=PaginationType.OFFSET,
page=1,
cursor=None,
items_per_page=10,
)
result = await RoleCursorCrud.paginate(db_session, **params, schema=RoleRead)
assert isinstance(result.pagination, OffsetPagination)
@pytest.mark.anyio
async def test_integrates_with_paginate_cursor(self, db_session: AsyncSession):
"""paginate_params output unpacks into paginate() for cursor strategy."""
from fastapi_toolsets.schemas import CursorPagination
await RoleCrud.create(db_session, RoleCreate(name="admin"))
params = await RoleCursorCrud.paginate_params()(
pagination_type=PaginationType.CURSOR,
page=1,
cursor=None,
items_per_page=10,
)
result = await RoleCursorCrud.paginate(db_session, **params, schema=RoleRead)
assert isinstance(result.pagination, CursorPagination)

View File

@@ -31,7 +31,6 @@ from fastapi_toolsets.models.watched import (
_after_flush, _after_flush,
_after_flush_postexec, _after_flush_postexec,
_after_rollback, _after_rollback,
_call_callback,
_task_error_handler, _task_error_handler,
_upsert_changes, _upsert_changes,
) )
@@ -128,6 +127,17 @@ class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes}) _test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model whose on_create always raises to test exception logging."""
__tablename__ = "mixin_failing_callback_models"
name: Mapped[str] = mapped_column(String(50))
async def on_create(self) -> None:
raise RuntimeError("callback intentionally failed")
class NonWatchedModel(MixinBase): class NonWatchedModel(MixinBase):
__tablename__ = "mixin_non_watched_models" __tablename__ = "mixin_non_watched_models"
@@ -135,6 +145,32 @@ class NonWatchedModel(MixinBase):
value: Mapped[str] = mapped_column(String(50)) value: Mapped[str] = mapped_column(String(50))
_attr_access_events: list[dict] = []
class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model used to verify that self attributes are accessible in every callback."""
__tablename__ = "mixin_attr_access_models"
name: Mapped[str] = mapped_column(String(50))
async def on_create(self) -> None:
_attr_access_events.append(
{"event": "create", "id": self.id, "name": self.name}
)
async def on_delete(self) -> None:
_attr_access_events.append(
{"event": "delete", "id": self.id, "name": self.name}
)
async def on_update(self, changes: dict) -> None:
_attr_access_events.append(
{"event": "update", "id": self.id, "name": self.name}
)
_sync_events: list[dict] = [] _sync_events: list[dict] = []
@@ -174,6 +210,25 @@ async def mixin_session():
await engine.dispose() await engine.dispose()
@pytest.fixture(scope="function")
async def mixin_session_expire():
"""Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=True)
session = session_factory()
try:
yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.drop_all)
await engine.dispose()
class TestUUIDMixin: class TestUUIDMixin:
@pytest.mark.anyio @pytest.mark.anyio
async def test_uuid_generated_by_db(self, mixin_session): async def test_uuid_generated_by_db(self, mixin_session):
@@ -742,6 +797,16 @@ class TestWatchedFieldsMixin:
assert _test_events == [] assert _test_events == []
@pytest.mark.anyio
async def test_callback_exception_is_logged(self, mixin_session):
"""Exceptions raised inside on_create are logged, not propagated."""
obj = FailingCallbackModel(name="boom")
mixin_session.add(obj)
with patch.object(_watched_module._logger, "error") as mock_error:
await mixin_session.commit()
await asyncio.sleep(0)
mock_error.assert_called_once()
@pytest.mark.anyio @pytest.mark.anyio
async def test_non_watched_model_no_callback(self, mixin_session): async def test_non_watched_model_no_callback(self, mixin_session):
"""Dirty objects whose type is not a WatchedFieldsMixin are skipped.""" """Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
@@ -903,65 +968,66 @@ class TestSyncCallbacks:
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
class TestCallCallback: class TestAttributeAccessInCallbacks:
"""Verify that self attributes are accessible inside every callback type.
Uses expire_on_commit=True (the SQLAlchemy default) so the tests would fail
without the snapshot-restore logic in _schedule_with_snapshot.
"""
@pytest.fixture(autouse=True)
def clear_events(self):
_attr_access_events.clear()
yield
_attr_access_events.clear()
@pytest.mark.anyio @pytest.mark.anyio
async def test_async_callback_scheduled_as_task(self): async def test_on_create_pk_and_field_accessible(self, mixin_session_expire):
"""_call_callback schedules async functions as tasks.""" """id (server default) and regular fields are readable inside on_create."""
called = [] obj = AttrAccessModel(name="hello")
mixin_session_expire.add(obj)
async def async_fn() -> None: await mixin_session_expire.commit()
called.append("async")
loop = asyncio.get_running_loop()
_call_callback(loop, async_fn)
await asyncio.sleep(0) await asyncio.sleep(0)
assert called == ["async"]
events = [e for e in _attr_access_events if e["event"] == "create"]
assert len(events) == 1
assert isinstance(events[0]["id"], uuid.UUID)
assert events[0]["name"] == "hello"
@pytest.mark.anyio @pytest.mark.anyio
async def test_sync_callback_called_directly(self): async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire):
"""_call_callback invokes sync functions immediately.""" """id and regular fields are readable inside on_delete."""
called = [] obj = AttrAccessModel(name="to-delete")
mixin_session_expire.add(obj)
def sync_fn() -> None: await mixin_session_expire.commit()
called.append("sync")
loop = asyncio.get_running_loop()
_call_callback(loop, sync_fn)
assert called == ["sync"]
@pytest.mark.anyio
async def test_sync_callback_exception_logged(self):
"""_call_callback logs exceptions from sync callbacks."""
def failing_fn() -> None:
raise RuntimeError("sync error")
loop = asyncio.get_running_loop()
with patch.object(_watched_module._logger, "error") as mock_error:
_call_callback(loop, failing_fn)
mock_error.assert_called_once()
@pytest.mark.anyio
async def test_async_callback_with_args(self):
"""_call_callback passes arguments to async callbacks."""
received = []
async def async_fn(changes: dict) -> None:
received.append(changes)
loop = asyncio.get_running_loop()
_call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}})
await asyncio.sleep(0) await asyncio.sleep(0)
assert received == [{"status": {"old": "a", "new": "b"}}] _attr_access_events.clear()
await mixin_session_expire.delete(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
events = [e for e in _attr_access_events if e["event"] == "delete"]
assert len(events) == 1
assert isinstance(events[0]["id"], uuid.UUID)
assert events[0]["name"] == "to-delete"
@pytest.mark.anyio @pytest.mark.anyio
async def test_sync_callback_with_args(self): async def test_on_update_pk_and_updated_field_accessible(
"""_call_callback passes arguments to sync callbacks.""" self, mixin_session_expire
received = [] ):
"""id and the new field value are readable inside on_update."""
obj = AttrAccessModel(name="original")
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
_attr_access_events.clear()
def sync_fn(changes: dict) -> None: obj.name = "updated"
received.append(changes) await mixin_session_expire.commit()
await asyncio.sleep(0)
loop = asyncio.get_running_loop() events = [e for e in _attr_access_events if e["event"] == "update"]
_call_callback(loop, sync_fn, {"x": 1}) assert len(events) == 1
assert received == [{"x": 1}] assert isinstance(events[0]["id"], uuid.UUID)
assert events[0]["name"] == "updated"

View File

@@ -201,6 +201,88 @@ class TestOffsetPagination:
assert data["page"] == 2 assert data["page"] == 2
assert data["has_more"] is True assert data["has_more"] is True
def test_total_count_can_be_none(self):
"""total_count accepts None (include_total=False mode)."""
pagination = OffsetPagination(
total_count=None,
items_per_page=20,
page=1,
has_more=True,
)
assert pagination.total_count is None
def test_serialization_with_none_total_count(self):
"""OffsetPagination serializes total_count=None correctly."""
pagination = OffsetPagination(
total_count=None,
items_per_page=20,
page=1,
has_more=False,
)
data = pagination.model_dump()
assert data["total_count"] is None
def test_pages_computed(self):
"""pages is ceil(total_count / items_per_page)."""
pagination = OffsetPagination(
total_count=42,
items_per_page=10,
page=1,
has_more=True,
)
assert pagination.pages == 5
def test_pages_exact_division(self):
"""pages is exact when total_count is evenly divisible."""
pagination = OffsetPagination(
total_count=40,
items_per_page=10,
page=1,
has_more=False,
)
assert pagination.pages == 4
def test_pages_zero_total(self):
"""pages is 0 when total_count is 0."""
pagination = OffsetPagination(
total_count=0,
items_per_page=10,
page=1,
has_more=False,
)
assert pagination.pages == 0
def test_pages_zero_items_per_page(self):
"""pages is 0 when items_per_page is 0."""
pagination = OffsetPagination(
total_count=100,
items_per_page=0,
page=1,
has_more=False,
)
assert pagination.pages == 0
def test_pages_none_when_total_count_none(self):
"""pages is None when total_count is None (include_total=False)."""
pagination = OffsetPagination(
total_count=None,
items_per_page=20,
page=1,
has_more=True,
)
assert pagination.pages is None
def test_pages_in_serialization(self):
"""pages appears in model_dump output."""
pagination = OffsetPagination(
total_count=25,
items_per_page=10,
page=1,
has_more=True,
)
data = pagination.model_dump()
assert data["pages"] == 3
class TestCursorPagination: class TestCursorPagination:
"""Tests for CursorPagination schema.""" """Tests for CursorPagination schema."""

2
uv.lock generated
View File

@@ -251,7 +251,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "2.3.0" version = "2.4.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },