Compare commits

..

2 Commits

5 changed files with 209 additions and 13 deletions

View File

@@ -278,6 +278,18 @@ _EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid)
"""Column types that support equality / IN filtering in build_filter_by.""" """Column types that support equality / IN filtering in build_filter_by."""
def _coerce_bool(value: Any) -> bool:
"""Coerce a string value to a Python bool for Boolean column filtering."""
if isinstance(value, bool):
return value
if isinstance(value, str):
if value.lower() == "true":
return True
if value.lower() == "false":
return False
raise ValueError(f"Cannot coerce {value!r} to bool")
def build_filter_by( def build_filter_by(
filter_by: dict[str, Any], filter_by: dict[str, Any],
facet_fields: Sequence[FacetFieldType], facet_fields: Sequence[FacetFieldType],
@@ -324,16 +336,17 @@ def build_filter_by(
added_join_keys.add(rel_key) added_join_keys.add(rel_key)
col_type = column.property.columns[0].type col_type = column.property.columns[0].type
if isinstance(col_type, ARRAY): if isinstance(col_type, Boolean):
coerce = _coerce_bool
if isinstance(value, list):
filters.append(column.in_([coerce(v) for v in value]))
else:
filters.append(column == coerce(value))
elif isinstance(col_type, ARRAY):
if isinstance(value, list): if isinstance(value, list):
filters.append(column.overlap(value)) filters.append(column.overlap(value))
else: else:
filters.append(column.any(value)) filters.append(column.any(value))
elif isinstance(col_type, Boolean):
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column.is_(value))
elif isinstance(col_type, _EQUALITY_TYPES): elif isinstance(col_type, _EQUALITY_TYPES):
if isinstance(value, list): if isinstance(value, list):
filters.append(column.in_(value)) filters.append(column.in_(value))

View File

@@ -15,7 +15,7 @@ ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel) SchemaType = TypeVar("SchemaType", bound=BaseModel)
# CRUD type aliases # CRUD type aliases
JoinType = list[tuple[type[DeclarativeBase], Any]] JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any] OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]

View File

@@ -139,6 +139,17 @@ class Post(Base):
tags: Mapped[list[Tag]] = relationship(secondary=post_tags) tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
class Transfer(Base):
"""Test model with two FKs to the same table (users)."""
__tablename__ = "transfers"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
amount: Mapped[str] = mapped_column(String(50))
sender_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
receiver_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
class Article(Base): class Article(Base):
"""Test article model with ARRAY and JSON columns.""" """Test article model with ARRAY and JSON columns."""
@@ -300,6 +311,23 @@ class ArticleRead(PydanticBase):
labels: list[str] labels: list[str]
class TransferCreate(BaseModel):
"""Schema for creating a transfer."""
id: uuid.UUID | None = None
amount: str
sender_id: uuid.UUID
receiver_id: uuid.UUID
class TransferRead(PydanticBase):
"""Schema for reading a transfer."""
id: uuid.UUID
amount: str
TransferCrud = CrudFactory(Transfer)
ArticleCrud = CrudFactory(Article) ArticleCrud = CrudFactory(Article)
RoleCrud = CrudFactory(Role) RoleCrud = CrudFactory(Role)
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id) RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)

View File

@@ -38,6 +38,10 @@ from .conftest import (
Tag, Tag,
TagCreate, TagCreate,
TagCrud, TagCrud,
Transfer,
TransferCreate,
TransferCrud,
TransferRead,
User, User,
UserCreate, UserCreate,
UserCrud, UserCrud,
@@ -1282,6 +1286,128 @@ class TestCrudJoins:
assert users[0].username == "multi_join" assert users[0].username == "multi_join"
class TestCrudAliasedJoins:
"""Tests for CRUD operations with aliased joins (same table joined twice)."""
@pytest.mark.anyio
async def test_get_multi_with_aliased_joins(self, db_session: AsyncSession):
"""Aliased joins allow joining the same table twice."""
from sqlalchemy.orm import aliased
alice = await UserCrud.create(
db_session, UserCreate(username="alice", email="alice@test.com")
)
bob = await UserCrud.create(
db_session, UserCreate(username="bob", email="bob@test.com")
)
await TransferCrud.create(
db_session,
TransferCreate(amount="100", sender_id=alice.id, receiver_id=bob.id),
)
Sender = aliased(User)
Receiver = aliased(User)
results = await TransferCrud.get_multi(
db_session,
joins=[
(Sender, Transfer.sender_id == Sender.id),
(Receiver, Transfer.receiver_id == Receiver.id),
],
filters=[Sender.username == "alice", Receiver.username == "bob"],
)
assert len(results) == 1
assert results[0].amount == "100"
@pytest.mark.anyio
async def test_get_multi_aliased_no_match(self, db_session: AsyncSession):
"""Aliased joins correctly filter out non-matching rows."""
from sqlalchemy.orm import aliased
alice = await UserCrud.create(
db_session, UserCreate(username="alice", email="alice@test.com")
)
bob = await UserCrud.create(
db_session, UserCreate(username="bob", email="bob@test.com")
)
await TransferCrud.create(
db_session,
TransferCreate(amount="100", sender_id=alice.id, receiver_id=bob.id),
)
Sender = aliased(User)
Receiver = aliased(User)
# bob is receiver, not sender — should return nothing
results = await TransferCrud.get_multi(
db_session,
joins=[
(Sender, Transfer.sender_id == Sender.id),
(Receiver, Transfer.receiver_id == Receiver.id),
],
filters=[Sender.username == "bob", Receiver.username == "alice"],
)
assert len(results) == 0
@pytest.mark.anyio
async def test_paginate_with_aliased_joins(self, db_session: AsyncSession):
"""Aliased joins work with offset_paginate."""
from sqlalchemy.orm import aliased
alice = await UserCrud.create(
db_session, UserCreate(username="alice", email="alice@test.com")
)
bob = await UserCrud.create(
db_session, UserCreate(username="bob", email="bob@test.com")
)
await TransferCrud.create(
db_session,
TransferCreate(amount="50", sender_id=alice.id, receiver_id=bob.id),
)
await TransferCrud.create(
db_session,
TransferCreate(amount="75", sender_id=bob.id, receiver_id=alice.id),
)
Sender = aliased(User)
result = await TransferCrud.offset_paginate(
db_session,
joins=[(Sender, Transfer.sender_id == Sender.id)],
filters=[Sender.username == "alice"],
schema=TransferRead,
)
assert result.pagination.total_count == 1
assert result.data[0].amount == "50"
@pytest.mark.anyio
async def test_count_with_aliased_join(self, db_session: AsyncSession):
"""Aliased joins work with count."""
from sqlalchemy.orm import aliased
alice = await UserCrud.create(
db_session, UserCreate(username="alice", email="alice@test.com")
)
bob = await UserCrud.create(
db_session, UserCreate(username="bob", email="bob@test.com")
)
await TransferCrud.create(
db_session,
TransferCreate(amount="10", sender_id=alice.id, receiver_id=bob.id),
)
await TransferCrud.create(
db_session,
TransferCreate(amount="20", sender_id=alice.id, receiver_id=bob.id),
)
Sender = aliased(User)
count = await TransferCrud.count(
db_session,
joins=[(Sender, Transfer.sender_id == Sender.id)],
filters=[Sender.username == "alice"],
)
assert count == 2
class TestCrudFactoryM2M: class TestCrudFactoryM2M:
"""Tests for CrudFactory with m2m_fields parameter.""" """Tests for CrudFactory with m2m_fields parameter."""

View File

@@ -971,7 +971,7 @@ class TestFilterBy:
@pytest.mark.anyio @pytest.mark.anyio
async def test_bool_filter_false(self, db_session: AsyncSession): async def test_bool_filter_false(self, db_session: AsyncSession):
"""filter_by with a boolean False value correctly filters rows.""" """filter_by with a string 'false' value correctly filters rows."""
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active]) UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
await UserCrud.create( await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com", is_active=True) db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
@@ -982,7 +982,7 @@ class TestFilterBy:
) )
result = await UserBoolCrud.offset_paginate( result = await UserBoolCrud.offset_paginate(
db_session, filter_by={"is_active": False}, schema=UserRead db_session, filter_by={"is_active": "false"}, schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -991,7 +991,7 @@ class TestFilterBy:
@pytest.mark.anyio @pytest.mark.anyio
async def test_bool_filter_true(self, db_session: AsyncSession): async def test_bool_filter_true(self, db_session: AsyncSession):
"""filter_by with a boolean True value correctly filters rows.""" """filter_by with a string 'true' value correctly filters rows."""
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active]) UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
await UserCrud.create( await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com", is_active=True) db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
@@ -1002,7 +1002,7 @@ class TestFilterBy:
) )
result = await UserBoolCrud.offset_paginate( result = await UserBoolCrud.offset_paginate(
db_session, filter_by={"is_active": True}, schema=UserRead db_session, filter_by={"is_active": "true"}, schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -1011,7 +1011,7 @@ class TestFilterBy:
@pytest.mark.anyio @pytest.mark.anyio
async def test_bool_filter_list(self, db_session: AsyncSession): async def test_bool_filter_list(self, db_session: AsyncSession):
"""filter_by with a list of booleans produces an IN clause.""" """filter_by with a list of string booleans produces an IN clause."""
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active]) UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
await UserCrud.create( await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com", is_active=True) db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
@@ -1022,12 +1022,41 @@ class TestFilterBy:
) )
result = await UserBoolCrud.offset_paginate( result = await UserBoolCrud.offset_paginate(
db_session, filter_by={"is_active": [True, False]}, schema=UserRead db_session, filter_by={"is_active": ["true", "false"]}, schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio
async def test_bool_filter_native_bool(self, db_session: AsyncSession):
"""filter_by with a native Python bool passes through coercion."""
UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active])
await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
)
result = await UserBoolCrud.offset_paginate(
db_session, filter_by={"is_active": True}, schema=UserRead
)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 1
def test_bool_coerce_invalid_value(self):
"""_coerce_bool raises ValueError for non-bool, non-string values."""
from fastapi_toolsets.crud.search import _coerce_bool
with pytest.raises(ValueError, match="Cannot coerce"):
_coerce_bool(42)
def test_bool_coerce_invalid_string(self):
"""_coerce_bool raises ValueError for unrecognized string values."""
from fastapi_toolsets.crud.search import _coerce_bool
with pytest.raises(ValueError, match="Cannot coerce"):
_coerce_bool("maybe")
@pytest.mark.anyio @pytest.mark.anyio
async def test_array_contains_single_value(self, db_session: AsyncSession): async def test_array_contains_single_value(self, db_session: AsyncSession):
"""filter_by on an ARRAY column with a scalar checks containment.""" """filter_by on an ARRAY column with a scalar checks containment."""