mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
fix: handle Date, Float, Numeric cursor column types in cursor_paginate (#90)
This commit is contained in:
@@ -148,6 +148,16 @@ The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_to
|
|||||||
!!! note
|
!!! note
|
||||||
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
||||||
|
|
||||||
|
The cursor value is base64-encoded when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
|
||||||
|
|
||||||
|
| SQLAlchemy type | Python type |
|
||||||
|
|---|---|
|
||||||
|
| `Integer`, `BigInteger`, `SmallInteger` | `int` |
|
||||||
|
| `Uuid` | `uuid.UUID` |
|
||||||
|
| `DateTime` | `datetime.datetime` |
|
||||||
|
| `Date` | `datetime.date` |
|
||||||
|
| `Float`, `Numeric` | `decimal.Decimal` |
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Paginate by the primary key
|
# Paginate by the primary key
|
||||||
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ import json
|
|||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
|
from datetime import date, datetime
|
||||||
|
from decimal import Decimal
|
||||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import Integer, Uuid, and_, func, select
|
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
||||||
from sqlalchemy import delete as sql_delete
|
from sqlalchemy import delete as sql_delete
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
@@ -920,8 +922,18 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cursor_val: Any = int(raw_val)
|
cursor_val: Any = int(raw_val)
|
||||||
elif isinstance(col_type, Uuid):
|
elif isinstance(col_type, Uuid):
|
||||||
cursor_val = uuid_module.UUID(raw_val)
|
cursor_val = uuid_module.UUID(raw_val)
|
||||||
|
elif isinstance(col_type, DateTime):
|
||||||
|
cursor_val = datetime.fromisoformat(raw_val)
|
||||||
|
elif isinstance(col_type, Date):
|
||||||
|
cursor_val = date.fromisoformat(raw_val)
|
||||||
|
elif isinstance(col_type, (Float, Numeric)):
|
||||||
|
cursor_val = Decimal(raw_val)
|
||||||
else:
|
else:
|
||||||
cursor_val = raw_val
|
raise ValueError(
|
||||||
|
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
|
||||||
|
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
|
||||||
|
"DateTime, Date, Float, Numeric."
|
||||||
|
)
|
||||||
filters.append(cursor_column > cursor_val)
|
filters.append(cursor_column > cursor_val)
|
||||||
|
|
||||||
# Build search filters
|
# Build search filters
|
||||||
@@ -1016,8 +1028,9 @@ def CrudFactory(
|
|||||||
instead of ``lazy="selectin"`` on the model so that loading
|
instead of ``lazy="selectin"`` on the model so that loading
|
||||||
strategy is explicit and per-CRUD. Overridden entirely (not
|
strategy is explicit and per-CRUD. Overridden entirely (not
|
||||||
merged) when ``load_options`` is provided at call-site.
|
merged) when ``load_options`` is provided at call-site.
|
||||||
cursor_column: Required to call ``cursor_paginate``
|
cursor_column: Required to call ``cursor_paginate``.
|
||||||
Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
|
Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
|
||||||
|
See the cursor pagination docs for supported column types.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AsyncCrud subclass bound to the model
|
AsyncCrud subclass bound to the model
|
||||||
|
|||||||
@@ -5,7 +5,20 @@ import uuid
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import Column, ForeignKey, Integer, String, Table, Uuid
|
import datetime
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
Numeric,
|
||||||
|
String,
|
||||||
|
Table,
|
||||||
|
Uuid,
|
||||||
|
)
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@@ -79,6 +92,27 @@ class IntRole(Base):
|
|||||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Event(Base):
|
||||||
|
"""Test model with DateTime and Date cursor columns."""
|
||||||
|
|
||||||
|
__tablename__ = "events"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
occurred_at: Mapped[datetime.datetime] = mapped_column(DateTime)
|
||||||
|
scheduled_date: Mapped[datetime.date] = mapped_column(Date)
|
||||||
|
|
||||||
|
|
||||||
|
class Product(Base):
|
||||||
|
"""Test model with Numeric cursor column."""
|
||||||
|
|
||||||
|
__tablename__ = "products"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
price: Mapped[decimal.Decimal] = mapped_column(Numeric(10, 2))
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
"""Test post model."""
|
"""Test post model."""
|
||||||
|
|
||||||
@@ -190,6 +224,21 @@ class IntRoleCreate(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EventCreate(BaseModel):
|
||||||
|
"""Schema for creating an Event."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
occurred_at: datetime.datetime
|
||||||
|
scheduled_date: datetime.date
|
||||||
|
|
||||||
|
|
||||||
|
class ProductCreate(BaseModel):
|
||||||
|
"""Schema for creating a Product."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
price: decimal.Decimal
|
||||||
|
|
||||||
|
|
||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||||
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
||||||
@@ -198,6 +247,11 @@ UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
|||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
TagCrud = CrudFactory(Tag)
|
TagCrud = CrudFactory(Tag)
|
||||||
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
|
EventCrud = CrudFactory(Event)
|
||||||
|
EventDateTimeCursorCrud = CrudFactory(Event, cursor_column=Event.occurred_at)
|
||||||
|
EventDateCursorCrud = CrudFactory(Event, cursor_column=Event.scheduled_date)
|
||||||
|
ProductCrud = CrudFactory(Product)
|
||||||
|
ProductNumericCursorCrud = CrudFactory(Product, cursor_column=Product.price)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@@ -11,6 +11,10 @@ from fastapi_toolsets.crud.factory import AsyncCrud
|
|||||||
from fastapi_toolsets.exceptions import NotFoundError
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
|
EventCreate,
|
||||||
|
EventCrud,
|
||||||
|
EventDateCursorCrud,
|
||||||
|
EventDateTimeCursorCrud,
|
||||||
IntRoleCreate,
|
IntRoleCreate,
|
||||||
IntRoleCursorCrud,
|
IntRoleCursorCrud,
|
||||||
Post,
|
Post,
|
||||||
@@ -19,6 +23,9 @@ from .conftest import (
|
|||||||
PostM2MCreate,
|
PostM2MCreate,
|
||||||
PostM2MCrud,
|
PostM2MCrud,
|
||||||
PostM2MUpdate,
|
PostM2MUpdate,
|
||||||
|
ProductCreate,
|
||||||
|
ProductCrud,
|
||||||
|
ProductNumericCursorCrud,
|
||||||
Role,
|
Role,
|
||||||
RoleCreate,
|
RoleCreate,
|
||||||
RoleCrud,
|
RoleCrud,
|
||||||
@@ -1935,32 +1942,31 @@ class TestCursorPaginateExtraOptions:
|
|||||||
assert page2.pagination.has_more is False
|
assert page2.pagination.has_more is False
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_string_cursor_column(self, db_session: AsyncSession):
|
async def test_unsupported_cursor_column_type_raises(
|
||||||
"""cursor_paginate decodes non-UUID/non-Integer cursor values (string branch)."""
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""cursor_paginate raises ValueError when cursor column type is not supported."""
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
from fastapi_toolsets.schemas import CursorPagination
|
from fastapi_toolsets.schemas import CursorPagination
|
||||||
|
|
||||||
RoleNameCursorCrud = CrudFactory(Role, cursor_column=Role.name)
|
RoleNameCursorCrud = CrudFactory(Role, cursor_column=Role.name)
|
||||||
|
|
||||||
for i in range(5):
|
await RoleCrud.create(db_session, RoleCreate(name="role00"))
|
||||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
await RoleCrud.create(db_session, RoleCreate(name="role01"))
|
||||||
|
|
||||||
page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
# First page succeeds (no cursor to decode)
|
||||||
|
page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=1)
|
||||||
assert isinstance(page1.pagination, CursorPagination)
|
|
||||||
assert len(page1.data) == 3
|
|
||||||
assert page1.pagination.has_more is True
|
assert page1.pagination.has_more is True
|
||||||
|
assert isinstance(page1.pagination, CursorPagination)
|
||||||
|
|
||||||
page2 = await RoleNameCursorCrud.cursor_paginate(
|
# Second page fails because String type is unsupported
|
||||||
|
with pytest.raises(ValueError, match="Unsupported cursor column type"):
|
||||||
|
await RoleNameCursorCrud.cursor_paginate(
|
||||||
db_session,
|
db_session,
|
||||||
cursor=page1.pagination.next_cursor,
|
cursor=page1.pagination.next_cursor,
|
||||||
items_per_page=3,
|
items_per_page=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(page2.pagination, CursorPagination)
|
|
||||||
assert len(page2.data) == 2
|
|
||||||
assert page2.pagination.has_more is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestCursorPaginateSearchJoins:
|
class TestCursorPaginateSearchJoins:
|
||||||
"""Tests for cursor_paginate() search that traverses relationships (search_joins)."""
|
"""Tests for cursor_paginate() search that traverses relationships (search_joins)."""
|
||||||
@@ -2019,3 +2025,120 @@ class TestGetWithForUpdate:
|
|||||||
|
|
||||||
assert result.id == role.id
|
assert result.id == role.id
|
||||||
assert result.name == "locked"
|
assert result.name == "locked"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCursorPaginateColumnTypes:
|
||||||
|
"""Tests for cursor_paginate() covering DateTime, Date and Numeric column types."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_datetime_cursor_column(self, db_session: AsyncSession):
|
||||||
|
"""cursor_paginate decodes DateTime cursor values to datetime objects."""
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import CursorPagination
|
||||||
|
|
||||||
|
base = datetime.datetime(2024, 1, 1, 0, 0, 0)
|
||||||
|
for i in range(5):
|
||||||
|
await EventCrud.create(
|
||||||
|
db_session,
|
||||||
|
EventCreate(
|
||||||
|
name=f"event{i}",
|
||||||
|
occurred_at=base + datetime.timedelta(hours=i),
|
||||||
|
scheduled_date=datetime.date(2024, 1, i + 1),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
page1 = await EventDateTimeCursorCrud.cursor_paginate(
|
||||||
|
db_session, items_per_page=3
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(page1.pagination, CursorPagination)
|
||||||
|
assert len(page1.data) == 3
|
||||||
|
assert page1.pagination.has_more is True
|
||||||
|
|
||||||
|
page2 = await EventDateTimeCursorCrud.cursor_paginate(
|
||||||
|
db_session,
|
||||||
|
cursor=page1.pagination.next_cursor,
|
||||||
|
items_per_page=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(page2.pagination, CursorPagination)
|
||||||
|
assert len(page2.data) == 2
|
||||||
|
assert page2.pagination.has_more is False
|
||||||
|
# No overlap between pages
|
||||||
|
page1_ids = {e.id for e in page1.data}
|
||||||
|
page2_ids = {e.id for e in page2.data}
|
||||||
|
assert page1_ids.isdisjoint(page2_ids)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_date_cursor_column(self, db_session: AsyncSession):
|
||||||
|
"""cursor_paginate decodes Date cursor values to date objects."""
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import CursorPagination
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
await EventCrud.create(
|
||||||
|
db_session,
|
||||||
|
EventCreate(
|
||||||
|
name=f"event{i}",
|
||||||
|
occurred_at=datetime.datetime(2024, 1, 1),
|
||||||
|
scheduled_date=datetime.date(2024, 1, i + 1),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
page1 = await EventDateCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
||||||
|
|
||||||
|
assert isinstance(page1.pagination, CursorPagination)
|
||||||
|
assert len(page1.data) == 3
|
||||||
|
assert page1.pagination.has_more is True
|
||||||
|
|
||||||
|
page2 = await EventDateCursorCrud.cursor_paginate(
|
||||||
|
db_session,
|
||||||
|
cursor=page1.pagination.next_cursor,
|
||||||
|
items_per_page=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(page2.pagination, CursorPagination)
|
||||||
|
assert len(page2.data) == 2
|
||||||
|
assert page2.pagination.has_more is False
|
||||||
|
page1_ids = {e.id for e in page1.data}
|
||||||
|
page2_ids = {e.id for e in page2.data}
|
||||||
|
assert page1_ids.isdisjoint(page2_ids)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_numeric_cursor_column(self, db_session: AsyncSession):
|
||||||
|
"""cursor_paginate decodes Numeric cursor values to Decimal objects."""
|
||||||
|
import decimal
|
||||||
|
|
||||||
|
from fastapi_toolsets.schemas import CursorPagination
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
await ProductCrud.create(
|
||||||
|
db_session,
|
||||||
|
ProductCreate(
|
||||||
|
name=f"product{i}",
|
||||||
|
price=decimal.Decimal(f"{i + 1}.99"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
page1 = await ProductNumericCursorCrud.cursor_paginate(
|
||||||
|
db_session, items_per_page=3
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(page1.pagination, CursorPagination)
|
||||||
|
assert len(page1.data) == 3
|
||||||
|
assert page1.pagination.has_more is True
|
||||||
|
|
||||||
|
page2 = await ProductNumericCursorCrud.cursor_paginate(
|
||||||
|
db_session,
|
||||||
|
cursor=page1.pagination.next_cursor,
|
||||||
|
items_per_page=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(page2.pagination, CursorPagination)
|
||||||
|
assert len(page2.data) == 2
|
||||||
|
assert page2.pagination.has_more is False
|
||||||
|
page1_ids = {p.id for p in page1.data}
|
||||||
|
page2_ids = {p.id for p in page2.data}
|
||||||
|
assert page1_ids.isdisjoint(page2_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user