mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
fix: coerce plain int values to enum member when filtering on IntEnum column (#231)
This commit is contained in:
@@ -347,6 +347,24 @@ def build_filter_by(
|
||||
filters.append(column.overlap(value))
|
||||
else:
|
||||
filters.append(column.any(value))
|
||||
elif isinstance(col_type, Enum):
|
||||
enum_class = col_type.enum_class
|
||||
if enum_class is not None and issubclass(enum_class, int):
|
||||
|
||||
def _coerce_int_enum(v: Any) -> Any:
|
||||
if isinstance(v, enum_class):
|
||||
return v
|
||||
return enum_class(int(v))
|
||||
|
||||
if isinstance(value, list):
|
||||
filters.append(column.in_([_coerce_int_enum(v) for v in value]))
|
||||
else:
|
||||
filters.append(column == _coerce_int_enum(value))
|
||||
else:
|
||||
if isinstance(value, list):
|
||||
filters.append(column.in_(value))
|
||||
else:
|
||||
filters.append(column == value)
|
||||
elif isinstance(col_type, _EQUALITY_TYPES):
|
||||
if isinstance(value, list):
|
||||
filters.append(column.in_(value))
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
@@ -12,6 +13,7 @@ from sqlalchemy import (
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum as SAEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
@@ -139,6 +141,35 @@ class Post(Base):
|
||||
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||
|
||||
|
||||
class OrderStatus(int, Enum):
|
||||
"""Integer-backed enum for order status."""
|
||||
|
||||
PENDING = 1
|
||||
PROCESSING = 2
|
||||
SHIPPED = 3
|
||||
CANCELLED = 4
|
||||
|
||||
|
||||
class Color(str, Enum):
|
||||
"""String-backed enum for color."""
|
||||
|
||||
RED = "red"
|
||||
GREEN = "green"
|
||||
BLUE = "blue"
|
||||
|
||||
|
||||
class Order(Base):
|
||||
"""Test model with an IntEnum column (Enum(int, Enum)) and a raw Integer column."""
|
||||
|
||||
__tablename__ = "orders"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
status: Mapped[OrderStatus] = mapped_column(SAEnum(OrderStatus))
|
||||
priority: Mapped[int] = mapped_column(Integer)
|
||||
color: Mapped[Color] = mapped_column(SAEnum(Color))
|
||||
|
||||
|
||||
class Transfer(Base):
|
||||
"""Test model with two FKs to the same table (users)."""
|
||||
|
||||
@@ -311,6 +342,26 @@ class ArticleRead(PydanticBase):
|
||||
labels: list[str]
|
||||
|
||||
|
||||
class OrderCreate(BaseModel):
|
||||
"""Schema for creating an order."""
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
name: str
|
||||
status: OrderStatus
|
||||
priority: int = 0
|
||||
color: Color = Color.RED
|
||||
|
||||
|
||||
class OrderRead(PydanticBase):
|
||||
"""Schema for reading an order."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
status: OrderStatus
|
||||
priority: int
|
||||
color: Color
|
||||
|
||||
|
||||
class TransferCreate(BaseModel):
|
||||
"""Schema for creating a transfer."""
|
||||
|
||||
@@ -327,6 +378,7 @@ class TransferRead(PydanticBase):
|
||||
amount: str
|
||||
|
||||
|
||||
OrderCrud = CrudFactory(Order)
|
||||
TransferCrud = CrudFactory(Transfer)
|
||||
ArticleCrud = CrudFactory(Article)
|
||||
RoleCrud = CrudFactory(Role)
|
||||
|
||||
@@ -23,6 +23,12 @@ from .conftest import (
|
||||
ArticleCreate,
|
||||
ArticleCrud,
|
||||
ArticleRead,
|
||||
Color,
|
||||
Order,
|
||||
OrderCreate,
|
||||
OrderCrud,
|
||||
OrderRead,
|
||||
OrderStatus,
|
||||
Role,
|
||||
RoleCreate,
|
||||
RoleCrud,
|
||||
@@ -1121,6 +1127,227 @@ class TestFilterBy:
|
||||
assert "JSON" in exc_info.value.col_type
|
||||
|
||||
|
||||
class TestFilterByIntEnum:
|
||||
"""Tests for filter_by on columns typed as (int, Enum) / IntEnum."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_intenum_member(self, db_session: AsyncSession):
|
||||
"""filter_by with an IntEnum member value filters correctly."""
|
||||
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED)
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-3", status=OrderStatus.PENDING)
|
||||
)
|
||||
|
||||
result = await OrderFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"status": OrderStatus.PENDING},
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
names = {o.name for o in result.data}
|
||||
assert names == {"order-1", "order-3"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_plain_int_value(self, db_session: AsyncSession):
|
||||
"""filter_by with a plain int (not an enum member) filters correctly on an IntEnum column."""
|
||||
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED)
|
||||
)
|
||||
|
||||
result = await OrderFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"status": 1}, # plain int, not OrderStatus.PENDING
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].name == "order-1"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_intenum_list(self, db_session: AsyncSession):
|
||||
"""filter_by with a list of IntEnum members produces an IN filter."""
|
||||
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED)
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-3", status=OrderStatus.CANCELLED)
|
||||
)
|
||||
|
||||
result = await OrderFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"status": [OrderStatus.PENDING, OrderStatus.SHIPPED]},
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
names = {o.name for o in result.data}
|
||||
assert names == {"order-1", "order-2"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_plain_int_list(self, db_session: AsyncSession):
|
||||
"""filter_by with a list of plain ints filters correctly on an IntEnum column."""
|
||||
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED)
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session, OrderCreate(name="order-3", status=OrderStatus.CANCELLED)
|
||||
)
|
||||
|
||||
result = await OrderFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"status": [1, 3]}, # plain ints for PENDING and SHIPPED
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
names = {o.name for o in result.data}
|
||||
assert names == {"order-1", "order-2"}
|
||||
|
||||
|
||||
class TestFilterByStrEnum:
|
||||
"""Tests for filter_by on columns typed as (str, Enum) / StrEnum (lines 364-367)."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_strenum_member(self, db_session: AsyncSession):
|
||||
"""filter_by with a StrEnum member on a string Enum column filters correctly."""
|
||||
OrderColorCrud = CrudFactory(Order, facet_fields=[Order.color])
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(name="red-order", status=OrderStatus.PENDING, color=Color.RED),
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(
|
||||
name="blue-order", status=OrderStatus.PENDING, color=Color.BLUE
|
||||
),
|
||||
)
|
||||
|
||||
result = await OrderColorCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"color": Color.RED},
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].name == "red-order"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_strenum_list(self, db_session: AsyncSession):
|
||||
"""filter_by with a list of StrEnum members produces an IN filter."""
|
||||
OrderColorCrud = CrudFactory(Order, facet_fields=[Order.color])
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(name="red-order", status=OrderStatus.PENDING, color=Color.RED),
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(
|
||||
name="green-order", status=OrderStatus.PENDING, color=Color.GREEN
|
||||
),
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(
|
||||
name="blue-order", status=OrderStatus.PENDING, color=Color.BLUE
|
||||
),
|
||||
)
|
||||
|
||||
result = await OrderColorCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"color": [Color.RED, Color.BLUE]},
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
names = {o.name for o in result.data}
|
||||
assert names == {"red-order", "blue-order"}
|
||||
|
||||
|
||||
class TestFilterByIntegerColumn:
|
||||
"""Tests for filter_by on plain Integer columns with IntEnum values."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_integer_column_with_intenum_member(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""filter_by with an IntEnum member on an Integer column works correctly."""
|
||||
OrderPriorityCrud = CrudFactory(Order, facet_fields=[Order.priority])
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(
|
||||
name="order-1", status=OrderStatus.PENDING, priority=OrderStatus.PENDING
|
||||
),
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(
|
||||
name="order-2", status=OrderStatus.SHIPPED, priority=OrderStatus.SHIPPED
|
||||
),
|
||||
)
|
||||
|
||||
result = await OrderPriorityCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={
|
||||
"priority": OrderStatus.PENDING
|
||||
}, # IntEnum member on Integer col
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].name == "order-1"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_integer_column_with_plain_int(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""filter_by with a plain int on an Integer column works correctly."""
|
||||
OrderPriorityCrud = CrudFactory(Order, facet_fields=[Order.priority])
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(name="order-1", status=OrderStatus.PENDING, priority=1),
|
||||
)
|
||||
await OrderCrud.create(
|
||||
db_session,
|
||||
OrderCreate(name="order-2", status=OrderStatus.SHIPPED, priority=3),
|
||||
)
|
||||
|
||||
result = await OrderPriorityCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"priority": 1},
|
||||
schema=OrderRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].name == "order-1"
|
||||
|
||||
|
||||
class TestFilterParamsViaConsolidated:
|
||||
"""Tests for filter params via consolidated offset_paginate_params()."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user