mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add cursor based pagination in CrudFactory (#86)
This commit is contained in:
@@ -2,12 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid as uuid_module
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy import Integer, Uuid, and_, func, select
|
||||
from sqlalchemy import delete as sql_delete
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
@@ -18,7 +21,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from ..schemas import PaginatedResponse, Pagination, Response
|
||||
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
@@ -27,6 +30,16 @@ JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
|
||||
|
||||
def _encode_cursor(value: Any) -> str:
|
||||
"""Encode cursor column value as an base64 string."""
|
||||
return base64.b64encode(json.dumps(str(value)).encode()).decode()
|
||||
|
||||
|
||||
def _decode_cursor(cursor: str) -> str:
|
||||
"""Decode cursor base64 string."""
|
||||
return json.loads(base64.b64decode(cursor.encode()).decode())
|
||||
|
||||
|
||||
class AsyncCrud(Generic[ModelType]):
|
||||
"""Generic async CRUD operations for SQLAlchemy models.
|
||||
|
||||
@@ -37,6 +50,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
default_load_options: ClassVar[list[ExecutableOption] | None] = None
|
||||
cursor_column: ClassVar[Any | None] = None
|
||||
|
||||
@classmethod
|
||||
def _resolve_load_options(
|
||||
@@ -664,7 +678,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def paginate( # pragma: no cover
|
||||
async def offset_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -683,7 +697,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def paginate( # pragma: no cover
|
||||
async def offset_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -700,7 +714,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@classmethod
|
||||
async def paginate(
|
||||
async def offset_paginate(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@@ -715,7 +729,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
"""Get paginated results with metadata.
|
||||
"""Get paginated results using offset-based pagination.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
@@ -731,7 +745,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
Dict with 'data' and 'pagination' keys
|
||||
PaginatedResponse with OffsetPagination metadata
|
||||
"""
|
||||
filters = list(filters) if filters else []
|
||||
offset = (page - 1) * items_per_page
|
||||
@@ -803,7 +817,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=Pagination(
|
||||
pagination=OffsetPagination(
|
||||
total_count=total_count,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
@@ -811,6 +825,175 @@ class AsyncCrud(Generic[ModelType]):
|
||||
),
|
||||
)
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
paginate = offset_paginate
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def cursor_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def cursor_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@classmethod
|
||||
async def cursor_paginate(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
"""Get paginated results using cursor-based pagination.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
cursor: Cursor string from a previous ``CursorPagination``.
|
||||
Omit (or pass ``None``) to start from the beginning.
|
||||
filters: List of SQLAlchemy filter conditions.
|
||||
joins: List of ``(model, condition)`` tuples for joining related
|
||||
tables.
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
|
||||
load_options: SQLAlchemy loader options. Falls back to
|
||||
``default_load_options`` when not provided.
|
||||
order_by: Additional ordering applied after the cursor column.
|
||||
items_per_page: Number of items per page (default 20).
|
||||
search: Search query string or SearchConfig object.
|
||||
search_fields: Fields to search in (overrides class default).
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
PaginatedResponse with CursorPagination metadata
|
||||
"""
|
||||
filters = list(filters) if filters else []
|
||||
search_joins: list[Any] = []
|
||||
|
||||
if cls.cursor_column is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.cursor_column is not set. "
|
||||
"Pass cursor_column=<column> to CrudFactory() to use cursor_paginate."
|
||||
)
|
||||
cursor_column: Any = cls.cursor_column
|
||||
cursor_col_name: str = cursor_column.key
|
||||
|
||||
if cursor is not None:
|
||||
raw_val = _decode_cursor(cursor)
|
||||
col_type = cursor_column.property.columns[0].type
|
||||
if isinstance(col_type, Integer):
|
||||
cursor_val: Any = int(raw_val)
|
||||
elif isinstance(col_type, Uuid):
|
||||
cursor_val = uuid_module.UUID(raw_val)
|
||||
else:
|
||||
cursor_val = raw_val
|
||||
filters.append(cursor_column > cursor_val)
|
||||
|
||||
# Build search filters
|
||||
if search:
|
||||
search_filters, search_joins = build_search_filters(
|
||||
cls.model,
|
||||
search,
|
||||
search_fields=search_fields,
|
||||
default_fields=cls.searchable_fields,
|
||||
)
|
||||
filters.extend(search_filters)
|
||||
|
||||
# Build query
|
||||
q = select(cls.model)
|
||||
|
||||
# Apply explicit joins
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
|
||||
# Apply search joins (always outer joins)
|
||||
for join_rel in search_joins:
|
||||
q = q.outerjoin(join_rel)
|
||||
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
|
||||
# Cursor column is always the primary sort
|
||||
q = q.order_by(cursor_column)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
|
||||
# Fetch one extra to detect whether a next page exists
|
||||
q = q.limit(items_per_page + 1)
|
||||
result = await session.execute(q)
|
||||
raw_items = cast(list[ModelType], result.unique().scalars().all())
|
||||
|
||||
has_more = len(raw_items) > items_per_page
|
||||
items_page = raw_items[:items_per_page]
|
||||
|
||||
# next_cursor points past the last item on this page
|
||||
next_cursor: str | None = None
|
||||
if has_more and items_page:
|
||||
next_cursor = _encode_cursor(getattr(items_page[-1], cursor_col_name))
|
||||
|
||||
# prev_cursor points to the first item on this page or None when on the first page
|
||||
prev_cursor: str | None = None
|
||||
if cursor is not None and items_page:
|
||||
prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))
|
||||
|
||||
items: list[Any] = (
|
||||
[schema.model_validate(item) for item in items_page]
|
||||
if schema
|
||||
else items_page
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=CursorPagination(
|
||||
next_cursor=next_cursor,
|
||||
prev_cursor=prev_cursor,
|
||||
items_per_page=items_per_page,
|
||||
has_more=has_more,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def CrudFactory(
|
||||
model: type[ModelType],
|
||||
@@ -818,6 +1001,7 @@ def CrudFactory(
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
m2m_fields: M2MFieldType | None = None,
|
||||
default_load_options: list[ExecutableOption] | None = None,
|
||||
cursor_column: Any | None = None,
|
||||
) -> type[AsyncCrud[ModelType]]:
|
||||
"""Create a CRUD class for a specific model.
|
||||
|
||||
@@ -832,6 +1016,8 @@ def CrudFactory(
|
||||
instead of ``lazy="selectin"`` on the model so that loading
|
||||
strategy is explicit and per-CRUD. Overridden entirely (not
|
||||
merged) when ``load_options`` is provided at call-site.
|
||||
cursor_column: Required to call ``cursor_paginate``
|
||||
Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
|
||||
|
||||
Returns:
|
||||
AsyncCrud subclass bound to the model
|
||||
@@ -857,6 +1043,12 @@ def CrudFactory(
|
||||
m2m_fields={"tag_ids": Post.tags},
|
||||
)
|
||||
|
||||
# With a fixed cursor column for cursor_paginate:
|
||||
PostCrud = CrudFactory(
|
||||
Post,
|
||||
cursor_column=Post.created_at,
|
||||
)
|
||||
|
||||
# With default load strategy (replaces lazy="selectin" on the model):
|
||||
ArticleCrud = CrudFactory(
|
||||
Article,
|
||||
@@ -878,7 +1070,7 @@ def CrudFactory(
|
||||
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
|
||||
|
||||
# With search
|
||||
result = await UserCrud.paginate(session, search="john")
|
||||
result = await UserCrud.offset_paginate(session, search="john")
|
||||
|
||||
# With joins (inner join by default):
|
||||
users = await UserCrud.get_multi(
|
||||
@@ -903,6 +1095,7 @@ def CrudFactory(
|
||||
"searchable_fields": searchable_fields,
|
||||
"m2m_fields": m2m_fields,
|
||||
"default_load_options": default_load_options,
|
||||
"cursor_column": cursor_column,
|
||||
},
|
||||
)
|
||||
return cast(type[AsyncCrud[ModelType]], cls)
|
||||
|
||||
Reference in New Issue
Block a user