From 7da34f33a20b619d8c41a9cfb76972f01d71a6a8 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Mon, 23 Feb 2026 20:58:43 +0100 Subject: [PATCH] fix: handle Date, Float, Numeric cursor column types in cursor_paginate (#90) --- docs/module/crud.md | 10 ++ src/fastapi_toolsets/crud/factory.py | 19 +++- tests/conftest.py | 56 +++++++++- tests/test_crud.py | 157 ++++++++++++++++++++++++--- 4 files changed, 221 insertions(+), 21 deletions(-) diff --git a/docs/module/crud.md b/docs/module/crud.md index 9f34ecf..d869b63 100644 --- a/docs/module/crud.md +++ b/docs/module/crud.md @@ -148,6 +148,16 @@ The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_to !!! note `cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`. +The cursor value is base64-encoded when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported: + +| SQLAlchemy type | Python type | +|---|---| +| `Integer`, `BigInteger`, `SmallInteger` | `int` | +| `Uuid` | `uuid.UUID` | +| `DateTime` | `datetime.datetime` | +| `Date` | `datetime.date` | +| `Float`, `Numeric` | `decimal.Decimal` | + ```python # Paginate by the primary key PostCrud = CrudFactory(model=Post, cursor_column=Post.id) diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 6323af6..75b08ad 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -7,10 +7,12 @@ import json import uuid as uuid_module import warnings from collections.abc import Mapping, Sequence +from datetime import date, datetime +from decimal import Decimal from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from pydantic import BaseModel -from sqlalchemy import Integer, Uuid, and_, func, select +from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select from sqlalchemy import delete as sql_delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import NoResultFound @@ -920,8 +922,18 @@ class AsyncCrud(Generic[ModelType]): cursor_val: Any = int(raw_val) elif isinstance(col_type, Uuid): cursor_val = uuid_module.UUID(raw_val) + elif isinstance(col_type, DateTime): + cursor_val = datetime.fromisoformat(raw_val) + elif isinstance(col_type, Date): + cursor_val = date.fromisoformat(raw_val) + elif isinstance(col_type, (Float, Numeric)): + cursor_val = Decimal(raw_val) else: - cursor_val = raw_val + raise ValueError( + f"Unsupported cursor column type: {type(col_type).__name__!r}. " + "Supported types: Integer, BigInteger, SmallInteger, Uuid, " + "DateTime, Date, Float, Numeric." + ) filters.append(cursor_column > cursor_val) # Build search filters @@ -1016,8 +1028,9 @@ def CrudFactory( instead of ``lazy="selectin"`` on the model so that loading strategy is explicit and per-CRUD. Overridden entirely (not merged) when ``load_options`` is provided at call-site. - cursor_column: Required to call ``cursor_paginate`` + cursor_column: Required to call ``cursor_paginate``. Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp). + See the cursor pagination docs for supported column types. Returns: AsyncCrud subclass bound to the model diff --git a/tests/conftest.py b/tests/conftest.py index 0c96388..5885118 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,20 @@ import uuid import pytest from pydantic import BaseModel -from sqlalchemy import Column, ForeignKey, Integer, String, Table, Uuid +import datetime +import decimal + +from sqlalchemy import ( + Column, + Date, + DateTime, + ForeignKey, + Integer, + Numeric, + String, + Table, + Uuid, +) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -79,6 +92,27 @@ class IntRole(Base): name: Mapped[str] = mapped_column(String(50), unique=True) +class Event(Base): + """Test model with DateTime and Date cursor columns.""" + + __tablename__ = "events" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(100)) + occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime) + scheduled_date: Mapped[datetime.date] = mapped_column(Date) + + +class Product(Base): + """Test model with Numeric cursor column.""" + + __tablename__ = "products" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(100)) + price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2)) + + class Post(Base): """Test post model.""" @@ -190,6 +224,21 @@ class IntRoleCreate(BaseModel): name: str +class EventCreate(BaseModel): + """Schema for creating an Event.""" + + name: str + occurred_at: datetime.datetime + scheduled_date: datetime.date + + +class ProductCreate(BaseModel): + """Schema for creating a Product.""" + + name: str + price: decimal.Decimal + + RoleCrud = CrudFactory(Role) RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id) IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id) @@ -198,6 +247,11 @@ UserCursorCrud = CrudFactory(User, cursor_column=User.id) PostCrud = CrudFactory(Post) TagCrud = CrudFactory(Tag) PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags}) +EventCrud = CrudFactory(Event) +EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at) +EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date) +ProductCrud = CrudFactory(Product) +ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price) @pytest.fixture diff --git a/tests/test_crud.py b/tests/test_crud.py index b15d011..ee94c71 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -11,6 +11,10 @@ from fastapi_toolsets.crud.factory import AsyncCrud from fastapi_toolsets.exceptions import NotFoundError from .conftest import ( + EventCreate, + EventCrud, + EventDateCursorCrud, + EventDateTimeCursorCrud, IntRoleCreate, IntRoleCursorCrud, Post, @@ -19,6 +23,9 @@ from .conftest import ( PostM2MCreate, PostM2MCrud, PostM2MUpdate, + ProductCreate, + ProductCrud, + ProductNumericCursorCrud, Role, RoleCreate, RoleCrud, @@ -1935,31 +1942,30 @@ class TestCursorPaginateExtraOptions: assert page2.pagination.has_more is False @pytest.mark.anyio - async def test_string_cursor_column(self, db_session: AsyncSession): - """cursor_paginate decodes non-UUID/non-Integer cursor values (string branch).""" + async def test_unsupported_cursor_column_type_raises( + self, db_session: AsyncSession + ): + """cursor_paginate raises ValueError when cursor column type is not supported.""" from fastapi_toolsets.crud import CrudFactory from fastapi_toolsets.schemas import CursorPagination RoleNameCursorCrud = CrudFactory(Role, cursor_column=Role.name) - for i in range(5): - await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}")) + await RoleCrud.create(db_session, RoleCreate(name="role00")) + await RoleCrud.create(db_session, RoleCreate(name="role01")) - page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=3) - - assert isinstance(page1.pagination, CursorPagination) - assert len(page1.data) == 3 + # First page succeeds (no cursor to decode) + page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=1) assert page1.pagination.has_more is True + assert isinstance(page1.pagination, CursorPagination) - page2 = await RoleNameCursorCrud.cursor_paginate( - db_session, - cursor=page1.pagination.next_cursor, - items_per_page=3, - ) - - assert isinstance(page2.pagination, CursorPagination) - assert len(page2.data) == 2 - assert page2.pagination.has_more is False + # Second page fails because String type is unsupported + with pytest.raises(ValueError, match="Unsupported cursor column type"): + await RoleNameCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=1, + ) class TestCursorPaginateSearchJoins: @@ -2019,3 +2025,120 @@ class TestGetWithForUpdate: assert result.id == role.id assert result.name == "locked" + + +class TestCursorPaginateColumnTypes: + """Tests for cursor_paginate() covering DateTime, Date and Numeric column types.""" + + @pytest.mark.anyio + async def test_datetime_cursor_column(self, db_session: AsyncSession): + """cursor_paginate decodes DateTime cursor values to datetime objects.""" + import datetime + + from fastapi_toolsets.schemas import CursorPagination + + base = datetime.datetime(2024, 1, 1, 0, 0, 0) + for i in range(5): + await EventCrud.create( + db_session, + EventCreate( + name=f"event{i}", + occurred_at=base + datetime.timedelta(hours=i), + scheduled_date=datetime.date(2024, 1, i + 1), + ), + ) + + page1 = await EventDateTimeCursorCrud.cursor_paginate( + db_session, items_per_page=3 + ) + + assert isinstance(page1.pagination, CursorPagination) + assert len(page1.data) == 3 + assert page1.pagination.has_more is True + + page2 = await EventDateTimeCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=3, + ) + + assert isinstance(page2.pagination, CursorPagination) + assert len(page2.data) == 2 + assert page2.pagination.has_more is False + # No overlap between pages + page1_ids = {e.id for e in page1.data} + page2_ids = {e.id for e in page2.data} + assert page1_ids.isdisjoint(page2_ids) + + @pytest.mark.anyio + async def test_date_cursor_column(self, db_session: AsyncSession): + """cursor_paginate decodes Date cursor values to date objects.""" + import datetime + + from fastapi_toolsets.schemas import CursorPagination + + for i in range(5): + await EventCrud.create( + db_session, + EventCreate( + name=f"event{i}", + occurred_at=datetime.datetime(2024, 1, 1), + scheduled_date=datetime.date(2024, 1, i + 1), + ), + ) + + page1 = await EventDateCursorCrud.cursor_paginate(db_session, items_per_page=3) + + assert isinstance(page1.pagination, CursorPagination) + assert len(page1.data) == 3 + assert page1.pagination.has_more is True + + page2 = await EventDateCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=3, + ) + + assert isinstance(page2.pagination, CursorPagination) + assert len(page2.data) == 2 + assert page2.pagination.has_more is False + page1_ids = {e.id for e in page1.data} + page2_ids = {e.id for e in page2.data} + assert page1_ids.isdisjoint(page2_ids) + + @pytest.mark.anyio + async def test_numeric_cursor_column(self, db_session: AsyncSession): + """cursor_paginate decodes Numeric cursor values to Decimal objects.""" + import decimal + + from fastapi_toolsets.schemas import CursorPagination + + for i in range(5): + await ProductCrud.create( + db_session, + ProductCreate( + name=f"product{i}", + price=decimal.Decimal(f"{i + 1}.99"), + ), + ) + + page1 = await ProductNumericCursorCrud.cursor_paginate( + db_session, items_per_page=3 + ) + + assert isinstance(page1.pagination, CursorPagination) + assert len(page1.data) == 3 + assert page1.pagination.has_more is True + + page2 = await ProductNumericCursorCrud.cursor_paginate( + db_session, + cursor=page1.pagination.next_cursor, + items_per_page=3, + ) + + assert isinstance(page2.pagination, CursorPagination) + assert len(page2.data) == 2 + assert page2.pagination.has_more is False + page1_ids = {p.id for p in page1.data} + page2_ids = {p.id for p in page2.data} + assert page1_ids.isdisjoint(page2_ids)