mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
fix: coerce string values to bool for Boolean facet field filtering (#219)
This commit is contained in:
@@ -278,6 +278,18 @@ _EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid)
|
|||||||
"""Column types that support equality / IN filtering in build_filter_by."""
|
"""Column types that support equality / IN filtering in build_filter_by."""
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_bool(value: Any) -> bool:
|
||||||
|
"""Coerce a string value to a Python bool for Boolean column filtering."""
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value.lower() == "true":
|
||||||
|
return True
|
||||||
|
if value.lower() == "false":
|
||||||
|
return False
|
||||||
|
raise ValueError(f"Cannot coerce {value!r} to bool")
|
||||||
|
|
||||||
|
|
||||||
def build_filter_by(
|
def build_filter_by(
|
||||||
filter_by: dict[str, Any],
|
filter_by: dict[str, Any],
|
||||||
facet_fields: Sequence[FacetFieldType],
|
facet_fields: Sequence[FacetFieldType],
|
||||||
@@ -324,16 +336,17 @@ def build_filter_by(
|
|||||||
added_join_keys.add(rel_key)
|
added_join_keys.add(rel_key)
|
||||||
|
|
||||||
col_type = column.property.columns[0].type
|
col_type = column.property.columns[0].type
|
||||||
if isinstance(col_type, ARRAY):
|
if isinstance(col_type, Boolean):
|
||||||
|
coerce = _coerce_bool
|
||||||
|
if isinstance(value, list):
|
||||||
|
filters.append(column.in_([coerce(v) for v in value]))
|
||||||
|
else:
|
||||||
|
filters.append(column == coerce(value))
|
||||||
|
elif isinstance(col_type, ARRAY):
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
filters.append(column.overlap(value))
|
filters.append(column.overlap(value))
|
||||||
else:
|
else:
|
||||||
filters.append(column.any(value))
|
filters.append(column.any(value))
|
||||||
elif isinstance(col_type, Boolean):
|
|
||||||
if isinstance(value, list):
|
|
||||||
filters.append(column.in_(value))
|
|
||||||
else:
|
|
||||||
filters.append(column.is_(value))
|
|
||||||
elif isinstance(col_type, _EQUALITY_TYPES):
|
elif isinstance(col_type, _EQUALITY_TYPES):
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
filters.append(column.in_(value))
|
filters.append(column.in_(value))
|
||||||
|
|||||||
@@ -971,7 +971,7 @@ class TestFilterBy:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_bool_filter_false(self, db_session: AsyncSession):
|
async def test_bool_filter_false(self, db_session: AsyncSession):
|
||||||
"""filter_by with a boolean False value correctly filters rows."""
|
"""filter_by with a string 'false' value correctly filters rows."""
|
||||||
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
|
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
|
||||||
await UserCrud.create(
|
await UserCrud.create(
|
||||||
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||||
@@ -982,7 +982,7 @@ class TestFilterBy:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await UserBoolCrud.offset_paginate(
|
result = await UserBoolCrud.offset_paginate(
|
||||||
db_session, filter_by={"is_active": False}, schema=UserRead
|
db_session, filter_by={"is_active": "false"}, schema=UserRead
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result.pagination, OffsetPagination)
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
@@ -991,7 +991,7 @@ class TestFilterBy:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_bool_filter_true(self, db_session: AsyncSession):
|
async def test_bool_filter_true(self, db_session: AsyncSession):
|
||||||
"""filter_by with a boolean True value correctly filters rows."""
|
"""filter_by with a string 'true' value correctly filters rows."""
|
||||||
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
|
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
|
||||||
await UserCrud.create(
|
await UserCrud.create(
|
||||||
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||||
@@ -1002,7 +1002,7 @@ class TestFilterBy:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await UserBoolCrud.offset_paginate(
|
result = await UserBoolCrud.offset_paginate(
|
||||||
db_session, filter_by={"is_active": True}, schema=UserRead
|
db_session, filter_by={"is_active": "true"}, schema=UserRead
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result.pagination, OffsetPagination)
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
@@ -1011,7 +1011,7 @@ class TestFilterBy:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_bool_filter_list(self, db_session: AsyncSession):
|
async def test_bool_filter_list(self, db_session: AsyncSession):
|
||||||
"""filter_by with a list of booleans produces an IN clause."""
|
"""filter_by with a list of string booleans produces an IN clause."""
|
||||||
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
|
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
|
||||||
await UserCrud.create(
|
await UserCrud.create(
|
||||||
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||||
@@ -1022,12 +1022,41 @@ class TestFilterBy:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await UserBoolCrud.offset_paginate(
|
result = await UserBoolCrud.offset_paginate(
|
||||||
db_session, filter_by={"is_active": [True, False]}, schema=UserRead
|
db_session, filter_by={"is_active": ["true", "false"]}, schema=UserRead
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result.pagination, OffsetPagination)
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
assert result.pagination.total_count == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_bool_filter_native_bool(self, db_session: AsyncSession):
|
||||||
|
"""filter_by with a native Python bool passes through coercion."""
|
||||||
|
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserBoolCrud.offset_paginate(
|
||||||
|
db_session, filter_by={"is_active": True}, schema=UserRead
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
|
||||||
|
def test_bool_coerce_invalid_value(self):
|
||||||
|
"""_coerce_bool raises ValueError for non-bool, non-string values."""
|
||||||
|
from fastapi_toolsets.crud.search import _coerce_bool
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Cannot coerce"):
|
||||||
|
_coerce_bool(42)
|
||||||
|
|
||||||
|
def test_bool_coerce_invalid_string(self):
|
||||||
|
"""_coerce_bool raises ValueError for unrecognized string values."""
|
||||||
|
from fastapi_toolsets.crud.search import _coerce_bool
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Cannot coerce"):
|
||||||
|
_coerce_bool("maybe")
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_array_contains_single_value(self, db_session: AsyncSession):
|
async def test_array_contains_single_value(self, db_session: AsyncSession):
|
||||||
"""filter_by on an ARRAY column with a scalar checks containment."""
|
"""filter_by on an ARRAY column with a scalar checks containment."""
|
||||||
|
|||||||
Reference in New Issue
Block a user