diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 5361faf..7d8f752 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -1229,37 +1229,45 @@ class AsyncCrud(Generic[ModelType]): ``OFFSET``, :class:`.CursorPaginatedResponse` when it is ``CURSOR``. """ - if pagination_type is PaginationType.CURSOR: - return await cls.cursor_paginate( - session, - cursor=cursor, - filters=filters, - joins=joins, - outer_join=outer_join, - load_options=load_options, - order_by=order_by, - items_per_page=items_per_page, - search=search, - search_fields=search_fields, - facet_fields=facet_fields, - filter_by=filter_by, - schema=schema, - ) - return await cls.offset_paginate( - session, - filters=filters, - joins=joins, - outer_join=outer_join, - load_options=load_options, - order_by=order_by, - page=page, - items_per_page=items_per_page, - search=search, - search_fields=search_fields, - facet_fields=facet_fields, - filter_by=filter_by, - schema=schema, - ) + if items_per_page < 1: + raise ValueError(f"items_per_page must be >= 1, got {items_per_page}") + match pagination_type: + case PaginationType.CURSOR: + return await cls.cursor_paginate( + session, + cursor=cursor, + filters=filters, + joins=joins, + outer_join=outer_join, + load_options=load_options, + order_by=order_by, + items_per_page=items_per_page, + search=search, + search_fields=search_fields, + facet_fields=facet_fields, + filter_by=filter_by, + schema=schema, + ) + case PaginationType.OFFSET: + if page < 1: + raise ValueError(f"page must be >= 1, got {page}") + return await cls.offset_paginate( + session, + filters=filters, + joins=joins, + outer_join=outer_join, + load_options=load_options, + order_by=order_by, + page=page, + items_per_page=items_per_page, + search=search, + search_fields=search_fields, + facet_fields=facet_fields, + filter_by=filter_by, + schema=schema, + ) + case _: + raise ValueError(f"Unknown pagination_type: {pagination_type!r}") def CrudFactory( diff --git a/tests/test_crud.py b/tests/test_crud.py index 58e9a0f..809c2fd 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -6,7 +6,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from fastapi_toolsets.crud import CrudFactory +from fastapi_toolsets.crud import CrudFactory, PaginationType from fastapi_toolsets.crud.factory import AsyncCrud from fastapi_toolsets.exceptions import NotFoundError @@ -2384,3 +2384,72 @@ class TestCursorPaginateColumnTypes: page1_ids = {p.id for p in page1.data} page2_ids = {p.id for p in page2.data} assert page1_ids.isdisjoint(page2_ids) + + +class TestPaginate: + """Tests for the unified paginate() method.""" + + @pytest.mark.anyio + async def test_offset_pagination(self, db_session: AsyncSession): + """paginate() with OFFSET returns OffsetPaginatedResponse.""" + from fastapi_toolsets.schemas import OffsetPagination + + await RoleCrud.create(db_session, RoleCreate(name="admin")) + await RoleCrud.create(db_session, RoleCreate(name="user")) + + result = await RoleCrud.paginate( + db_session, + pagination_type=PaginationType.OFFSET, + schema=RoleRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination_type == PaginationType.OFFSET + + @pytest.mark.anyio + async def test_cursor_pagination(self, db_session: AsyncSession): + """paginate() with CURSOR returns CursorPaginatedResponse.""" + from fastapi_toolsets.schemas import CursorPagination + + await RoleCursorCrud.create(db_session, RoleCreate(name="admin")) + + result = await RoleCursorCrud.paginate( + db_session, + pagination_type=PaginationType.CURSOR, + schema=RoleRead, + ) + + assert isinstance(result.pagination, CursorPagination) + assert result.pagination_type == PaginationType.CURSOR + + @pytest.mark.anyio + async def test_invalid_items_per_page_raises(self, db_session: AsyncSession): + """paginate() raises ValueError when items_per_page < 1.""" + with pytest.raises(ValueError, match="items_per_page"): + await RoleCrud.paginate( + db_session, + pagination_type=PaginationType.OFFSET, + items_per_page=0, + schema=RoleRead, + ) + + @pytest.mark.anyio + async def test_invalid_page_raises(self, db_session: AsyncSession): + """paginate() raises ValueError when page < 1 for offset pagination.""" + with pytest.raises(ValueError, match="page"): + await RoleCrud.paginate( + db_session, + pagination_type=PaginationType.OFFSET, + page=0, + schema=RoleRead, + ) + + @pytest.mark.anyio + async def test_unknown_pagination_type_raises(self, db_session: AsyncSession): + """paginate() raises ValueError for unknown pagination_type.""" + with pytest.raises(ValueError, match="Unknown pagination_type"): + await RoleCrud.paginate( + db_session, + pagination_type="unknown", + schema=RoleRead, + ) # type: ignore[no-matching-overload]