refactor: CursorDirection enum, cursor value parsing and __class_getitem__ caching (#144)

This commit is contained in:
d3vyce
2026-03-15 17:48:50 +01:00
committed by GitHub
parent c863744012
commit fd7269a372
4 changed files with 63 additions and 38 deletions

View File

@@ -9,6 +9,7 @@ import uuid as uuid_module
from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query
@@ -49,17 +50,43 @@ from .search import (
)
def _encode_cursor(value: Any, *, direction: str = "next") -> str:
class _CursorDirection(str, Enum):
NEXT = "next"
PREV = "prev"
def _encode_cursor(
value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
) -> str:
"""Encode a cursor column value and navigation direction as a base64 string."""
return base64.b64encode(
json.dumps({"val": str(value), "dir": direction}).encode()
).decode()
def _decode_cursor(cursor: str) -> tuple[str, str]:
def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
"""Decode a cursor base64 string into ``(raw_value, direction)``."""
payload = json.loads(base64.b64decode(cursor.encode()).decode())
return payload["val"], payload["dir"]
return payload["val"], _CursorDirection(payload["dir"])
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
"""Parse a raw cursor string value back into the appropriate Python type."""
if isinstance(col_type, Integer):
return int(raw_val)
if isinstance(col_type, Uuid):
return uuid_module.UUID(raw_val)
if isinstance(col_type, DateTime):
return datetime.fromisoformat(raw_val)
if isinstance(col_type, Date):
return date.fromisoformat(raw_val)
if isinstance(col_type, (Float, Numeric)):
return Decimal(raw_val)
raise ValueError(
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
"DateTime, Date, Float, Numeric."
)
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
@@ -108,8 +135,9 @@ class AsyncCrud(Generic[ModelType]):
if raw_fields is None:
cls.searchable_fields = [pk_col]
else:
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)}
if pk_key not in existing_keys:
if not any(
not isinstance(f, tuple) and f.key == pk_key for f in raw_fields
):
cls.searchable_fields = [pk_col, *raw_fields]
@classmethod
@@ -1048,27 +1076,12 @@ class AsyncCrud(Generic[ModelType]):
cursor_column: Any = cls.cursor_column
cursor_col_name: str = cursor_column.key
direction = "next"
direction = _CursorDirection.NEXT
if cursor is not None:
raw_val, direction = _decode_cursor(cursor)
col_type = cursor_column.property.columns[0].type
if isinstance(col_type, Integer):
cursor_val: Any = int(raw_val)
elif isinstance(col_type, Uuid):
cursor_val = uuid_module.UUID(raw_val)
elif isinstance(col_type, DateTime):
cursor_val = datetime.fromisoformat(raw_val)
elif isinstance(col_type, Date):
cursor_val = date.fromisoformat(raw_val)
elif isinstance(col_type, (Float, Numeric)):
cursor_val = Decimal(raw_val)
else:
raise ValueError(
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
"DateTime, Date, Float, Numeric."
)
if direction == "prev":
cursor_val: Any = _parse_cursor_value(raw_val, col_type)
if direction is _CursorDirection.PREV:
filters.append(cursor_column < cursor_val)
else:
filters.append(cursor_column > cursor_val)
@@ -1099,7 +1112,7 @@ class AsyncCrud(Generic[ModelType]):
q = q.options(*resolved)
# Cursor column is always the primary sort; reverse direction for prev traversal
if direction == "prev":
if direction is _CursorDirection.PREV:
q = q.order_by(cursor_column.desc())
else:
q = q.order_by(cursor_column)
@@ -1115,32 +1128,34 @@ class AsyncCrud(Generic[ModelType]):
items_page = raw_items[:items_per_page]
# Restore ascending order when traversing backward
if direction == "prev":
if direction is _CursorDirection.PREV:
items_page = list(reversed(items_page))
# next_cursor: points past the last item in ascending order
next_cursor: str | None = None
if direction == "next":
if direction is _CursorDirection.NEXT:
if has_more and items_page:
next_cursor = _encode_cursor(
getattr(items_page[-1], cursor_col_name), direction="next"
getattr(items_page[-1], cursor_col_name),
direction=_CursorDirection.NEXT,
)
else:
# Going backward: always provide a next_cursor to allow returning forward
if items_page:
next_cursor = _encode_cursor(
getattr(items_page[-1], cursor_col_name), direction="next"
getattr(items_page[-1], cursor_col_name),
direction=_CursorDirection.NEXT,
)
# prev_cursor: points before the first item in ascending order
prev_cursor: str | None = None
if direction == "next" and cursor is not None and items_page:
if direction is _CursorDirection.NEXT and cursor is not None and items_page:
prev_cursor = _encode_cursor(
getattr(items_page[0], cursor_col_name), direction="prev"
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
)
elif direction == "prev" and has_more and items_page:
elif direction is _CursorDirection.PREV and has_more and items_page:
prev_cursor = _encode_cursor(
getattr(items_page[0], cursor_col_name), direction="prev"
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
)
items: list[Any] = [schema.model_validate(item) for item in items_page]

View File

@@ -150,14 +150,20 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
pagination_type: PaginationType | None = None
filter_attributes: dict[str, list[Any]] | None = None
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
def __class_getitem__( # type: ignore[invalid-method-override]
cls, item: type[Any] | tuple[type[Any], ...]
) -> type[Any]:
if cls is PaginatedResponse and not isinstance(item, TypeVar):
return Annotated[ # type: ignore[invalid-return-type]
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
Field(discriminator="pagination_type"),
]
cached = cls._discriminated_union_cache.get(item)
if cached is None:
cached = Annotated[
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
Field(discriminator="pagination_type"),
]
cls._discriminated_union_cache[item] = cached
return cached # type: ignore[invalid-return-type]
return super().__class_getitem__(item)

View File

@@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from fastapi_toolsets.crud import CrudFactory, PaginationType
from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
from fastapi_toolsets.exceptions import NotFoundError
from .conftest import (
@@ -2016,7 +2016,7 @@ class TestCursorPaginatePrevCursor:
assert isinstance(page1.pagination, CursorPagination)
# Manually craft a backward cursor before any existing id
before_all = _encode_cursor(0, direction="prev")
before_all = _encode_cursor(0, direction=_CursorDirection.PREV)
empty = await IntRoleCursorCrud.cursor_paginate(
db_session, cursor=before_all, items_per_page=5, schema=IntRoleRead
)

View File

@@ -324,6 +324,10 @@ class TestPaginatedResponse:
assert CursorPaginatedResponse[dict] in union_args
assert OffsetPaginatedResponse[dict] in union_args
def test_class_getitem_is_cached(self):
"""Repeated subscripting with the same type returns the identical cached object."""
assert PaginatedResponse[dict] is PaginatedResponse[dict]
def test_class_getitem_with_typevar_returns_generic(self):
"""PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation."""
from typing import TypeVar