mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: normalize enum facet values to member names and accept name strings in filter_by (#235)
This commit is contained in:
@@ -265,7 +265,15 @@ async def build_facets(
|
|||||||
else:
|
else:
|
||||||
q = q.order_by(column)
|
q = q.order_by(column)
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
values = [row[0] for row in result.all() if row[0] is not None]
|
col_type = column.property.columns[0].type
|
||||||
|
enum_class = getattr(col_type, "enum_class", None)
|
||||||
|
values = [
|
||||||
|
row[0].name
|
||||||
|
if (enum_class is not None and isinstance(row[0], enum_class))
|
||||||
|
else row[0]
|
||||||
|
for row in result.all()
|
||||||
|
if row[0] is not None
|
||||||
|
]
|
||||||
return key, values
|
return key, values
|
||||||
|
|
||||||
pairs = await asyncio.gather(
|
pairs = await asyncio.gather(
|
||||||
@@ -349,18 +357,18 @@ def build_filter_by(
|
|||||||
filters.append(column.any(value))
|
filters.append(column.any(value))
|
||||||
elif isinstance(col_type, Enum):
|
elif isinstance(col_type, Enum):
|
||||||
enum_class = col_type.enum_class
|
enum_class = col_type.enum_class
|
||||||
if enum_class is not None and issubclass(enum_class, int):
|
if enum_class is not None:
|
||||||
|
|
||||||
def _coerce_int_enum(v: Any) -> Any:
|
def _coerce_enum(v: Any) -> Any:
|
||||||
if isinstance(v, enum_class):
|
if isinstance(v, enum_class):
|
||||||
return v
|
return v
|
||||||
return enum_class(int(v))
|
return enum_class[v] # lookup by name: "PENDING", "RED"
|
||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
filters.append(column.in_([_coerce_int_enum(v) for v in value]))
|
filters.append(column.in_([_coerce_enum(v) for v in value]))
|
||||||
else:
|
else:
|
||||||
filters.append(column == _coerce_int_enum(value))
|
filters.append(column == _coerce_enum(value))
|
||||||
else:
|
else: # pragma: no cover
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
filters.append(column.in_(value))
|
filters.append(column.in_(value))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1156,25 +1156,16 @@ class TestFilterByIntEnum:
|
|||||||
assert names == {"order-1", "order-3"}
|
assert names == {"order-1", "order-3"}
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_filter_by_plain_int_value(self, db_session: AsyncSession):
|
async def test_filter_by_plain_int_value_raises(self, db_session: AsyncSession):
|
||||||
"""filter_by with a plain int (not an enum member) filters correctly on an IntEnum column."""
|
"""filter_by with a plain int on an IntEnum column raises KeyError — use name or member."""
|
||||||
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||||
await OrderCrud.create(
|
|
||||||
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
|
||||||
)
|
|
||||||
await OrderCrud.create(
|
|
||||||
db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await OrderFacetCrud.offset_paginate(
|
with pytest.raises(KeyError):
|
||||||
db_session,
|
await OrderFacetCrud.offset_paginate(
|
||||||
filter_by={"status": 1}, # plain int, not OrderStatus.PENDING
|
db_session,
|
||||||
schema=OrderRead,
|
filter_by={"status": 1},
|
||||||
)
|
schema=OrderRead,
|
||||||
|
)
|
||||||
assert isinstance(result.pagination, OffsetPagination)
|
|
||||||
assert result.pagination.total_count == 1
|
|
||||||
assert result.data[0].name == "order-1"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_filter_by_intenum_list(self, db_session: AsyncSession):
|
async def test_filter_by_intenum_list(self, db_session: AsyncSession):
|
||||||
@@ -1202,8 +1193,43 @@ class TestFilterByIntEnum:
|
|||||||
assert names == {"order-1", "order-2"}
|
assert names == {"order-1", "order-2"}
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_filter_by_plain_int_list(self, db_session: AsyncSession):
|
async def test_filter_by_plain_int_list_raises(self, db_session: AsyncSession):
|
||||||
"""filter_by with a list of plain ints filters correctly on an IntEnum column."""
|
"""filter_by with a list of plain ints on an IntEnum column raises KeyError — use names or members."""
|
||||||
|
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||||
|
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await OrderFacetCrud.offset_paginate(
|
||||||
|
db_session,
|
||||||
|
filter_by={"status": [1, 3]},
|
||||||
|
schema=OrderRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_by_intenum_name_string(self, db_session: AsyncSession):
|
||||||
|
"""filter_by with the enum member name as a string filters correctly."""
|
||||||
|
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||||
|
await OrderCrud.create(
|
||||||
|
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
||||||
|
)
|
||||||
|
await OrderCrud.create(
|
||||||
|
db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await OrderFacetCrud.offset_paginate(
|
||||||
|
db_session,
|
||||||
|
filter_by={
|
||||||
|
"status": "PENDING"
|
||||||
|
}, # name as string, e.g. from HTTP query param
|
||||||
|
schema=OrderRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].name == "order-1"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_filter_by_intenum_name_string_list(self, db_session: AsyncSession):
|
||||||
|
"""filter_by with a list of enum name strings produces an IN filter."""
|
||||||
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status])
|
||||||
await OrderCrud.create(
|
await OrderCrud.create(
|
||||||
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING)
|
||||||
@@ -1217,7 +1243,7 @@ class TestFilterByIntEnum:
|
|||||||
|
|
||||||
result = await OrderFacetCrud.offset_paginate(
|
result = await OrderFacetCrud.offset_paginate(
|
||||||
db_session,
|
db_session,
|
||||||
filter_by={"status": [1, 3]}, # plain ints for PENDING and SHIPPED
|
filter_by={"status": ["PENDING", "SHIPPED"]},
|
||||||
schema=OrderRead,
|
schema=OrderRead,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user