Files
fastapi-toolsets/tests/test_db.py
2026-04-12 18:46:57 +02:00

931 lines
33 KiB
Python

"""Tests for fastapi_toolsets.db module."""
import asyncio
import uuid
import pytest
from sqlalchemy import (
Column,
ForeignKey,
ForeignKeyConstraint,
String,
Table,
Uuid,
select,
text,
)
from sqlalchemy.engine import make_url
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
mapped_column,
relationship,
selectinload,
)
from fastapi_toolsets.db import (
LockMode,
cleanup_tables,
create_database,
create_db_context,
create_db_dependency,
get_transaction,
lock_tables,
m2m_add,
m2m_remove,
m2m_set,
wait_for_row_change,
)
from fastapi_toolsets.exceptions import NotFoundError
from fastapi_toolsets.pytest import create_db_session
from .conftest import DATABASE_URL, Base, Post, Role, RoleCrud, Tag, User, UserCrud
class TestCreateDbDependency:
    """Tests for create_db_dependency."""

    @pytest.mark.anyio
    async def test_yields_session(self):
        """Dependency yields a valid session."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)
        db_gen = get_db()
        try:
            async for session in db_gen:
                assert isinstance(session, AsyncSession)
                break
        finally:
            # ``break`` leaves the generator suspended at its ``yield``; close it
            # explicitly so its cleanup runs now instead of at garbage
            # collection, after the engine has already been disposed.
            await db_gen.aclose()
            await engine.dispose()

    @pytest.mark.anyio
    async def test_auto_commits_transaction(self):
        """Dependency auto-commits if transaction is active."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        try:
            get_db = create_db_dependency(session_factory)
            async for session in get_db():
                role = Role(name="test_role_dep")
                session.add(role)
                await session.flush()
            # The generator is exhausted here, so its post-yield path has run;
            # a fresh session can only see the row once it has been committed.
            async with session_factory() as verify_session:
                result = await RoleCrud.first(
                    verify_session, [Role.name == "test_role_dep"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()

    @pytest.mark.anyio
    async def test_in_transaction_on_yield(self):
        """Session is already in a transaction when the endpoint body starts."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)
        db_gen = get_db()
        try:
            async for session in db_gen:
                assert session.in_transaction()
                break
        finally:
            # Finalize the suspended generator before disposing the engine.
            await db_gen.aclose()
            await engine.dispose()

    @pytest.mark.anyio
    async def test_no_commit_when_not_in_transaction(self):
        """Dependency skips commit if the session is no longer in a transaction on exit."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)
        try:
            async for session in get_db():
                # Manually commit — session exits the transaction
                await session.commit()
                assert not session.in_transaction()
            # Exhausting the loop ran the dependency's post-yield path; it must
            # not call commit again (no error).
        finally:
            await engine.dispose()

    @pytest.mark.anyio
    async def test_update_after_lock_tables_is_persisted(self):
        """Changes made after lock_tables exits (before endpoint returns) are committed.

        Regression: without the auto-begin fix, lock_tables would start and commit a
        real outer transaction, leaving the session idle. Any modifications after that
        point were silently dropped.
        """
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        try:
            get_db = create_db_dependency(session_factory)
            async for session in get_db():
                async with lock_tables(session, [Role]):
                    role = Role(name="lock_then_update")
                    session.add(role)
                    await session.flush()
                # lock_tables has exited — outer transaction must still be open
                assert session.in_transaction()
                role.name = "updated_after_lock"
            # The rename is only visible here if the dependency committed it.
            async with session_factory() as verify:
                result = await RoleCrud.first(
                    verify, [Role.name == "updated_after_lock"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()
class TestCreateDbContext:
    """Tests for create_db_context."""

    @pytest.mark.anyio
    async def test_context_manager_yields_session(self):
        """Context manager yields a valid session."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db_context = create_db_context(session_factory)
        try:
            async with get_db_context() as session:
                assert isinstance(session, AsyncSession)
        finally:
            # Dispose even when the assertion fails so the pool is not leaked.
            await engine.dispose()

    @pytest.mark.anyio
    async def test_context_manager_commits(self):
        """Context manager commits on exit."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        try:
            get_db_context = create_db_context(session_factory)
            async with get_db_context() as session:
                role = Role(name="context_role")
                session.add(role)
                await session.flush()
            # Verified from a separate session after the context has exited.
            async with session_factory() as verify_session:
                result = await RoleCrud.first(
                    verify_session, [Role.name == "context_role"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()
class TestGetTransaction:
    """Tests for get_transaction context manager."""

    @pytest.mark.anyio
    async def test_starts_transaction(self, db_session: AsyncSession):
        """get_transaction starts a new transaction."""
        async with get_transaction(db_session):
            role = Role(name="tx_role")
            db_session.add(role)
        result = await RoleCrud.first(db_session, [Role.name == "tx_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_nested_transaction_uses_savepoint(self, db_session: AsyncSession):
        """Nested transactions use savepoints."""
        async with get_transaction(db_session):
            role1 = Role(name="outer_role")
            db_session.add(role1)
            await db_session.flush()
            async with get_transaction(db_session):
                role2 = Role(name="inner_role")
                db_session.add(role2)
        results = await RoleCrud.get_multi(db_session)
        names = {r.name for r in results}
        assert "outer_role" in names
        assert "inner_role" in names

    @pytest.mark.anyio
    async def test_rollback_on_exception(self, db_session: AsyncSession):
        """Transaction rolls back on exception and re-raises it."""
        # pytest.raises also asserts the error propagates; a bare
        # ``except ValueError: pass`` would pass if it were swallowed.
        with pytest.raises(ValueError, match="Simulated error"):
            async with get_transaction(db_session):
                role = Role(name="rollback_role")
                db_session.add(role)
                await db_session.flush()
                raise ValueError("Simulated error")
        result = await RoleCrud.first(db_session, [Role.name == "rollback_role"])
        assert result is None

    @pytest.mark.anyio
    async def test_nested_rollback_preserves_outer(self, db_session: AsyncSession):
        """Nested rollback preserves outer transaction."""
        async with get_transaction(db_session):
            role1 = Role(name="preserved_role")
            db_session.add(role1)
            await db_session.flush()
            with pytest.raises(ValueError, match="Inner error"):
                async with get_transaction(db_session):
                    role2 = Role(name="rolled_back_role")
                    db_session.add(role2)
                    await db_session.flush()
                    raise ValueError("Inner error")
        outer = await RoleCrud.first(db_session, [Role.name == "preserved_role"])
        inner = await RoleCrud.first(db_session, [Role.name == "rolled_back_role"])
        assert outer is not None
        assert inner is None
class TestLockMode:
    """Tests for the LockMode enum."""

    def test_lock_modes_exist(self):
        """Every expected lock mode is defined with its SQL keyword text."""
        expected = {
            "ACCESS_SHARE": "ACCESS SHARE",
            "ROW_SHARE": "ROW SHARE",
            "ROW_EXCLUSIVE": "ROW EXCLUSIVE",
            "SHARE_UPDATE_EXCLUSIVE": "SHARE UPDATE EXCLUSIVE",
            "SHARE": "SHARE",
            "SHARE_ROW_EXCLUSIVE": "SHARE ROW EXCLUSIVE",
            "EXCLUSIVE": "EXCLUSIVE",
            "ACCESS_EXCLUSIVE": "ACCESS EXCLUSIVE",
        }
        for member_name, sql_text in expected.items():
            assert getattr(LockMode, member_name) == sql_text

    def test_lock_mode_is_string(self):
        """Lock mode members are ``str`` instances carrying their SQL text."""
        member = LockMode.EXCLUSIVE
        assert isinstance(member, str)
        assert member.value == "EXCLUSIVE"
class TestLockTables:
    """Tests for lock_tables context manager (PostgreSQL-specific)."""

    @pytest.mark.anyio
    async def test_lock_single_table(self, db_session: AsyncSession):
        """Lock a single table."""
        async with lock_tables(db_session, [Role]):
            # Inside the lock, we can still perform operations
            role = Role(name="locked_role")
            db_session.add(role)
            await db_session.flush()
        # After the lock block exits, the flushed row is still visible in the
        # session (the fixture session itself is never committed here).
        result = await RoleCrud.first(db_session, [Role.name == "locked_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_multiple_tables(self, db_session: AsyncSession):
        """Lock multiple tables."""
        async with lock_tables(db_session, [Role, User]):
            role = Role(name="multi_lock_role")
            db_session.add(role)
            await db_session.flush()
        result = await RoleCrud.first(db_session, [Role.name == "multi_lock_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_with_custom_mode(self, db_session: AsyncSession):
        """Lock with custom lock mode."""
        async with lock_tables(db_session, [Role], mode=LockMode.EXCLUSIVE):
            role = Role(name="exclusive_lock_role")
            db_session.add(role)
            await db_session.flush()
        result = await RoleCrud.first(db_session, [Role.name == "exclusive_lock_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_rollback_on_exception(self, db_session: AsyncSession):
        """Lock context rolls back on exception and re-raises it."""
        # pytest.raises also asserts the error propagates out of lock_tables.
        with pytest.raises(ValueError, match="Simulated error"):
            async with lock_tables(db_session, [Role]):
                role = Role(name="lock_rollback_role")
                db_session.add(role)
                await db_session.flush()
                raise ValueError("Simulated error")
        result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
        assert result is None
class TestWaitForRowChange:
    """Tests for wait_for_row_change polling function."""

    # Upper bound for polls that wait on a background writer (presumably
    # seconds, like the 0.2 used in test_timeout_raises). If the writer fails,
    # the poll raises TimeoutError instead of possibly spinning forever —
    # the helper's default timeout is not visible here; assume unbounded.
    _WAIT_TIMEOUT = 5.0

    @pytest.mark.anyio
    async def test_detects_update(self, db_session: AsyncSession, engine):
        """Returns updated instance when a column value changes."""
        role = Role(name="watch_role")
        db_session.add(role)
        await db_session.commit()

        async def update_later():
            await asyncio.sleep(0.15)
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                assert r is not None
                r.name = "updated_role"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        try:
            result = await wait_for_row_change(
                db_session,
                Role,
                role.id,
                interval=0.05,
                timeout=self._WAIT_TIMEOUT,
            )
        finally:
            # Always await the writer so its own failure surfaces and the task
            # is not left dangling when the poll raises.
            await update_task
        assert result.name == "updated_role"

    @pytest.mark.anyio
    async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
        """Only triggers on changes to specified columns."""
        user = User(username="testuser", email="test@example.com")
        db_session.add(user)
        await db_session.commit()

        async def update_later():
            factory = async_sessionmaker(engine, expire_on_commit=False)
            # First: change email (not watched) — should not trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.email = "new@example.com"
                await other.commit()
            # Second: change username (watched) — should trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.username = "newuser"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        try:
            result = await wait_for_row_change(
                db_session,
                User,
                user.id,
                columns=["username"],
                interval=0.05,
                timeout=self._WAIT_TIMEOUT,
            )
        finally:
            await update_task
        assert result.username == "newuser"
        assert result.email == "new@example.com"

    @pytest.mark.anyio
    async def test_nonexistent_row_raises(self, db_session: AsyncSession):
        """Raises NotFoundError when the row does not exist."""
        fake_id = uuid.uuid4()
        with pytest.raises(NotFoundError, match="not found"):
            await wait_for_row_change(db_session, Role, fake_id, interval=0.05)

    @pytest.mark.anyio
    async def test_timeout_raises(self, db_session: AsyncSession):
        """Raises TimeoutError when no change is detected within timeout."""
        role = Role(name="timeout_role")
        db_session.add(role)
        await db_session.commit()
        with pytest.raises(TimeoutError):
            await wait_for_row_change(
                db_session, Role, role.id, interval=0.05, timeout=0.2
            )

    @pytest.mark.anyio
    async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
        """Raises NotFoundError when the row is deleted during polling."""
        role = Role(name="delete_role")
        db_session.add(role)
        await db_session.commit()

        async def delete_later():
            await asyncio.sleep(0.15)
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                assert r is not None
                await other.delete(r)
                await other.commit()

        delete_task = asyncio.create_task(delete_later())
        try:
            with pytest.raises(NotFoundError):
                await wait_for_row_change(
                    db_session,
                    Role,
                    role.id,
                    interval=0.05,
                    timeout=self._WAIT_TIMEOUT,
                )
        finally:
            await delete_task
class TestCreateDatabase:
    """Tests for create_database."""

    @pytest.mark.anyio
    async def test_creates_database(self):
        """Database is created by create_database."""
        # Derive the name from the test URL; URL.database is Optional[str],
        # hence the narrowing assert.
        expected_db = make_url(DATABASE_URL).set(database="test_create_db_general").database
        assert expected_db is not None
        engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
        try:
            async with engine.connect() as conn:
                await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
            await create_database(db_name=expected_db, server_url=DATABASE_URL)
            async with engine.connect() as conn:
                result = await conn.execute(
                    text("SELECT 1 FROM pg_database WHERE datname = :name"),
                    {"name": expected_db},
                )
                assert result.scalar() == 1
        finally:
            try:
                # Cleanup runs even when the assertion fails, so a failed run
                # does not leave the database behind for the next one.
                async with engine.connect() as conn:
                    await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
            finally:
                await engine.dispose()
class TestCleanupTables:
    """Tests for the cleanup_tables helper."""

    @pytest.mark.anyio
    async def test_truncates_all_tables(self):
        """All table rows are removed after cleanup_tables."""

        async def row_counts(session):
            # (roles, users), queried in that order.
            return await RoleCrud.count(session), await UserCrud.count(session)

        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            seeded_role = Role(id=uuid.uuid4(), name="cleanup_role")
            session.add(seeded_role)
            await session.flush()
            session.add(
                User(
                    id=uuid.uuid4(),
                    username="cleanup_user",
                    email="cleanup@test.com",
                    role_id=seeded_role.id,
                )
            )
            await session.commit()
            # Exactly one row in each table before cleanup ...
            assert await row_counts(session) == (1, 1)
            await cleanup_tables(session, Base)
            # ... and none afterwards.
            assert await row_counts(session) == (0, 0)

    @pytest.mark.anyio
    async def test_noop_for_empty_metadata(self):
        """cleanup_tables does not raise when metadata has no tables."""

        class EmptyBase(DeclarativeBase):
            """Declarative base whose metadata holds no tables."""

        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            # Passing only if this call completes without raising.
            await cleanup_tables(session, EmptyBase)
class TestM2MAdd:
    """Tests for the m2m_add helper."""

    async def _author_post(
        self,
        session: AsyncSession,
        username: str,
        email: str,
        title: str,
        *tag_names: str,
    ) -> tuple[Post, list[Tag]]:
        """Flush an author, then a post by them plus one ``Tag`` per name."""
        author = User(username=username, email=email)
        session.add(author)
        await session.flush()
        post = Post(title=title, author_id=author.id)
        tags = [Tag(name=tag_name) for tag_name in tag_names]
        session.add_all([post, *tags])
        await session.flush()
        return post, tags

    async def _reload_post(self, session: AsyncSession, post: Post) -> Post:
        """Re-select *post* with its ``tags`` collection eagerly loaded."""
        stmt = select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        return (await session.execute(stmt)).scalar_one()

    @pytest.mark.anyio
    async def test_adds_single_related(self, db_session: AsyncSession):
        """Associates one related instance via the secondary table."""
        post, (tag,) = await self._author_post(
            db_session, "m2m_author", "m2m@test.com", "Post A", "python"
        )
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)
        loaded = await self._reload_post(db_session, post)
        assert len(loaded.tags) == 1
        assert loaded.tags[0].id == tag.id

    @pytest.mark.anyio
    async def test_adds_multiple_related(self, db_session: AsyncSession):
        """Associates multiple related instances in a single call."""
        post, (tag1, tag2, tag3) = await self._author_post(
            db_session, "m2m_author2", "m2m2@test.com", "Post B", "web", "api", "async"
        )
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag1, tag2, tag3)
        loaded = await self._reload_post(db_session, post)
        assert {t.id for t in loaded.tags} == {tag1.id, tag2.id, tag3.id}

    @pytest.mark.anyio
    async def test_noop_for_empty_related(self, db_session: AsyncSession):
        """Calling with no related instances is a no-op."""
        post, _ = await self._author_post(
            db_session, "m2m_author3", "m2m3@test.com", "Post C"
        )
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags)  # no related instances
        loaded = await self._reload_post(db_session, post)
        assert loaded.tags == []

    @pytest.mark.anyio
    async def test_ignore_conflicts_true(self, db_session: AsyncSession):
        """Duplicate inserts are silently skipped when ignore_conflicts=True."""
        post, (tag,) = await self._author_post(
            db_session, "m2m_author4", "m2m4@test.com", "Post D", "duplicate_tag"
        )
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)
        # A repeat insert with ignore_conflicts=True must not raise.
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag, ignore_conflicts=True)
        loaded = await self._reload_post(db_session, post)
        assert len(loaded.tags) == 1

    @pytest.mark.anyio
    async def test_ignore_conflicts_false_raises(self, db_session: AsyncSession):
        """Duplicate inserts raise IntegrityError when ignore_conflicts=False (default)."""
        post, (tag,) = await self._author_post(
            db_session, "m2m_author5", "m2m5@test.com", "Post E", "conflict_tag"
        )
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)
        with pytest.raises(IntegrityError):
            async with get_transaction(db_session):
                await m2m_add(db_session, post, Post.tags, tag)

    @pytest.mark.anyio
    async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
        """Passing a non-M2M relationship attribute raises TypeError."""
        author = User(username="m2m_author6", email="m2m6@test.com")
        db_session.add(author)
        await db_session.flush()
        target_role = Role(name="type_err_role")
        db_session.add(target_role)
        await db_session.flush()
        with pytest.raises(TypeError, match="Many-to-Many"):
            await m2m_add(db_session, author, User.role, target_role)

    @pytest.mark.anyio
    async def test_works_inside_lock_tables(self, db_session: AsyncSession):
        """m2m_add works correctly inside a lock_tables nested transaction."""
        author = User(username="m2m_lock_author", email="m2m_lock@test.com")
        db_session.add(author)
        await db_session.flush()
        async with lock_tables(db_session, [Tag]):
            locked_tag = Tag(name="locked_tag")
            db_session.add(locked_tag)
            await db_session.flush()
            post = Post(title="Post Lock", author_id=author.id)
            db_session.add(post)
            await db_session.flush()
            await m2m_add(db_session, post, Post.tags, locked_tag)
        loaded = await self._reload_post(db_session, post)
        assert len(loaded.tags) == 1
        assert loaded.tags[0].name == "locked_tag"
class _LocalBase(DeclarativeBase):
    """Separate declarative base for the composite-primary-key fixture models.

    Its tables are not part of ``Base.metadata``; the composite-PK test creates
    and drops them explicitly through ``_LocalBase.metadata``.
    """
# Association (secondary) table for the composite-PK M2M fixture. Each row
# links one _comp_owners row (single UUID key) to one _comp_items row, whose
# primary key spans two string columns. All three columns form this table's
# primary key, so a given owner/item pair can appear only once.
_comp_assoc = Table(
    "_comp_assoc",
    _LocalBase.metadata,
    Column("owner_id", Uuid, ForeignKey("_comp_owners.id"), primary_key=True),
    Column("item_group", String(50), primary_key=True),
    Column("item_code", String(50), primary_key=True),
    # Composite foreign key: (item_group, item_code) -> the item's two-column PK.
    ForeignKeyConstraint(
        ["item_group", "item_code"],
        ["_comp_items.group_id", "_comp_items.item_code"],
    ),
)
class _CompOwner(_LocalBase):
    """Owner side of the composite-PK M2M fixture, keyed by a single UUID."""

    __tablename__ = "_comp_owners"
    # Python-side default: a fresh uuid4 is generated when no id is supplied.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Many-to-many to _CompItem through the _comp_assoc secondary table.
    items: Mapped[list["_CompItem"]] = relationship(secondary=_comp_assoc)
class _CompItem(_LocalBase):
    """Related side of the composite-PK M2M fixture (two-column primary key)."""

    __tablename__ = "_comp_items"
    # Together, group_id and item_code form the composite primary key that
    # _comp_assoc references.
    group_id: Mapped[str] = mapped_column(String(50), primary_key=True)
    item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
class TestM2MRemove:
    """Tests for m2m_remove helper."""

    async def _setup(
        self, session: AsyncSession, username: str, email: str, *tag_names: str
    ):
        """Create a user, post, and tags; associate all tags with the post."""
        user = User(username=username, email=email)
        session.add(user)
        await session.flush()
        post = Post(title=f"Post {username}", author_id=user.id)
        tags = [Tag(name=n) for n in tag_names]
        session.add(post)
        session.add_all(tags)
        await session.flush()
        async with get_transaction(session):
            await m2m_add(session, post, Post.tags, *tags)
        return post, tags

    async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
        """Re-select *post* with tags eagerly loaded and return that collection."""
        result = await session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        return result.scalar_one().tags

    @pytest.mark.anyio
    async def test_removes_single(self, db_session: AsyncSession):
        """Removes one association, leaving others intact."""
        post, (tag1, tag2) = await self._setup(
            db_session, "rm_author1", "rm1@test.com", "tag_rm_a", "tag_rm_b"
        )
        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags, tag1)
        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1
        assert remaining[0].id == tag2.id

    @pytest.mark.anyio
    async def test_removes_multiple(self, db_session: AsyncSession):
        """Removes multiple associations in one call."""
        post, (tag1, tag2, tag3) = await self._setup(
            db_session, "rm_author2", "rm2@test.com", "tag_rm_c", "tag_rm_d", "tag_rm_e"
        )
        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags, tag1, tag3)
        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1
        assert remaining[0].id == tag2.id

    @pytest.mark.anyio
    async def test_noop_for_empty_related(self, db_session: AsyncSession):
        """Calling with no related instances is a no-op."""
        post, (tag,) = await self._setup(
            db_session, "rm_author3", "rm3@test.com", "tag_rm_f"
        )
        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags)
        remaining = await self._load_tags(db_session, post)
        # The original association must survive unchanged.
        assert len(remaining) == 1
        assert remaining[0].id == tag.id

    @pytest.mark.anyio
    async def test_idempotent_for_missing_association(self, db_session: AsyncSession):
        """Removing a non-existent association does not raise."""
        post, (tag1,) = await self._setup(
            db_session, "rm_author4", "rm4@test.com", "tag_rm_g"
        )
        tag2 = Tag(name="tag_rm_h")
        db_session.add(tag2)
        await db_session.flush()
        # tag2 was never associated — should not raise
        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags, tag2)
        remaining = await self._load_tags(db_session, post)
        # The unrelated existing association must be left in place.
        assert len(remaining) == 1
        assert remaining[0].id == tag1.id

    @pytest.mark.anyio
    async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
        """Passing a non-M2M relationship attribute raises TypeError."""
        user = User(username="rm_author5", email="rm5@test.com")
        db_session.add(user)
        await db_session.flush()
        role = Role(name="rm_type_err_role")
        db_session.add(role)
        await db_session.flush()
        with pytest.raises(TypeError, match="Many-to-Many"):
            await m2m_remove(db_session, user, User.role, role)

    @pytest.mark.anyio
    async def test_removes_composite_pk_related(self):
        """Composite-PK branch: DELETE uses tuple IN when related side has multi-col PK."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        async with engine.begin() as conn:
            await conn.run_sync(_LocalBase.metadata.create_all)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        try:
            async with session_factory() as session:
                owner = _CompOwner()
                item1 = _CompItem(group_id="g1", item_code="c1")
                item2 = _CompItem(group_id="g1", item_code="c2")
                session.add_all([owner, item1, item2])
                await session.flush()
                async with get_transaction(session):
                    await m2m_add(session, owner, _CompOwner.items, item1, item2)
                async with get_transaction(session):
                    await m2m_remove(session, owner, _CompOwner.items, item1)
                await session.commit()
            # Verify from a fresh session that only item2 is still linked.
            async with session_factory() as verify:
                result = await verify.execute(
                    select(_CompOwner)
                    .where(_CompOwner.id == owner.id)
                    .options(selectinload(_CompOwner.items))
                )
                loaded = result.scalar_one()
                assert len(loaded.items) == 1
                assert (loaded.items[0].group_id, loaded.items[0].item_code) == (
                    "g1",
                    "c2",
                )
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(_LocalBase.metadata.drop_all)
            await engine.dispose()
class TestM2MSet:
    """Tests for the m2m_set helper."""

    async def _seed(
        self,
        session: AsyncSession,
        username: str,
        email: str,
        title: str,
        *tag_names: str,
    ) -> tuple[Post, list[Tag]]:
        """Flush an author, then a post by them plus one ``Tag`` per name."""
        author = User(username=username, email=email)
        session.add(author)
        await session.flush()
        post = Post(title=title, author_id=author.id)
        tags = [Tag(name=tag_name) for tag_name in tag_names]
        session.add_all([post, *tags])
        await session.flush()
        return post, tags

    async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
        """Re-select *post* with tags eagerly loaded and return that collection."""
        query = select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        rows = await session.execute(query)
        return rows.scalar_one().tags

    @pytest.mark.anyio
    async def test_replaces_existing_set(self, db_session: AsyncSession):
        """Replaces the full association set atomically."""
        post, (tag1, tag2, tag3) = await self._seed(
            db_session,
            "set_author1",
            "set1@test.com",
            "Post Set A",
            "tag_set_a",
            "tag_set_b",
            "tag_set_c",
        )
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag1, tag2)
        async with get_transaction(db_session):
            await m2m_set(db_session, post, Post.tags, tag3)
        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1
        assert remaining[0].id == tag3.id

    @pytest.mark.anyio
    async def test_clears_all_when_no_related(self, db_session: AsyncSession):
        """Passing no related instances clears all associations."""
        post, (tag,) = await self._seed(
            db_session, "set_author2", "set2@test.com", "Post Set B", "tag_set_d"
        )
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)
        async with get_transaction(db_session):
            await m2m_set(db_session, post, Post.tags)
        assert await self._load_tags(db_session, post) == []

    @pytest.mark.anyio
    async def test_set_on_empty_then_populate(self, db_session: AsyncSession):
        """m2m_set works on a post with no existing associations."""
        post, (tag1, tag2) = await self._seed(
            db_session,
            "set_author3",
            "set3@test.com",
            "Post Set C",
            "tag_set_e",
            "tag_set_f",
        )
        async with get_transaction(db_session):
            await m2m_set(db_session, post, Post.tags, tag1, tag2)
        remaining = await self._load_tags(db_session, post)
        assert {t.id for t in remaining} == {tag1.id, tag2.id}

    @pytest.mark.anyio
    async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
        """Passing a non-M2M relationship attribute raises TypeError."""
        author = User(username="set_author4", email="set4@test.com")
        db_session.add(author)
        await db_session.flush()
        target_role = Role(name="set_type_err_role")
        db_session.add(target_role)
        await db_session.flush()
        with pytest.raises(TypeError, match="Many-to-Many"):
            await m2m_set(db_session, author, User.role, target_role)