diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 57c0e99..b270f8b 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -9,6 +9,7 @@ import uuid as uuid_module from collections.abc import Awaitable, Callable, Sequence from datetime import date, datetime from decimal import Decimal +from enum import Enum from typing import Any, ClassVar, Generic, Literal, Self, cast, overload from fastapi import Query @@ -49,17 +50,43 @@ from .search import ( ) -def _encode_cursor(value: Any, *, direction: str = "next") -> str: +class _CursorDirection(str, Enum): + NEXT = "next" + PREV = "prev" + + +def _encode_cursor( + value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT +) -> str: """Encode a cursor column value and navigation direction as a base64 string.""" return base64.b64encode( json.dumps({"val": str(value), "dir": direction}).encode() ).decode() -def _decode_cursor(cursor: str) -> tuple[str, str]: +def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]: """Decode a cursor base64 string into ``(raw_value, direction)``.""" payload = json.loads(base64.b64decode(cursor.encode()).decode()) - return payload["val"], payload["dir"] + return payload["val"], _CursorDirection(payload["dir"]) + + +def _parse_cursor_value(raw_val: str, col_type: Any) -> Any: + """Parse a raw cursor string value back into the appropriate Python type.""" + if isinstance(col_type, Integer): + return int(raw_val) + if isinstance(col_type, Uuid): + return uuid_module.UUID(raw_val) + if isinstance(col_type, DateTime): + return datetime.fromisoformat(raw_val) + if isinstance(col_type, Date): + return date.fromisoformat(raw_val) + if isinstance(col_type, (Float, Numeric)): + return Decimal(raw_val) + raise ValueError( + f"Unsupported cursor column type: {type(col_type).__name__!r}. " + "Supported types: Integer, BigInteger, SmallInteger, Uuid, " + "DateTime, Date, Float, Numeric." + ) def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any: @@ -108,8 +135,9 @@ class AsyncCrud(Generic[ModelType]): if raw_fields is None: cls.searchable_fields = [pk_col] else: - existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)} - if pk_key not in existing_keys: + if not any( + not isinstance(f, tuple) and f.key == pk_key for f in raw_fields + ): cls.searchable_fields = [pk_col, *raw_fields] @classmethod @@ -1048,27 +1076,12 @@ class AsyncCrud(Generic[ModelType]): cursor_column: Any = cls.cursor_column cursor_col_name: str = cursor_column.key - direction = "next" + direction = _CursorDirection.NEXT if cursor is not None: raw_val, direction = _decode_cursor(cursor) col_type = cursor_column.property.columns[0].type - if isinstance(col_type, Integer): - cursor_val: Any = int(raw_val) - elif isinstance(col_type, Uuid): - cursor_val = uuid_module.UUID(raw_val) - elif isinstance(col_type, DateTime): - cursor_val = datetime.fromisoformat(raw_val) - elif isinstance(col_type, Date): - cursor_val = date.fromisoformat(raw_val) - elif isinstance(col_type, (Float, Numeric)): - cursor_val = Decimal(raw_val) - else: - raise ValueError( - f"Unsupported cursor column type: {type(col_type).__name__!r}. " - "Supported types: Integer, BigInteger, SmallInteger, Uuid, " - "DateTime, Date, Float, Numeric." - ) - if direction == "prev": + cursor_val: Any = _parse_cursor_value(raw_val, col_type) + if direction is _CursorDirection.PREV: filters.append(cursor_column < cursor_val) else: filters.append(cursor_column > cursor_val) @@ -1099,7 +1112,7 @@ class AsyncCrud(Generic[ModelType]): q = q.options(*resolved) # Cursor column is always the primary sort; reverse direction for prev traversal - if direction == "prev": + if direction is _CursorDirection.PREV: q = q.order_by(cursor_column.desc()) else: q = q.order_by(cursor_column) @@ -1115,32 +1128,34 @@ class AsyncCrud(Generic[ModelType]): items_page = raw_items[:items_per_page] # Restore ascending order when traversing backward - if direction == "prev": + if direction is _CursorDirection.PREV: items_page = list(reversed(items_page)) # next_cursor: points past the last item in ascending order next_cursor: str | None = None - if direction == "next": + if direction is _CursorDirection.NEXT: if has_more and items_page: next_cursor = _encode_cursor( - getattr(items_page[-1], cursor_col_name), direction="next" + getattr(items_page[-1], cursor_col_name), + direction=_CursorDirection.NEXT, ) else: # Going backward: always provide a next_cursor to allow returning forward if items_page: next_cursor = _encode_cursor( - getattr(items_page[-1], cursor_col_name), direction="next" + getattr(items_page[-1], cursor_col_name), + direction=_CursorDirection.NEXT, ) # prev_cursor: points before the first item in ascending order prev_cursor: str | None = None - if direction == "next" and cursor is not None and items_page: + if direction is _CursorDirection.NEXT and cursor is not None and items_page: prev_cursor = _encode_cursor( - getattr(items_page[0], cursor_col_name), direction="prev" + getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV ) - elif direction == "prev" and has_more and items_page: + elif direction is _CursorDirection.PREV and has_more and items_page: prev_cursor = _encode_cursor( - getattr(items_page[0], cursor_col_name), direction="prev" + getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV ) items: list[Any] = [schema.model_validate(item) for item in items_page] diff --git a/src/fastapi_toolsets/schemas.py b/src/fastapi_toolsets/schemas.py index 504d0b5..a3dd176 100644 --- a/src/fastapi_toolsets/schemas.py +++ b/src/fastapi_toolsets/schemas.py @@ -150,14 +150,20 @@ class PaginatedResponse(BaseResponse, Generic[DataT]): pagination_type: PaginationType | None = None filter_attributes: dict[str, list[Any]] | None = None + _discriminated_union_cache: ClassVar[dict[Any, Any]] = {} + def __class_getitem__( # type: ignore[invalid-method-override] cls, item: type[Any] | tuple[type[Any], ...] ) -> type[Any]: if cls is PaginatedResponse and not isinstance(item, TypeVar): - return Annotated[ # type: ignore[invalid-return-type] - Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form] - Field(discriminator="pagination_type"), - ] + cached = cls._discriminated_union_cache.get(item) + if cached is None: + cached = Annotated[ + Union[CursorPaginatedResponse[item], OffsetPaginatedResponse[item]], # type: ignore[invalid-type-form] + Field(discriminator="pagination_type"), + ] + cls._discriminated_union_cache[item] = cached + return cached # type: ignore[invalid-return-type] return super().__class_getitem__(item) diff --git a/tests/test_crud.py b/tests/test_crud.py index 790ae72..306a61e 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from fastapi_toolsets.crud import CrudFactory, PaginationType -from fastapi_toolsets.crud.factory import AsyncCrud +from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection from fastapi_toolsets.exceptions import NotFoundError from .conftest import ( @@ -2016,7 +2016,7 @@ class TestCursorPaginatePrevCursor: assert isinstance(page1.pagination, CursorPagination) # Manually craft a backward cursor before any existing id - before_all = _encode_cursor(0, direction="prev") + before_all = _encode_cursor(0, direction=_CursorDirection.PREV) empty = await IntRoleCursorCrud.cursor_paginate( db_session, cursor=before_all, items_per_page=5, schema=IntRoleRead ) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 3433d52..c58f5a4 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -324,6 +324,10 @@ class TestPaginatedResponse: assert CursorPaginatedResponse[dict] in union_args assert OffsetPaginatedResponse[dict] in union_args + def test_class_getitem_is_cached(self): + """Repeated subscripting with the same type returns the identical cached object.""" + assert PaginatedResponse[dict] is PaginatedResponse[dict] + def test_class_getitem_with_typevar_returns_generic(self): """PaginatedResponse[TypeVar] falls through to Pydantic generic parametrisation.""" from typing import TypeVar