mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
fix: widen JoinType to accept aliased and polymorphic targets
This commit is contained in:
@@ -15,7 +15,7 @@ ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|||||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||||
|
|
||||||
# CRUD type aliases
|
# CRUD type aliases
|
||||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
|
||||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||||
|
|
||||||
|
|||||||
@@ -139,6 +139,17 @@ class Post(Base):
|
|||||||
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||||
|
|
||||||
|
|
||||||
|
class Transfer(Base):
|
||||||
|
"""Test model with two FKs to the same table (users)."""
|
||||||
|
|
||||||
|
__tablename__ = "transfers"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
amount: Mapped[str] = mapped_column(String(50))
|
||||||
|
sender_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
receiver_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
|
||||||
class Article(Base):
|
class Article(Base):
|
||||||
"""Test article model with ARRAY and JSON columns."""
|
"""Test article model with ARRAY and JSON columns."""
|
||||||
|
|
||||||
@@ -300,6 +311,23 @@ class ArticleRead(PydanticBase):
|
|||||||
labels: list[str]
|
labels: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TransferCreate(BaseModel):
|
||||||
|
"""Schema for creating a transfer."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
amount: str
|
||||||
|
sender_id: uuid.UUID
|
||||||
|
receiver_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
|
class TransferRead(PydanticBase):
|
||||||
|
"""Schema for reading a transfer."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
amount: str
|
||||||
|
|
||||||
|
|
||||||
|
TransferCrud = CrudFactory(Transfer)
|
||||||
ArticleCrud = CrudFactory(Article)
|
ArticleCrud = CrudFactory(Article)
|
||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ from .conftest import (
|
|||||||
Tag,
|
Tag,
|
||||||
TagCreate,
|
TagCreate,
|
||||||
TagCrud,
|
TagCrud,
|
||||||
|
Transfer,
|
||||||
|
TransferCreate,
|
||||||
|
TransferCrud,
|
||||||
|
TransferRead,
|
||||||
User,
|
User,
|
||||||
UserCreate,
|
UserCreate,
|
||||||
UserCrud,
|
UserCrud,
|
||||||
@@ -1282,6 +1286,128 @@ class TestCrudJoins:
|
|||||||
assert users[0].username == "multi_join"
|
assert users[0].username == "multi_join"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrudAliasedJoins:
|
||||||
|
"""Tests for CRUD operations with aliased joins (same table joined twice)."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_with_aliased_joins(self, db_session: AsyncSession):
|
||||||
|
"""Aliased joins allow joining the same table twice."""
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
|
|
||||||
|
alice = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="alice@test.com")
|
||||||
|
)
|
||||||
|
bob = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="bob@test.com")
|
||||||
|
)
|
||||||
|
await TransferCrud.create(
|
||||||
|
db_session,
|
||||||
|
TransferCreate(amount="100", sender_id=alice.id, receiver_id=bob.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
Sender = aliased(User)
|
||||||
|
Receiver = aliased(User)
|
||||||
|
|
||||||
|
results = await TransferCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[
|
||||||
|
(Sender, Transfer.sender_id == Sender.id),
|
||||||
|
(Receiver, Transfer.receiver_id == Receiver.id),
|
||||||
|
],
|
||||||
|
filters=[Sender.username == "alice", Receiver.username == "bob"],
|
||||||
|
)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].amount == "100"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_aliased_no_match(self, db_session: AsyncSession):
|
||||||
|
"""Aliased joins correctly filter out non-matching rows."""
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
|
|
||||||
|
alice = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="alice@test.com")
|
||||||
|
)
|
||||||
|
bob = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="bob@test.com")
|
||||||
|
)
|
||||||
|
await TransferCrud.create(
|
||||||
|
db_session,
|
||||||
|
TransferCreate(amount="100", sender_id=alice.id, receiver_id=bob.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
Sender = aliased(User)
|
||||||
|
Receiver = aliased(User)
|
||||||
|
|
||||||
|
# bob is receiver, not sender — should return nothing
|
||||||
|
results = await TransferCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[
|
||||||
|
(Sender, Transfer.sender_id == Sender.id),
|
||||||
|
(Receiver, Transfer.receiver_id == Receiver.id),
|
||||||
|
],
|
||||||
|
filters=[Sender.username == "bob", Receiver.username == "alice"],
|
||||||
|
)
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_paginate_with_aliased_joins(self, db_session: AsyncSession):
|
||||||
|
"""Aliased joins work with offset_paginate."""
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
|
|
||||||
|
alice = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="alice@test.com")
|
||||||
|
)
|
||||||
|
bob = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="bob@test.com")
|
||||||
|
)
|
||||||
|
await TransferCrud.create(
|
||||||
|
db_session,
|
||||||
|
TransferCreate(amount="50", sender_id=alice.id, receiver_id=bob.id),
|
||||||
|
)
|
||||||
|
await TransferCrud.create(
|
||||||
|
db_session,
|
||||||
|
TransferCreate(amount="75", sender_id=bob.id, receiver_id=alice.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
Sender = aliased(User)
|
||||||
|
result = await TransferCrud.offset_paginate(
|
||||||
|
db_session,
|
||||||
|
joins=[(Sender, Transfer.sender_id == Sender.id)],
|
||||||
|
filters=[Sender.username == "alice"],
|
||||||
|
schema=TransferRead,
|
||||||
|
)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].amount == "50"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_count_with_aliased_join(self, db_session: AsyncSession):
|
||||||
|
"""Aliased joins work with count."""
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
|
|
||||||
|
alice = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="alice@test.com")
|
||||||
|
)
|
||||||
|
bob = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="bob@test.com")
|
||||||
|
)
|
||||||
|
await TransferCrud.create(
|
||||||
|
db_session,
|
||||||
|
TransferCreate(amount="10", sender_id=alice.id, receiver_id=bob.id),
|
||||||
|
)
|
||||||
|
await TransferCrud.create(
|
||||||
|
db_session,
|
||||||
|
TransferCreate(amount="20", sender_id=alice.id, receiver_id=bob.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
Sender = aliased(User)
|
||||||
|
count = await TransferCrud.count(
|
||||||
|
db_session,
|
||||||
|
joins=[(Sender, Transfer.sender_id == Sender.id)],
|
||||||
|
filters=[Sender.username == "alice"],
|
||||||
|
)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
class TestCrudFactoryM2M:
|
class TestCrudFactoryM2M:
|
||||||
"""Tests for CrudFactory with m2m_fields parameter."""
|
"""Tests for CrudFactory with m2m_fields parameter."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user