diff --git a/docs_src/examples/pagination_search/routes.py b/docs_src/examples/pagination_search/routes.py index e88778d..821e82b 100644 --- a/docs_src/examples/pagination_search/routes.py +++ b/docs_src/examples/pagination_search/routes.py @@ -2,8 +2,12 @@ from typing import Annotated from fastapi import APIRouter, Depends, Query -from fastapi_toolsets.crud import OrderByClause -from fastapi_toolsets.schemas import PaginatedResponse +from fastapi_toolsets.crud import OrderByClause, PaginationType +from fastapi_toolsets.schemas import ( + CursorPaginatedResponse, + OffsetPaginatedResponse, + PaginatedResponse, +) from .crud import ArticleCrud from .db import SessionDep @@ -24,7 +28,7 @@ async def list_articles_offset( page: int = Query(1, ge=1), items_per_page: int = Query(20, ge=1, le=100), search: str | None = None, -) -> PaginatedResponse[ArticleRead]: +) -> OffsetPaginatedResponse[ArticleRead]: return await ArticleCrud.offset_paginate( session=session, page=page, @@ -47,7 +51,7 @@ async def list_articles_cursor( cursor: str | None = None, items_per_page: int = Query(20, ge=1, le=100), search: str | None = None, -) -> PaginatedResponse[ArticleRead]: +) -> CursorPaginatedResponse[ArticleRead]: return await ArticleCrud.cursor_paginate( session=session, cursor=cursor, @@ -57,3 +61,42 @@ async def list_articles_cursor( order_by=order_by, schema=ArticleRead, ) + + +@router.get("/") +async def list_articles( + session: SessionDep, + filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())], + order_by: Annotated[ + OrderByClause | None, + Depends(ArticleCrud.order_params(default_field=Article.created_at)), + ], + pagination_type: PaginationType = PaginationType.OFFSET, + page: int = Query(1, ge=1, description="Current page (offset pagination only)"), + cursor: str | None = Query( + None, description="Cursor token (cursor pagination only)" + ), + items_per_page: int = Query(20, ge=1, le=100), + search: str | None = None, +) -> PaginatedResponse[ArticleRead]: + """List articles using either offset or cursor pagination. + + Pass `pagination_type=offset` (default) for page-based pagination with a + total count, or `pagination_type=cursor` for efficient cursor-based + pagination suited to large datasets and infinite scroll. + + - **offset**: use `page` to navigate; response includes `total_count`. + - **cursor**: use the `next_cursor` / `prev_cursor` from the previous + response as the `cursor` query parameter; no total count is returned. + """ + return await ArticleCrud.paginate( + session, + pagination_type=pagination_type, + page=page, + cursor=cursor, + items_per_page=items_per_page, + search=search, + filter_by=filter_by or None, + order_by=order_by, + schema=ArticleRead, + ) diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index cb22110..bc4fb60 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -1,6 +1,7 @@ """Generic async CRUD operations for SQLAlchemy models.""" from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError +from ..schemas import PaginationType from ..types import ( FacetFieldType, JoinType, @@ -8,10 +9,11 @@ from ..types import ( OrderByClause, SearchFieldType, ) -from .factory import CrudFactory +from .factory import AsyncCrud, CrudFactory from .search import SearchConfig, get_searchable_fields __all__ = [ + "AsyncCrud", "CrudFactory", "FacetFieldType", "get_searchable_fields", @@ -20,6 +22,7 @@ __all__ = [ "M2MFieldType", "NoSearchableFieldsError", "OrderByClause", + "PaginationType", "SearchConfig", "SearchFieldType", ] diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index ce754ff..5361faf 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -23,7 +23,14 @@ from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction from ..exceptions import InvalidOrderFieldError, NotFoundError -from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response +from ..schemas import ( + CursorPaginatedResponse, + CursorPagination, + OffsetPaginatedResponse, + OffsetPagination, + PaginationType, + Response, +) from ..types import ( FacetFieldType, JoinType, @@ -889,7 +896,7 @@ class AsyncCrud(Generic[ModelType]): facet_fields: Sequence[FacetFieldType] | None = None, filter_by: dict[str, Any] | BaseModel | None = None, schema: type[BaseModel], - ) -> PaginatedResponse[Any]: + ) -> OffsetPaginatedResponse[Any]: """Get paginated results using offset-based pagination. Args: @@ -971,7 +978,7 @@ class AsyncCrud(Generic[ModelType]): session, facet_fields, filters, search_joins ) - return PaginatedResponse( + return OffsetPaginatedResponse( data=items, pagination=OffsetPagination( total_count=total_count, @@ -999,7 +1006,7 @@ class AsyncCrud(Generic[ModelType]): facet_fields: Sequence[FacetFieldType] | None = None, filter_by: dict[str, Any] | BaseModel | None = None, schema: type[BaseModel], - ) -> PaginatedResponse[Any]: + ) -> CursorPaginatedResponse[Any]: """Get paginated results using cursor-based pagination. Args: @@ -1113,7 +1120,7 @@ class AsyncCrud(Generic[ModelType]): session, facet_fields, filters, search_joins ) - return PaginatedResponse( + return CursorPaginatedResponse( data=items, pagination=CursorPagination( next_cursor=next_cursor, @@ -1124,6 +1131,136 @@ class AsyncCrud(Generic[ModelType]): filter_attributes=filter_attributes, ) + @overload + @classmethod + async def paginate( # pragma: no cover + cls: type[Self], + session: AsyncSession, + *, + pagination_type: Literal[PaginationType.OFFSET], + filters: list[Any] | None = ..., + joins: JoinType | None = ..., + outer_join: bool = ..., + load_options: Sequence[ExecutableOption] | None = ..., + order_by: OrderByClause | None = ..., + page: int = ..., + cursor: str | None = ..., + items_per_page: int = ..., + search: str | SearchConfig | None = ..., + search_fields: Sequence[SearchFieldType] | None = ..., + facet_fields: Sequence[FacetFieldType] | None = ..., + filter_by: dict[str, Any] | BaseModel | None = ..., + schema: type[BaseModel], + ) -> OffsetPaginatedResponse[Any]: ... + + @overload + @classmethod + async def paginate( # pragma: no cover + cls: type[Self], + session: AsyncSession, + *, + pagination_type: Literal[PaginationType.CURSOR], + filters: list[Any] | None = ..., + joins: JoinType | None = ..., + outer_join: bool = ..., + load_options: Sequence[ExecutableOption] | None = ..., + order_by: OrderByClause | None = ..., + page: int = ..., + cursor: str | None = ..., + items_per_page: int = ..., + search: str | SearchConfig | None = ..., + search_fields: Sequence[SearchFieldType] | None = ..., + facet_fields: Sequence[FacetFieldType] | None = ..., + filter_by: dict[str, Any] | BaseModel | None = ..., + schema: type[BaseModel], + ) -> CursorPaginatedResponse[Any]: ... + + @classmethod + async def paginate( + cls: type[Self], + session: AsyncSession, + *, + pagination_type: PaginationType = PaginationType.OFFSET, + filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, + load_options: Sequence[ExecutableOption] | None = None, + order_by: OrderByClause | None = None, + page: int = 1, + cursor: str | None = None, + items_per_page: int = 20, + search: str | SearchConfig | None = None, + search_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, + filter_by: dict[str, Any] | BaseModel | None = None, + schema: type[BaseModel], + ) -> OffsetPaginatedResponse[Any] | CursorPaginatedResponse[Any]: + """Get paginated results using either offset or cursor pagination. + + Args: + session: DB async session. + pagination_type: Pagination strategy. Defaults to + ``PaginationType.OFFSET``. + filters: List of SQLAlchemy filter conditions. + joins: List of ``(model, condition)`` tuples for joining related + tables. + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN. + load_options: SQLAlchemy loader options. Falls back to + ``default_load_options`` when not provided. + order_by: Column or expression to order results by. + page: Page number (1-indexed). Only used when + ``pagination_type`` is ``OFFSET``. + cursor: Cursor token from a previous + :class:`.CursorPaginatedResponse`. Only used when + ``pagination_type`` is ``CURSOR``. + items_per_page: Number of items per page (default 20). + search: Search query string or :class:`.SearchConfig` object. + search_fields: Fields to search in (overrides class default). + facet_fields: Columns to compute distinct values for (overrides + class default). + filter_by: Dict of ``{column_key: value}`` to filter by declared + facet fields. Keys must match the ``column.key`` of a facet + field. Scalar → equality, list → IN clause. Raises + :exc:`.InvalidFacetFilterError` for unknown keys. + schema: Pydantic schema to serialize each item into. + + Returns: + :class:`.OffsetPaginatedResponse` when ``pagination_type`` is + ``OFFSET``, :class:`.CursorPaginatedResponse` when it is + ``CURSOR``. + """ + if pagination_type is PaginationType.CURSOR: + return await cls.cursor_paginate( + session, + cursor=cursor, + filters=filters, + joins=joins, + outer_join=outer_join, + load_options=load_options, + order_by=order_by, + items_per_page=items_per_page, + search=search, + search_fields=search_fields, + facet_fields=facet_fields, + filter_by=filter_by, + schema=schema, + ) + return await cls.offset_paginate( + session, + filters=filters, + joins=joins, + outer_join=outer_join, + load_options=load_options, + order_by=order_by, + page=page, + items_per_page=items_per_page, + search=search, + search_fields=search_fields, + facet_fields=facet_fields, + filter_by=filter_by, + schema=schema, + ) + def CrudFactory( model: type[ModelType], diff --git a/src/fastapi_toolsets/schemas.py b/src/fastapi_toolsets/schemas.py index 80016cd..bf0e88b 100644 --- a/src/fastapi_toolsets/schemas.py +++ b/src/fastapi_toolsets/schemas.py @@ -1,7 +1,7 @@ """Base Pydantic schemas for API responses.""" from enum import Enum -from typing import Any, ClassVar, Generic +from typing import Any, ClassVar, Generic, Literal from pydantic import BaseModel, ConfigDict @@ -10,9 +10,12 @@ from .types import DataT __all__ = [ "ApiError", "CursorPagination", + "CursorPaginatedResponse", "ErrorResponse", "OffsetPagination", + "OffsetPaginatedResponse", "PaginatedResponse", + "PaginationType", "PydanticBase", "Response", "ResponseStatus", @@ -123,9 +126,48 @@ class CursorPagination(PydanticBase): has_more: bool +class PaginationType(str, Enum): + """Pagination strategy selector for :meth:`.AsyncCrud.paginate`.""" + + OFFSET = "offset" + CURSOR = "cursor" + + class PaginatedResponse(BaseResponse, Generic[DataT]): - """Paginated API response for list endpoints.""" + """Paginated API response for list endpoints. + + Base class and return type for endpoints that support both pagination + strategies. Use :class:`OffsetPaginatedResponse` or + :class:`CursorPaginatedResponse` when the strategy is fixed; use + ``PaginatedResponse`` as the return annotation for unified endpoints that + dispatch via :meth:`~fastapi_toolsets.crud.factory.AsyncCrud.paginate`. + """ data: list[DataT] pagination: OffsetPagination | CursorPagination + pagination_type: PaginationType | None = None filter_attributes: dict[str, list[Any]] | None = None + + +class OffsetPaginatedResponse(PaginatedResponse[DataT]): + """Paginated response with typed offset-based pagination metadata. + + The ``pagination_type`` field is always ``"offset"`` and acts as a + discriminator, allowing frontend clients to narrow the union type returned + by a unified ``paginate()`` endpoint. + """ + + pagination: OffsetPagination + pagination_type: Literal[PaginationType.OFFSET] = PaginationType.OFFSET + + +class CursorPaginatedResponse(PaginatedResponse[DataT]): + """Paginated response with typed cursor-based pagination metadata. + + The ``pagination_type`` field is always ``"cursor"`` and acts as a + discriminator, allowing frontend clients to narrow the union type returned + by a unified ``paginate()`` endpoint. + """ + + pagination: CursorPagination + pagination_type: Literal[PaginationType.CURSOR] = PaginationType.CURSOR diff --git a/tests/test_example_pagination_search.py b/tests/test_example_pagination_search.py index 9ca98fe..3281cd2 100644 --- a/tests/test_example_pagination_search.py +++ b/tests/test_example_pagination_search.py @@ -393,3 +393,105 @@ class TestCursorSorting: body = resp.json() assert body["error_code"] == "SORT-422" assert body["status"] == "FAIL" + + +class TestPaginateUnified: + """Tests for the unified GET /articles/ endpoint using paginate().""" + + @pytest.mark.anyio + async def test_defaults_to_offset_pagination( + self, client: AsyncClient, ex_db_session + ): + """Without pagination_type, defaults to offset pagination.""" + await seed(ex_db_session) + + resp = await client.get("/articles/") + + assert resp.status_code == 200 + body = resp.json() + assert body["pagination_type"] == "offset" + assert "total_count" in body["pagination"] + assert body["pagination"]["total_count"] == 3 + + @pytest.mark.anyio + async def test_explicit_offset_pagination(self, client: AsyncClient, ex_db_session): + """pagination_type=offset returns OffsetPagination metadata.""" + await seed(ex_db_session) + + resp = await client.get( + "/articles/?pagination_type=offset&page=1&items_per_page=2" + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["pagination_type"] == "offset" + assert body["pagination"]["total_count"] == 3 + assert body["pagination"]["page"] == 1 + assert body["pagination"]["has_more"] is True + assert len(body["data"]) == 2 + + @pytest.mark.anyio + async def test_cursor_pagination_type(self, client: AsyncClient, ex_db_session): + """pagination_type=cursor returns CursorPagination metadata.""" + await seed(ex_db_session) + + resp = await client.get("/articles/?pagination_type=cursor&items_per_page=2") + + assert resp.status_code == 200 + body = resp.json() + assert body["pagination_type"] == "cursor" + assert "next_cursor" in body["pagination"] + assert "total_count" not in body["pagination"] + assert body["pagination"]["has_more"] is True + assert len(body["data"]) == 2 + + @pytest.mark.anyio + async def test_cursor_pagination_navigate_pages( + self, client: AsyncClient, ex_db_session + ): + """Cursor from first page can be used to fetch the next page.""" + await seed(ex_db_session) + + first = await client.get("/articles/?pagination_type=cursor&items_per_page=2") + assert first.status_code == 200 + first_body = first.json() + next_cursor = first_body["pagination"]["next_cursor"] + assert next_cursor is not None + + second = await client.get( + f"/articles/?pagination_type=cursor&items_per_page=2&cursor={next_cursor}" + ) + assert second.status_code == 200 + second_body = second.json() + assert second_body["pagination_type"] == "cursor" + assert second_body["pagination"]["has_more"] is False + assert len(second_body["data"]) == 1 + + @pytest.mark.anyio + async def test_cursor_pagination_with_search( + self, client: AsyncClient, ex_db_session + ): + """paginate() with cursor type respects search parameter.""" + await seed(ex_db_session) + + resp = await client.get("/articles/?pagination_type=cursor&search=fastapi") + + assert resp.status_code == 200 + body = resp.json() + assert body["pagination_type"] == "cursor" + assert len(body["data"]) == 1 + assert body["data"][0]["title"] == "FastAPI tips" + + @pytest.mark.anyio + async def test_offset_pagination_with_filter( + self, client: AsyncClient, ex_db_session + ): + """paginate() with offset type respects filter_by parameter.""" + await seed(ex_db_session) + + resp = await client.get("/articles/?pagination_type=offset&status=published") + + assert resp.status_code == 200 + body = resp.json() + assert body["pagination_type"] == "offset" + assert body["pagination"]["total_count"] == 2 diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 17694b3..a5f17a0 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -6,9 +6,12 @@ from pydantic import ValidationError from fastapi_toolsets.schemas import ( ApiError, CursorPagination, + CursorPaginatedResponse, ErrorResponse, OffsetPagination, + OffsetPaginatedResponse, PaginatedResponse, + PaginationType, Response, ResponseStatus, ) @@ -312,11 +315,6 @@ class TestPaginatedResponse: def test_generic_type_hint(self): """PaginatedResponse supports generic type hints.""" - - class UserOut: - id: int - name: str - pagination = OffsetPagination( total_count=1, items_per_page=10, @@ -371,6 +369,191 @@ class TestPaginatedResponse: assert isinstance(response.pagination, CursorPagination) +class TestPaginationType: + """Tests for PaginationType enum.""" + + def test_offset_value(self): + """OFFSET has string value 'offset'.""" + assert PaginationType.OFFSET == "offset" + assert PaginationType.OFFSET.value == "offset" + + def test_cursor_value(self): + """CURSOR has string value 'cursor'.""" + assert PaginationType.CURSOR == "cursor" + assert PaginationType.CURSOR.value == "cursor" + + def test_is_string_enum(self): + """PaginationType is a string enum.""" + assert isinstance(PaginationType.OFFSET, str) + assert isinstance(PaginationType.CURSOR, str) + + def test_members(self): + """PaginationType has exactly two members.""" + assert set(PaginationType) == {PaginationType.OFFSET, PaginationType.CURSOR} + + +class TestOffsetPaginatedResponse: + """Tests for OffsetPaginatedResponse schema.""" + + def test_pagination_type_is_offset(self): + """pagination_type is always PaginationType.OFFSET.""" + response = OffsetPaginatedResponse( + data=[], + pagination=OffsetPagination( + total_count=0, items_per_page=10, page=1, has_more=False + ), + ) + assert response.pagination_type is PaginationType.OFFSET + + def test_pagination_type_serializes_to_string(self): + """pagination_type serializes to 'offset' in JSON mode.""" + response = OffsetPaginatedResponse( + data=[], + pagination=OffsetPagination( + total_count=0, items_per_page=10, page=1, has_more=False + ), + ) + assert response.model_dump(mode="json")["pagination_type"] == "offset" + + def test_pagination_field_is_typed(self): + """pagination field is OffsetPagination, not the union.""" + response = OffsetPaginatedResponse( + data=[{"id": 1}], + pagination=OffsetPagination( + total_count=10, items_per_page=5, page=2, has_more=True + ), + ) + assert isinstance(response.pagination, OffsetPagination) + assert response.pagination.total_count == 10 + assert response.pagination.page == 2 + + def test_is_subclass_of_paginated_response(self): + """OffsetPaginatedResponse IS a PaginatedResponse.""" + response = OffsetPaginatedResponse( + data=[], + pagination=OffsetPagination( + total_count=0, items_per_page=10, page=1, has_more=False + ), + ) + assert isinstance(response, PaginatedResponse) + + def test_pagination_type_default_cannot_be_overridden_to_cursor(self): + """pagination_type rejects values other than OFFSET.""" + with pytest.raises(ValidationError): + OffsetPaginatedResponse( + data=[], + pagination=OffsetPagination( + total_count=0, items_per_page=10, page=1, has_more=False + ), + pagination_type=PaginationType.CURSOR, # type: ignore[arg-type] + ) + + def test_filter_attributes_defaults_to_none(self): + """filter_attributes defaults to None.""" + response = OffsetPaginatedResponse( + data=[], + pagination=OffsetPagination( + total_count=0, items_per_page=10, page=1, has_more=False + ), + ) + assert response.filter_attributes is None + + def test_full_serialization(self): + """Full JSON serialization includes all expected fields.""" + response = OffsetPaginatedResponse( + data=[{"id": 1}], + pagination=OffsetPagination( + total_count=1, items_per_page=10, page=1, has_more=False + ), + filter_attributes={"status": ["active"]}, + ) + data = response.model_dump(mode="json") + + assert data["pagination_type"] == "offset" + assert data["status"] == "SUCCESS" + assert data["data"] == [{"id": 1}] + assert data["pagination"]["total_count"] == 1 + assert data["filter_attributes"] == {"status": ["active"]} + + +class TestCursorPaginatedResponse: + """Tests for CursorPaginatedResponse schema.""" + + def test_pagination_type_is_cursor(self): + """pagination_type is always PaginationType.CURSOR.""" + response = CursorPaginatedResponse( + data=[], + pagination=CursorPagination( + next_cursor=None, items_per_page=10, has_more=False + ), + ) + assert response.pagination_type is PaginationType.CURSOR + + def test_pagination_type_serializes_to_string(self): + """pagination_type serializes to 'cursor' in JSON mode.""" + response = CursorPaginatedResponse( + data=[], + pagination=CursorPagination( + next_cursor=None, items_per_page=10, has_more=False + ), + ) + assert response.model_dump(mode="json")["pagination_type"] == "cursor" + + def test_pagination_field_is_typed(self): + """pagination field is CursorPagination, not the union.""" + response = CursorPaginatedResponse( + data=[{"id": 1}], + pagination=CursorPagination( + next_cursor="abc123", + prev_cursor=None, + items_per_page=20, + has_more=True, + ), + ) + assert isinstance(response.pagination, CursorPagination) + assert response.pagination.next_cursor == "abc123" + assert response.pagination.has_more is True + + def test_is_subclass_of_paginated_response(self): + """CursorPaginatedResponse IS a PaginatedResponse.""" + response = CursorPaginatedResponse( + data=[], + pagination=CursorPagination( + next_cursor=None, items_per_page=10, has_more=False + ), + ) + assert isinstance(response, PaginatedResponse) + + def test_pagination_type_default_cannot_be_overridden_to_offset(self): + """pagination_type rejects values other than CURSOR.""" + with pytest.raises(ValidationError): + CursorPaginatedResponse( + data=[], + pagination=CursorPagination( + next_cursor=None, items_per_page=10, has_more=False + ), + pagination_type=PaginationType.OFFSET, # type: ignore[arg-type] + ) + + def test_full_serialization(self): + """Full JSON serialization includes all expected fields.""" + response = CursorPaginatedResponse( + data=[{"id": 1}], + pagination=CursorPagination( + next_cursor="tok_next", + prev_cursor="tok_prev", + items_per_page=10, + has_more=True, + ), + ) + data = response.model_dump(mode="json") + + assert data["pagination_type"] == "cursor" + assert data["status"] == "SUCCESS" + assert data["pagination"]["next_cursor"] == "tok_next" + assert data["pagination"]["prev_cursor"] == "tok_prev" + + class TestFromAttributes: """Tests for from_attributes config (ORM mode)."""