mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
refactor: CursorDirection enum, cursor value parsing and __class_getitem__ caching (#144)
This commit is contained in:
@@ -9,6 +9,7 @@ import uuid as uuid_module
|
|||||||
from collections.abc import Awaitable, Callable, Sequence
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
@@ -49,17 +50,43 @@ from .search import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _encode_cursor(value: Any, *, direction: str = "next") -> str:
|
class _CursorDirection(str, Enum):
|
||||||
|
NEXT = "next"
|
||||||
|
PREV = "prev"
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_cursor(
|
||||||
|
value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
|
||||||
|
) -> str:
|
||||||
"""Encode a cursor column value and navigation direction as a base64 string."""
|
"""Encode a cursor column value and navigation direction as a base64 string."""
|
||||||
return base64.b64encode(
|
return base64.b64encode(
|
||||||
json.dumps({"val": str(value), "dir": direction}).encode()
|
json.dumps({"val": str(value), "dir": direction}).encode()
|
||||||
).decode()
|
).decode()
|
||||||
|
|
||||||
|
|
||||||
def _decode_cursor(cursor: str) -> tuple[str, str]:
|
def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
|
||||||
"""Decode a cursor base64 string into ``(raw_value, direction)``."""
|
"""Decode a cursor base64 string into ``(raw_value, direction)``."""
|
||||||
payload = json.loads(base64.b64decode(cursor.encode()).decode())
|
payload = json.loads(base64.b64decode(cursor.encode()).decode())
|
||||||
return payload["val"], payload["dir"]
|
return payload["val"], _CursorDirection(payload["dir"])
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
|
||||||
|
"""Parse a raw cursor string value back into the appropriate Python type."""
|
||||||
|
if isinstance(col_type, Integer):
|
||||||
|
return int(raw_val)
|
||||||
|
if isinstance(col_type, Uuid):
|
||||||
|
return uuid_module.UUID(raw_val)
|
||||||
|
if isinstance(col_type, DateTime):
|
||||||
|
return datetime.fromisoformat(raw_val)
|
||||||
|
if isinstance(col_type, Date):
|
||||||
|
return date.fromisoformat(raw_val)
|
||||||
|
if isinstance(col_type, (Float, Numeric)):
|
||||||
|
return Decimal(raw_val)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
|
||||||
|
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
|
||||||
|
"DateTime, Date, Float, Numeric."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
||||||
@@ -108,8 +135,9 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
if raw_fields is None:
|
if raw_fields is None:
|
||||||
cls.searchable_fields = [pk_col]
|
cls.searchable_fields = [pk_col]
|
||||||
else:
|
else:
|
||||||
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)}
|
if not any(
|
||||||
if pk_key not in existing_keys:
|
not isinstance(f, tuple) and f.key == pk_key for f in raw_fields
|
||||||
|
):
|
||||||
cls.searchable_fields = [pk_col, *raw_fields]
|
cls.searchable_fields = [pk_col, *raw_fields]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1048,27 +1076,12 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cursor_column: Any = cls.cursor_column
|
cursor_column: Any = cls.cursor_column
|
||||||
cursor_col_name: str = cursor_column.key
|
cursor_col_name: str = cursor_column.key
|
||||||
|
|
||||||
direction = "next"
|
direction = _CursorDirection.NEXT
|
||||||
if cursor is not None:
|
if cursor is not None:
|
||||||
raw_val, direction = _decode_cursor(cursor)
|
raw_val, direction = _decode_cursor(cursor)
|
||||||
col_type = cursor_column.property.columns[0].type
|
col_type = cursor_column.property.columns[0].type
|
||||||
if isinstance(col_type, Integer):
|
cursor_val: Any = _parse_cursor_value(raw_val, col_type)
|
||||||
cursor_val: Any = int(raw_val)
|
if direction is _CursorDirection.PREV:
|
||||||
elif isinstance(col_type, Uuid):
|
|
||||||
cursor_val = uuid_module.UUID(raw_val)
|
|
||||||
elif isinstance(col_type, DateTime):
|
|
||||||
cursor_val = datetime.fromisoformat(raw_val)
|
|
||||||
elif isinstance(col_type, Date):
|
|
||||||
cursor_val = date.fromisoformat(raw_val)
|
|
||||||
elif isinstance(col_type, (Float, Numeric)):
|
|
||||||
cursor_val = Decimal(raw_val)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported cursor column type: {type(col_type).__name__!r}. "
|
|
||||||
"Supported types: Integer, BigInteger, SmallInteger, Uuid, "
|
|
||||||
"DateTime, Date, Float, Numeric."
|
|
||||||
)
|
|
||||||
if direction == "prev":
|
|
||||||
filters.append(cursor_column < cursor_val)
|
filters.append(cursor_column < cursor_val)
|
||||||
else:
|
else:
|
||||||
filters.append(cursor_column > cursor_val)
|
filters.append(cursor_column > cursor_val)
|
||||||
@@ -1099,7 +1112,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
q = q.options(*resolved)
|
q = q.options(*resolved)
|
||||||
|
|
||||||
# Cursor column is always the primary sort; reverse direction for prev traversal
|
# Cursor column is always the primary sort; reverse direction for prev traversal
|
||||||
if direction == "prev":
|
if direction is _CursorDirection.PREV:
|
||||||
q = q.order_by(cursor_column.desc())
|
q = q.order_by(cursor_column.desc())
|
||||||
else:
|
else:
|
||||||
q = q.order_by(cursor_column)
|
q = q.order_by(cursor_column)
|
||||||
@@ -1115,32 +1128,34 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
items_page = raw_items[:items_per_page]
|
items_page = raw_items[:items_per_page]
|
||||||
|
|
||||||
# Restore ascending order when traversing backward
|
# Restore ascending order when traversing backward
|
||||||
if direction == "prev":
|
if direction is _CursorDirection.PREV:
|
||||||
items_page = list(reversed(items_page))
|
items_page = list(reversed(items_page))
|
||||||
|
|
||||||
# next_cursor: points past the last item in ascending order
|
# next_cursor: points past the last item in ascending order
|
||||||
next_cursor: str | None = None
|
next_cursor: str | None = None
|
||||||
if direction == "next":
|
if direction is _CursorDirection.NEXT:
|
||||||
if has_more and items_page:
|
if has_more and items_page:
|
||||||
next_cursor = _encode_cursor(
|
next_cursor = _encode_cursor(
|
||||||
getattr(items_page[-1], cursor_col_name), direction="next"
|
getattr(items_page[-1], cursor_col_name),
|
||||||
|
direction=_CursorDirection.NEXT,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Going backward: always provide a next_cursor to allow returning forward
|
# Going backward: always provide a next_cursor to allow returning forward
|
||||||
if items_page:
|
if items_page:
|
||||||
next_cursor = _encode_cursor(
|
next_cursor = _encode_cursor(
|
||||||
getattr(items_page[-1], cursor_col_name), direction="next"
|
getattr(items_page[-1], cursor_col_name),
|
||||||
|
direction=_CursorDirection.NEXT,
|
||||||
)
|
)
|
||||||
|
|
||||||
# prev_cursor: points before the first item in ascending order
|
# prev_cursor: points before the first item in ascending order
|
||||||
prev_cursor: str | None = None
|
prev_cursor: str | None = None
|
||||||
if direction == "next" and cursor is not None and items_page:
|
if direction is _CursorDirection.NEXT and cursor is not None and items_page:
|
||||||
prev_cursor = _encode_cursor(
|
prev_cursor = _encode_cursor(
|
||||||
getattr(items_page[0], cursor_col_name), direction="prev"
|
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
|
||||||
)
|
)
|
||||||
elif direction == "prev" and has_more and items_page:
|
elif direction is _CursorDirection.PREV and has_more and items_page:
|
||||||
prev_cursor = _encode_cursor(
|
prev_cursor = _encode_cursor(
|
||||||
getattr(items_page[0], cursor_col_name), direction="prev"
|
getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
|
||||||
)
|
)
|
||||||
|
|
||||||
items: list[Any] = [schema.model_validate(item) for item in items_page]
|
items: list[Any] = [schema.model_validate(item) for item in items_page]
|
||||||
|
|||||||
@@ -150,14 +150,20 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
|
|||||||
pagination_type: PaginationType | None = None
|
pagination_type: PaginationType | None = None
|
||||||
filter_attributes: dict[str, list[Any]] | None = None
|
filter_attributes: dict[str, list[Any]] | None = None
|
||||||
|
|
||||||
|
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
|
||||||
|
|
||||||
def __class_getitem__( # type: ignore[invalid-method-override]
|
def __class_getitem__( # type: ignore[invalid-method-override]
|
||||||
cls, item: type[Any] | tuple[type[Any], ...]
|
cls, item: type[Any] | tuple[type[Any], ...]
|
||||||
) -> type[Any]:
|
) -> type[Any]:
|
||||||
if cls is PaginatedResponse and not isinstance(item, TypeVar):
|
if cls is PaginatedResponse and not isinstance(item, TypeVar):
|
||||||
return Annotated[ # type: ignore[invalid-return-type]
|
cached = cls._discriminated_union_cache.get(item)
|
||||||
|
if cached is None:
|
||||||
|
cached = Annotated[
|
||||||
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
|
Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form]
|
||||||
Field(discriminator="pagination_type"),
|
Field(discriminator="pagination_type"),
|
||||||
]
|
]
|
||||||
|
cls._discriminated_union_cache[item] = cached
|
||||||
|
return cached # type: ignore[invalid-return-type]
|
||||||
return super().__class_getitem__(item)
|
return super().__class_getitem__(item)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory, PaginationType
|
from fastapi_toolsets.crud import CrudFactory, PaginationType
|
||||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
|
||||||
from fastapi_toolsets.exceptions import NotFoundError
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
@@ -2016,7 +2016,7 @@ class TestCursorPaginatePrevCursor:
|
|||||||
assert isinstance(page1.pagination, CursorPagination)
|
assert isinstance(page1.pagination, CursorPagination)
|
||||||
|
|
||||||
# Manually craft a backward cursor before any existing id
|
# Manually craft a backward cursor before any existing id
|
||||||
before_all = _encode_cursor(0, direction="prev")
|
before_all = _encode_cursor(0, direction=_CursorDirection.PREV)
|
||||||
empty = await IntRoleCursorCrud.cursor_paginate(
|
empty = await IntRoleCursorCrud.cursor_paginate(
|
||||||
db_session, cursor=before_all, items_per_page=5, schema=IntRoleRead
|
db_session, cursor=before_all, items_per_page=5, schema=IntRoleRead
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -324,6 +324,10 @@ class TestPaginatedResponse:
|
|||||||
assert CursorPaginatedResponse[dict] in union_args
|
assert CursorPaginatedResponse[dict] in union_args
|
||||||
assert OffsetPaginatedResponse[dict] in union_args
|
assert OffsetPaginatedResponse[dict] in union_args
|
||||||
|
|
||||||
|
def test_class_getitem_is_cached(self):
|
||||||
|
"""Repeated subscripting with the same type returns the identical cached object."""
|
||||||
|
assert PaginatedResponse[dict] is PaginatedResponse[dict]
|
||||||
|
|
||||||
def test_class_getitem_with_typevar_returns_generic(self):
|
def test_class_getitem_with_typevar_returns_generic(self):
|
||||||
"""PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation."""
|
"""PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation."""
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|||||||
Reference in New Issue
Block a user