From 2a498148182c53aa3d469883a52db24c91b0e79f Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Tue, 7 Apr 2026 11:42:24 +0200 Subject: [PATCH] feat: normalize enum facet values to member names and accept name strings in filter_by (#235) --- src/fastapi_toolsets/crud/search.py | 22 +++++++--- tests/test_crud_search.py | 66 ++++++++++++++++++++--------- 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/src/fastapi_toolsets/crud/search.py b/src/fastapi_toolsets/crud/search.py index cd7bae1..e2368d8 100644 --- a/src/fastapi_toolsets/crud/search.py +++ b/src/fastapi_toolsets/crud/search.py @@ -265,7 +265,15 @@ async def build_facets( else: q = q.order_by(column) result = await session.execute(q) - values = [row[0] for row in result.all() if row[0] is not None] + col_type = column.property.columns[0].type + enum_class = getattr(col_type, "enum_class", None) + values = [ + row[0].name + if (enum_class is not None and isinstance(row[0], enum_class)) + else row[0] + for row in result.all() + if row[0] is not None + ] return key, values pairs = await asyncio.gather( @@ -349,18 +357,18 @@ def build_filter_by( filters.append(column.any(value)) elif isinstance(col_type, Enum): enum_class = col_type.enum_class - if enum_class is not None and issubclass(enum_class, int): + if enum_class is not None: - def _coerce_int_enum(v: Any) -> Any: + def _coerce_enum(v: Any) -> Any: if isinstance(v, enum_class): return v - return enum_class(int(v)) + return enum_class[v] # lookup by name: "PENDING", "RED" if isinstance(value, list): - filters.append(column.in_([_coerce_int_enum(v) for v in value])) + filters.append(column.in_([_coerce_enum(v) for v in value])) else: - filters.append(column == _coerce_int_enum(value)) - else: + filters.append(column == _coerce_enum(value)) + else: # pragma: no cover if isinstance(value, list): filters.append(column.in_(value)) else: diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 426bf4b..1ea6b4d 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -1156,25 +1156,16 @@ class TestFilterByIntEnum: assert names == {"order-1", "order-3"} @pytest.mark.anyio - async def test_filter_by_plain_int_value(self, db_session: AsyncSession): - """filter_by with a plain int (not an enum member) filters correctly on an IntEnum column.""" + async def test_filter_by_plain_int_value_raises(self, db_session: AsyncSession): + """filter_by with a plain int on an IntEnum column raises KeyError — use name or member.""" OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) - await OrderCrud.create( - db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING) - ) - await OrderCrud.create( - db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED) - ) - result = await OrderFacetCrud.offset_paginate( - db_session, - filter_by={"status": 1}, # plain int, not OrderStatus.PENDING - schema=OrderRead, - ) - - assert isinstance(result.pagination, OffsetPagination) - assert result.pagination.total_count == 1 - assert result.data[0].name == "order-1" + with pytest.raises(KeyError): + await OrderFacetCrud.offset_paginate( + db_session, + filter_by={"status": 1}, + schema=OrderRead, + ) @pytest.mark.anyio async def test_filter_by_intenum_list(self, db_session: AsyncSession): @@ -1202,8 +1193,43 @@ class TestFilterByIntEnum: assert names == {"order-1", "order-2"} @pytest.mark.anyio - async def test_filter_by_plain_int_list(self, db_session: AsyncSession): - """filter_by with a list of plain ints filters correctly on an IntEnum column.""" + async def test_filter_by_plain_int_list_raises(self, db_session: AsyncSession): + """filter_by with a list of plain ints on an IntEnum column raises KeyError — use names or members.""" + OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) + + with pytest.raises(KeyError): + await OrderFacetCrud.offset_paginate( + db_session, + filter_by={"status": [1, 3]}, + schema=OrderRead, + ) + + @pytest.mark.anyio + async def test_filter_by_intenum_name_string(self, db_session: AsyncSession): + """filter_by with the enum member name as a string filters correctly.""" + OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) + await OrderCrud.create( + db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING) + ) + await OrderCrud.create( + db_session, OrderCreate(name="order-2", status=OrderStatus.SHIPPED) + ) + + result = await OrderFacetCrud.offset_paginate( + db_session, + filter_by={ + "status": "PENDING" + }, # name as string, e.g. from HTTP query param + schema=OrderRead, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].name == "order-1" + + @pytest.mark.anyio + async def test_filter_by_intenum_name_string_list(self, db_session: AsyncSession): + """filter_by with a list of enum name strings produces an IN filter.""" OrderFacetCrud = CrudFactory(Order, facet_fields=[Order.status]) await OrderCrud.create( db_session, OrderCreate(name="order-1", status=OrderStatus.PENDING) @@ -1217,7 +1243,7 @@ class TestFilterByIntEnum: result = await OrderFacetCrud.offset_paginate( db_session, - filter_by={"status": [1, 3]}, # plain ints for PENDING and SHIPPED + filter_by={"status": ["PENDING", "SHIPPED"]}, schema=OrderRead, )