feat: add M2M helpers (#247)

This commit is contained in:
d3vyce
2026-04-12 18:46:57 +02:00
committed by GitHub
parent 863e6ce6e9
commit 9268b576b4
5 changed files with 656 additions and 11 deletions

View File

@@ -4,11 +4,13 @@ import asyncio
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from enum import Enum
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
from sqlalchemy import text
from sqlalchemy import Table, delete, text, tuple_
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
from sqlalchemy.orm.relationships import RelationshipProperty
from .exceptions import NotFoundError
@@ -20,6 +22,9 @@ __all__ = [
"create_db_dependency",
"get_transaction",
"lock_tables",
"m2m_add",
"m2m_remove",
"m2m_set",
"wait_for_row_change",
]
@@ -339,3 +344,140 @@ async def wait_for_row_change(
current = {col: getattr(instance, col) for col in watch_cols}
if current != initial:
return instance
def _m2m_prop(rel_attr: QueryableAttribute) -> RelationshipProperty: # type: ignore[type-arg]
"""Return the validated M2M RelationshipProperty for *rel_attr*.
Raises TypeError if *rel_attr* is not a Many-to-Many relationship.
"""
prop = rel_attr.property
if not isinstance(prop, RelationshipProperty) or prop.secondary is None:
raise TypeError(
f"m2m helpers require a Many-to-Many relationship attribute, "
f"got {rel_attr!r}. Use a relationship with a secondary table."
)
return prop
async def m2m_add(
session: AsyncSession,
instance: DeclarativeBase,
rel_attr: QueryableAttribute,
*related: DeclarativeBase,
ignore_conflicts: bool = False,
) -> None:
"""Insert rows into a Many-to-Many association table without loading the ORM collection.
Args:
session: DB async session.
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
*related: One or more related instances to associate with ``instance``.
ignore_conflicts: When ``True``, silently skip rows that already exist
in the association table (``ON CONFLICT DO NOTHING``).
Raises:
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
"""
prop = _m2m_prop(rel_attr)
if not related:
return
secondary = cast(Table, prop.secondary)
assert secondary is not None # guaranteed by _m2m_prop
sync_pairs = prop.secondary_synchronize_pairs
assert sync_pairs is not None # set whenever secondary is set
# synchronize_pairs: [(parent_col, assoc_col), ...]
# secondary_synchronize_pairs: [(related_col, assoc_col), ...]
rows: list[dict[str, Any]] = []
for rel_instance in related:
row: dict[str, Any] = {}
for parent_col, assoc_col in prop.synchronize_pairs:
row[assoc_col.name] = getattr(instance, cast(str, parent_col.key))
for related_col, assoc_col in sync_pairs:
row[assoc_col.name] = getattr(rel_instance, cast(str, related_col.key))
rows.append(row)
stmt = pg_insert(secondary).values(rows)
if ignore_conflicts:
stmt = stmt.on_conflict_do_nothing()
await session.execute(stmt)
async def m2m_remove(
session: AsyncSession,
instance: DeclarativeBase,
rel_attr: QueryableAttribute,
*related: DeclarativeBase,
) -> None:
"""Remove rows from a Many-to-Many association table without loading the ORM collection.
Args:
session: DB async session.
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
*related: One or more related instances to disassociate from ``instance``.
Raises:
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
"""
prop = _m2m_prop(rel_attr)
if not related:
return
secondary = cast(Table, prop.secondary)
assert secondary is not None # guaranteed by _m2m_prop
related_pairs = prop.secondary_synchronize_pairs
assert related_pairs is not None # set whenever secondary is set
parent_where = [
assoc_col == getattr(instance, cast(str, parent_col.key))
for parent_col, assoc_col in prop.synchronize_pairs
]
if len(related_pairs) == 1:
related_col, assoc_col = related_pairs[0]
related_values = [getattr(r, cast(str, related_col.key)) for r in related]
related_where = assoc_col.in_(related_values)
else:
assoc_cols = [ac for _, ac in related_pairs]
rel_cols = [rc for rc, _ in related_pairs]
related_values_t = [
tuple(getattr(r, cast(str, rc.key)) for rc in rel_cols) for r in related
]
related_where = tuple_(*assoc_cols).in_(related_values_t)
await session.execute(delete(secondary).where(*parent_where, related_where))
async def m2m_set(
session: AsyncSession,
instance: DeclarativeBase,
rel_attr: QueryableAttribute,
*related: DeclarativeBase,
) -> None:
"""Replace the entire Many-to-Many association set atomically.
Args:
session: DB async session.
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
*related: The new complete set of related instances.
Raises:
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
"""
prop = _m2m_prop(rel_attr)
secondary = cast(Table, prop.secondary)
assert secondary is not None # guaranteed by _m2m_prop
parent_where = [
assoc_col == getattr(instance, cast(str, parent_col.key))
for parent_col, assoc_col in prop.synchronize_pairs
]
await session.execute(delete(secondary).where(*parent_where))
if related:
await m2m_add(session, instance, rel_attr, *related)