mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: add M2M helpers (#247)
This commit is contained in:
454
tests/test_db.py
454
tests/test_db.py
@@ -4,10 +4,26 @@ import asyncio
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
ForeignKey,
|
||||
ForeignKeyConstraint,
|
||||
String,
|
||||
Table,
|
||||
Uuid,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
mapped_column,
|
||||
relationship,
|
||||
selectinload,
|
||||
)
|
||||
|
||||
from fastapi_toolsets.db import (
|
||||
LockMode,
|
||||
@@ -17,12 +33,15 @@ from fastapi_toolsets.db import (
|
||||
create_db_dependency,
|
||||
get_transaction,
|
||||
lock_tables,
|
||||
m2m_add,
|
||||
m2m_remove,
|
||||
m2m_set,
|
||||
wait_for_row_change,
|
||||
)
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
from fastapi_toolsets.pytest import create_db_session
|
||||
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||
from .conftest import DATABASE_URL, Base, Post, Role, RoleCrud, Tag, User, UserCrud
|
||||
|
||||
|
||||
class TestCreateDbDependency:
|
||||
@@ -81,6 +100,21 @@ class TestCreateDbDependency:
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_no_commit_when_not_in_transaction(self):
|
||||
"""Dependency skips commit if the session is no longer in a transaction on exit."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
get_db = create_db_dependency(session_factory)
|
||||
|
||||
async for session in get_db():
|
||||
# Manually commit — session exits the transaction
|
||||
await session.commit()
|
||||
assert not session.in_transaction()
|
||||
# The dependency's post-yield path must not call commit again (no error)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_after_lock_tables_is_persisted(self):
|
||||
"""Changes made after lock_tables exits (before endpoint returns) are committed.
|
||||
@@ -480,3 +514,417 @@ class TestCleanupTables:
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
# Should not raise
|
||||
await cleanup_tables(session, EmptyBase)
|
||||
|
||||
|
||||
class TestM2MAdd:
|
||||
"""Tests for m2m_add helper."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_adds_single_related(self, db_session: AsyncSession):
|
||||
"""Associates one related instance via the secondary table."""
|
||||
user = User(username="m2m_author", email="m2m@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post A", author_id=user.id)
|
||||
tag = Tag(name="python")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.tags) == 1
|
||||
assert loaded.tags[0].id == tag.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_adds_multiple_related(self, db_session: AsyncSession):
|
||||
"""Associates multiple related instances in a single call."""
|
||||
user = User(username="m2m_author2", email="m2m2@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post B", author_id=user.id)
|
||||
tag1 = Tag(name="web")
|
||||
tag2 = Tag(name="api")
|
||||
tag3 = Tag(name="async")
|
||||
db_session.add_all([post, tag1, tag2, tag3])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag1, tag2, tag3)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert {t.id for t in loaded.tags} == {tag1.id, tag2.id, tag3.id}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_noop_for_empty_related(self, db_session: AsyncSession):
|
||||
"""Calling with no related instances is a no-op."""
|
||||
user = User(username="m2m_author3", email="m2m3@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post C", author_id=user.id)
|
||||
db_session.add(post)
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags) # no related instances
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert loaded.tags == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ignore_conflicts_true(self, db_session: AsyncSession):
|
||||
"""Duplicate inserts are silently skipped when ignore_conflicts=True."""
|
||||
user = User(username="m2m_author4", email="m2m4@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post D", author_id=user.id)
|
||||
tag = Tag(name="duplicate_tag")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
# Second call with ignore_conflicts=True must not raise
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag, ignore_conflicts=True)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.tags) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ignore_conflicts_false_raises(self, db_session: AsyncSession):
|
||||
"""Duplicate inserts raise IntegrityError when ignore_conflicts=False (default)."""
|
||||
user = User(username="m2m_author5", email="m2m5@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post E", author_id=user.id)
|
||||
tag = Tag(name="conflict_tag")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
|
||||
"""Passing a non-M2M relationship attribute raises TypeError."""
|
||||
user = User(username="m2m_author6", email="m2m6@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
role = Role(name="type_err_role")
|
||||
db_session.add(role)
|
||||
await db_session.flush()
|
||||
|
||||
with pytest.raises(TypeError, match="Many-to-Many"):
|
||||
await m2m_add(db_session, user, User.role, role)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_works_inside_lock_tables(self, db_session: AsyncSession):
|
||||
"""m2m_add works correctly inside a lock_tables nested transaction."""
|
||||
user = User(username="m2m_lock_author", email="m2m_lock@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
async with lock_tables(db_session, [Tag]):
|
||||
tag = Tag(name="locked_tag")
|
||||
db_session.add(tag)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Lock", author_id=user.id)
|
||||
db_session.add(post)
|
||||
await db_session.flush()
|
||||
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.tags) == 1
|
||||
assert loaded.tags[0].name == "locked_tag"
|
||||
|
||||
|
||||
class _LocalBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
_comp_assoc = Table(
|
||||
"_comp_assoc",
|
||||
_LocalBase.metadata,
|
||||
Column("owner_id", Uuid, ForeignKey("_comp_owners.id"), primary_key=True),
|
||||
Column("item_group", String(50), primary_key=True),
|
||||
Column("item_code", String(50), primary_key=True),
|
||||
ForeignKeyConstraint(
|
||||
["item_group", "item_code"],
|
||||
["_comp_items.group_id", "_comp_items.item_code"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class _CompOwner(_LocalBase):
|
||||
__tablename__ = "_comp_owners"
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
items: Mapped[list["_CompItem"]] = relationship(secondary=_comp_assoc)
|
||||
|
||||
|
||||
class _CompItem(_LocalBase):
|
||||
__tablename__ = "_comp_items"
|
||||
group_id: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
|
||||
|
||||
class TestM2MRemove:
|
||||
"""Tests for m2m_remove helper."""
|
||||
|
||||
async def _setup(
|
||||
self, session: AsyncSession, username: str, email: str, *tag_names: str
|
||||
):
|
||||
"""Create a user, post, and tags; associate all tags with the post."""
|
||||
user = User(username=username, email=email)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
post = Post(title=f"Post {username}", author_id=user.id)
|
||||
tags = [Tag(name=n) for n in tag_names]
|
||||
session.add(post)
|
||||
session.add_all(tags)
|
||||
await session.flush()
|
||||
|
||||
async with get_transaction(session):
|
||||
await m2m_add(session, post, Post.tags, *tags)
|
||||
|
||||
return post, tags
|
||||
|
||||
async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
|
||||
result = await session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
return result.scalar_one().tags
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_removes_single(self, db_session: AsyncSession):
|
||||
"""Removes one association, leaving others intact."""
|
||||
post, (tag1, tag2) = await self._setup(
|
||||
db_session, "rm_author1", "rm1@test.com", "tag_rm_a", "tag_rm_b"
|
||||
)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags, tag1)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].id == tag2.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_removes_multiple(self, db_session: AsyncSession):
|
||||
"""Removes multiple associations in one call."""
|
||||
post, (tag1, tag2, tag3) = await self._setup(
|
||||
db_session, "rm_author2", "rm2@test.com", "tag_rm_c", "tag_rm_d", "tag_rm_e"
|
||||
)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags, tag1, tag3)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].id == tag2.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_noop_for_empty_related(self, db_session: AsyncSession):
|
||||
"""Calling with no related instances is a no-op."""
|
||||
post, (tag,) = await self._setup(
|
||||
db_session, "rm_author3", "rm3@test.com", "tag_rm_f"
|
||||
)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_idempotent_for_missing_association(self, db_session: AsyncSession):
|
||||
"""Removing a non-existent association does not raise."""
|
||||
post, (tag1,) = await self._setup(
|
||||
db_session, "rm_author4", "rm4@test.com", "tag_rm_g"
|
||||
)
|
||||
tag2 = Tag(name="tag_rm_h")
|
||||
db_session.add(tag2)
|
||||
await db_session.flush()
|
||||
|
||||
# tag2 was never associated — should not raise
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags, tag2)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
|
||||
"""Passing a non-M2M relationship attribute raises TypeError."""
|
||||
user = User(username="rm_author5", email="rm5@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
role = Role(name="rm_type_err_role")
|
||||
db_session.add(role)
|
||||
await db_session.flush()
|
||||
|
||||
with pytest.raises(TypeError, match="Many-to-Many"):
|
||||
await m2m_remove(db_session, user, User.role, role)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_removes_composite_pk_related(self):
|
||||
"""Composite-PK branch: DELETE uses tuple IN when related side has multi-col PK."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(_LocalBase.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
try:
|
||||
async with session_factory() as session:
|
||||
owner = _CompOwner()
|
||||
item1 = _CompItem(group_id="g1", item_code="c1")
|
||||
item2 = _CompItem(group_id="g1", item_code="c2")
|
||||
session.add_all([owner, item1, item2])
|
||||
await session.flush()
|
||||
|
||||
async with get_transaction(session):
|
||||
await m2m_add(session, owner, _CompOwner.items, item1, item2)
|
||||
|
||||
async with get_transaction(session):
|
||||
await m2m_remove(session, owner, _CompOwner.items, item1)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as verify:
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await verify.execute(
|
||||
select(_CompOwner)
|
||||
.where(_CompOwner.id == owner.id)
|
||||
.options(selectinload(_CompOwner.items))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.items) == 1
|
||||
assert (loaded.items[0].group_id, loaded.items[0].item_code) == (
|
||||
"g1",
|
||||
"c2",
|
||||
)
|
||||
finally:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(_LocalBase.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestM2MSet:
|
||||
"""Tests for m2m_set helper."""
|
||||
|
||||
async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
|
||||
result = await session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
return result.scalar_one().tags
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_replaces_existing_set(self, db_session: AsyncSession):
|
||||
"""Replaces the full association set atomically."""
|
||||
user = User(username="set_author1", email="set1@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Set A", author_id=user.id)
|
||||
tag1 = Tag(name="tag_set_a")
|
||||
tag2 = Tag(name="tag_set_b")
|
||||
tag3 = Tag(name="tag_set_c")
|
||||
db_session.add_all([post, tag1, tag2, tag3])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag1, tag2)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_set(db_session, post, Post.tags, tag3)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].id == tag3.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_clears_all_when_no_related(self, db_session: AsyncSession):
|
||||
"""Passing no related instances clears all associations."""
|
||||
user = User(username="set_author2", email="set2@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Set B", author_id=user.id)
|
||||
tag = Tag(name="tag_set_d")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_set(db_session, post, Post.tags)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert remaining == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_set_on_empty_then_populate(self, db_session: AsyncSession):
|
||||
"""m2m_set works on a post with no existing associations."""
|
||||
user = User(username="set_author3", email="set3@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Set C", author_id=user.id)
|
||||
tag1 = Tag(name="tag_set_e")
|
||||
tag2 = Tag(name="tag_set_f")
|
||||
db_session.add_all([post, tag1, tag2])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_set(db_session, post, Post.tags, tag1, tag2)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert {t.id for t in remaining} == {tag1.id, tag2.id}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
|
||||
"""Passing a non-M2M relationship attribute raises TypeError."""
|
||||
user = User(username="set_author4", email="set4@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
role = Role(name="set_type_err_role")
|
||||
db_session.add(role)
|
||||
await db_session.flush()
|
||||
|
||||
with pytest.raises(TypeError, match="Many-to-Many"):
|
||||
await m2m_set(db_session, user, User.role, role)
|
||||
|
||||
Reference in New Issue
Block a user