refactor: CursorDirection enum, cursor value parsing and __class_getitem__ caching (#144)

This commit is contained in:
d3vyce
2026-03-15 17:48:50 +01:00
committed by GitHub
parent c863744012
commit fd7269a372
4 changed files with 63 additions and 38 deletions

View File

@@ -9,6 +9,7 @@ import uuid as uuid_module
from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query
@@ -49,17 +50,43 @@ from .search import (
)
def _encode_cursor(value: Any, *, direction: str = "next") -> str:
class _CursorDirection(str, Enum):
NEXT = "next"
PREV = "prev"
def _encode_cursor(
value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
) -> str:
"""Encode a cursor column value and navigation direction as a base64 string."""
return base64.b64encode(
json.dumps({"val": str(value), "dir": direction}).encode()
).decode()
def _decode_cursor(cursor: str) -> tuple[str, str]:
def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
"""Decode a cursor base64 string into ``(raw_value, direction)``."""
payload = json.loads(base64.b64decode(cursor.encode()).decode())
return payload["val"], payload["dir"]
return payload["val"], _CursorDirection(payload["dir"])
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
"""Parse a raw cursor string value back into the appropriate Python type."""
if isinstance(col_type, Integer):
return int(raw_val)
if isinstance(col_type, Uuid):
return uuid_module.UUID(raw_val)
if isinstance(col_type, DateTime):
return datetime.fromisoformat(raw_val)
if isinstance(col_type, Date):
return date.fromisoformat(raw_val)
if isinstance(col_type, (Float, Numeric)):
return Decimal(raw_val)
raise ValueError(
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
"DateTime, Date, Float, Numeric."
)
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
@@ -108,8 +135,9 @@ class AsyncCrud(Generic[ModelType]):
if raw_fields is None:
cls.searchable_fields = [pk_col]
else:
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)}
if pk_key not in existing_keys:
if not any(
not isinstance(f, tuple) and f.key == pk_key for f in raw_fields
):
cls.searchable_fields = [pk_col, *raw_fields]
@classmethod
@@ -1048,27 +1076,12 @@ class AsyncCrud(Generic[ModelType]):
cursor_column: Any = cls.cursor_column
cursor_col_name: str = cursor_column.key
direction = "next"
direction = _CursorDirection.NEXT
if cursor is not None:
raw_val, direction = _decode_cursor(cursor)
col_type = cursor_column.property.columns[0].type
if isinstance(col_type, Integer):
cursor_val: Any = int(raw_val)
elif isinstance(col_type, Uuid):
cursor_val = uuid_module.UUID(raw_val)
elif isinstance(col_type, DateTime):
cursor_val = datetime.fromisoformat(raw_val)
elif isinstance(col_type, Date):
cursor_val = date.fromisoformat(raw_val)
elif isinstance(col_type, (Float, Numeric)):
cursor_val = Decimal(raw_val)
else:
raise ValueError(
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
"DateTime, Date, Float, Numeric."
)
if direction == "prev":
cursor_val: Any = _parse_cursor_value(raw_val, col_type)
if direction is _CursorDirection.PREV:
filters.append(cursor_column < cursor_val)
else:
filters.append(cursor_column > cursor_val)
@@ -1099,7 +1112,7 @@ class AsyncCrud(Generic[ModelType]):
q = q.options(*resolved)
# Cursor column is always the primary sort; reverse direction for prev traversal
if direction == "prev":
if direction is _CursorDirection.PREV:
q = q.order_by(cursor_column.desc())
else:
q = q.order_by(cursor_column)
@@ -1115,32 +1128,34 @@ class AsyncCrud(Generic[ModelType]):
items_page = raw_items[:items_per_page]
# Restore ascending order when traversing backward
if direction == "prev":
if direction is _CursorDirection.PREV:
items_page = list(reversed(items_page))
# next_cursor: points past the last item in ascending order
next_cursor: str | None = None
if direction == "next":
if direction is _CursorDirection.NEXT:
if has_more and items_page:
next_cursor = _encode_cursor(
getattr(items_page[-1], cursor_col_name), direction="next"
getattr(items_page[-1], cursor_col_name),
direction=_CursorDirection.NEXT,
)
else:
# Going backward: always provide a next_cursor to allow returning forward
if items_page:
next_cursor = _encode_cursor(
getattr(items_page[-1], cursor_col_name), direction="next"
getattr(items_page[-1], cursor_col_name),
direction=_CursorDirection.NEXT,
)
# prev_cursor: points before the first item in ascending order
prev_cursor: str | None = None
if direction == "next" and cursor is not None and items_page:
if direction is _CursorDirection.NEXT and cursor is not None and items_page:
prev_cursor = _encode_cursor(
getattr(items_page[0], cursor_col_name), direction="prev"
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
)
elif direction == "prev" and has_more and items_page:
elif direction is _CursorDirection.PREV and has_more and items_page:
prev_cursor = _encode_cursor(
getattr(items_page[0], cursor_col_name), direction="prev"
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
)
items: list[Any] = [schema.model_validate(item) for item in items_page]

View File

@@ -150,14 +150,20 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
pagination_type: PaginationType | None = None
filter_attributes: dict[str, list[Any]] | None = None
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
def __class_getitem__( # type: ignore[invalid-method-override]
cls, item: type[Any] | tuple[type[Any], ...]
) -> type[Any]:
if cls is PaginatedResponse and not isinstance(item, TypeVar):
return Annotated[ # type: ignore[invalid-return-type]
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
Field(discriminator="pagination_type"),
]
cached = cls._discriminated_union_cache.get(item)
if cached is None:
cached = Annotated[
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
Field(discriminator="pagination_type"),
]
cls._discriminated_union_cache[item] = cached
return cached # type: ignore[invalid-return-type]
return super().__class_getitem__(item)