feat: add cursor based pagination in CrudFactory (#86)

This commit is contained in:
d3vyce
2026-02-23 13:51:34 +01:00
committed by GitHub
parent 7482bc5dad
commit 6cf7df55ef
7 changed files with 1003 additions and 41 deletions

View File

@@ -2,12 +2,15 @@
from __future__ import annotations
import base64
import json
import uuid as uuid_module
import warnings
from collections.abc import Mapping, Sequence
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel
from sqlalchemy import and_, func, select
from sqlalchemy import Integer, Uuid, and_, func, select
from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
@@ -18,7 +21,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import NotFoundError
from ..schemas import PaginatedResponse, Pagination, Response
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
@@ -27,6 +30,16 @@ JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
def _encode_cursor(value: Any) -> str:
    """Encode a cursor column value as a base64 string.

    The value is converted with ``str()``, serialized as a JSON string and
    the resulting bytes are base64-encoded.

    Args:
        value: Cursor column value of the row the cursor should point at.

    Returns:
        Base64 text representing the value.
    """
    json_payload = json.dumps(str(value))
    encoded_bytes = base64.b64encode(json_payload.encode())
    return encoded_bytes.decode()
def _decode_cursor(cursor: str) -> str:
    """Decode a base64 cursor string back to the stringified column value.

    The cursor normally comes from a client request, so it is treated as
    untrusted input: any malformed value is reported as a single
    ``ValueError`` instead of leaking base64/UTF-8/JSON specific errors.

    Args:
        cursor: Base64 text wrapping a JSON-encoded string.

    Returns:
        The decoded string payload.

    Raises:
        ValueError: If ``cursor`` is not valid base64, UTF-8 or JSON, or if
            the JSON payload is not a string.
    """
    try:
        payload = json.loads(base64.b64decode(cursor.encode()).decode())
    except ValueError as exc:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses, so one handler covers every decoding stage.
        raise ValueError("Invalid cursor") from exc
    if not isinstance(payload, str):
        # Valid JSON that is not a string (number, list, ...) would break the
        # declared ``str`` return type.
        raise ValueError("Invalid cursor")
    return payload
class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models.
@@ -37,6 +50,7 @@ class AsyncCrud(Generic[ModelType]):
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None
@classmethod
def _resolve_load_options(
@@ -664,7 +678,7 @@ class AsyncCrud(Generic[ModelType]):
@overload
@classmethod
async def paginate( # pragma: no cover
async def offset_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
@@ -683,7 +697,7 @@ class AsyncCrud(Generic[ModelType]):
# Backward-compatible - will be removed in v2.0
@overload
@classmethod
async def paginate( # pragma: no cover
async def offset_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
@@ -700,7 +714,7 @@ class AsyncCrud(Generic[ModelType]):
) -> PaginatedResponse[ModelType]: ...
@classmethod
async def paginate(
async def offset_paginate(
cls: type[Self],
session: AsyncSession,
*,
@@ -715,7 +729,7 @@ class AsyncCrud(Generic[ModelType]):
search_fields: Sequence[SearchFieldType] | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results with metadata.
"""Get paginated results using offset-based pagination.
Args:
session: DB async session
@@ -731,7 +745,7 @@ class AsyncCrud(Generic[ModelType]):
schema: Optional Pydantic schema to serialize each item into.
Returns:
Dict with 'data' and 'pagination' keys
PaginatedResponse with OffsetPagination metadata
"""
filters = list(filters) if filters else []
offset = (page - 1) * items_per_page
@@ -803,7 +817,7 @@ class AsyncCrud(Generic[ModelType]):
return PaginatedResponse(
data=items,
pagination=Pagination(
pagination=OffsetPagination(
total_count=total_count,
items_per_page=items_per_page,
page=page,
@@ -811,6 +825,175 @@ class AsyncCrud(Generic[ModelType]):
),
)
# Backward-compatible alias for ``offset_paginate`` - will be removed in v2.0
paginate = offset_paginate
# Typing overload: when ``schema`` is supplied, the response is typed with that schema.
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
# Backward-compatible - will be removed in v2.0
# Typing overload: without ``schema``, the response is typed with the model itself.
# NOTE(review): this method has no ``paginate``-style predecessor visible here;
# confirm the removal notice above is intended for it.
@overload
@classmethod
async def cursor_paginate( # pragma: no cover
cls: type[Self],
session: AsyncSession,
*,
cursor: str | None = None,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@classmethod
async def cursor_paginate(
    cls: type[Self],
    session: AsyncSession,
    *,
    cursor: str | None = None,
    filters: list[Any] | None = None,
    joins: JoinType | None = None,
    outer_join: bool = False,
    load_options: list[ExecutableOption] | None = None,
    order_by: Any | None = None,
    items_per_page: int = 20,
    search: str | SearchConfig | None = None,
    search_fields: Sequence[SearchFieldType] | None = None,
    schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
    """Get paginated results using cursor-based pagination.

    Rows are ordered by ``cls.cursor_column`` and, when a cursor is given,
    restricted to rows whose cursor column is strictly greater than the
    decoded cursor value.

    Args:
        session: DB async session.
        cursor: Cursor string from a previous ``CursorPagination``.
            Omit (or pass ``None``) to start from the beginning.
        filters: List of SQLAlchemy filter conditions.
        joins: List of ``(model, condition)`` tuples for joining related
            tables.
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
        load_options: SQLAlchemy loader options. Falls back to
            ``default_load_options`` when not provided.
        order_by: Additional ordering applied after the cursor column.
        items_per_page: Number of items per page (default 20).
        search: Search query string or SearchConfig object.
        search_fields: Fields to search in (overrides class default).
        schema: Optional Pydantic schema to serialize each item into.

    Returns:
        PaginatedResponse with CursorPagination metadata. ``next_cursor``
        encodes the cursor-column value of the last returned item and is
        only set when more rows exist. ``prev_cursor`` encodes the value of
        the first returned item and is only set when a ``cursor`` was
        passed in.

    Raises:
        ValueError: If ``cursor_column`` is not configured on the class, or
            if ``cursor`` cannot be decoded/converted to the column type.
    """
    # Imported locally so this method does not depend on additional
    # module-level imports.
    from datetime import date, datetime

    from sqlalchemy import Date, DateTime

    if cls.cursor_column is None:
        raise ValueError(
            f"{cls.__name__}.cursor_column is not set. "
            "Pass cursor_column=<column> to CrudFactory() to use cursor_paginate."
        )

    filters = list(filters) if filters else []
    search_joins: list[Any] = []

    cursor_column: Any = cls.cursor_column
    cursor_col_name: str = cursor_column.key

    if cursor is not None:
        col_type = cursor_column.property.columns[0].type
        cursor_val: Any
        try:
            raw_val = _decode_cursor(cursor)
            # The cursor carries ``str(value)``; convert it back to the
            # Python type the column expects before binding it.
            if isinstance(col_type, Integer):
                cursor_val = int(raw_val)
            elif isinstance(col_type, Uuid):
                cursor_val = uuid_module.UUID(raw_val)
            elif isinstance(col_type, DateTime):
                # ``str(datetime)`` is ISO format with a space separator,
                # which ``fromisoformat`` parses. A plain string is not a
                # valid bind value for date/time columns on every backend.
                cursor_val = datetime.fromisoformat(raw_val)
            elif isinstance(col_type, Date):
                cursor_val = date.fromisoformat(raw_val)
            else:
                cursor_val = raw_val
        except (AttributeError, TypeError, ValueError) as exc:
            # The cursor is client-supplied: report one predictable error.
            raise ValueError(f"Invalid cursor: {cursor!r}") from exc
        filters.append(cursor_column > cursor_val)

    # Build search filters
    if search:
        search_filters, search_joins = build_search_filters(
            cls.model,
            search,
            search_fields=search_fields,
            default_fields=cls.searchable_fields,
        )
        filters.extend(search_filters)

    # Build query
    q = select(cls.model)

    # Apply explicit joins
    if joins:
        for model, condition in joins:
            q = (
                q.outerjoin(model, condition)
                if outer_join
                else q.join(model, condition)
            )

    # Apply search joins (always outer joins)
    for join_rel in search_joins:
        q = q.outerjoin(join_rel)

    if filters:
        q = q.where(and_(*filters))

    if resolved := cls._resolve_load_options(load_options):
        q = q.options(*resolved)

    # Cursor column is always the primary sort
    q = q.order_by(cursor_column)
    if order_by is not None:
        q = q.order_by(order_by)

    # Fetch one extra to detect whether a next page exists
    q = q.limit(items_per_page + 1)
    result = await session.execute(q)
    raw_items = cast(list[ModelType], result.unique().scalars().all())

    has_more = len(raw_items) > items_per_page
    items_page = raw_items[:items_per_page]

    # next_cursor points past the last item on this page
    next_cursor: str | None = None
    if has_more and items_page:
        next_cursor = _encode_cursor(getattr(items_page[-1], cursor_col_name))

    # prev_cursor encodes the first item on this page, or None on the first page.
    # NOTE(review): the cursor filter above is always "greater than", so
    # passing prev_cursor back as ``cursor`` yields the rows after this
    # page's first item rather than the preceding page - confirm intent.
    prev_cursor: str | None = None
    if cursor is not None and items_page:
        prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))

    items: list[Any] = (
        [schema.model_validate(item) for item in items_page]
        if schema
        else items_page
    )

    return PaginatedResponse(
        data=items,
        pagination=CursorPagination(
            next_cursor=next_cursor,
            prev_cursor=prev_cursor,
            items_per_page=items_per_page,
            has_more=has_more,
        ),
    )
def CrudFactory(
model: type[ModelType],
@@ -818,6 +1001,7 @@ def CrudFactory(
searchable_fields: Sequence[SearchFieldType] | None = None,
m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None,
cursor_column: Any | None = None,
) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model.
@@ -832,6 +1016,8 @@ def CrudFactory(
instead of ``lazy="selectin"`` on the model so that loading
strategy is explicit and per-CRUD. Overridden entirely (not
merged) when ``load_options`` is provided at call-site.
cursor_column: Column used by ``cursor_paginate`` to order and filter pages. Required to call ``cursor_paginate``.
Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
Returns:
AsyncCrud subclass bound to the model
@@ -857,6 +1043,12 @@ def CrudFactory(
m2m_fields={"tag_ids": Post.tags},
)
# With a fixed cursor column for cursor_paginate:
PostCrud = CrudFactory(
Post,
cursor_column=Post.created_at,
)
# With default load strategy (replaces lazy="selectin" on the model):
ArticleCrud = CrudFactory(
Article,
@@ -878,7 +1070,7 @@ def CrudFactory(
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
# With search
result = await UserCrud.paginate(session, search="john")
result = await UserCrud.offset_paginate(session, search="john")
# With joins (inner join by default):
users = await UserCrud.get_multi(
@@ -903,6 +1095,7 @@ def CrudFactory(
"searchable_fields": searchable_fields,
"m2m_fields": m2m_fields,
"default_load_options": default_load_options,
"cursor_column": cursor_column,
},
)
return cast(type[AsyncCrud[ModelType]], cls)