diff --git a/src/fastapi_toolsets/types.py b/src/fastapi_toolsets/types.py index a89eef8..ed43164 100644 --- a/src/fastapi_toolsets/types.py +++ b/src/fastapi_toolsets/types.py @@ -15,7 +15,7 @@ ModelType = TypeVar("ModelType", bound=DeclarativeBase) SchemaType = TypeVar("SchemaType", bound=BaseModel) # CRUD type aliases -JoinType = list[tuple[type[DeclarativeBase], Any]] +JoinType = list[tuple[type[DeclarativeBase] | Any, Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]] OrderByClause = ColumnElement[Any] | QueryableAttribute[Any] diff --git a/tests/conftest.py b/tests/conftest.py index 292a2d9..fbf30fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -139,6 +139,17 @@ class Post(Base): tags: Mapped[list[Tag]] = relationship(secondary=post_tags) +class Transfer(Base): + """Test model with two FKs to the same table (users).""" + + __tablename__ = "transfers" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + amount: Mapped[str] = mapped_column(String(50)) + sender_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id")) + receiver_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id")) + + class Article(Base): """Test article model with ARRAY and JSON columns.""" @@ -300,6 +311,23 @@ class ArticleRead(PydanticBase): labels: list[str] +class TransferCreate(BaseModel): + """Schema for creating a transfer.""" + + id: uuid.UUID | None = None + amount: str + sender_id: uuid.UUID + receiver_id: uuid.UUID + + +class TransferRead(PydanticBase): + """Schema for reading a transfer.""" + + id: uuid.UUID + amount: str + + +TransferCrud = CrudFactory(Transfer) ArticleCrud = CrudFactory(Article) RoleCrud = CrudFactory(Role) RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id) diff --git a/tests/test_crud.py b/tests/test_crud.py index aa5708a..3c3ffe4 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -38,6 +38,10 @@ from .conftest import ( Tag, TagCreate, TagCrud, + Transfer, + TransferCreate, + TransferCrud, + TransferRead, User, UserCreate, UserCrud, @@ -1282,6 +1286,128 @@ class TestCrudJoins: assert users[0].username == "multi_join" +class TestCrudAliasedJoins: + """Tests for CRUD operations with aliased joins (same table joined twice).""" + + @pytest.mark.anyio + async def test_get_multi_with_aliased_joins(self, db_session: AsyncSession): + """Aliased joins allow joining the same table twice.""" + from sqlalchemy.orm import aliased + + alice = await UserCrud.create( + db_session, UserCreate(username="alice", email="alice@test.com") + ) + bob = await UserCrud.create( + db_session, UserCreate(username="bob", email="bob@test.com") + ) + await TransferCrud.create( + db_session, + TransferCreate(amount="100", sender_id=alice.id, receiver_id=bob.id), + ) + + Sender = aliased(User) + Receiver = aliased(User) + + results = await TransferCrud.get_multi( + db_session, + joins=[ + (Sender, Transfer.sender_id == Sender.id), + (Receiver, Transfer.receiver_id == Receiver.id), + ], + filters=[Sender.username == "alice", Receiver.username == "bob"], + ) + assert len(results) == 1 + assert results[0].amount == "100" + + @pytest.mark.anyio + async def test_get_multi_aliased_no_match(self, db_session: AsyncSession): + """Aliased joins correctly filter out non-matching rows.""" + from sqlalchemy.orm import aliased + + alice = await UserCrud.create( + db_session, UserCreate(username="alice", email="alice@test.com") + ) + bob = await UserCrud.create( + db_session, UserCreate(username="bob", email="bob@test.com") + ) + await TransferCrud.create( + db_session, + TransferCreate(amount="100", sender_id=alice.id, receiver_id=bob.id), + ) + + Sender = aliased(User) + Receiver = aliased(User) + + # bob is receiver, not sender — should return nothing + results = await TransferCrud.get_multi( + db_session, + joins=[ + (Sender, Transfer.sender_id == Sender.id), + (Receiver, Transfer.receiver_id == Receiver.id), + ], + filters=[Sender.username == "bob", Receiver.username == "alice"], + ) + assert len(results) == 0 + + @pytest.mark.anyio + async def test_paginate_with_aliased_joins(self, db_session: AsyncSession): + """Aliased joins work with offset_paginate.""" + from sqlalchemy.orm import aliased + + alice = await UserCrud.create( + db_session, UserCreate(username="alice", email="alice@test.com") + ) + bob = await UserCrud.create( + db_session, UserCreate(username="bob", email="bob@test.com") + ) + await TransferCrud.create( + db_session, + TransferCreate(amount="50", sender_id=alice.id, receiver_id=bob.id), + ) + await TransferCrud.create( + db_session, + TransferCreate(amount="75", sender_id=bob.id, receiver_id=alice.id), + ) + + Sender = aliased(User) + result = await TransferCrud.offset_paginate( + db_session, + joins=[(Sender, Transfer.sender_id == Sender.id)], + filters=[Sender.username == "alice"], + schema=TransferRead, + ) + assert result.pagination.total_count == 1 + assert result.data[0].amount == "50" + + @pytest.mark.anyio + async def test_count_with_aliased_join(self, db_session: AsyncSession): + """Aliased joins work with count.""" + from sqlalchemy.orm import aliased + + alice = await UserCrud.create( + db_session, UserCreate(username="alice", email="alice@test.com") + ) + bob = await UserCrud.create( + db_session, UserCreate(username="bob", email="bob@test.com") + ) + await TransferCrud.create( + db_session, + TransferCreate(amount="10", sender_id=alice.id, receiver_id=bob.id), + ) + await TransferCrud.create( + db_session, + TransferCreate(amount="20", sender_id=alice.id, receiver_id=bob.id), + ) + + Sender = aliased(User) + count = await TransferCrud.count( + db_session, + joins=[(Sender, Transfer.sender_id == Sender.id)], + filters=[Sender.username == "alice"], + ) + assert count == 2 + + class TestCrudFactoryM2M: """Tests for CrudFactory with m2m_fields parameter."""