Mirror of https://github.com/d3vyce/fastapi-toolsets.git
Synced 2026-04-15 22:26:25 +02:00
931 lines · 33 KiB · Python
"""Tests for fastapi_toolsets.db module."""
|
|
|
|
import asyncio
|
|
import uuid
|
|
|
|
import pytest
|
|
from sqlalchemy import (
|
|
Column,
|
|
ForeignKey,
|
|
ForeignKeyConstraint,
|
|
String,
|
|
Table,
|
|
Uuid,
|
|
select,
|
|
text,
|
|
)
|
|
from sqlalchemy.engine import make_url
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.orm import (
|
|
DeclarativeBase,
|
|
Mapped,
|
|
mapped_column,
|
|
relationship,
|
|
selectinload,
|
|
)
|
|
|
|
from fastapi_toolsets.db import (
|
|
LockMode,
|
|
cleanup_tables,
|
|
create_database,
|
|
create_db_context,
|
|
create_db_dependency,
|
|
get_transaction,
|
|
lock_tables,
|
|
m2m_add,
|
|
m2m_remove,
|
|
m2m_set,
|
|
wait_for_row_change,
|
|
)
|
|
from fastapi_toolsets.exceptions import NotFoundError
|
|
from fastapi_toolsets.pytest import create_db_session
|
|
|
|
from .conftest import DATABASE_URL, Base, Post, Role, RoleCrud, Tag, User, UserCrud
|
|
|
|
|
|
class TestCreateDbDependency:
    """Tests for create_db_dependency."""

    @pytest.mark.anyio
    async def test_yields_session(self):
        """Dependency yields a valid session."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)

        # The dependency is an async generator; consume a single yield.
        async for session in get_db():
            assert isinstance(session, AsyncSession)
            break

        await engine.dispose()

    @pytest.mark.anyio
    async def test_auto_commits_transaction(self):
        """Dependency auto-commits if transaction is active."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        try:
            get_db = create_db_dependency(session_factory)

            # Only flush inside the generator body; the commit must come from
            # the dependency's post-yield logic.
            async for session in get_db():
                role = Role(name="test_role_dep")
                session.add(role)
                await session.flush()

            # Verify with a separate session so only committed data is visible.
            async with session_factory() as verify_session:
                result = await RoleCrud.first(
                    verify_session, [Role.name == "test_role_dep"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()

    @pytest.mark.anyio
    async def test_in_transaction_on_yield(self):
        """Session is already in a transaction when the endpoint body starts."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)

        async for session in get_db():
            assert session.in_transaction()
            break

        await engine.dispose()

    @pytest.mark.anyio
    async def test_no_commit_when_not_in_transaction(self):
        """Dependency skips commit if the session is no longer in a transaction on exit."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)

        async for session in get_db():
            # Manually commit — session exits the transaction
            await session.commit()
            assert not session.in_transaction()
            # The dependency's post-yield path must not call commit again (no error)

        await engine.dispose()

    @pytest.mark.anyio
    async def test_update_after_lock_tables_is_persisted(self):
        """Changes made after lock_tables exits (before endpoint returns) are committed.

        Regression: without the auto-begin fix, lock_tables would start and commit a
        real outer transaction, leaving the session idle. Any modifications after that
        point were silently dropped.
        """
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        try:
            get_db = create_db_dependency(session_factory)

            async for session in get_db():
                async with lock_tables(session, [Role]):
                    role = Role(name="lock_then_update")
                    session.add(role)
                    await session.flush()
                # lock_tables has exited — outer transaction must still be open
                assert session.in_transaction()
                role.name = "updated_after_lock"

            # The rename made after the lock must be committed by the dependency.
            async with session_factory() as verify:
                result = await RoleCrud.first(
                    verify, [Role.name == "updated_after_lock"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()
|
|
class TestCreateDbContext:
    """Tests for create_db_context."""

    @pytest.mark.anyio
    async def test_context_manager_yields_session(self):
        """Context manager yields a valid session."""
        db_engine = create_async_engine(DATABASE_URL, echo=False)
        factory = async_sessionmaker(db_engine, expire_on_commit=False)
        make_session_ctx = create_db_context(factory)

        async with make_session_ctx() as session:
            assert isinstance(session, AsyncSession)

        await db_engine.dispose()

    @pytest.mark.anyio
    async def test_context_manager_commits(self):
        """Context manager commits on exit."""
        db_engine = create_async_engine(DATABASE_URL, echo=False)
        factory = async_sessionmaker(db_engine, expire_on_commit=False)

        async with db_engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        try:
            make_session_ctx = create_db_context(factory)

            # Flush only; the commit must come from the context manager itself.
            async with make_session_ctx() as session:
                session.add(Role(name="context_role"))
                await session.flush()

            # A fresh session sees the row only if it was actually committed.
            async with factory() as checker:
                found = await RoleCrud.first(checker, [Role.name == "context_role"])
                assert found is not None
        finally:
            async with db_engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await db_engine.dispose()
|
|
class TestGetTransaction:
    """Tests for get_transaction context manager."""

    @pytest.mark.anyio
    async def test_starts_transaction(self, db_session: AsyncSession):
        """get_transaction starts a new transaction."""
        async with get_transaction(db_session):
            role = Role(name="tx_role")
            db_session.add(role)

        # The row is visible after the context exits (work was committed).
        result = await RoleCrud.first(db_session, [Role.name == "tx_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_nested_transaction_uses_savepoint(self, db_session: AsyncSession):
        """Nested transactions use savepoints."""
        async with get_transaction(db_session):
            role1 = Role(name="outer_role")
            db_session.add(role1)
            await db_session.flush()

            # Inner context must nest (savepoint) rather than fail because a
            # transaction is already in progress.
            async with get_transaction(db_session):
                role2 = Role(name="inner_role")
                db_session.add(role2)

        results = await RoleCrud.get_multi(db_session)
        names = {r.name for r in results}
        assert "outer_role" in names
        assert "inner_role" in names

    @pytest.mark.anyio
    async def test_rollback_on_exception(self, db_session: AsyncSession):
        """Transaction rolls back on exception."""
        try:
            async with get_transaction(db_session):
                role = Role(name="rollback_role")
                db_session.add(role)
                await db_session.flush()
                raise ValueError("Simulated error")
        except ValueError:
            pass

        # The flushed row must have been rolled back with the transaction.
        result = await RoleCrud.first(db_session, [Role.name == "rollback_role"])
        assert result is None

    @pytest.mark.anyio
    async def test_nested_rollback_preserves_outer(self, db_session: AsyncSession):
        """Nested rollback preserves outer transaction."""
        async with get_transaction(db_session):
            role1 = Role(name="preserved_role")
            db_session.add(role1)
            await db_session.flush()

            try:
                # Only the inner savepoint should roll back, not the outer work.
                async with get_transaction(db_session):
                    role2 = Role(name="rolled_back_role")
                    db_session.add(role2)
                    await db_session.flush()
                    raise ValueError("Inner error")
            except ValueError:
                pass

        outer = await RoleCrud.first(db_session, [Role.name == "preserved_role"])
        inner = await RoleCrud.first(db_session, [Role.name == "rolled_back_role"])
        assert outer is not None
        assert inner is None
|
|
class TestLockMode:
    """Tests for LockMode enum."""

    def test_lock_modes_exist(self):
        """All expected lock modes are defined."""
        # Each member must compare equal to its PostgreSQL lock-mode keyword.
        expected = {
            LockMode.ACCESS_SHARE: "ACCESS SHARE",
            LockMode.ROW_SHARE: "ROW SHARE",
            LockMode.ROW_EXCLUSIVE: "ROW EXCLUSIVE",
            LockMode.SHARE_UPDATE_EXCLUSIVE: "SHARE UPDATE EXCLUSIVE",
            LockMode.SHARE: "SHARE",
            LockMode.SHARE_ROW_EXCLUSIVE: "SHARE ROW EXCLUSIVE",
            LockMode.EXCLUSIVE: "EXCLUSIVE",
            LockMode.ACCESS_EXCLUSIVE: "ACCESS EXCLUSIVE",
        }
        for mode, sql_keyword in expected.items():
            assert mode == sql_keyword

    def test_lock_mode_is_string(self):
        """Lock modes are string enums."""
        member = LockMode.EXCLUSIVE
        assert isinstance(member, str)
        assert member.value == "EXCLUSIVE"
|
|
class TestLockTables:
    """Tests for lock_tables context manager (PostgreSQL-specific)."""

    @pytest.mark.anyio
    async def test_lock_single_table(self, db_session: AsyncSession):
        """Lock a single table."""
        async with lock_tables(db_session, [Role]):
            # Inside the lock, we can still perform operations
            role = Role(name="locked_role")
            db_session.add(role)
            await db_session.flush()

        # After lock is released, verify the data was committed
        result = await RoleCrud.first(db_session, [Role.name == "locked_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_multiple_tables(self, db_session: AsyncSession):
        """Lock multiple tables."""
        async with lock_tables(db_session, [Role, User]):
            role = Role(name="multi_lock_role")
            db_session.add(role)
            await db_session.flush()

        result = await RoleCrud.first(db_session, [Role.name == "multi_lock_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_with_custom_mode(self, db_session: AsyncSession):
        """Lock with custom lock mode."""
        # Exercises the `mode` parameter path with a non-default lock mode.
        async with lock_tables(db_session, [Role], mode=LockMode.EXCLUSIVE):
            role = Role(name="exclusive_lock_role")
            db_session.add(role)
            await db_session.flush()

        result = await RoleCrud.first(db_session, [Role.name == "exclusive_lock_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_rollback_on_exception(self, db_session: AsyncSession):
        """Lock context rolls back on exception."""
        try:
            async with lock_tables(db_session, [Role]):
                role = Role(name="lock_rollback_role")
                db_session.add(role)
                await db_session.flush()
                raise ValueError("Simulated error")
        except ValueError:
            pass

        # The flushed row must not survive the rollback.
        result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
        assert result is None
|
|
class TestWaitForRowChange:
    """Tests for wait_for_row_change polling function."""

    @pytest.mark.anyio
    async def test_detects_update(self, db_session: AsyncSession, engine):
        """Returns updated instance when a column value changes."""
        role = Role(name="watch_role")
        db_session.add(role)
        await db_session.commit()

        async def update_later():
            # Sleep past a few polling intervals before mutating the row
            # from a second, independent session.
            await asyncio.sleep(0.15)
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                assert r is not None
                r.name = "updated_role"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
        await update_task

        assert result.name == "updated_role"

    @pytest.mark.anyio
    async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
        """Only triggers on changes to specified columns."""
        user = User(username="testuser", email="test@example.com")
        db_session.add(user)
        await db_session.commit()

        async def update_later():
            factory = async_sessionmaker(engine, expire_on_commit=False)
            # First: change email (not watched) — should not trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.email = "new@example.com"
                await other.commit()
            # Second: change username (watched) — should trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.username = "newuser"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        result = await wait_for_row_change(
            db_session, User, user.id, columns=["username"], interval=0.05
        )
        await update_task

        # Both updates are visible in the returned row, but only the username
        # change may have ended the wait.
        assert result.username == "newuser"
        assert result.email == "new@example.com"

    @pytest.mark.anyio
    async def test_nonexistent_row_raises(self, db_session: AsyncSession):
        """Raises NotFoundError when the row does not exist."""
        fake_id = uuid.uuid4()
        with pytest.raises(NotFoundError, match="not found"):
            await wait_for_row_change(db_session, Role, fake_id, interval=0.05)

    @pytest.mark.anyio
    async def test_timeout_raises(self, db_session: AsyncSession):
        """Raises TimeoutError when no change is detected within timeout."""
        role = Role(name="timeout_role")
        db_session.add(role)
        await db_session.commit()

        # Nothing modifies the row, so the poll must give up after `timeout`.
        with pytest.raises(TimeoutError):
            await wait_for_row_change(
                db_session, Role, role.id, interval=0.05, timeout=0.2
            )

    @pytest.mark.anyio
    async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
        """Raises NotFoundError when the row is deleted during polling."""
        role = Role(name="delete_role")
        db_session.add(role)
        await db_session.commit()

        async def delete_later():
            await asyncio.sleep(0.15)
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                # Guard against passing None to delete(); consistent with the
                # other background helpers in this class.
                assert r is not None
                await other.delete(r)
                await other.commit()

        delete_task = asyncio.create_task(delete_later())
        with pytest.raises(NotFoundError):
            await wait_for_row_change(db_session, Role, role.id, interval=0.05)
        await delete_task
|
|
class TestCreateDatabase:
    """Tests for create_database."""

    @pytest.mark.anyio
    async def test_creates_database(self):
        """Database is created by create_database."""
        target_url = (
            make_url(DATABASE_URL)
            .set(database="test_create_db_general")
            .render_as_string(hide_password=False)
        )
        expected_db = make_url(target_url).database
        assert expected_db is not None

        # AUTOCOMMIT is required: CREATE/DROP DATABASE cannot run inside a
        # PostgreSQL transaction block.
        engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
        try:
            # Ensure a clean slate before exercising create_database.
            async with engine.connect() as conn:
                await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))

            await create_database(db_name=expected_db, server_url=DATABASE_URL)

            async with engine.connect() as conn:
                result = await conn.execute(
                    text("SELECT 1 FROM pg_database WHERE datname = :name"),
                    {"name": expected_db},
                )
                assert result.scalar() == 1
        finally:
            # Cleanup runs even when the assertion above fails, so reruns do
            # not inherit a stale test database.
            async with engine.connect() as conn:
                await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
            await engine.dispose()
|
|
class TestCleanupTables:
    """Tests for cleanup_tables helper."""

    @pytest.mark.anyio
    async def test_truncates_all_tables(self):
        """All table rows are removed after cleanup_tables."""
        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            seeded_role = Role(id=uuid.uuid4(), name="cleanup_role")
            session.add(seeded_role)
            await session.flush()

            seeded_user = User(
                id=uuid.uuid4(),
                username="cleanup_user",
                email="cleanup@test.com",
                role_id=seeded_role.id,
            )
            session.add(seeded_user)
            await session.commit()

            # Sanity check: both rows were persisted before the cleanup call.
            assert await RoleCrud.count(session) == 1
            assert await UserCrud.count(session) == 1

            await cleanup_tables(session, Base)

            # Every table registered on Base must now be empty.
            assert await RoleCrud.count(session) == 0
            assert await UserCrud.count(session) == 0

    @pytest.mark.anyio
    async def test_noop_for_empty_metadata(self):
        """cleanup_tables does not raise when metadata has no tables."""

        class EmptyBase(DeclarativeBase):
            pass

        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            # Should not raise
            await cleanup_tables(session, EmptyBase)
|
|
|
class TestM2MAdd:
    """Tests for m2m_add helper."""

    @pytest.mark.anyio
    async def test_adds_single_related(self, db_session: AsyncSession):
        """Associates one related instance via the secondary table."""
        user = User(username="m2m_author", email="m2m@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post A", author_id=user.id)
        tag = Tag(name="python")
        db_session.add_all([post, tag])
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)

        # Reload with eager loading to confirm the association row exists.
        result = await db_session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        loaded = result.scalar_one()
        assert len(loaded.tags) == 1
        assert loaded.tags[0].id == tag.id

    @pytest.mark.anyio
    async def test_adds_multiple_related(self, db_session: AsyncSession):
        """Associates multiple related instances in a single call."""
        user = User(username="m2m_author2", email="m2m2@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post B", author_id=user.id)
        tag1 = Tag(name="web")
        tag2 = Tag(name="api")
        tag3 = Tag(name="async")
        db_session.add_all([post, tag1, tag2, tag3])
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag1, tag2, tag3)

        result = await db_session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        loaded = result.scalar_one()
        assert {t.id for t in loaded.tags} == {tag1.id, tag2.id, tag3.id}

    @pytest.mark.anyio
    async def test_noop_for_empty_related(self, db_session: AsyncSession):
        """Calling with no related instances is a no-op."""
        user = User(username="m2m_author3", email="m2m3@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post C", author_id=user.id)
        db_session.add(post)
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags)  # no related instances

        result = await db_session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        loaded = result.scalar_one()
        assert loaded.tags == []

    @pytest.mark.anyio
    async def test_ignore_conflicts_true(self, db_session: AsyncSession):
        """Duplicate inserts are silently skipped when ignore_conflicts=True."""
        user = User(username="m2m_author4", email="m2m4@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post D", author_id=user.id)
        tag = Tag(name="duplicate_tag")
        db_session.add_all([post, tag])
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)

        # Second call with ignore_conflicts=True must not raise
        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag, ignore_conflicts=True)

        # Still exactly one association — the duplicate was skipped, not doubled.
        result = await db_session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        loaded = result.scalar_one()
        assert len(loaded.tags) == 1

    @pytest.mark.anyio
    async def test_ignore_conflicts_false_raises(self, db_session: AsyncSession):
        """Duplicate inserts raise IntegrityError when ignore_conflicts=False (default)."""
        user = User(username="m2m_author5", email="m2m5@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post E", author_id=user.id)
        tag = Tag(name="conflict_tag")
        db_session.add_all([post, tag])
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)

        # Re-inserting the same association must surface the constraint error.
        with pytest.raises(IntegrityError):
            async with get_transaction(db_session):
                await m2m_add(db_session, post, Post.tags, tag)

    @pytest.mark.anyio
    async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
        """Passing a non-M2M relationship attribute raises TypeError."""
        user = User(username="m2m_author6", email="m2m6@test.com")
        db_session.add(user)
        await db_session.flush()

        role = Role(name="type_err_role")
        db_session.add(role)
        await db_session.flush()

        # User.role is not a many-to-many relationship, so m2m_add must reject it.
        with pytest.raises(TypeError, match="Many-to-Many"):
            await m2m_add(db_session, user, User.role, role)

    @pytest.mark.anyio
    async def test_works_inside_lock_tables(self, db_session: AsyncSession):
        """m2m_add works correctly inside a lock_tables nested transaction."""
        user = User(username="m2m_lock_author", email="m2m_lock@test.com")
        db_session.add(user)
        await db_session.flush()

        async with lock_tables(db_session, [Tag]):
            tag = Tag(name="locked_tag")
            db_session.add(tag)
            await db_session.flush()

            post = Post(title="Post Lock", author_id=user.id)
            db_session.add(post)
            await db_session.flush()

            await m2m_add(db_session, post, Post.tags, tag)

        result = await db_session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        loaded = result.scalar_one()
        assert len(loaded.tags) == 1
        assert loaded.tags[0].name == "locked_tag"
|
|
class _LocalBase(DeclarativeBase):
    """Private declarative base so the composite-PK fixtures below keep their
    own metadata, separate from the shared conftest Base."""

    pass
|
|
|
|
|
|
# Association table whose related side is referenced by a composite foreign
# key — exercises m2m_remove's multi-column-PK DELETE branch.
_comp_assoc = Table(
    "_comp_assoc",
    _LocalBase.metadata,
    Column("owner_id", Uuid, ForeignKey("_comp_owners.id"), primary_key=True),
    Column("item_group", String(50), primary_key=True),
    Column("item_code", String(50), primary_key=True),
    # Composite FK: the pair (item_group, item_code) points at _comp_items' PK.
    ForeignKeyConstraint(
        ["item_group", "item_code"],
        ["_comp_items.group_id", "_comp_items.item_code"],
    ),
)
|
|
|
|
|
|
class _CompOwner(_LocalBase):
    """Owning side of the composite-PK many-to-many fixture."""

    __tablename__ = "_comp_owners"
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # M2M to _CompItem through the composite-key association table above.
    items: Mapped[list["_CompItem"]] = relationship(secondary=_comp_assoc)
|
|
|
|
|
|
class _CompItem(_LocalBase):
    """Related side with a two-column (composite) primary key."""

    __tablename__ = "_comp_items"
    group_id: Mapped[str] = mapped_column(String(50), primary_key=True)
    item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
|
|
|
|
|
|
class TestM2MRemove:
    """Tests for m2m_remove helper."""

    async def _setup(
        self, session: AsyncSession, username: str, email: str, *tag_names: str
    ):
        """Create a user, post, and tags; associate all tags with the post."""
        user = User(username=username, email=email)
        session.add(user)
        await session.flush()

        post = Post(title=f"Post {username}", author_id=user.id)
        tags = [Tag(name=n) for n in tag_names]
        session.add(post)
        session.add_all(tags)
        await session.flush()

        async with get_transaction(session):
            await m2m_add(session, post, Post.tags, *tags)

        return post, tags

    async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
        """Reload the post with its tags eagerly loaded and return the tag list."""
        result = await session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        return result.scalar_one().tags

    @pytest.mark.anyio
    async def test_removes_single(self, db_session: AsyncSession):
        """Removes one association, leaving others intact."""
        post, (tag1, tag2) = await self._setup(
            db_session, "rm_author1", "rm1@test.com", "tag_rm_a", "tag_rm_b"
        )

        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags, tag1)

        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1
        assert remaining[0].id == tag2.id

    @pytest.mark.anyio
    async def test_removes_multiple(self, db_session: AsyncSession):
        """Removes multiple associations in one call."""
        post, (tag1, tag2, tag3) = await self._setup(
            db_session, "rm_author2", "rm2@test.com", "tag_rm_c", "tag_rm_d", "tag_rm_e"
        )

        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags, tag1, tag3)

        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1
        assert remaining[0].id == tag2.id

    @pytest.mark.anyio
    async def test_noop_for_empty_related(self, db_session: AsyncSession):
        """Calling with no related instances is a no-op."""
        post, (tag,) = await self._setup(
            db_session, "rm_author3", "rm3@test.com", "tag_rm_f"
        )

        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags)

        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1

    @pytest.mark.anyio
    async def test_idempotent_for_missing_association(self, db_session: AsyncSession):
        """Removing a non-existent association does not raise."""
        post, (tag1,) = await self._setup(
            db_session, "rm_author4", "rm4@test.com", "tag_rm_g"
        )
        tag2 = Tag(name="tag_rm_h")
        db_session.add(tag2)
        await db_session.flush()

        # tag2 was never associated — should not raise
        async with get_transaction(db_session):
            await m2m_remove(db_session, post, Post.tags, tag2)

        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1

    @pytest.mark.anyio
    async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
        """Passing a non-M2M relationship attribute raises TypeError."""
        user = User(username="rm_author5", email="rm5@test.com")
        db_session.add(user)
        await db_session.flush()

        role = Role(name="rm_type_err_role")
        db_session.add(role)
        await db_session.flush()

        with pytest.raises(TypeError, match="Many-to-Many"):
            await m2m_remove(db_session, user, User.role, role)

    @pytest.mark.anyio
    async def test_removes_composite_pk_related(self):
        """Composite-PK branch: DELETE uses tuple IN when related side has multi-col PK."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        async with engine.begin() as conn:
            await conn.run_sync(_LocalBase.metadata.create_all)

        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        try:
            async with session_factory() as session:
                owner = _CompOwner()
                item1 = _CompItem(group_id="g1", item_code="c1")
                item2 = _CompItem(group_id="g1", item_code="c2")
                session.add_all([owner, item1, item2])
                await session.flush()

                async with get_transaction(session):
                    await m2m_add(session, owner, _CompOwner.items, item1, item2)

                async with get_transaction(session):
                    await m2m_remove(session, owner, _CompOwner.items, item1)

                await session.commit()

            # select / selectinload come from the module-level imports; the
            # redundant local re-imports were removed.
            async with session_factory() as verify:
                result = await verify.execute(
                    select(_CompOwner)
                    .where(_CompOwner.id == owner.id)
                    .options(selectinload(_CompOwner.items))
                )
                loaded = result.scalar_one()
                assert len(loaded.items) == 1
                assert (loaded.items[0].group_id, loaded.items[0].item_code) == (
                    "g1",
                    "c2",
                )
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(_LocalBase.metadata.drop_all)
            await engine.dispose()
|
|
class TestM2MSet:
    """Tests for m2m_set helper."""

    async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
        """Reload the post with its tags eagerly loaded and return the tag list."""
        result = await session.execute(
            select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
        )
        return result.scalar_one().tags

    @pytest.mark.anyio
    async def test_replaces_existing_set(self, db_session: AsyncSession):
        """Replaces the full association set atomically."""
        user = User(username="set_author1", email="set1@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post Set A", author_id=user.id)
        tag1 = Tag(name="tag_set_a")
        tag2 = Tag(name="tag_set_b")
        tag3 = Tag(name="tag_set_c")
        db_session.add_all([post, tag1, tag2, tag3])
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag1, tag2)

        # Setting to tag3 must drop tag1/tag2 and keep only tag3.
        async with get_transaction(db_session):
            await m2m_set(db_session, post, Post.tags, tag3)

        remaining = await self._load_tags(db_session, post)
        assert len(remaining) == 1
        assert remaining[0].id == tag3.id

    @pytest.mark.anyio
    async def test_clears_all_when_no_related(self, db_session: AsyncSession):
        """Passing no related instances clears all associations."""
        user = User(username="set_author2", email="set2@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post Set B", author_id=user.id)
        tag = Tag(name="tag_set_d")
        db_session.add_all([post, tag])
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_add(db_session, post, Post.tags, tag)

        async with get_transaction(db_session):
            await m2m_set(db_session, post, Post.tags)

        remaining = await self._load_tags(db_session, post)
        assert remaining == []

    @pytest.mark.anyio
    async def test_set_on_empty_then_populate(self, db_session: AsyncSession):
        """m2m_set works on a post with no existing associations."""
        user = User(username="set_author3", email="set3@test.com")
        db_session.add(user)
        await db_session.flush()

        post = Post(title="Post Set C", author_id=user.id)
        tag1 = Tag(name="tag_set_e")
        tag2 = Tag(name="tag_set_f")
        db_session.add_all([post, tag1, tag2])
        await db_session.flush()

        async with get_transaction(db_session):
            await m2m_set(db_session, post, Post.tags, tag1, tag2)

        remaining = await self._load_tags(db_session, post)
        assert {t.id for t in remaining} == {tag1.id, tag2.id}

    @pytest.mark.anyio
    async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
        """Passing a non-M2M relationship attribute raises TypeError."""
        user = User(username="set_author4", email="set4@test.com")
        db_session.add(user)
        await db_session.flush()

        role = Role(name="set_type_err_role")
        db_session.add(role)
        await db_session.flush()

        with pytest.raises(TypeError, match="Many-to-Many"):
            await m2m_set(db_session, user, User.role, role)