diff --git a/src/fastapi_toolsets/crud/search.py b/src/fastapi_toolsets/crud/search.py index 7f9b665..cd7bae1 100644 --- a/src/fastapi_toolsets/crud/search.py +++ b/src/fastapi_toolsets/crud/search.py @@ -347,6 +347,24 @@ def build_filter_by( filters.append(column.overlap(value)) else: filters.append(column.any(value)) + elif isinstance(col_type, Enum): + enum_class = col_type.enum_class + if enum_class is not None and issubclass(enum_class, int): + + def _coerce_int_enum(v: Any) -> Any: + if isinstance(v, enum_class): + return v + return enum_class(int(v)) + + if isinstance(value, list): + filters.append(column.in_([_coerce_int_enum(v) for v in value])) + else: + filters.append(column == _coerce_int_enum(value)) + else: + if isinstance(value, list): + filters.append(column.in_(value)) + else: + filters.append(column == value) elif isinstance(col_type, _EQUALITY_TYPES): if isinstance(value, list): filters.append(column.in_(value)) diff --git a/tests/conftest.py b/tests/conftest.py index fbf30fa..c71287a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import os import uuid +from enum import Enum import pytest from pydantic import BaseModel @@ -12,6 +13,7 @@ from sqlalchemy import ( Column, Date, DateTime, + Enum as SAEnum, ForeignKey, Integer, JSON, @@ -139,6 +141,35 @@ class Post(Base): tags: Mapped[list[Tag]] = relationship(secondary=post_tags) +class OrderStatus(int, Enum): + """Integer-backed enum for order status.""" + + PENDING = 1 + PROCESSING = 2 + SHIPPED = 3 + CANCELLED = 4 + + +class Color(str, Enum): + """String-backed enum for color.""" + + RED = "red" + GREEN = "green" + BLUE = "blue" + + +class Order(Base): + """Test model with an IntEnum column (Enum(int, Enum)) and a raw Integer column.""" + + __tablename__ = "orders" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(100)) + status: Mapped[OrderStatus] = mapped_column(SAEnum(OrderStatus)) + priority: Mapped[int] = mapped_column(Integer) + color: Mapped[Color] = mapped_column(SAEnum(Color)) + + class Transfer(Base): """Test model with two FKs to the same table (users).""" @@ -311,6 +342,26 @@ class ArticleRead(PydanticBase): labels: list[str] +class OrderCreate(BaseModel): + """Schema for creating an order.""" + + id: uuid.UUID | None = None + name: str + status: OrderStatus + priority: int = 0 + color: Color = Color.RED + + +class OrderRead(PydanticBase): + """Schema for reading an order.""" + + id: uuid.UUID + name: str + status: OrderStatus + priority: int + color: Color + + class TransferCreate(BaseModel): """Schema for creating a transfer.""" @@ -327,6 +378,7 @@ class TransferRead(PydanticBase): amount: str +OrderCrud = CrudFactory(Order) TransferCrud = CrudFactory(Transfer) ArticleCrud = CrudFactory(Article) RoleCrud = CrudFactory(Role) diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 7760037..426bf4b 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -23,6 +23,12 @@ from .conftest import ( ArticleCreate, ArticleCrud, ArticleRead, + Color, + Order, + OrderCreate, + OrderCrud, + OrderRead, + OrderStatus, Role, RoleCreate, RoleCrud, @@ -1121,6 +1127,227 @@ class TestFilterBy: assert "JSON" in exc_info.value.col_type +class TestFilterByIntEnum: + """Tests for filter_by on columns typed as (int, Enum) / IntEnum.""" + + @pytest.mark.anyio + async def test_filter_by_intenum_member(self, db_session: AsyncSession): + """filter_by with an IntEnum member value filters correctly.""" + OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) + await OrderCrud.create( + db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-3", status=OrderStatus.PENDING) + ) + + result = await OrderFacetCrud.offset_paginate( + db_session, + filter_by={"status": OrderStatus.PENDING}, + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + names = {o.name for o in result.data} + assert names == {"order-1", "order-3"} + + @pytest.mark.anyio + async def test_filter_by_plain_int_value(self, db_session: AsyncSession): + """filter_by with a plain int (not an enum member) filters correctly on an IntEnum column.""" + OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) + await OrderCrud.create( + db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED) + ) + + result = await OrderFacetCrud.offset_paginate( + db_session, + filter_by={"status": 1}, # plain int, not OrderStatus.PENDING + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].name == "order-1" + + @pytest.mark.anyio + async def test_filter_by_intenum_list(self, db_session: AsyncSession): + """filter_by with a list of IntEnum members produces an IN filter.""" + OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) + await OrderCrud.create( + db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-3", status=OrderStatus.CANCELLED) + ) + + result = await OrderFacetCrud.offset_paginate( + db_session, + filter_by={"status": [OrderStatus.PENDING, OrderStatus.SHIPPED]}, + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + names = {o.name for o in result.data} + assert names == {"order-1", "order-2"} + + @pytest.mark.anyio + async def test_filter_by_plain_int_list(self, db_session: AsyncSession): + """filter_by with a list of plain ints filters correctly on an IntEnum column.""" + OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) + await OrderCrud.create( + db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-3", status=OrderStatus.CANCELLED) + ) + + result = await OrderFacetCrud.offset_paginate( + db_session, + filter_by={"status": [1, 3]}, # plain ints for PENDING and SHIPPED + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + names = {o.name for o in result.data} + assert names == {"order-1", "order-2"} + + +class TestFilterByStrEnum: + """Tests for filter_by on columns typed as (str, Enum) / StrEnum (lines 364-367).""" + + @pytest.mark.anyio + async def test_filter_by_strenum_member(self, db_session: AsyncSession): + """filter_by with a StrEnum member on a string Enum column filters correctly.""" + OrderColorCrud = CrudFactory(Order, facet_fields=[Order.color]) + await OrderCrud.create( + db_session, + OrderCreate(name="red-order", status=OrderStatus.PENDING, color=Color.RED), + ) + await OrderCrud.create( + db_session, + OrderCreate( + name="blue-order", status=OrderStatus.PENDING, color=Color.BLUE + ), + ) + + result = await OrderColorCrud.offset_paginate( + db_session, + filter_by={"color": Color.RED}, + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].name == "red-order" + + @pytest.mark.anyio + async def test_filter_by_strenum_list(self, db_session: AsyncSession): + """filter_by with a list of StrEnum members produces an IN filter.""" + OrderColorCrud = CrudFactory(Order, facet_fields=[Order.color]) + await OrderCrud.create( + db_session, + OrderCreate(name="red-order", status=OrderStatus.PENDING, color=Color.RED), + ) + await OrderCrud.create( + db_session, + OrderCreate( + name="green-order", status=OrderStatus.PENDING, color=Color.GREEN + ), + ) + await OrderCrud.create( + db_session, + OrderCreate( + name="blue-order", status=OrderStatus.PENDING, color=Color.BLUE + ), + ) + + result = await OrderColorCrud.offset_paginate( + db_session, + filter_by={"color": [Color.RED, Color.BLUE]}, + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + names = {o.name for o in result.data} + assert names == {"red-order", "blue-order"} + + +class TestFilterByIntegerColumn: + """Tests for filter_by on plain Integer columns with IntEnum values.""" + + @pytest.mark.anyio + async def test_filter_by_integer_column_with_intenum_member( + self, db_session: AsyncSession + ): + """filter_by with an IntEnum member on an Integer column works correctly.""" + OrderPriorityCrud = CrudFactory(Order, facet_fields=[Order.priority]) + await OrderCrud.create( + db_session, + OrderCreate( + name="order-1", status=OrderStatus.PENDING, priority=OrderStatus.PENDING + ), + ) + await OrderCrud.create( + db_session, + OrderCreate( + name="order-2", status=OrderStatus.SHIPPED, priority=OrderStatus.SHIPPED + ), + ) + + result = await OrderPriorityCrud.offset_paginate( + db_session, + filter_by={ + "priority": OrderStatus.PENDING + }, # IntEnum member on Integer col + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].name == "order-1" + + @pytest.mark.anyio + async def test_filter_by_integer_column_with_plain_int( + self, db_session: AsyncSession + ): + """filter_by with a plain int on an Integer column works correctly.""" + OrderPriorityCrud = CrudFactory(Order, facet_fields=[Order.priority]) + await OrderCrud.create( + db_session, + OrderCreate(name="order-1", status=OrderStatus.PENDING, priority=1), + ) + await OrderCrud.create( + db_session, + OrderCreate(name="order-2", status=OrderStatus.SHIPPED, priority=3), + ) + + result = await OrderPriorityCrud.offset_paginate( + db_session, + filter_by={"priority": 1}, + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].name == "order-1" + + class TestFilterParamsViaConsolidated: """Tests for filter params via consolidated offset_paginate_params()."""