mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: add M2M helpers (#247)
This commit is contained in:
@@ -118,6 +118,57 @@ async def clean(db_session):
|
||||
await cleanup_tables(session=db_session, base=Base)
|
||||
```
|
||||
|
||||
## Many-to-Many helpers
|
||||
|
||||
SQLAlchemy's ORM collection API triggers lazy-loads when you append to a relationship inside a savepoint (e.g. inside `lock_tables` or a nested `get_transaction`). The three `m2m_*` helpers bypass the ORM collection entirely and issue direct SQL against the association table.
|
||||
|
||||
### `m2m_add` — insert associations
|
||||
|
||||
[`m2m_add`](../reference/db.md#fastapi_toolsets.db.m2m_add) inserts one or more rows into a secondary table without touching the ORM collection:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import lock_tables, m2m_add
|
||||
|
||||
async with lock_tables(session, [Tag]):
|
||||
tag = await TagCrud.create(session, TagCreate(name="python"))
|
||||
await m2m_add(session, post, Post.tags, tag)
|
||||
```
|
||||
|
||||
Pass `ignore_conflicts=True` to silently skip associations that already exist:
|
||||
|
||||
```python
|
||||
await m2m_add(session, post, Post.tags, tag, ignore_conflicts=True)
|
||||
```
|
||||
|
||||
### `m2m_remove` — delete associations
|
||||
|
||||
[`m2m_remove`](../reference/db.md#fastapi_toolsets.db.m2m_remove) deletes specific association rows. Removing a non-existent association is a no-op:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import get_transaction, m2m_remove
|
||||
|
||||
async with get_transaction(session):
|
||||
await m2m_remove(session, post, Post.tags, tag1, tag2)
|
||||
```
|
||||
|
||||
### `m2m_set` — replace the full set
|
||||
|
||||
[`m2m_set`](../reference/db.md#fastapi_toolsets.db.m2m_set) atomically replaces all associations: it deletes every existing row for the owner instance then inserts the new set. Passing no related instances clears the association entirely:
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.db import get_transaction, m2m_set
|
||||
|
||||
# Replace all tags
|
||||
async with get_transaction(session):
|
||||
await m2m_set(session, post, Post.tags, tag_a, tag_b)
|
||||
|
||||
# Clear all tags
|
||||
async with get_transaction(session):
|
||||
await m2m_set(session, post, Post.tags)
|
||||
```
|
||||
|
||||
All three helpers raise `TypeError` if the relationship attribute is not a Many-to-Many (i.e. has no secondary table).
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/db.md)
|
||||
|
||||
@@ -13,6 +13,9 @@ from fastapi_toolsets.db import (
|
||||
create_db_context,
|
||||
get_transaction,
|
||||
lock_tables,
|
||||
m2m_add,
|
||||
m2m_remove,
|
||||
m2m_set,
|
||||
wait_for_row_change,
|
||||
)
|
||||
```
|
||||
@@ -32,3 +35,9 @@ from fastapi_toolsets.db import (
|
||||
## ::: fastapi_toolsets.db.create_database
|
||||
|
||||
## ::: fastapi_toolsets.db.cleanup_tables
|
||||
|
||||
## ::: fastapi_toolsets.db.m2m_add
|
||||
|
||||
## ::: fastapi_toolsets.db.m2m_remove
|
||||
|
||||
## ::: fastapi_toolsets.db.m2m_set
|
||||
|
||||
@@ -4,11 +4,13 @@ import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import Table, delete, text, tuple_
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
||||
from sqlalchemy.orm.relationships import RelationshipProperty
|
||||
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
@@ -20,6 +22,9 @@ __all__ = [
|
||||
"create_db_dependency",
|
||||
"get_transaction",
|
||||
"lock_tables",
|
||||
"m2m_add",
|
||||
"m2m_remove",
|
||||
"m2m_set",
|
||||
"wait_for_row_change",
|
||||
]
|
||||
|
||||
@@ -339,3 +344,140 @@ async def wait_for_row_change(
|
||||
current = {col: getattr(instance, col) for col in watch_cols}
|
||||
if current != initial:
|
||||
return instance
|
||||
|
||||
|
||||
def _m2m_prop(rel_attr: QueryableAttribute) -> RelationshipProperty: # type: ignore[type-arg]
|
||||
"""Return the validated M2M RelationshipProperty for *rel_attr*.
|
||||
|
||||
Raises TypeError if *rel_attr* is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = rel_attr.property
|
||||
if not isinstance(prop, RelationshipProperty) or prop.secondary is None:
|
||||
raise TypeError(
|
||||
f"m2m helpers require a Many-to-Many relationship attribute, "
|
||||
f"got {rel_attr!r}. Use a relationship with a secondary table."
|
||||
)
|
||||
return prop
|
||||
|
||||
|
||||
async def m2m_add(
|
||||
session: AsyncSession,
|
||||
instance: DeclarativeBase,
|
||||
rel_attr: QueryableAttribute,
|
||||
*related: DeclarativeBase,
|
||||
ignore_conflicts: bool = False,
|
||||
) -> None:
|
||||
"""Insert rows into a Many-to-Many association table without loading the ORM collection.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
|
||||
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
|
||||
*related: One or more related instances to associate with ``instance``.
|
||||
ignore_conflicts: When ``True``, silently skip rows that already exist
|
||||
in the association table (``ON CONFLICT DO NOTHING``).
|
||||
|
||||
Raises:
|
||||
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = _m2m_prop(rel_attr)
|
||||
if not related:
|
||||
return
|
||||
|
||||
secondary = cast(Table, prop.secondary)
|
||||
assert secondary is not None # guaranteed by _m2m_prop
|
||||
sync_pairs = prop.secondary_synchronize_pairs
|
||||
assert sync_pairs is not None # set whenever secondary is set
|
||||
|
||||
# synchronize_pairs: [(parent_col, assoc_col), ...]
|
||||
# secondary_synchronize_pairs: [(related_col, assoc_col), ...]
|
||||
rows: list[dict[str, Any]] = []
|
||||
for rel_instance in related:
|
||||
row: dict[str, Any] = {}
|
||||
for parent_col, assoc_col in prop.synchronize_pairs:
|
||||
row[assoc_col.name] = getattr(instance, cast(str, parent_col.key))
|
||||
for related_col, assoc_col in sync_pairs:
|
||||
row[assoc_col.name] = getattr(rel_instance, cast(str, related_col.key))
|
||||
rows.append(row)
|
||||
|
||||
stmt = pg_insert(secondary).values(rows)
|
||||
if ignore_conflicts:
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
await session.execute(stmt)
|
||||
|
||||
|
||||
async def m2m_remove(
|
||||
session: AsyncSession,
|
||||
instance: DeclarativeBase,
|
||||
rel_attr: QueryableAttribute,
|
||||
*related: DeclarativeBase,
|
||||
) -> None:
|
||||
"""Remove rows from a Many-to-Many association table without loading the ORM collection.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
|
||||
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
|
||||
*related: One or more related instances to disassociate from ``instance``.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = _m2m_prop(rel_attr)
|
||||
if not related:
|
||||
return
|
||||
|
||||
secondary = cast(Table, prop.secondary)
|
||||
assert secondary is not None # guaranteed by _m2m_prop
|
||||
related_pairs = prop.secondary_synchronize_pairs
|
||||
assert related_pairs is not None # set whenever secondary is set
|
||||
|
||||
parent_where = [
|
||||
assoc_col == getattr(instance, cast(str, parent_col.key))
|
||||
for parent_col, assoc_col in prop.synchronize_pairs
|
||||
]
|
||||
|
||||
if len(related_pairs) == 1:
|
||||
related_col, assoc_col = related_pairs[0]
|
||||
related_values = [getattr(r, cast(str, related_col.key)) for r in related]
|
||||
related_where = assoc_col.in_(related_values)
|
||||
else:
|
||||
assoc_cols = [ac for _, ac in related_pairs]
|
||||
rel_cols = [rc for rc, _ in related_pairs]
|
||||
related_values_t = [
|
||||
tuple(getattr(r, cast(str, rc.key)) for rc in rel_cols) for r in related
|
||||
]
|
||||
related_where = tuple_(*assoc_cols).in_(related_values_t)
|
||||
|
||||
await session.execute(delete(secondary).where(*parent_where, related_where))
|
||||
|
||||
|
||||
async def m2m_set(
|
||||
session: AsyncSession,
|
||||
instance: DeclarativeBase,
|
||||
rel_attr: QueryableAttribute,
|
||||
*related: DeclarativeBase,
|
||||
) -> None:
|
||||
"""Replace the entire Many-to-Many association set atomically.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
|
||||
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
|
||||
*related: The new complete set of related instances.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = _m2m_prop(rel_attr)
|
||||
secondary = cast(Table, prop.secondary)
|
||||
assert secondary is not None # guaranteed by _m2m_prop
|
||||
|
||||
parent_where = [
|
||||
assoc_col == getattr(instance, cast(str, parent_col.key))
|
||||
for parent_col, assoc_col in prop.synchronize_pairs
|
||||
]
|
||||
await session.execute(delete(secondary).where(*parent_where))
|
||||
|
||||
if related:
|
||||
await m2m_add(session, instance, rel_attr, *related)
|
||||
|
||||
454
tests/test_db.py
454
tests/test_db.py
@@ -4,10 +4,26 @@ import asyncio
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
ForeignKey,
|
||||
ForeignKeyConstraint,
|
||||
String,
|
||||
Table,
|
||||
Uuid,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
mapped_column,
|
||||
relationship,
|
||||
selectinload,
|
||||
)
|
||||
|
||||
from fastapi_toolsets.db import (
|
||||
LockMode,
|
||||
@@ -17,12 +33,15 @@ from fastapi_toolsets.db import (
|
||||
create_db_dependency,
|
||||
get_transaction,
|
||||
lock_tables,
|
||||
m2m_add,
|
||||
m2m_remove,
|
||||
m2m_set,
|
||||
wait_for_row_change,
|
||||
)
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
from fastapi_toolsets.pytest import create_db_session
|
||||
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||
from .conftest import DATABASE_URL, Base, Post, Role, RoleCrud, Tag, User, UserCrud
|
||||
|
||||
|
||||
class TestCreateDbDependency:
|
||||
@@ -81,6 +100,21 @@ class TestCreateDbDependency:
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_no_commit_when_not_in_transaction(self):
|
||||
"""Dependency skips commit if the session is no longer in a transaction on exit."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
get_db = create_db_dependency(session_factory)
|
||||
|
||||
async for session in get_db():
|
||||
# Manually commit — session exits the transaction
|
||||
await session.commit()
|
||||
assert not session.in_transaction()
|
||||
# The dependency's post-yield path must not call commit again (no error)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_after_lock_tables_is_persisted(self):
|
||||
"""Changes made after lock_tables exits (before endpoint returns) are committed.
|
||||
@@ -480,3 +514,417 @@ class TestCleanupTables:
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
# Should not raise
|
||||
await cleanup_tables(session, EmptyBase)
|
||||
|
||||
|
||||
class TestM2MAdd:
|
||||
"""Tests for m2m_add helper."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_adds_single_related(self, db_session: AsyncSession):
|
||||
"""Associates one related instance via the secondary table."""
|
||||
user = User(username="m2m_author", email="m2m@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post A", author_id=user.id)
|
||||
tag = Tag(name="python")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.tags) == 1
|
||||
assert loaded.tags[0].id == tag.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_adds_multiple_related(self, db_session: AsyncSession):
|
||||
"""Associates multiple related instances in a single call."""
|
||||
user = User(username="m2m_author2", email="m2m2@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post B", author_id=user.id)
|
||||
tag1 = Tag(name="web")
|
||||
tag2 = Tag(name="api")
|
||||
tag3 = Tag(name="async")
|
||||
db_session.add_all([post, tag1, tag2, tag3])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag1, tag2, tag3)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert {t.id for t in loaded.tags} == {tag1.id, tag2.id, tag3.id}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_noop_for_empty_related(self, db_session: AsyncSession):
|
||||
"""Calling with no related instances is a no-op."""
|
||||
user = User(username="m2m_author3", email="m2m3@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post C", author_id=user.id)
|
||||
db_session.add(post)
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags) # no related instances
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert loaded.tags == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ignore_conflicts_true(self, db_session: AsyncSession):
|
||||
"""Duplicate inserts are silently skipped when ignore_conflicts=True."""
|
||||
user = User(username="m2m_author4", email="m2m4@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post D", author_id=user.id)
|
||||
tag = Tag(name="duplicate_tag")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
# Second call with ignore_conflicts=True must not raise
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag, ignore_conflicts=True)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.tags) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ignore_conflicts_false_raises(self, db_session: AsyncSession):
|
||||
"""Duplicate inserts raise IntegrityError when ignore_conflicts=False (default)."""
|
||||
user = User(username="m2m_author5", email="m2m5@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post E", author_id=user.id)
|
||||
tag = Tag(name="conflict_tag")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
|
||||
"""Passing a non-M2M relationship attribute raises TypeError."""
|
||||
user = User(username="m2m_author6", email="m2m6@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
role = Role(name="type_err_role")
|
||||
db_session.add(role)
|
||||
await db_session.flush()
|
||||
|
||||
with pytest.raises(TypeError, match="Many-to-Many"):
|
||||
await m2m_add(db_session, user, User.role, role)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_works_inside_lock_tables(self, db_session: AsyncSession):
|
||||
"""m2m_add works correctly inside a lock_tables nested transaction."""
|
||||
user = User(username="m2m_lock_author", email="m2m_lock@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
async with lock_tables(db_session, [Tag]):
|
||||
tag = Tag(name="locked_tag")
|
||||
db_session.add(tag)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Lock", author_id=user.id)
|
||||
db_session.add(post)
|
||||
await db_session.flush()
|
||||
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.tags) == 1
|
||||
assert loaded.tags[0].name == "locked_tag"
|
||||
|
||||
|
||||
class _LocalBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
_comp_assoc = Table(
|
||||
"_comp_assoc",
|
||||
_LocalBase.metadata,
|
||||
Column("owner_id", Uuid, ForeignKey("_comp_owners.id"), primary_key=True),
|
||||
Column("item_group", String(50), primary_key=True),
|
||||
Column("item_code", String(50), primary_key=True),
|
||||
ForeignKeyConstraint(
|
||||
["item_group", "item_code"],
|
||||
["_comp_items.group_id", "_comp_items.item_code"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class _CompOwner(_LocalBase):
|
||||
__tablename__ = "_comp_owners"
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
items: Mapped[list["_CompItem"]] = relationship(secondary=_comp_assoc)
|
||||
|
||||
|
||||
class _CompItem(_LocalBase):
|
||||
__tablename__ = "_comp_items"
|
||||
group_id: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
|
||||
|
||||
class TestM2MRemove:
|
||||
"""Tests for m2m_remove helper."""
|
||||
|
||||
async def _setup(
|
||||
self, session: AsyncSession, username: str, email: str, *tag_names: str
|
||||
):
|
||||
"""Create a user, post, and tags; associate all tags with the post."""
|
||||
user = User(username=username, email=email)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
post = Post(title=f"Post {username}", author_id=user.id)
|
||||
tags = [Tag(name=n) for n in tag_names]
|
||||
session.add(post)
|
||||
session.add_all(tags)
|
||||
await session.flush()
|
||||
|
||||
async with get_transaction(session):
|
||||
await m2m_add(session, post, Post.tags, *tags)
|
||||
|
||||
return post, tags
|
||||
|
||||
async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
|
||||
result = await session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
return result.scalar_one().tags
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_removes_single(self, db_session: AsyncSession):
|
||||
"""Removes one association, leaving others intact."""
|
||||
post, (tag1, tag2) = await self._setup(
|
||||
db_session, "rm_author1", "rm1@test.com", "tag_rm_a", "tag_rm_b"
|
||||
)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags, tag1)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].id == tag2.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_removes_multiple(self, db_session: AsyncSession):
|
||||
"""Removes multiple associations in one call."""
|
||||
post, (tag1, tag2, tag3) = await self._setup(
|
||||
db_session, "rm_author2", "rm2@test.com", "tag_rm_c", "tag_rm_d", "tag_rm_e"
|
||||
)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags, tag1, tag3)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].id == tag2.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_noop_for_empty_related(self, db_session: AsyncSession):
|
||||
"""Calling with no related instances is a no-op."""
|
||||
post, (tag,) = await self._setup(
|
||||
db_session, "rm_author3", "rm3@test.com", "tag_rm_f"
|
||||
)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_idempotent_for_missing_association(self, db_session: AsyncSession):
|
||||
"""Removing a non-existent association does not raise."""
|
||||
post, (tag1,) = await self._setup(
|
||||
db_session, "rm_author4", "rm4@test.com", "tag_rm_g"
|
||||
)
|
||||
tag2 = Tag(name="tag_rm_h")
|
||||
db_session.add(tag2)
|
||||
await db_session.flush()
|
||||
|
||||
# tag2 was never associated — should not raise
|
||||
async with get_transaction(db_session):
|
||||
await m2m_remove(db_session, post, Post.tags, tag2)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
|
||||
"""Passing a non-M2M relationship attribute raises TypeError."""
|
||||
user = User(username="rm_author5", email="rm5@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
role = Role(name="rm_type_err_role")
|
||||
db_session.add(role)
|
||||
await db_session.flush()
|
||||
|
||||
with pytest.raises(TypeError, match="Many-to-Many"):
|
||||
await m2m_remove(db_session, user, User.role, role)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_removes_composite_pk_related(self):
|
||||
"""Composite-PK branch: DELETE uses tuple IN when related side has multi-col PK."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(_LocalBase.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
try:
|
||||
async with session_factory() as session:
|
||||
owner = _CompOwner()
|
||||
item1 = _CompItem(group_id="g1", item_code="c1")
|
||||
item2 = _CompItem(group_id="g1", item_code="c2")
|
||||
session.add_all([owner, item1, item2])
|
||||
await session.flush()
|
||||
|
||||
async with get_transaction(session):
|
||||
await m2m_add(session, owner, _CompOwner.items, item1, item2)
|
||||
|
||||
async with get_transaction(session):
|
||||
await m2m_remove(session, owner, _CompOwner.items, item1)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as verify:
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await verify.execute(
|
||||
select(_CompOwner)
|
||||
.where(_CompOwner.id == owner.id)
|
||||
.options(selectinload(_CompOwner.items))
|
||||
)
|
||||
loaded = result.scalar_one()
|
||||
assert len(loaded.items) == 1
|
||||
assert (loaded.items[0].group_id, loaded.items[0].item_code) == (
|
||||
"g1",
|
||||
"c2",
|
||||
)
|
||||
finally:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(_LocalBase.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestM2MSet:
|
||||
"""Tests for m2m_set helper."""
|
||||
|
||||
async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
|
||||
result = await session.execute(
|
||||
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
|
||||
)
|
||||
return result.scalar_one().tags
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_replaces_existing_set(self, db_session: AsyncSession):
|
||||
"""Replaces the full association set atomically."""
|
||||
user = User(username="set_author1", email="set1@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Set A", author_id=user.id)
|
||||
tag1 = Tag(name="tag_set_a")
|
||||
tag2 = Tag(name="tag_set_b")
|
||||
tag3 = Tag(name="tag_set_c")
|
||||
db_session.add_all([post, tag1, tag2, tag3])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag1, tag2)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_set(db_session, post, Post.tags, tag3)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].id == tag3.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_clears_all_when_no_related(self, db_session: AsyncSession):
|
||||
"""Passing no related instances clears all associations."""
|
||||
user = User(username="set_author2", email="set2@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Set B", author_id=user.id)
|
||||
tag = Tag(name="tag_set_d")
|
||||
db_session.add_all([post, tag])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_add(db_session, post, Post.tags, tag)
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_set(db_session, post, Post.tags)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert remaining == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_set_on_empty_then_populate(self, db_session: AsyncSession):
|
||||
"""m2m_set works on a post with no existing associations."""
|
||||
user = User(username="set_author3", email="set3@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
post = Post(title="Post Set C", author_id=user.id)
|
||||
tag1 = Tag(name="tag_set_e")
|
||||
tag2 = Tag(name="tag_set_f")
|
||||
db_session.add_all([post, tag1, tag2])
|
||||
await db_session.flush()
|
||||
|
||||
async with get_transaction(db_session):
|
||||
await m2m_set(db_session, post, Post.tags, tag1, tag2)
|
||||
|
||||
remaining = await self._load_tags(db_session, post)
|
||||
assert {t.id for t in remaining} == {tag1.id, tag2.id}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
|
||||
"""Passing a non-M2M relationship attribute raises TypeError."""
|
||||
user = User(username="set_author4", email="set4@test.com")
|
||||
db_session.add(user)
|
||||
await db_session.flush()
|
||||
|
||||
role = Role(name="set_type_err_role")
|
||||
db_session.add(role)
|
||||
await db_session.flush()
|
||||
|
||||
with pytest.raises(TypeError, match="Many-to-Many"):
|
||||
await m2m_set(db_session, user, User.role, role)
|
||||
|
||||
@@ -536,11 +536,6 @@ class TestCreateWorkerDatabase:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local models for composite-PK coverage (own Base → own tables, isolated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _LocalBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user