feat: add cursor based pagination in CrudFactory (#86)

This commit is contained in:
d3vyce
2026-02-23 13:51:34 +01:00
committed by GitHub
parent 7482bc5dad
commit 6cf7df55ef
7 changed files with 1003 additions and 41 deletions

View File

@@ -1,6 +1,6 @@
# CRUD # CRUD
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support. This module has features that are only compatible with Postgres. Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
!!! info !!! info
This module has been coded and tested to be compatible with PostgreSQL only. This module has been coded and tested to be compatible with PostgreSQL only.
@@ -48,6 +48,21 @@ exists = await UserCrud.exists(session=session, filters=[User.email == email])
## Pagination ## Pagination
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
Two pagination strategies are available. Both return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) but differ in how they navigate through results.
| | `offset_paginate` | `cursor_paginate` |
|---|---|---|
| Total count | Yes | No |
| Jump to arbitrary page | Yes | No |
| Performance on deep pages | Degrades | Constant |
| Stable under concurrent inserts | No | Yes |
| Search compatible | Yes | Yes |
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll |
### Offset pagination
```python ```python
@router.get( @router.get(
"", "",
@@ -58,14 +73,88 @@ async def get_users(
items_per_page: int = 50, items_per_page: int = 50,
page: int = 1, page: int = 1,
): ):
return await crud.UserCrud.paginate( return await crud.UserCrud.offset_paginate(
session=session, session=session,
items_per_page=items_per_page, items_per_page=items_per_page,
page=page, page=page,
) )
``` ```
The [`paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) function will return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse). The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object:
```json
{
"status": "SUCCESS",
"data": ["..."],
"pagination": {
"total_count": 100,
"page": 1,
"items_per_page": 20,
"has_more": true
}
}
```
!!! warning "Deprecated: `paginate`"
The `paginate` function is a backward-compatible alias for `offset_paginate`. This function is **deprecated** and will be removed in **v2.0**.
### Cursor pagination
```python
@router.get(
"",
response_model=PaginatedResponse[UserRead],
)
async def list_users(
session: SessionDep,
cursor: str | None = None,
items_per_page: int = 20,
):
return await UserCrud.cursor_paginate(
session=session,
cursor=cursor,
items_per_page=items_per_page,
)
```
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object:
```json
{
"status": "SUCCESS",
"data": ["..."],
"pagination": {
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
"prev_cursor": null,
"items_per_page": 20,
"has_more": true
}
}
```
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
#### Choosing a cursor column
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
- Auto-increment integer PKs
- UUID v7 PKs
- Timestamps
!!! warning
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
!!! note
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
```python
# Paginate by the primary key
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
# Paginate by a timestamp column instead
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
```
## Search ## Search
@@ -82,7 +171,7 @@ PostCrud = CrudFactory(
) )
``` ```
This allow to do a search with the [`paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) function: This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
```python ```python
@router.get( @router.get(
@@ -95,7 +184,7 @@ async def get_users(
page: int = 1, page: int = 1,
search: str | None = None, search: str | None = None,
): ):
return await crud.UserCrud.paginate( return await crud.UserCrud.offset_paginate(
session=session, session=session,
items_per_page=items_per_page, items_per_page=items_per_page,
page=page, page=page,
@@ -103,9 +192,28 @@ async def get_users(
) )
``` ```
```python
@router.get(
"",
response_model=PaginatedResponse[User],
)
async def get_users(
session: SessionDep,
cursor: str | None = None,
items_per_page: int = 50,
search: str | None = None,
):
return await crud.UserCrud.cursor_paginate(
session=session,
items_per_page=items_per_page,
cursor=cursor,
search=search,
)
```
## Relationship loading ## Relationship loading
!!! info "Added in v1.1" !!! info "Added in `v1.1`"
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly. By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
@@ -124,7 +232,7 @@ ArticleCrud = CrudFactory(
) )
``` ```
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control: `default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
```python ```python
# Only loads category, tags are not loaded # Only loads category, tags are not loaded
@@ -168,7 +276,7 @@ await UserCrud.upsert(
!!! info "Added in `v1.1`" !!! info "Added in `v1.1`"
Pass a Pydantic schema class to `create`, `get`, `update`, or `paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse): Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
```python ```python
class UserRead(PydanticBase): class UserRead(PydanticBase):
@@ -188,7 +296,7 @@ async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
@router.get("") @router.get("")
async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]: async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]:
return await crud.UserCrud.paginate( return await crud.UserCrud.offset_paginate(
session=session, session=session,
page=page, page=page,
schema=UserRead, schema=UserRead,

View File

@@ -2,12 +2,15 @@
from __future__ import annotations from __future__ import annotations
import base64
import json
import uuid as uuid_module
import warnings import warnings
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import and_, func, select from sqlalchemy import Integer, Uuid, and_, func, select
from sqlalchemy import delete as sql_delete from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound from sqlalchemy.exc import NoResultFound
@@ -18,7 +21,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
from ..exceptions import NotFoundError from ..exceptions import NotFoundError
from ..schemas import PaginatedResponse, Pagination, Response from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import SearchConfig, SearchFieldType, build_search_filters from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
@@ -27,6 +30,16 @@ JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]]
def _encode_cursor(value: Any) -> str:
"""Encode cursor column value as an base64 string."""
return base64.b64encode(json.dumps(str(value)).encode()).decode()
def _decode_cursor(cursor: str) -> str:
"""Decode cursor base64 string."""
return json.loads(base64.b64decode(cursor.encode()).decode())
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models. """Generic async CRUD operations for SQLAlchemy models.
@@ -37,6 +50,7 @@ class AsyncCrud(Generic[ModelType]):
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None default_load_options: ClassVar[list[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None
@classmethod @classmethod
def _resolve_load_options( def _resolve_load_options(
@@ -664,7 +678,7 @@ class AsyncCrud(Generic[ModelType]):
@overload @overload
@classmethod @classmethod
async def paginate( # pragma: no cover async def offset_paginate( # pragma: no cover
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
*, *,
@@ -683,7 +697,7 @@ class AsyncCrud(Generic[ModelType]):
# Backward-compatible - will be removed in v2.0 # Backward-compatible - will be removed in v2.0
@overload @overload
@classmethod @classmethod
async def paginate( # pragma: no cover async def offset_paginate( # pragma: no cover
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
*, *,
@@ -700,7 +714,7 @@ class AsyncCrud(Generic[ModelType]):
) -> PaginatedResponse[ModelType]: ... ) -> PaginatedResponse[ModelType]: ...
@classmethod @classmethod
async def paginate( async def offset_paginate(
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
*, *,
@@ -715,7 +729,7 @@ class AsyncCrud(Generic[ModelType]):
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
schema: type[BaseModel] | None = None, schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: ) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results with metadata. """Get paginated results using offset-based pagination.
Args: Args:
session: DB async session session: DB async session
@@ -731,7 +745,7 @@ class AsyncCrud(Generic[ModelType]):
schema: Optional Pydantic schema to serialize each item into. schema: Optional Pydantic schema to serialize each item into.
Returns: Returns:
Dict with 'data' and 'pagination' keys PaginatedResponse with OffsetPagination metadata
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
@@ -803,7 +817,7 @@ class AsyncCrud(Generic[ModelType]):
return PaginatedResponse( return PaginatedResponse(
data=items, data=items,
pagination=Pagination( pagination=OffsetPagination(
total_count=total_count, total_count=total_count,
items_per_page=items_per_page, items_per_page=items_per_page,
page=page, page=page,
@@ -811,6 +825,175 @@ class AsyncCrud(Generic[ModelType]):
), ),
) )
# Backward-compatible - will be removed in v2.0
paginate = offset_paginate
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@classmethod
async def cursor_paginate(
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results using cursor-based pagination.
Args:
session: DB async session.
cursor: Cursor string from a previous ``CursorPagination``.
Omit (or pass ``None``) to start from the beginning.
filters: List of SQLAlchemy filter conditions.
joins: List of ``(model, condition)`` tuples for joining related
tables.
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
load_options: SQLAlchemy loader options. Falls back to
``default_load_options`` when not provided.
order_by: Additional ordering applied after the cursor column.
items_per_page: Number of items per page (default 20).
search: Search query string or SearchConfig object.
search_fields: Fields to search in (overrides class default).
schema: Optional Pydantic schema to serialize each item into.
Returns:
PaginatedResponse with CursorPagination metadata
"""
filters = list(filters) if filters else []
search_joins: list[Any] = []
if cls.cursor_column is None:
raise ValueError(
f"{cls.__name__}.cursor_column is not set. "
"Pass cursor_column=<column> to CrudFactory() to use cursor_paginate."
)
cursor_column: Any = cls.cursor_column
cursor_col_name: str = cursor_column.key
if cursor is not None:
raw_val = _decode_cursor(cursor)
col_type = cursor_column.property.columns[0].type
if isinstance(col_type, Integer):
cursor_val: Any = int(raw_val)
elif isinstance(col_type, Uuid):
cursor_val = uuid_module.UUID(raw_val)
else:
cursor_val = raw_val
filters.append(cursor_column > cursor_val)
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
# Build query
q = select(cls.model)
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
# Cursor column is always the primary sort
q = q.order_by(cursor_column)
if order_by is not None:
q = q.order_by(order_by)
# Fetch one extra to detect whether a next page exists
q = q.limit(items_per_page + 1)
result = await session.execute(q)
raw_items = cast(list[ModelType], result.unique().scalars().all())
has_more = len(raw_items) > items_per_page
items_page = raw_items[:items_per_page]
# next_cursor points past the last item on this page
next_cursor: str | None = None
if has_more and items_page:
next_cursor = _encode_cursor(getattr(items_page[-1], cursor_col_name))
# prev_cursor points to the first item on this page or None when on the first page
prev_cursor: str | None = None
if cursor is not None and items_page:
prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))
items: list[Any] = (
[schema.model_validate(item) for item in items_page]
if schema
else items_page
)
return PaginatedResponse(
data=items,
pagination=CursorPagination(
next_cursor=next_cursor,
prev_cursor=prev_cursor,
items_per_page=items_per_page,
has_more=has_more,
),
)
def CrudFactory( def CrudFactory(
model: type[ModelType], model: type[ModelType],
@@ -818,6 +1001,7 @@ def CrudFactory(
searchable_fields: Sequence[SearchFieldType] | None = None, searchable_fields: Sequence[SearchFieldType] | None = None,
m2m_fields: M2MFieldType | None = None, m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None, default_load_options: list[ExecutableOption] | None = None,
cursor_column: Any | None = None,
) -> type[AsyncCrud[ModelType]]: ) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model. """Create a CRUD class for a specific model.
@@ -832,6 +1016,8 @@ def CrudFactory(
instead of ``lazy="selectin"`` on the model so that loading instead of ``lazy="selectin"`` on the model so that loading
strategy is explicit and per-CRUD. Overridden entirely (not strategy is explicit and per-CRUD. Overridden entirely (not
merged) when ``load_options`` is provided at call-site. merged) when ``load_options`` is provided at call-site.
cursor_column: Required to call ``cursor_paginate``
Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
Returns: Returns:
AsyncCrud subclass bound to the model AsyncCrud subclass bound to the model
@@ -857,6 +1043,12 @@ def CrudFactory(
m2m_fields={"tag_ids": Post.tags}, m2m_fields={"tag_ids": Post.tags},
) )
# With a fixed cursor column for cursor_paginate:
PostCrud = CrudFactory(
Post,
cursor_column=Post.created_at,
)
# With default load strategy (replaces lazy="selectin" on the model): # With default load strategy (replaces lazy="selectin" on the model):
ArticleCrud = CrudFactory( ArticleCrud = CrudFactory(
Article, Article,
@@ -878,7 +1070,7 @@ def CrudFactory(
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2])) post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
# With search # With search
result = await UserCrud.paginate(session, search="john") result = await UserCrud.offset_paginate(session, search="john")
# With joins (inner join by default): # With joins (inner join by default):
users = await UserCrud.get_multi( users = await UserCrud.get_multi(
@@ -903,6 +1095,7 @@ def CrudFactory(
"searchable_fields": searchable_fields, "searchable_fields": searchable_fields,
"m2m_fields": m2m_fields, "m2m_fields": m2m_fields,
"default_load_options": default_load_options, "default_load_options": default_load_options,
"cursor_column": cursor_column,
}, },
) )
return cast(type[AsyncCrud[ModelType]], cls) return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict
__all__ = [ __all__ = [
"ApiError", "ApiError",
"CursorPagination",
"ErrorResponse", "ErrorResponse",
"OffsetPagination",
"Pagination", "Pagination",
"PaginatedResponse", "PaginatedResponse",
"PydanticBase", "PydanticBase",
@@ -90,8 +92,8 @@ class ErrorResponse(BaseResponse):
data: Any | None = None data: Any | None = None
class Pagination(PydanticBase): class OffsetPagination(PydanticBase):
"""Pagination metadata for list responses. """Pagination metadata for offset-based list responses.
Attributes: Attributes:
total_count: Total number of items across all pages total_count: Total number of items across all pages
@@ -106,17 +108,28 @@ class Pagination(PydanticBase):
has_more: bool has_more: bool
class PaginatedResponse(BaseResponse, Generic[DataT]): # Backward-compatible - will be removed in v2.0
"""Paginated API response for list endpoints. Pagination = OffsetPagination
Example:
```python class CursorPagination(PydanticBase):
PaginatedResponse[UserRead]( """Pagination metadata for cursor-based list responses.
data=users,
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True) Attributes:
) next_cursor: Encoded cursor for the next page, or None on the last page.
``` prev_cursor: Encoded cursor for the previous page, or None on the first page.
items_per_page: Number of items requested per page.
has_more: Whether there is at least one more page after this one.
""" """
next_cursor: str | None
prev_cursor: str | None = None
items_per_page: int
has_more: bool
class PaginatedResponse(BaseResponse, Generic[DataT]):
"""Paginated API response for list endpoints."""
data: list[DataT] data: list[DataT]
pagination: Pagination pagination: OffsetPagination | CursorPagination

View File

@@ -5,13 +5,12 @@ import uuid
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import Column, ForeignKey, String, Table, Uuid from sqlalchemy import Column, ForeignKey, Integer, String, Table, Uuid
from fastapi_toolsets.schemas import PydanticBase
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.crud import CrudFactory from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.schemas import PydanticBase
DATABASE_URL = os.getenv( DATABASE_URL = os.getenv(
key="DATABASE_URL", key="DATABASE_URL",
@@ -71,6 +70,15 @@ post_tags = Table(
) )
class IntRole(Base):
"""Test role model with auto-increment integer PK."""
__tablename__ = "int_roles"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50), unique=True)
class Post(Base): class Post(Base):
"""Test post model.""" """Test post model."""
@@ -116,7 +124,7 @@ class UserCreate(BaseModel):
class UserRead(PydanticBase): class UserRead(PydanticBase):
"""Schema for reading a user (subset of fields).""" """Schema for reading a user (subset of fields — no email)."""
id: uuid.UUID id: uuid.UUID
username: str username: str
@@ -176,8 +184,17 @@ class PostM2MUpdate(BaseModel):
tag_ids: list[uuid.UUID] | None = None tag_ids: list[uuid.UUID] | None = None
class IntRoleCreate(BaseModel):
"""Schema for creating an IntRole."""
name: str
RoleCrud = CrudFactory(Role) RoleCrud = CrudFactory(Role)
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
UserCrud = CrudFactory(User) UserCrud = CrudFactory(User)
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
PostCrud = CrudFactory(Post) PostCrud = CrudFactory(Post)
TagCrud = CrudFactory(Tag) TagCrud = CrudFactory(Tag)
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags}) PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})

View File

@@ -11,6 +11,8 @@ from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError from fastapi_toolsets.exceptions import NotFoundError
from .conftest import ( from .conftest import (
IntRoleCreate,
IntRoleCursorCrud,
Post, Post,
PostCreate, PostCreate,
PostCrud, PostCrud,
@@ -20,6 +22,7 @@ from .conftest import (
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
RoleCursorCrud,
RoleRead, RoleRead,
RoleUpdate, RoleUpdate,
TagCreate, TagCreate,
@@ -27,6 +30,7 @@ from .conftest import (
User, User,
UserCreate, UserCreate,
UserCrud, UserCrud,
UserCursorCrud,
UserRead, UserRead,
UserUpdate, UserUpdate,
) )
@@ -581,8 +585,11 @@ class TestCrudPaginate:
for i in range(25): for i in range(25):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
from fastapi_toolsets.schemas import OffsetPagination
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10) result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
assert isinstance(result.pagination, OffsetPagination)
assert len(result.data) == 10 assert len(result.data) == 10
assert result.pagination.total_count == 25 assert result.pagination.total_count == 25
assert result.pagination.page == 1 assert result.pagination.page == 1
@@ -613,6 +620,8 @@ class TestCrudPaginate:
), ),
) )
from fastapi_toolsets.schemas import OffsetPagination
result = await UserCrud.paginate( result = await UserCrud.paginate(
db_session, db_session,
filters=[User.is_active == True], # noqa: E712 filters=[User.is_active == True], # noqa: E712
@@ -620,6 +629,7 @@ class TestCrudPaginate:
items_per_page=10, items_per_page=10,
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 5 assert result.pagination.total_count == 5
@pytest.mark.anyio @pytest.mark.anyio
@@ -835,6 +845,8 @@ class TestCrudJoins:
), ),
) )
from fastapi_toolsets.schemas import OffsetPagination
# Paginate users with published posts # Paginate users with published posts
result = await UserCrud.paginate( result = await UserCrud.paginate(
db_session, db_session,
@@ -844,6 +856,7 @@ class TestCrudJoins:
items_per_page=10, items_per_page=10,
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 3 assert result.pagination.total_count == 3
assert len(result.data) == 3 assert len(result.data) == 3
@@ -866,6 +879,8 @@ class TestCrudJoins:
UserCreate(username="without_post", email="without@test.com"), UserCreate(username="without_post", email="without@test.com"),
) )
from fastapi_toolsets.schemas import OffsetPagination
# Paginate with outer join # Paginate with outer join
result = await UserCrud.paginate( result = await UserCrud.paginate(
db_session, db_session,
@@ -875,6 +890,7 @@ class TestCrudJoins:
items_per_page=10, items_per_page=10,
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
assert len(result.data) == 2 assert len(result.data) == 2
@@ -1512,3 +1528,494 @@ class TestSchemaResponse:
assert isinstance(result, Response) assert isinstance(result, Response)
assert isinstance(result.data, RoleRead) assert isinstance(result.data, RoleRead)
class TestPaginateAlias:
"""Tests that paginate is a backward-compatible alias for offset_paginate."""
def test_paginate_is_alias_of_offset_paginate(self):
"""paginate and offset_paginate are the same underlying function."""
assert RoleCrud.paginate.__func__ is RoleCrud.offset_paginate.__func__
@pytest.mark.anyio
async def test_paginate_alias_returns_offset_pagination(
self, db_session: AsyncSession
):
"""paginate() still works and returns PaginatedResponse with OffsetPagination."""
from fastapi_toolsets.schemas import OffsetPagination, PaginatedResponse
for i in range(3):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
assert isinstance(result, PaginatedResponse)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 3
assert result.pagination.page == 1
class TestCursorPaginate:
"""Tests for cursor-based pagination via cursor_paginate()."""
@pytest.mark.anyio
async def test_first_page_no_cursor(self, db_session: AsyncSession):
"""cursor_paginate without cursor returns the first page."""
from fastapi_toolsets.schemas import CursorPagination, PaginatedResponse
for i in range(25):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
assert isinstance(result, PaginatedResponse)
assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 10
assert result.pagination.has_more is True
assert result.pagination.next_cursor is not None
assert result.pagination.prev_cursor is None
assert result.pagination.items_per_page == 10
@pytest.mark.anyio
async def test_last_page(self, db_session: AsyncSession):
"""cursor_paginate returns has_more=False and next_cursor=None on last page."""
from fastapi_toolsets.schemas import CursorPagination
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 5
assert result.pagination.has_more is False
assert result.pagination.next_cursor is None
@pytest.mark.anyio
async def test_advances_correctly(self, db_session: AsyncSession):
"""Providing next_cursor from the first page returns the next page."""
for i in range(15):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
assert isinstance(page1.pagination, CursorPagination)
assert len(page1.data) == 10
assert page1.pagination.has_more is True
cursor = page1.pagination.next_cursor
page2 = await RoleCursorCrud.cursor_paginate(
db_session, cursor=cursor, items_per_page=10
)
assert isinstance(page2.pagination, CursorPagination)
assert len(page2.data) == 5
assert page2.pagination.has_more is False
assert page2.pagination.next_cursor is None
@pytest.mark.anyio
async def test_no_duplicates_across_pages(self, db_session: AsyncSession):
"""Items from consecutive cursor pages are non-overlapping and cover all rows."""
for i in range(7):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=4)
assert isinstance(page1.pagination, CursorPagination)
page2 = await RoleCursorCrud.cursor_paginate(
db_session,
cursor=page1.pagination.next_cursor,
items_per_page=4,
)
ids_page1 = {r.id for r in page1.data}
ids_page2 = {r.id for r in page2.data}
assert ids_page1.isdisjoint(ids_page2)
assert len(ids_page1 | ids_page2) == 7
@pytest.mark.anyio
async def test_empty_table(self, db_session: AsyncSession):
"""cursor_paginate on an empty table returns empty data with no cursor."""
from fastapi_toolsets.schemas import CursorPagination
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
assert isinstance(result.pagination, CursorPagination)
assert result.data == []
assert result.pagination.has_more is False
assert result.pagination.next_cursor is None
assert result.pagination.prev_cursor is None
@pytest.mark.anyio
async def test_with_filters(self, db_session: AsyncSession):
"""cursor_paginate respects filters."""
for i in range(10):
await UserCrud.create(
db_session,
UserCreate(
username=f"user{i}",
email=f"user{i}@test.com",
is_active=i % 2 == 0,
),
)
result = await UserCursorCrud.cursor_paginate(
db_session,
filters=[User.is_active == True], # noqa: E712
items_per_page=20,
)
assert len(result.data) == 5
assert all(u.is_active for u in result.data)
@pytest.mark.anyio
async def test_with_schema(self, db_session: AsyncSession):
"""cursor_paginate with schema serializes items into the schema."""
from fastapi_toolsets.schemas import PaginatedResponse
for i in range(3):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCursorCrud.cursor_paginate(db_session, schema=RoleRead)
assert isinstance(result, PaginatedResponse)
assert all(isinstance(item, RoleRead) for item in result.data)
assert all(
hasattr(item, "id") and hasattr(item, "name") for item in result.data
)
@pytest.mark.anyio
async def test_with_cursor_column(self, db_session: AsyncSession):
"""cursor_paginate uses cursor_column set on CrudFactory."""
from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.schemas import CursorPagination
RoleNameCrud = CrudFactory(Role, cursor_column=Role.name)
for i in range(5):
await RoleNameCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleNameCrud.cursor_paginate(db_session, items_per_page=3)
assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 3
assert result.pagination.has_more is True
assert result.pagination.next_cursor is not None
@pytest.mark.anyio
async def test_raises_without_cursor_column(self, db_session: AsyncSession):
"""cursor_paginate raises ValueError when cursor_column is not configured."""
with pytest.raises(ValueError, match="cursor_column is not set"):
await RoleCrud.cursor_paginate(db_session)
class TestCursorPaginatePrevCursor:
"""Tests for prev_cursor behavior in cursor_paginate()."""
@pytest.mark.anyio
async def test_prev_cursor_none_on_first_page(self, db_session: AsyncSession):
"""prev_cursor is None when no cursor was provided (first page)."""
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
from fastapi_toolsets.schemas import CursorPagination
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=3)
assert isinstance(result.pagination, CursorPagination)
assert result.pagination.prev_cursor is None
@pytest.mark.anyio
async def test_prev_cursor_set_on_subsequent_pages(self, db_session: AsyncSession):
"""prev_cursor is set when a cursor was provided (subsequent pages)."""
for i in range(10):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5)
assert isinstance(page1.pagination, CursorPagination)
page2 = await RoleCursorCrud.cursor_paginate(
db_session,
cursor=page1.pagination.next_cursor,
items_per_page=5,
)
assert isinstance(page2.pagination, CursorPagination)
assert page2.pagination.prev_cursor is not None
@pytest.mark.anyio
async def test_prev_cursor_points_to_first_item(self, db_session: AsyncSession):
"""prev_cursor encodes the value of the first item on the current page."""
import base64
import json
for i in range(10):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
from fastapi_toolsets.schemas import CursorPagination
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5)
assert isinstance(page1.pagination, CursorPagination)
page2 = await RoleCursorCrud.cursor_paginate(
db_session,
cursor=page1.pagination.next_cursor,
items_per_page=5,
)
assert isinstance(page2.pagination, CursorPagination)
assert page2.pagination.prev_cursor is not None
# Decode prev_cursor and compare to first item's id
decoded = json.loads(
base64.b64decode(page2.pagination.prev_cursor.encode()).decode()
)
first_item_id = str(page2.data[0].id)
assert decoded == first_item_id
class TestCursorPaginateWithSearch:
"""Tests for cursor_paginate() combined with search."""
@pytest.mark.anyio
async def test_cursor_paginate_with_search(self, db_session: AsyncSession):
"""cursor_paginate respects search filters alongside cursor predicate."""
from fastapi_toolsets.crud import CrudFactory
# Create a CRUD with searchable fields and cursor column
SearchableRoleCrud = CrudFactory(
Role, searchable_fields=[Role.name], cursor_column=Role.id
)
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"admin{i:02d}"))
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"user{i:02d}"))
result = await SearchableRoleCrud.cursor_paginate(
db_session,
search="admin",
items_per_page=20,
)
assert len(result.data) == 5
assert all("admin" in r.name for r in result.data)
class TestCursorPaginateExtraOptions:
"""Tests for cursor_paginate() covering joins, load_options, and order_by."""
@pytest.mark.anyio
async def test_with_joins(self, db_session: AsyncSession):
"""cursor_paginate applies explicit inner joins."""
from fastapi_toolsets.schemas import CursorPagination
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
for i in range(5):
await UserCrud.create(
db_session,
UserCreate(
username=f"u{i}",
email=f"u{i}@test.com",
role_id=role.id,
),
)
# One user without a role to confirm inner join excludes them
await UserCrud.create(
db_session,
UserCreate(username="norole", email="norole@test.com"),
)
result = await UserCursorCrud.cursor_paginate(
db_session,
joins=[(Role, User.role_id == Role.id)],
items_per_page=20,
)
assert isinstance(result.pagination, CursorPagination)
# Only users with a role are returned (inner join)
assert len(result.data) == 5
@pytest.mark.anyio
async def test_with_outer_join(self, db_session: AsyncSession):
"""cursor_paginate applies LEFT OUTER JOIN when outer_join=True."""
from fastapi_toolsets.schemas import CursorPagination
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
for i in range(3):
await UserCrud.create(
db_session,
UserCreate(
username=f"u{i}",
email=f"u{i}@test.com",
role_id=role.id,
),
)
await UserCrud.create(
db_session,
UserCreate(username="norole", email="norole@test.com"),
)
result = await UserCursorCrud.cursor_paginate(
db_session,
joins=[(Role, User.role_id == Role.id)],
outer_join=True,
items_per_page=20,
)
assert isinstance(result.pagination, CursorPagination)
# All users are included (outer join)
assert len(result.data) == 4
@pytest.mark.anyio
async def test_with_load_options(self, db_session: AsyncSession):
"""cursor_paginate passes load_options to the query."""
from fastapi_toolsets.schemas import CursorPagination
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
for i in range(3):
await UserCrud.create(
db_session,
UserCreate(
username=f"u{i}",
email=f"u{i}@test.com",
role_id=role.id,
),
)
result = await UserCursorCrud.cursor_paginate(
db_session,
load_options=[selectinload(User.role)],
items_per_page=20,
)
assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 3
# Relationship was eagerly loaded
assert all(u.role is not None for u in result.data)
@pytest.mark.anyio
async def test_with_order_by(self, db_session: AsyncSession):
"""cursor_paginate applies additional order_by after the cursor column."""
from fastapi_toolsets.schemas import CursorPagination
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCursorCrud.cursor_paginate(
db_session,
order_by=Role.name.desc(),
items_per_page=3,
)
assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 3
@pytest.mark.anyio
async def test_integer_cursor_column(self, db_session: AsyncSession):
"""cursor_paginate decodes Integer cursor values correctly."""
from fastapi_toolsets.schemas import CursorPagination
for i in range(5):
await IntRoleCursorCrud.create(db_session, IntRoleCreate(name=f"role{i}"))
page1 = await IntRoleCursorCrud.cursor_paginate(db_session, items_per_page=3)
assert isinstance(page1.pagination, CursorPagination)
assert len(page1.data) == 3
assert page1.pagination.has_more is True
page2 = await IntRoleCursorCrud.cursor_paginate(
db_session,
cursor=page1.pagination.next_cursor,
items_per_page=3,
)
assert isinstance(page2.pagination, CursorPagination)
assert len(page2.data) == 2
assert page2.pagination.has_more is False
@pytest.mark.anyio
async def test_string_cursor_column(self, db_session: AsyncSession):
"""cursor_paginate decodes non-UUID/non-Integer cursor values (string branch)."""
from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.schemas import CursorPagination
RoleNameCursorCrud = CrudFactory(Role, cursor_column=Role.name)
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=3)
assert isinstance(page1.pagination, CursorPagination)
assert len(page1.data) == 3
assert page1.pagination.has_more is True
page2 = await RoleNameCursorCrud.cursor_paginate(
db_session,
cursor=page1.pagination.next_cursor,
items_per_page=3,
)
assert isinstance(page2.pagination, CursorPagination)
assert len(page2.data) == 2
assert page2.pagination.has_more is False
class TestCursorPaginateSearchJoins:
"""Tests for cursor_paginate() search that traverses relationships (search_joins)."""
@pytest.mark.anyio
async def test_search_via_relationship(self, db_session: AsyncSession):
"""cursor_paginate outerjoin search-join when searching through a relationship."""
from fastapi_toolsets.schemas import CursorPagination
role_admin = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
role_user = await RoleCrud.create(db_session, RoleCreate(name="regularuser"))
for i in range(3):
await UserCrud.create(
db_session,
UserCreate(
username=f"admin_u{i}",
email=f"admin_u{i}@test.com",
role_id=role_admin.id,
),
)
for i in range(2):
await UserCrud.create(
db_session,
UserCreate(
username=f"reg_u{i}",
email=f"reg_u{i}@test.com",
role_id=role_user.id,
),
)
result = await UserCursorCrud.cursor_paginate(
db_session,
search="administrator",
search_fields=[(User.role, Role.name)],
items_per_page=20,
)
assert isinstance(result.pagination, CursorPagination)
assert len(result.data) == 3
class TestGetWithForUpdate:
"""Tests for get() with with_for_update=True."""
@pytest.mark.anyio
async def test_get_with_for_update(self, db_session: AsyncSession):
"""get() with with_for_update=True locks the row."""
role = await RoleCrud.create(db_session, RoleCreate(name="locked"))
result = await RoleCrud.get(
db_session,
filters=[Role.id == role.id],
with_for_update=True,
)
assert result.id == role.id
assert result.name == "locked"

View File

@@ -6,6 +6,7 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
from fastapi_toolsets.schemas import OffsetPagination
from .conftest import ( from .conftest import (
Role, Role,
@@ -39,6 +40,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -57,6 +59,7 @@ class TestPaginateSearch:
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -84,6 +87,7 @@ class TestPaginateSearch:
search_fields=[(User.role, Role.name)], search_fields=[(User.role, Role.name)],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -102,6 +106,7 @@ class TestPaginateSearch:
search_fields=[User.username, (User.role, Role.name)], search_fields=[User.username, (User.role, Role.name)],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -117,6 +122,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -132,6 +138,7 @@ class TestPaginateSearch:
search=SearchConfig(query="johndoe", case_sensitive=True), search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 0 assert result.pagination.total_count == 0
# Should find (case match) # Should find (case match)
@@ -140,6 +147,7 @@ class TestPaginateSearch:
search=SearchConfig(query="JohnDoe", case_sensitive=True), search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -153,9 +161,11 @@ class TestPaginateSearch:
) )
result = await UserCrud.paginate(db_session, search="") result = await UserCrud.paginate(db_session, search="")
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
result = await UserCrud.paginate(db_session, search=None) result = await UserCrud.paginate(db_session, search=None)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -177,6 +187,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
assert result.data[0].username == "active_john" assert result.data[0].username == "active_john"
@@ -189,6 +200,7 @@ class TestPaginateSearch:
result = await UserCrud.paginate(db_session, search="findme") result = await UserCrud.paginate(db_session, search="findme")
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
@@ -204,6 +216,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 0 assert result.pagination.total_count == 0
assert result.data == [] assert result.data == []
@@ -224,6 +237,7 @@ class TestPaginateSearch:
items_per_page=5, items_per_page=5,
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 15 assert result.pagination.total_count == 15
assert len(result.data) == 5 assert len(result.data) == 5
assert result.pagination.has_more is True assert result.pagination.has_more is True
@@ -248,6 +262,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -270,6 +285,7 @@ class TestPaginateSearch:
order_by=User.username, order_by=User.username,
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 3 assert result.pagination.total_count == 3
usernames = [u.username for u in result.data] usernames = [u.username for u in result.data]
assert usernames == ["alice", "bob", "charlie"] assert usernames == ["alice", "bob", "charlie"]
@@ -292,6 +308,7 @@ class TestPaginateSearch:
search_fields=[User.id, User.username], search_fields=[User.id, User.username],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
assert result.data[0].id == user_id assert result.data[0].id == user_id
@@ -318,6 +335,7 @@ class TestSearchConfig:
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
assert result.data[0].username == "john_test" assert result.data[0].username == "john_test"
@@ -333,6 +351,7 @@ class TestSearchConfig:
search=SearchConfig(query="findme", fields=[User.email]), search=SearchConfig(query="findme", fields=[User.email]),
) )
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1

View File

@@ -5,7 +5,9 @@ from pydantic import ValidationError
from fastapi_toolsets.schemas import ( from fastapi_toolsets.schemas import (
ApiError, ApiError,
CursorPagination,
ErrorResponse, ErrorResponse,
OffsetPagination,
PaginatedResponse, PaginatedResponse,
Pagination, Pagination,
Response, Response,
@@ -154,12 +156,12 @@ class TestErrorResponse:
assert data["description"] == "Details" assert data["description"] == "Details"
class TestPagination: class TestOffsetPagination:
"""Tests for Pagination schema.""" """Tests for OffsetPagination schema (canonical name for offset-based pagination)."""
def test_create_pagination(self): def test_create_pagination(self):
"""Create Pagination with all fields.""" """Create OffsetPagination with all fields."""
pagination = Pagination( pagination = OffsetPagination(
total_count=100, total_count=100,
items_per_page=10, items_per_page=10,
page=1, page=1,
@@ -173,7 +175,7 @@ class TestPagination:
def test_last_page_has_more_false(self): def test_last_page_has_more_false(self):
"""Last page has has_more=False.""" """Last page has has_more=False."""
pagination = Pagination( pagination = OffsetPagination(
total_count=25, total_count=25,
items_per_page=10, items_per_page=10,
page=3, page=3,
@@ -183,8 +185,8 @@ class TestPagination:
assert pagination.has_more is False assert pagination.has_more is False
def test_serialization(self): def test_serialization(self):
"""Pagination serializes correctly.""" """OffsetPagination serializes correctly."""
pagination = Pagination( pagination = OffsetPagination(
total_count=50, total_count=50,
items_per_page=20, items_per_page=20,
page=2, page=2,
@@ -197,6 +199,77 @@ class TestPagination:
assert data["page"] == 2 assert data["page"] == 2
assert data["has_more"] is True assert data["has_more"] is True
def test_pagination_alias_is_offset_pagination(self):
"""Pagination is a backward-compatible alias for OffsetPagination."""
assert Pagination is OffsetPagination
def test_pagination_alias_constructs_offset_pagination(self):
"""Code using Pagination(...) still works unchanged."""
pagination = Pagination(
total_count=10,
items_per_page=5,
page=2,
has_more=False,
)
assert isinstance(pagination, OffsetPagination)
class TestCursorPagination:
"""Tests for CursorPagination schema."""
def test_create_with_next_cursor(self):
"""CursorPagination with a next cursor indicates more pages."""
pagination = CursorPagination(
next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==",
items_per_page=20,
has_more=True,
)
assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ=="
assert pagination.prev_cursor is None
assert pagination.items_per_page == 20
assert pagination.has_more is True
def test_create_last_page(self):
"""CursorPagination for the last page has next_cursor=None and has_more=False."""
pagination = CursorPagination(
next_cursor=None,
items_per_page=20,
has_more=False,
)
assert pagination.next_cursor is None
assert pagination.has_more is False
def test_prev_cursor_defaults_to_none(self):
"""prev_cursor defaults to None."""
pagination = CursorPagination(
next_cursor=None, items_per_page=10, has_more=False
)
assert pagination.prev_cursor is None
def test_prev_cursor_can_be_set(self):
"""prev_cursor can be explicitly set."""
pagination = CursorPagination(
next_cursor="next123",
prev_cursor="prev456",
items_per_page=10,
has_more=True,
)
assert pagination.prev_cursor == "prev456"
def test_serialization(self):
"""CursorPagination serializes correctly."""
pagination = CursorPagination(
next_cursor="abc123",
prev_cursor="xyz789",
items_per_page=20,
has_more=True,
)
data = pagination.model_dump()
assert data["next_cursor"] == "abc123"
assert data["prev_cursor"] == "xyz789"
assert data["items_per_page"] == 20
assert data["has_more"] is True
class TestPaginatedResponse: class TestPaginatedResponse:
"""Tests for PaginatedResponse schema.""" """Tests for PaginatedResponse schema."""
@@ -214,6 +287,7 @@ class TestPaginatedResponse:
pagination=pagination, pagination=pagination,
) )
assert isinstance(response.pagination, OffsetPagination)
assert len(response.data) == 2 assert len(response.data) == 2
assert response.pagination.total_count == 30 assert response.pagination.total_count == 30
assert response.status == ResponseStatus.SUCCESS assert response.status == ResponseStatus.SUCCESS
@@ -247,6 +321,7 @@ class TestPaginatedResponse:
pagination=pagination, pagination=pagination,
) )
assert isinstance(response.pagination, OffsetPagination)
assert response.data == [] assert response.data == []
assert response.pagination.total_count == 0 assert response.pagination.total_count == 0
@@ -290,6 +365,36 @@ class TestPaginatedResponse:
assert data["data"] == ["item1", "item2"] assert data["data"] == ["item1", "item2"]
assert data["pagination"]["page"] == 5 assert data["pagination"]["page"] == 5
def test_pagination_field_accepts_offset_pagination(self):
"""PaginatedResponse.pagination accepts OffsetPagination."""
response = PaginatedResponse(
data=[1, 2],
pagination=OffsetPagination(
total_count=2, items_per_page=10, page=1, has_more=False
),
)
assert isinstance(response.pagination, OffsetPagination)
def test_pagination_field_accepts_cursor_pagination(self):
"""PaginatedResponse.pagination accepts CursorPagination."""
response = PaginatedResponse(
data=[1, 2],
pagination=CursorPagination(
next_cursor=None, items_per_page=10, has_more=False
),
)
assert isinstance(response.pagination, CursorPagination)
def test_pagination_alias_accepted(self):
"""Constructing PaginatedResponse with Pagination (alias) still works."""
response = PaginatedResponse(
data=[],
pagination=Pagination(
total_count=0, items_per_page=10, page=1, has_more=False
),
)
assert isinstance(response.pagination, OffsetPagination)
class TestFromAttributes: class TestFromAttributes:
"""Tests for from_attributes config (ORM mode).""" """Tests for from_attributes config (ORM mode)."""