mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
refactor: CursorDirection enum, cursor value parsing and __class_getitem__ caching (#144)
This commit is contained in:
@@ -9,6 +9,7 @@ import uuid as uuid_module
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
||||
|
||||
from fastapi import Query
|
||||
@@ -49,17 +50,43 @@ from .search import (
|
||||
)
|
||||
|
||||
|
||||
def _encode_cursor(value: Any, *, direction: str = "next") -> str:
|
||||
class _CursorDirection(str, Enum):
|
||||
NEXT = "next"
|
||||
PREV = "prev"
|
||||
|
||||
|
||||
def _encode_cursor(
|
||||
value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
|
||||
) -> str:
|
||||
"""Encode a cursor column value and navigation direction as a base64 string."""
|
||||
return base64.b64encode(
|
||||
json.dumps({"val": str(value), "dir": direction}).encode()
|
||||
).decode()
|
||||
|
||||
|
||||
def _decode_cursor(cursor: str) -> tuple[str, str]:
|
||||
def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
|
||||
"""Decode a cursor base64 string into ``(raw_value, direction)``."""
|
||||
payload = json.loads(base64.b64decode(cursor.encode()).decode())
|
||||
return payload["val"], payload["dir"]
|
||||
return payload["val"], _CursorDirection(payload["dir"])
|
||||
|
||||
|
||||
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
|
||||
"""Parse a raw cursor string value back into the appropriate Python type."""
|
||||
if isinstance(col_type, Integer):
|
||||
return int(raw_val)
|
||||
if isinstance(col_type, Uuid):
|
||||
return uuid_module.UUID(raw_val)
|
||||
if isinstance(col_type, DateTime):
|
||||
return datetime.fromisoformat(raw_val)
|
||||
if isinstance(col_type, Date):
|
||||
return date.fromisoformat(raw_val)
|
||||
if isinstance(col_type, (Float, Numeric)):
|
||||
return Decimal(raw_val)
|
||||
raise ValueError(
|
||||
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
|
||||
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
|
||||
"DateTime, Date, Float, Numeric."
|
||||
)
|
||||
|
||||
|
||||
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
||||
@@ -108,8 +135,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
if raw_fields is None:
|
||||
cls.searchable_fields = [pk_col]
|
||||
else:
|
||||
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)}
|
||||
if pk_key not in existing_keys:
|
||||
if not any(
|
||||
not isinstance(f, tuple) and f.key == pk_key for f in raw_fields
|
||||
):
|
||||
cls.searchable_fields = [pk_col, *raw_fields]
|
||||
|
||||
@classmethod
|
||||
@@ -1048,27 +1076,12 @@ class AsyncCrud(Generic[ModelType]):
|
||||
cursor_column: Any = cls.cursor_column
|
||||
cursor_col_name: str = cursor_column.key
|
||||
|
||||
direction = "next"
|
||||
direction = _CursorDirection.NEXT
|
||||
if cursor is not None:
|
||||
raw_val, direction = _decode_cursor(cursor)
|
||||
col_type = cursor_column.property.columns[0].type
|
||||
if isinstance(col_type, Integer):
|
||||
cursor_val: Any = int(raw_val)
|
||||
elif isinstance(col_type, Uuid):
|
||||
cursor_val = uuid_module.UUID(raw_val)
|
||||
elif isinstance(col_type, DateTime):
|
||||
cursor_val = datetime.fromisoformat(raw_val)
|
||||
elif isinstance(col_type, Date):
|
||||
cursor_val = date.fromisoformat(raw_val)
|
||||
elif isinstance(col_type, (Float, Numeric)):
|
||||
cursor_val = Decimal(raw_val)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
|
||||
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
|
||||
"DateTime, Date, Float, Numeric."
|
||||
)
|
||||
if direction == "prev":
|
||||
cursor_val: Any = _parse_cursor_value(raw_val, col_type)
|
||||
if direction is _CursorDirection.PREV:
|
||||
filters.append(cursor_column < cursor_val)
|
||||
else:
|
||||
filters.append(cursor_column > cursor_val)
|
||||
@@ -1099,7 +1112,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
q = q.options(*resolved)
|
||||
|
||||
# Cursor column is always the primary sort; reverse direction for prev traversal
|
||||
if direction == "prev":
|
||||
if direction is _CursorDirection.PREV:
|
||||
q = q.order_by(cursor_column.desc())
|
||||
else:
|
||||
q = q.order_by(cursor_column)
|
||||
@@ -1115,32 +1128,34 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_page = raw_items[:items_per_page]
|
||||
|
||||
# Restore ascending order when traversing backward
|
||||
if direction == "prev":
|
||||
if direction is _CursorDirection.PREV:
|
||||
items_page = list(reversed(items_page))
|
||||
|
||||
# next_cursor: points past the last item in ascending order
|
||||
next_cursor: str | None = None
|
||||
if direction == "next":
|
||||
if direction is _CursorDirection.NEXT:
|
||||
if has_more and items_page:
|
||||
next_cursor = _encode_cursor(
|
||||
getattr(items_page[-1], cursor_col_name), direction="next"
|
||||
getattr(items_page[-1], cursor_col_name),
|
||||
direction=_CursorDirection.NEXT,
|
||||
)
|
||||
else:
|
||||
# Going backward: always provide a next_cursor to allow returning forward
|
||||
if items_page:
|
||||
next_cursor = _encode_cursor(
|
||||
getattr(items_page[-1], cursor_col_name), direction="next"
|
||||
getattr(items_page[-1], cursor_col_name),
|
||||
direction=_CursorDirection.NEXT,
|
||||
)
|
||||
|
||||
# prev_cursor: points before the first item in ascending order
|
||||
prev_cursor: str | None = None
|
||||
if direction == "next" and cursor is not None and items_page:
|
||||
if direction is _CursorDirection.NEXT and cursor is not None and items_page:
|
||||
prev_cursor = _encode_cursor(
|
||||
getattr(items_page[0], cursor_col_name), direction="prev"
|
||||
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
|
||||
)
|
||||
elif direction == "prev" and has_more and items_page:
|
||||
elif direction is _CursorDirection.PREV and has_more and items_page:
|
||||
prev_cursor = _encode_cursor(
|
||||
getattr(items_page[0], cursor_col_name), direction="prev"
|
||||
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
|
||||
)
|
||||
|
||||
items: list[Any] = [schema.model_validate(item) for item in items_page]
|
||||
|
||||
@@ -150,14 +150,20 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||
pagination_type: PaginationType | None = None
|
||||
filter_attributes: dict[str, list[Any]] | None = None
|
||||
|
||||
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
|
||||
|
||||
def __class_getitem__( # type: ignore[invalid-method-override]
|
||||
cls, item: type[Any] | tuple[type[Any], ...]
|
||||
) -> type[Any]:
|
||||
if cls is PaginatedResponse and not isinstance(item, TypeVar):
|
||||
return Annotated[ # type: ignore[invalid-return-type]
|
||||
cached = cls._discriminated_union_cache.get(item)
|
||||
if cached is None:
|
||||
cached = Annotated[
|
||||
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
|
||||
Field(discriminator="pagination_type"),
|
||||
]
|
||||
cls._discriminated_union_cache[item] = cached
|
||||
return cached # type: ignore[invalid-return-type]
|
||||
return super().__class_getitem__(item)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from fastapi_toolsets.crud import CrudFactory, PaginationType
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
|
||||
from .conftest import (
|
||||
@@ -2016,7 +2016,7 @@ class TestCursorPaginatePrevCursor:
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
|
||||
# Manually craft a backward cursor before any existing id
|
||||
before_all = _encode_cursor(0, direction="prev")
|
||||
before_all = _encode_cursor(0, direction=_CursorDirection.PREV)
|
||||
empty = await IntRoleCursorCrud.cursor_paginate(
|
||||
db_session, cursor=before_all, items_per_page=5, schema=IntRoleRead
|
||||
)
|
||||
|
||||
@@ -324,6 +324,10 @@ class TestPaginatedResponse:
|
||||
assert CursorPaginatedResponse[dict] in union_args
|
||||
assert OffsetPaginatedResponse[dict] in union_args
|
||||
|
||||
def test_class_getitem_is_cached(self):
|
||||
"""Repeated subscripting with the same type returns the identical cached object."""
|
||||
assert PaginatedResponse[dict] is PaginatedResponse[dict]
|
||||
|
||||
def test_class_getitem_with_typevar_returns_generic(self):
|
||||
"""PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation."""
|
||||
from typing import TypeVar
|
||||
|
||||
Reference in New Issue
Block a user