refactor: CursorDirection enum, cursor value parsing and __class_getitem__ caching (#144)

This commit is contained in:
d3vyce
2026-03-15 17:48:50 +01:00
committed by GitHub
parent c863744012
commit fd7269a372
4 changed files with 63 additions and 38 deletions

View File

@@ -9,6 +9,7 @@ import uuid as uuid_module
from collections.abc import Awaitable, Callable, Sequence from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal from decimal import Decimal
from enum import Enum
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query from fastapi import Query
@@ -49,17 +50,43 @@ from .search import (
) )
def _encode_cursor(value: Any, *, direction: str = "next") -> str: class _CursorDirection(str, Enum):
NEXT = "next"
PREV = "prev"
def _encode_cursor(
value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
) -> str:
"""Encode a cursor column value and navigation direction as a base64 string.""" """Encode a cursor column value and navigation direction as a base64 string."""
return base64.b64encode( return base64.b64encode(
json.dumps({"val": str(value), "dir": direction}).encode() json.dumps({"val": str(value), "dir": direction}).encode()
).decode() ).decode()
def _decode_cursor(cursor: str) -> tuple[str, str]: def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
"""Decode a cursor base64 string into ``(raw_value, direction)``.""" """Decode a cursor base64 string into ``(raw_value, direction)``."""
payload = json.loads(base64.b64decode(cursor.encode()).decode()) payload = json.loads(base64.b64decode(cursor.encode()).decode())
return payload["val"], payload["dir"] return payload["val"], _CursorDirection(payload["dir"])
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
"""Parse a raw cursor string value back into the appropriate Python type."""
if isinstance(col_type, Integer):
return int(raw_val)
if isinstance(col_type, Uuid):
return uuid_module.UUID(raw_val)
if isinstance(col_type, DateTime):
return datetime.fromisoformat(raw_val)
if isinstance(col_type, Date):
return date.fromisoformat(raw_val)
if isinstance(col_type, (Float, Numeric)):
return Decimal(raw_val)
raise ValueError(
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
"DateTime, Date, Float, Numeric."
)
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any: def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
@@ -108,8 +135,9 @@ class AsyncCrud(Generic[ModelType]):
if raw_fields is None: if raw_fields is None:
cls.searchable_fields = [pk_col] cls.searchable_fields = [pk_col]
else: else:
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)} if not any(
if pk_key not in existing_keys: not isinstance(f, tuple) and f.key == pk_key for f in raw_fields
):
cls.searchable_fields = [pk_col, *raw_fields] cls.searchable_fields = [pk_col, *raw_fields]
@classmethod @classmethod
@@ -1048,27 +1076,12 @@ class AsyncCrud(Generic[ModelType]):
cursor_column: Any = cls.cursor_column cursor_column: Any = cls.cursor_column
cursor_col_name: str = cursor_column.key cursor_col_name: str = cursor_column.key
direction = "next" direction = _CursorDirection.NEXT
if cursor is not None: if cursor is not None:
raw_val, direction = _decode_cursor(cursor) raw_val, direction = _decode_cursor(cursor)
col_type = cursor_column.property.columns[0].type col_type = cursor_column.property.columns[0].type
if isinstance(col_type, Integer): cursor_val: Any = _parse_cursor_value(raw_val, col_type)
cursor_val: Any = int(raw_val) if direction is _CursorDirection.PREV:
elif isinstance(col_type, Uuid):
cursor_val = uuid_module.UUID(raw_val)
elif isinstance(col_type, DateTime):
cursor_val = datetime.fromisoformat(raw_val)
elif isinstance(col_type, Date):
cursor_val = date.fromisoformat(raw_val)
elif isinstance(col_type, (Float, Numeric)):
cursor_val = Decimal(raw_val)
else:
raise ValueError(
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
"DateTime, Date, Float, Numeric."
)
if direction == "prev":
filters.append(cursor_column < cursor_val) filters.append(cursor_column < cursor_val)
else: else:
filters.append(cursor_column > cursor_val) filters.append(cursor_column > cursor_val)
@@ -1099,7 +1112,7 @@ class AsyncCrud(Generic[ModelType]):
q = q.options(*resolved) q = q.options(*resolved)
# Cursor column is always the primary sort; reverse direction for prev traversal # Cursor column is always the primary sort; reverse direction for prev traversal
if direction == "prev": if direction is _CursorDirection.PREV:
q = q.order_by(cursor_column.desc()) q = q.order_by(cursor_column.desc())
else: else:
q = q.order_by(cursor_column) q = q.order_by(cursor_column)
@@ -1115,32 +1128,34 @@ class AsyncCrud(Generic[ModelType]):
items_page = raw_items[:items_per_page] items_page = raw_items[:items_per_page]
# Restore ascending order when traversing backward # Restore ascending order when traversing backward
if direction == "prev": if direction is _CursorDirection.PREV:
items_page = list(reversed(items_page)) items_page = list(reversed(items_page))
# next_cursor: points past the last item in ascending order # next_cursor: points past the last item in ascending order
next_cursor: str | None = None next_cursor: str | None = None
if direction == "next": if direction is _CursorDirection.NEXT:
if has_more and items_page: if has_more and items_page:
next_cursor = _encode_cursor( next_cursor = _encode_cursor(
getattr(items_page[-1], cursor_col_name), direction="next" getattr(items_page[-1], cursor_col_name),
direction=_CursorDirection.NEXT,
) )
else: else:
# Going backward: always provide a next_cursor to allow returning forward # Going backward: always provide a next_cursor to allow returning forward
if items_page: if items_page:
next_cursor = _encode_cursor( next_cursor = _encode_cursor(
getattr(items_page[-1], cursor_col_name), direction="next" getattr(items_page[-1], cursor_col_name),
direction=_CursorDirection.NEXT,
) )
# prev_cursor: points before the first item in ascending order # prev_cursor: points before the first item in ascending order
prev_cursor: str | None = None prev_cursor: str | None = None
if direction == "next" and cursor is not None and items_page: if direction is _CursorDirection.NEXT and cursor is not None and items_page:
prev_cursor = _encode_cursor( prev_cursor = _encode_cursor(
getattr(items_page[0], cursor_col_name), direction="prev" getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
) )
elif direction == "prev" and has_more and items_page: elif direction is _CursorDirection.PREV and has_more and items_page:
prev_cursor = _encode_cursor( prev_cursor = _encode_cursor(
getattr(items_page[0], cursor_col_name), direction="prev" getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
) )
items: list[Any] = [schema.model_validate(item) for item in items_page] items: list[Any] = [schema.model_validate(item) for item in items_page]

View File

@@ -150,14 +150,20 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
pagination_type: PaginationType | None = None pagination_type: PaginationType | None = None
filter_attributes: dict[str, list[Any]] | None = None filter_attributes: dict[str, list[Any]] | None = None
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
def __class_getitem__( # type: ignore[invalid-method-override] def __class_getitem__( # type: ignore[invalid-method-override]
cls, item: type[Any] | tuple[type[Any], ...] cls, item: type[Any] | tuple[type[Any], ...]
) -> type[Any]: ) -> type[Any]:
if cls is PaginatedResponse and not isinstance(item, TypeVar): if cls is PaginatedResponse and not isinstance(item, TypeVar):
return Annotated[ # type: ignore[invalid-return-type] cached = cls._discriminated_union_cache.get(item)
if cached is None:
cached = Annotated[
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form] Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
Field(discriminator="pagination_type"), Field(discriminator="pagination_type"),
] ]
cls._discriminated_union_cache[item] = cached
return cached # type: ignore[invalid-return-type]
return super().__class_getitem__(item) return super().__class_getitem__(item)

View File

@@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from fastapi_toolsets.crud import CrudFactory, PaginationType from fastapi_toolsets.crud import CrudFactory, PaginationType
from fastapi_toolsets.crud.factory import AsyncCrud from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
from fastapi_toolsets.exceptions import NotFoundError from fastapi_toolsets.exceptions import NotFoundError
from .conftest import ( from .conftest import (
@@ -2016,7 +2016,7 @@ class TestCursorPaginatePrevCursor:
assert isinstance(page1.pagination, CursorPagination) assert isinstance(page1.pagination, CursorPagination)
# Manually craft a backward cursor before any existing id # Manually craft a backward cursor before any existing id
before_all = _encode_cursor(0, direction="prev") before_all = _encode_cursor(0, direction=_CursorDirection.PREV)
empty = await IntRoleCursorCrud.cursor_paginate( empty = await IntRoleCursorCrud.cursor_paginate(
db_session, cursor=before_all, items_per_page=5, schema=IntRoleRead db_session, cursor=before_all, items_per_page=5, schema=IntRoleRead
) )

View File

@@ -324,6 +324,10 @@ class TestPaginatedResponse:
assert CursorPaginatedResponse[dict] in union_args assert CursorPaginatedResponse[dict] in union_args
assert OffsetPaginatedResponse[dict] in union_args assert OffsetPaginatedResponse[dict] in union_args
def test_class_getitem_is_cached(self):
"""Repeated subscripting with the same type returns the identical cached object."""
assert PaginatedResponse[dict] is PaginatedResponse[dict]
def test_class_getitem_with_typevar_returns_generic(self): def test_class_getitem_with_typevar_returns_generic(self):
"""PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation.""" """PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation."""
from typing import TypeVar from typing import TypeVar