diff --git a/docs/module/crud.md b/docs/module/crud.md index f7cae44..9f34ecf 100644 --- a/docs/module/crud.md +++ b/docs/module/crud.md @@ -1,6 +1,6 @@ # CRUD -Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support. This module has features that are only compatible with Postgres. +Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support. !!! info This module has been coded and tested to be compatible with PostgreSQL only. @@ -48,6 +48,21 @@ exists = await UserCrud.exists(session=session, filters=[User.email == email]) ## Pagination +!!! info "Added in `v1.1` (only offset_pagination via `paginate` if ` Response[UserRead]: @router.get("") async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]: - return await crud.UserCrud.paginate( + return await crud.UserCrud.offset_paginate( session=session, page=page, schema=UserRead, diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 4927e36..6323af6 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -2,12 +2,15 @@ from __future__ import annotations +import base64 +import json +import uuid as uuid_module import warnings from collections.abc import Mapping, Sequence from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from pydantic import BaseModel -from sqlalchemy import and_, func, select +from sqlalchemy import Integer, Uuid, and_, func, select from sqlalchemy import delete as sql_delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import NoResultFound @@ -18,7 +21,7 @@ from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction from ..exceptions import NotFoundError -from ..schemas import PaginatedResponse, Pagination, Response +from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response from .search import SearchConfig, SearchFieldType, build_search_filters ModelType = TypeVar("ModelType", bound=DeclarativeBase) @@ -27,6 +30,16 @@ JoinType = list[tuple[type[DeclarativeBase], Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]] +def _encode_cursor(value: Any) -> str: + """Encode cursor column value as an base64 string.""" + return base64.b64encode(json.dumps(str(value)).encode()).decode() + + +def _decode_cursor(cursor: str) -> str: + """Decode cursor base64 string.""" + return json.loads(base64.b64decode(cursor.encode()).decode()) + + class AsyncCrud(Generic[ModelType]): """Generic async CRUD operations for SQLAlchemy models. @@ -37,6 +50,7 @@ class AsyncCrud(Generic[ModelType]): searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None default_load_options: ClassVar[list[ExecutableOption] | None] = None + cursor_column: ClassVar[Any | None] = None @classmethod def _resolve_load_options( @@ -664,7 +678,7 @@ class AsyncCrud(Generic[ModelType]): @overload @classmethod - async def paginate( # pragma: no cover + async def offset_paginate( # pragma: no cover cls: type[Self], session: AsyncSession, *, @@ -683,7 +697,7 @@ class AsyncCrud(Generic[ModelType]): # Backward-compatible - will be removed in v2.0 @overload @classmethod - async def paginate( # pragma: no cover + async def offset_paginate( # pragma: no cover cls: type[Self], session: AsyncSession, *, @@ -700,7 +714,7 @@ class AsyncCrud(Generic[ModelType]): ) -> PaginatedResponse[ModelType]: ... @classmethod - async def paginate( + async def offset_paginate( cls: type[Self], session: AsyncSession, *, @@ -715,7 +729,7 @@ class AsyncCrud(Generic[ModelType]): search_fields: Sequence[SearchFieldType] | None = None, schema: type[BaseModel] | None = None, ) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: - """Get paginated results with metadata. + """Get paginated results using offset-based pagination. Args: session: DB async session @@ -731,7 +745,7 @@ class AsyncCrud(Generic[ModelType]): schema: Optional Pydantic schema to serialize each item into. Returns: - Dict with 'data' and 'pagination' keys + PaginatedResponse with OffsetPagination metadata """ filters = list(filters) if filters else [] offset = (page - 1) * items_per_page @@ -803,7 +817,7 @@ class AsyncCrud(Generic[ModelType]): return PaginatedResponse( data=items, - pagination=Pagination( + pagination=OffsetPagination( total_count=total_count, items_per_page=items_per_page, page=page, @@ -811,6 +825,175 @@ class AsyncCrud(Generic[ModelType]): ), ) + # Backward-compatible - will be removed in v2.0 + paginate = offset_paginate + + @overload + @classmethod + async def cursor_paginate( # pragma: no cover + cls: type[Self], + session: AsyncSession, + *, + cursor: str | None = None, + filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, + load_options: list[ExecutableOption] | None = None, + order_by: Any | None = None, + items_per_page: int = 20, + search: str | SearchConfig | None = None, + search_fields: Sequence[SearchFieldType] | None = None, + schema: type[SchemaType], + ) -> PaginatedResponse[SchemaType]: ... + + # Backward-compatible - will be removed in v2.0 + @overload + @classmethod + async def cursor_paginate( # pragma: no cover + cls: type[Self], + session: AsyncSession, + *, + cursor: str | None = None, + filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, + load_options: list[ExecutableOption] | None = None, + order_by: Any | None = None, + items_per_page: int = 20, + search: str | SearchConfig | None = None, + search_fields: Sequence[SearchFieldType] | None = None, + schema: None = ..., + ) -> PaginatedResponse[ModelType]: ... + + @classmethod + async def cursor_paginate( + cls: type[Self], + session: AsyncSession, + *, + cursor: str | None = None, + filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, + load_options: list[ExecutableOption] | None = None, + order_by: Any | None = None, + items_per_page: int = 20, + search: str | SearchConfig | None = None, + search_fields: Sequence[SearchFieldType] | None = None, + schema: type[BaseModel] | None = None, + ) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: + """Get paginated results using cursor-based pagination. + + Args: + session: DB async session. + cursor: Cursor string from a previous ``CursorPagination``. + Omit (or pass ``None``) to start from the beginning. + filters: List of SQLAlchemy filter conditions. + joins: List of ``(model, condition)`` tuples for joining related + tables. + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN. + load_options: SQLAlchemy loader options. Falls back to + ``default_load_options`` when not provided. + order_by: Additional ordering applied after the cursor column. + items_per_page: Number of items per page (default 20). + search: Search query string or SearchConfig object. + search_fields: Fields to search in (overrides class default). + schema: Optional Pydantic schema to serialize each item into. + + Returns: + PaginatedResponse with CursorPagination metadata + """ + filters = list(filters) if filters else [] + search_joins: list[Any] = [] + + if cls.cursor_column is None: + raise ValueError( + f"{cls.__name__}.cursor_column is not set. " + "Pass cursor_column= to CrudFactory() to use cursor_paginate." + ) + cursor_column: Any = cls.cursor_column + cursor_col_name: str = cursor_column.key + + if cursor is not None: + raw_val = _decode_cursor(cursor) + col_type = cursor_column.property.columns[0].type + if isinstance(col_type, Integer): + cursor_val: Any = int(raw_val) + elif isinstance(col_type, Uuid): + cursor_val = uuid_module.UUID(raw_val) + else: + cursor_val = raw_val + filters.append(cursor_column > cursor_val) + + # Build search filters + if search: + search_filters, search_joins = build_search_filters( + cls.model, + search, + search_fields=search_fields, + default_fields=cls.searchable_fields, + ) + filters.extend(search_filters) + + # Build query + q = select(cls.model) + + # Apply explicit joins + if joins: + for model, condition in joins: + q = ( + q.outerjoin(model, condition) + if outer_join + else q.join(model, condition) + ) + + # Apply search joins (always outer joins) + for join_rel in search_joins: + q = q.outerjoin(join_rel) + + if filters: + q = q.where(and_(*filters)) + if resolved := cls._resolve_load_options(load_options): + q = q.options(*resolved) + + # Cursor column is always the primary sort + q = q.order_by(cursor_column) + if order_by is not None: + q = q.order_by(order_by) + + # Fetch one extra to detect whether a next page exists + q = q.limit(items_per_page + 1) + result = await session.execute(q) + raw_items = cast(list[ModelType], result.unique().scalars().all()) + + has_more = len(raw_items) > items_per_page + items_page = raw_items[:items_per_page] + + # next_cursor points past the last item on this page + next_cursor: str | None = None + if has_more and items_page: + next_cursor = _encode_cursor(getattr(items_page[-1], cursor_col_name)) + + # prev_cursor points to the first item on this page or None when on the first page + prev_cursor: str | None = None + if cursor is not None and items_page: + prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name)) + + items: list[Any] = ( + [schema.model_validate(item) for item in items_page] + if schema + else items_page + ) + + return PaginatedResponse( + data=items, + pagination=CursorPagination( + next_cursor=next_cursor, + prev_cursor=prev_cursor, + items_per_page=items_per_page, + has_more=has_more, + ), + ) + def CrudFactory( model: type[ModelType], @@ -818,6 +1001,7 @@ def CrudFactory( searchable_fields: Sequence[SearchFieldType] | None = None, m2m_fields: M2MFieldType | None = None, default_load_options: list[ExecutableOption] | None = None, + cursor_column: Any | None = None, ) -> type[AsyncCrud[ModelType]]: """Create a CRUD class for a specific model. @@ -832,6 +1016,8 @@ def CrudFactory( instead of ``lazy="selectin"`` on the model so that loading strategy is explicit and per-CRUD. Overridden entirely (not merged) when ``load_options`` is provided at call-site. + cursor_column: Required to call ``cursor_paginate`` + Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp). Returns: AsyncCrud subclass bound to the model @@ -857,6 +1043,12 @@ def CrudFactory( m2m_fields={"tag_ids": Post.tags}, ) + # With a fixed cursor column for cursor_paginate: + PostCrud = CrudFactory( + Post, + cursor_column=Post.created_at, + ) + # With default load strategy (replaces lazy="selectin" on the model): ArticleCrud = CrudFactory( Article, @@ -878,7 +1070,7 @@ def CrudFactory( post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2])) # With search - result = await UserCrud.paginate(session, search="john") + result = await UserCrud.offset_paginate(session, search="john") # With joins (inner join by default): users = await UserCrud.get_multi( @@ -903,6 +1095,7 @@ def CrudFactory( "searchable_fields": searchable_fields, "m2m_fields": m2m_fields, "default_load_options": default_load_options, + "cursor_column": cursor_column, }, ) return cast(type[AsyncCrud[ModelType]], cls) diff --git a/src/fastapi_toolsets/schemas.py b/src/fastapi_toolsets/schemas.py index 787f691..0a0070f 100644 --- a/src/fastapi_toolsets/schemas.py +++ b/src/fastapi_toolsets/schemas.py @@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict __all__ = [ "ApiError", + "CursorPagination", "ErrorResponse", + "OffsetPagination", "Pagination", "PaginatedResponse", "PydanticBase", @@ -90,8 +92,8 @@ class ErrorResponse(BaseResponse): data: Any | None = None -class Pagination(PydanticBase): - """Pagination metadata for list responses. +class OffsetPagination(PydanticBase): + """Pagination metadata for offset-based list responses. Attributes: total_count: Total number of items across all pages @@ -106,17 +108,28 @@ class Pagination(PydanticBase): has_more: bool -class PaginatedResponse(BaseResponse, Generic[DataT]): - """Paginated API response for list endpoints. +# Backward-compatible - will be removed in v2.0 +Pagination = OffsetPagination - Example: - ```python - PaginatedResponse[UserRead]( - data=users, - pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True) - ) - ``` + +class CursorPagination(PydanticBase): + """Pagination metadata for cursor-based list responses. + + Attributes: + next_cursor: Encoded cursor for the next page, or None on the last page. + prev_cursor: Encoded cursor for the previous page, or None on the first page. + items_per_page: Number of items requested per page. + has_more: Whether there is at least one more page after this one. """ + next_cursor: str | None + prev_cursor: str | None = None + items_per_page: int + has_more: bool + + +class PaginatedResponse(BaseResponse, Generic[DataT]): + """Paginated API response for list endpoints.""" + data: list[DataT] - pagination: Pagination + pagination: OffsetPagination | CursorPagination diff --git a/tests/conftest.py b/tests/conftest.py index aafe043..0c96388 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,13 +5,12 @@ import uuid import pytest from pydantic import BaseModel -from sqlalchemy import Column, ForeignKey, String, Table, Uuid - -from fastapi_toolsets.schemas import PydanticBase +from sqlalchemy import Column, ForeignKey, Integer, String, Table, Uuid from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from fastapi_toolsets.crud import CrudFactory +from fastapi_toolsets.schemas import PydanticBase DATABASE_URL = os.getenv( key="DATABASE_URL", @@ -71,6 +70,15 @@ post_tags = Table( ) +class IntRole(Base): + """Test role model with auto-increment integer PK.""" + + __tablename__ = "int_roles" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(50), unique=True) + + class Post(Base): """Test post model.""" @@ -116,7 +124,7 @@ class UserCreate(BaseModel): class UserRead(PydanticBase): - """Schema for reading a user (subset of fields).""" + """Schema for reading a user (subset of fields — no email).""" id: uuid.UUID username: str @@ -176,8 +184,17 @@ class PostM2MUpdate(BaseModel): tag_ids: list[uuid.UUID] | None = None +class IntRoleCreate(BaseModel): + """Schema for creating an IntRole.""" + + name: str + + RoleCrud = CrudFactory(Role) +RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id) +IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id) UserCrud = CrudFactory(User) +UserCursorCrud = CrudFactory(User, cursor_column=User.id) PostCrud = CrudFactory(Post) TagCrud = CrudFactory(Tag) PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags}) diff --git a/tests/test_crud.py b/tests/test_crud.py index 0aa2df8..b15d011 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -11,6 +11,8 @@ from fastapi_toolsets.crud.factory import AsyncCrud from fastapi_toolsets.exceptions import NotFoundError from .conftest import ( + IntRoleCreate, + IntRoleCursorCrud, Post, PostCreate, PostCrud, @@ -20,6 +22,7 @@ from .conftest import ( Role, RoleCreate, RoleCrud, + RoleCursorCrud, RoleRead, RoleUpdate, TagCreate, @@ -27,6 +30,7 @@ from .conftest import ( User, UserCreate, UserCrud, + UserCursorCrud, UserRead, UserUpdate, ) @@ -581,8 +585,11 @@ class TestCrudPaginate: for i in range(25): await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + from fastapi_toolsets.schemas import OffsetPagination + result = await RoleCrud.paginate(db_session, page=1, items_per_page=10) + assert isinstance(result.pagination, OffsetPagination) assert len(result.data) == 10 assert result.pagination.total_count == 25 assert result.pagination.page == 1 @@ -613,6 +620,8 @@ class TestCrudPaginate: ), ) + from fastapi_toolsets.schemas import OffsetPagination + result = await UserCrud.paginate( db_session, filters=[User.is_active == True], # noqa: E712 @@ -620,6 +629,7 @@ class TestCrudPaginate: items_per_page=10, ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 5 @pytest.mark.anyio @@ -835,6 +845,8 @@ class TestCrudJoins: ), ) + from fastapi_toolsets.schemas import OffsetPagination + # Paginate users with published posts result = await UserCrud.paginate( db_session, @@ -844,6 +856,7 @@ class TestCrudJoins: items_per_page=10, ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 3 assert len(result.data) == 3 @@ -866,6 +879,8 @@ class TestCrudJoins: UserCreate(username="without_post", email="without@test.com"), ) + from fastapi_toolsets.schemas import OffsetPagination + # Paginate with outer join result = await UserCrud.paginate( db_session, @@ -875,6 +890,7 @@ class TestCrudJoins: items_per_page=10, ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 assert len(result.data) == 2 @@ -1512,3 +1528,494 @@ class TestSchemaResponse: assert isinstance(result, Response) assert isinstance(result.data, RoleRead) + + +class TestPaginateAlias: + """Tests that paginate is a backward-compatible alias for offset_paginate.""" + + def test_paginate_is_alias_of_offset_paginate(self): + """paginate and offset_paginate are the same underlying function.""" + assert RoleCrud.paginate.__func__ is RoleCrud.offset_paginate.__func__ + + @pytest.mark.anyio + async def test_paginate_alias_returns_offset_pagination( + self, db_session: AsyncSession + ): + """paginate() still works and returns PaginatedResponse with OffsetPagination.""" + from fastapi_toolsets.schemas import OffsetPagination, PaginatedResponse + + for i in range(3): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + result = await RoleCrud.paginate(db_session, page=1, items_per_page=10) + + assert isinstance(result, PaginatedResponse) + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 3 + assert result.pagination.page == 1 + + +class TestCursorPaginate: + """Tests for cursor-based pagination via cursor_paginate().""" + + @pytest.mark.anyio + async def test_first_page_no_cursor(self, db_session: AsyncSession): + """cursor_paginate without cursor returns the first page.""" + from fastapi_toolsets.schemas import CursorPagination, PaginatedResponse + + for i in range(25): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) + + assert isinstance(result, PaginatedResponse) + assert isinstance(result.pagination, CursorPagination) + assert len(result.data) == 10 + assert result.pagination.has_more is True + assert result.pagination.next_cursor is not None + assert result.pagination.prev_cursor is None + assert result.pagination.items_per_page == 10 + + @pytest.mark.anyio + async def test_last_page(self, db_session: AsyncSession): + """cursor_paginate returns has_more=False and next_cursor=None on last page.""" + from fastapi_toolsets.schemas import CursorPagination + + for i in range(5): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) + + assert isinstance(result.pagination, CursorPagination) + assert len(result.data) == 5 + assert result.pagination.has_more is False + assert result.pagination.next_cursor is None + + @pytest.mark.anyio + async def test_advances_correctly(self, db_session: AsyncSession): + """Providing next_cursor from the first page returns the next page.""" + for i in range(15): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + from fastapi_toolsets.schemas import CursorPagination + + page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) + assert isinstance(page1.pagination, CursorPagination) + assert len(page1.data) == 10 + assert page1.pagination.has_more is True + + cursor = page1.pagination.next_cursor + page2 = await RoleCursorCrud.cursor_paginate( + db_session, cursor=cursor, items_per_page=10 + ) + assert isinstance(page2.pagination, CursorPagination) + assert len(page2.data) == 5 + assert page2.pagination.has_more is False + assert page2.pagination.next_cursor is None + + @pytest.mark.anyio + async def test_no_duplicates_across_pages(self, db_session: AsyncSession): + """Items from consecutive cursor pages are non-overlapping and cover all rows.""" + for i in range(7): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + from fastapi_toolsets.schemas import CursorPagination + + page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=4) + assert isinstance(page1.pagination, CursorPagination) + page2 = await RoleCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=4, + ) + + ids_page1 = {r.id for r in page1.data} + ids_page2 = {r.id for r in page2.data} + assert ids_page1.isdisjoint(ids_page2) + assert len(ids_page1 | ids_page2) == 7 + + @pytest.mark.anyio + async def test_empty_table(self, db_session: AsyncSession): + """cursor_paginate on an empty table returns empty data with no cursor.""" + from fastapi_toolsets.schemas import CursorPagination + + result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10) + + assert isinstance(result.pagination, CursorPagination) + assert result.data == [] + assert result.pagination.has_more is False + assert result.pagination.next_cursor is None + assert result.pagination.prev_cursor is None + + @pytest.mark.anyio + async def test_with_filters(self, db_session: AsyncSession): + """cursor_paginate respects filters.""" + for i in range(10): + await UserCrud.create( + db_session, + UserCreate( + username=f"user{i}", + email=f"user{i}@test.com", + is_active=i % 2 == 0, + ), + ) + + result = await UserCursorCrud.cursor_paginate( + db_session, + filters=[User.is_active == True], # noqa: E712 + items_per_page=20, + ) + + assert len(result.data) == 5 + assert all(u.is_active for u in result.data) + + @pytest.mark.anyio + async def test_with_schema(self, db_session: AsyncSession): + """cursor_paginate with schema serializes items into the schema.""" + from fastapi_toolsets.schemas import PaginatedResponse + + for i in range(3): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + result = await RoleCursorCrud.cursor_paginate(db_session, schema=RoleRead) + + assert isinstance(result, PaginatedResponse) + assert all(isinstance(item, RoleRead) for item in result.data) + assert all( + hasattr(item, "id") and hasattr(item, "name") for item in result.data + ) + + @pytest.mark.anyio + async def test_with_cursor_column(self, db_session: AsyncSession): + """cursor_paginate uses cursor_column set on CrudFactory.""" + from fastapi_toolsets.crud import CrudFactory + from fastapi_toolsets.schemas import CursorPagination + + RoleNameCrud = CrudFactory(Role, cursor_column=Role.name) + + for i in range(5): + await RoleNameCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + result = await RoleNameCrud.cursor_paginate(db_session, items_per_page=3) + + assert isinstance(result.pagination, CursorPagination) + assert len(result.data) == 3 + assert result.pagination.has_more is True + assert result.pagination.next_cursor is not None + + @pytest.mark.anyio + async def test_raises_without_cursor_column(self, db_session: AsyncSession): + """cursor_paginate raises ValueError when cursor_column is not configured.""" + with pytest.raises(ValueError, match="cursor_column is not set"): + await RoleCrud.cursor_paginate(db_session) + + +class TestCursorPaginatePrevCursor: + """Tests for prev_cursor behavior in cursor_paginate().""" + + @pytest.mark.anyio + async def test_prev_cursor_none_on_first_page(self, db_session: AsyncSession): + """prev_cursor is None when no cursor was provided (first page).""" + for i in range(5): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + from fastapi_toolsets.schemas import CursorPagination + + result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=3) + + assert isinstance(result.pagination, CursorPagination) + assert result.pagination.prev_cursor is None + + @pytest.mark.anyio + async def test_prev_cursor_set_on_subsequent_pages(self, db_session: AsyncSession): + """prev_cursor is set when a cursor was provided (subsequent pages).""" + for i in range(10): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + from fastapi_toolsets.schemas import CursorPagination + + page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5) + assert isinstance(page1.pagination, CursorPagination) + page2 = await RoleCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=5, + ) + assert isinstance(page2.pagination, CursorPagination) + assert page2.pagination.prev_cursor is not None + + @pytest.mark.anyio + async def test_prev_cursor_points_to_first_item(self, db_session: AsyncSession): + """prev_cursor encodes the value of the first item on the current page.""" + import base64 + import json + + for i in range(10): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + from fastapi_toolsets.schemas import CursorPagination + + page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5) + assert isinstance(page1.pagination, CursorPagination) + page2 = await RoleCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=5, + ) + assert isinstance(page2.pagination, CursorPagination) + assert page2.pagination.prev_cursor is not None + + # Decode prev_cursor and compare to first item's id + decoded = json.loads( + base64.b64decode(page2.pagination.prev_cursor.encode()).decode() + ) + first_item_id = str(page2.data[0].id) + assert decoded == first_item_id + + +class TestCursorPaginateWithSearch: + """Tests for cursor_paginate() combined with search.""" + + @pytest.mark.anyio + async def test_cursor_paginate_with_search(self, db_session: AsyncSession): + """cursor_paginate respects search filters alongside cursor predicate.""" + from fastapi_toolsets.crud import CrudFactory + + # Create a CRUD with searchable fields and cursor column + SearchableRoleCrud = CrudFactory( + Role, searchable_fields=[Role.name], cursor_column=Role.id + ) + + for i in range(5): + await RoleCrud.create(db_session, RoleCreate(name=f"admin{i:02d}")) + for i in range(5): + await RoleCrud.create(db_session, RoleCreate(name=f"user{i:02d}")) + + result = await SearchableRoleCrud.cursor_paginate( + db_session, + search="admin", + items_per_page=20, + ) + + assert len(result.data) == 5 + assert all("admin" in r.name for r in result.data) + + +class TestCursorPaginateExtraOptions: + """Tests for cursor_paginate() covering joins, load_options, and order_by.""" + + @pytest.mark.anyio + async def test_with_joins(self, db_session: AsyncSession): + """cursor_paginate applies explicit inner joins.""" + from fastapi_toolsets.schemas import CursorPagination + + role = await RoleCrud.create(db_session, RoleCreate(name="member")) + for i in range(5): + await UserCrud.create( + db_session, + UserCreate( + username=f"u{i}", + email=f"u{i}@test.com", + role_id=role.id, + ), + ) + # One user without a role to confirm inner join excludes them + await UserCrud.create( + db_session, + UserCreate(username="norole", email="norole@test.com"), + ) + + result = await UserCursorCrud.cursor_paginate( + db_session, + joins=[(Role, User.role_id == Role.id)], + items_per_page=20, + ) + + assert isinstance(result.pagination, CursorPagination) + # Only users with a role are returned (inner join) + assert len(result.data) == 5 + + @pytest.mark.anyio + async def test_with_outer_join(self, db_session: AsyncSession): + """cursor_paginate applies LEFT OUTER JOIN when outer_join=True.""" + from fastapi_toolsets.schemas import CursorPagination + + role = await RoleCrud.create(db_session, RoleCreate(name="member")) + for i in range(3): + await UserCrud.create( + db_session, + UserCreate( + username=f"u{i}", + email=f"u{i}@test.com", + role_id=role.id, + ), + ) + await UserCrud.create( + db_session, + UserCreate(username="norole", email="norole@test.com"), + ) + + result = await UserCursorCrud.cursor_paginate( + db_session, + joins=[(Role, User.role_id == Role.id)], + outer_join=True, + items_per_page=20, + ) + + assert isinstance(result.pagination, CursorPagination) + # All users are included (outer join) + assert len(result.data) == 4 + + @pytest.mark.anyio + async def test_with_load_options(self, db_session: AsyncSession): + """cursor_paginate passes load_options to the query.""" + from fastapi_toolsets.schemas import CursorPagination + + role = await RoleCrud.create(db_session, RoleCreate(name="manager")) + for i in range(3): + await UserCrud.create( + db_session, + UserCreate( + username=f"u{i}", + email=f"u{i}@test.com", + role_id=role.id, + ), + ) + + result = await UserCursorCrud.cursor_paginate( + db_session, + load_options=[selectinload(User.role)], + items_per_page=20, + ) + + assert isinstance(result.pagination, CursorPagination) + assert len(result.data) == 3 + # Relationship was eagerly loaded + assert all(u.role is not None for u in result.data) + + @pytest.mark.anyio + async def test_with_order_by(self, db_session: AsyncSession): + """cursor_paginate applies additional order_by after the cursor column.""" + from fastapi_toolsets.schemas import CursorPagination + + for i in range(5): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + result = await RoleCursorCrud.cursor_paginate( + db_session, + order_by=Role.name.desc(), + items_per_page=3, + ) + + assert isinstance(result.pagination, CursorPagination) + assert len(result.data) == 3 + + @pytest.mark.anyio + async def test_integer_cursor_column(self, db_session: AsyncSession): + """cursor_paginate decodes Integer cursor values correctly.""" + from fastapi_toolsets.schemas import CursorPagination + + for i in range(5): + await IntRoleCursorCrud.create(db_session, IntRoleCreate(name=f"role{i}")) + + page1 = await IntRoleCursorCrud.cursor_paginate(db_session, items_per_page=3) + + assert isinstance(page1.pagination, CursorPagination) + assert len(page1.data) == 3 + assert page1.pagination.has_more is True + + page2 = await IntRoleCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=3, + ) + + assert isinstance(page2.pagination, CursorPagination) + assert len(page2.data) == 2 + assert page2.pagination.has_more is False + + @pytest.mark.anyio + async def test_string_cursor_column(self, db_session: AsyncSession): + """cursor_paginate decodes non-UUID/non-Integer cursor values (string branch).""" + from fastapi_toolsets.crud import CrudFactory + from fastapi_toolsets.schemas import CursorPagination + + RoleNameCursorCrud = CrudFactory(Role, cursor_column=Role.name) + + for i in range(5): + await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + + page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=3) + + assert isinstance(page1.pagination, CursorPagination) + assert len(page1.data) == 3 + assert page1.pagination.has_more is True + + page2 = await RoleNameCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=3, + ) + + assert isinstance(page2.pagination, CursorPagination) + assert len(page2.data) == 2 + assert page2.pagination.has_more is False + + +class TestCursorPaginateSearchJoins: + """Tests for cursor_paginate() search that traverses relationships (search_joins).""" + + @pytest.mark.anyio + async def test_search_via_relationship(self, db_session: AsyncSession): + """cursor_paginate outerjoin search-join when searching through a relationship.""" + from fastapi_toolsets.schemas import CursorPagination + + role_admin = await RoleCrud.create(db_session, RoleCreate(name="administrator")) + role_user = await RoleCrud.create(db_session, RoleCreate(name="regularuser")) + + for i in range(3): + await UserCrud.create( + db_session, + UserCreate( + username=f"admin_u{i}", + email=f"admin_u{i}@test.com", + role_id=role_admin.id, + ), + ) + for i in range(2): + await UserCrud.create( + db_session, + UserCreate( + username=f"reg_u{i}", + email=f"reg_u{i}@test.com", + role_id=role_user.id, + ), + ) + + result = await UserCursorCrud.cursor_paginate( + db_session, + search="administrator", + search_fields=[(User.role, Role.name)], + items_per_page=20, + ) + + assert isinstance(result.pagination, CursorPagination) + assert len(result.data) == 3 + + +class TestGetWithForUpdate: + """Tests for get() with with_for_update=True.""" + + @pytest.mark.anyio + async def test_get_with_for_update(self, db_session: AsyncSession): + """get() with with_for_update=True locks the row.""" + role = await RoleCrud.create(db_session, RoleCreate(name="locked")) + + result = await RoleCrud.get( + db_session, + filters=[Role.id == role.id], + with_for_update=True, + ) + + assert result.id == role.id + assert result.name == "locked" diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 8f45826..79e6eaf 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -6,6 +6,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from fastapi_toolsets.crud import SearchConfig, get_searchable_fields +from fastapi_toolsets.schemas import OffsetPagination from .conftest import ( Role, @@ -39,6 +40,7 @@ class TestPaginateSearch: search_fields=[User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 @pytest.mark.anyio @@ -57,6 +59,7 @@ class TestPaginateSearch: search_fields=[User.username, User.email], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 @pytest.mark.anyio @@ -84,6 +87,7 @@ class TestPaginateSearch: search_fields=[(User.role, Role.name)], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 @pytest.mark.anyio @@ -102,6 +106,7 @@ class TestPaginateSearch: search_fields=[User.username, (User.role, Role.name)], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 @pytest.mark.anyio @@ -117,6 +122,7 @@ class TestPaginateSearch: search_fields=[User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 @pytest.mark.anyio @@ -132,6 +138,7 @@ class TestPaginateSearch: search=SearchConfig(query="johndoe", case_sensitive=True), search_fields=[User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 0 # Should find (case match) @@ -140,6 +147,7 @@ class TestPaginateSearch: search=SearchConfig(query="JohnDoe", case_sensitive=True), search_fields=[User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 @pytest.mark.anyio @@ -153,9 +161,11 @@ class TestPaginateSearch: ) result = await UserCrud.paginate(db_session, search="") + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 result = await UserCrud.paginate(db_session, search=None) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 @pytest.mark.anyio @@ -177,6 +187,7 @@ class TestPaginateSearch: search_fields=[User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 assert result.data[0].username == "active_john" @@ -189,6 +200,7 @@ class TestPaginateSearch: result = await UserCrud.paginate(db_session, search="findme") + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 @pytest.mark.anyio @@ -204,6 +216,7 @@ class TestPaginateSearch: search_fields=[User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 0 assert result.data == [] @@ -224,6 +237,7 @@ class TestPaginateSearch: items_per_page=5, ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 15 assert len(result.data) == 5 assert result.pagination.has_more is True @@ -248,6 +262,7 @@ class TestPaginateSearch: search_fields=[User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 @pytest.mark.anyio @@ -270,6 +285,7 @@ class TestPaginateSearch: order_by=User.username, ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 3 usernames = [u.username for u in result.data] assert usernames == ["alice", "bob", "charlie"] @@ -292,6 +308,7 @@ class TestPaginateSearch: search_fields=[User.id, User.username], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 assert result.data[0].id == user_id @@ -318,6 +335,7 @@ class TestSearchConfig: search_fields=[User.username, User.email], ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 assert result.data[0].username == "john_test" @@ -333,6 +351,7 @@ class TestSearchConfig: search=SearchConfig(query="findme", fields=[User.email]), ) + assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 1 diff --git a/tests/test_schemas.py b/tests/test_schemas.py index abea9ac..c23b660 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -5,7 +5,9 @@ from pydantic import ValidationError from fastapi_toolsets.schemas import ( ApiError, + CursorPagination, ErrorResponse, + OffsetPagination, PaginatedResponse, Pagination, Response, @@ -154,12 +156,12 @@ class TestErrorResponse: assert data["description"] == "Details" -class TestPagination: - """Tests for Pagination schema.""" +class TestOffsetPagination: + """Tests for OffsetPagination schema (canonical name for offset-based pagination).""" def test_create_pagination(self): - """Create Pagination with all fields.""" - pagination = Pagination( + """Create OffsetPagination with all fields.""" + pagination = OffsetPagination( total_count=100, items_per_page=10, page=1, @@ -173,7 +175,7 @@ class TestPagination: def test_last_page_has_more_false(self): """Last page has has_more=False.""" - pagination = Pagination( + pagination = OffsetPagination( total_count=25, items_per_page=10, page=3, @@ -183,8 +185,8 @@ class TestPagination: assert pagination.has_more is False def test_serialization(self): - """Pagination serializes correctly.""" - pagination = Pagination( + """OffsetPagination serializes correctly.""" + pagination = OffsetPagination( total_count=50, items_per_page=20, page=2, @@ -197,6 +199,77 @@ class TestPagination: assert data["page"] == 2 assert data["has_more"] is True + def test_pagination_alias_is_offset_pagination(self): + """Pagination is a backward-compatible alias for OffsetPagination.""" + assert Pagination is OffsetPagination + + def test_pagination_alias_constructs_offset_pagination(self): + """Code using Pagination(...) still works unchanged.""" + pagination = Pagination( + total_count=10, + items_per_page=5, + page=2, + has_more=False, + ) + assert isinstance(pagination, OffsetPagination) + + +class TestCursorPagination: + """Tests for CursorPagination schema.""" + + def test_create_with_next_cursor(self): + """CursorPagination with a next cursor indicates more pages.""" + pagination = CursorPagination( + next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==", + items_per_page=20, + has_more=True, + ) + assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ==" + assert pagination.prev_cursor is None + assert pagination.items_per_page == 20 + assert pagination.has_more is True + + def test_create_last_page(self): + """CursorPagination for the last page has next_cursor=None and has_more=False.""" + pagination = CursorPagination( + next_cursor=None, + items_per_page=20, + has_more=False, + ) + assert pagination.next_cursor is None + assert pagination.has_more is False + + def test_prev_cursor_defaults_to_none(self): + """prev_cursor defaults to None.""" + pagination = CursorPagination( + next_cursor=None, items_per_page=10, has_more=False + ) + assert pagination.prev_cursor is None + + def test_prev_cursor_can_be_set(self): + """prev_cursor can be explicitly set.""" + pagination = CursorPagination( + next_cursor="next123", + prev_cursor="prev456", + items_per_page=10, + has_more=True, + ) + assert pagination.prev_cursor == "prev456" + + def test_serialization(self): + """CursorPagination serializes correctly.""" + pagination = CursorPagination( + next_cursor="abc123", + prev_cursor="xyz789", + items_per_page=20, + has_more=True, + ) + data = pagination.model_dump() + assert data["next_cursor"] == "abc123" + assert data["prev_cursor"] == "xyz789" + assert data["items_per_page"] == 20 + assert data["has_more"] is True + class TestPaginatedResponse: """Tests for PaginatedResponse schema.""" @@ -214,6 +287,7 @@ class TestPaginatedResponse: pagination=pagination, ) + assert isinstance(response.pagination, OffsetPagination) assert len(response.data) == 2 assert response.pagination.total_count == 30 assert response.status == ResponseStatus.SUCCESS @@ -247,6 +321,7 @@ class TestPaginatedResponse: pagination=pagination, ) + assert isinstance(response.pagination, OffsetPagination) assert response.data == [] assert response.pagination.total_count == 0 @@ -290,6 +365,36 @@ class TestPaginatedResponse: assert data["data"] == ["item1", "item2"] assert data["pagination"]["page"] == 5 + def test_pagination_field_accepts_offset_pagination(self): + """PaginatedResponse.pagination accepts OffsetPagination.""" + response = PaginatedResponse( + data=[1, 2], + pagination=OffsetPagination( + total_count=2, items_per_page=10, page=1, has_more=False + ), + ) + assert isinstance(response.pagination, OffsetPagination) + + def test_pagination_field_accepts_cursor_pagination(self): + """PaginatedResponse.pagination accepts CursorPagination.""" + response = PaginatedResponse( + data=[1, 2], + pagination=CursorPagination( + next_cursor=None, items_per_page=10, has_more=False + ), + ) + assert isinstance(response.pagination, CursorPagination) + + def test_pagination_alias_accepted(self): + """Constructing PaginatedResponse with Pagination (alias) still works.""" + response = PaginatedResponse( + data=[], + pagination=Pagination( + total_count=0, items_per_page=10, page=1, has_more=False + ), + ) + assert isinstance(response.pagination, OffsetPagination) + class TestFromAttributes: """Tests for from_attributes config (ORM mode)."""