mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: add M2M helpers (#247)
This commit is contained in:
@@ -4,11 +4,13 @@ import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import Table, delete, text, tuple_
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
||||
from sqlalchemy.orm.relationships import RelationshipProperty
|
||||
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
@@ -20,6 +22,9 @@ __all__ = [
|
||||
"create_db_dependency",
|
||||
"get_transaction",
|
||||
"lock_tables",
|
||||
"m2m_add",
|
||||
"m2m_remove",
|
||||
"m2m_set",
|
||||
"wait_for_row_change",
|
||||
]
|
||||
|
||||
@@ -339,3 +344,140 @@ async def wait_for_row_change(
|
||||
current = {col: getattr(instance, col) for col in watch_cols}
|
||||
if current != initial:
|
||||
return instance
|
||||
|
||||
|
||||
def _m2m_prop(rel_attr: QueryableAttribute) -> RelationshipProperty: # type: ignore[type-arg]
|
||||
"""Return the validated M2M RelationshipProperty for *rel_attr*.
|
||||
|
||||
Raises TypeError if *rel_attr* is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = rel_attr.property
|
||||
if not isinstance(prop, RelationshipProperty) or prop.secondary is None:
|
||||
raise TypeError(
|
||||
f"m2m helpers require a Many-to-Many relationship attribute, "
|
||||
f"got {rel_attr!r}. Use a relationship with a secondary table."
|
||||
)
|
||||
return prop
|
||||
|
||||
|
||||
async def m2m_add(
|
||||
session: AsyncSession,
|
||||
instance: DeclarativeBase,
|
||||
rel_attr: QueryableAttribute,
|
||||
*related: DeclarativeBase,
|
||||
ignore_conflicts: bool = False,
|
||||
) -> None:
|
||||
"""Insert rows into a Many-to-Many association table without loading the ORM collection.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
|
||||
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
|
||||
*related: One or more related instances to associate with ``instance``.
|
||||
ignore_conflicts: When ``True``, silently skip rows that already exist
|
||||
in the association table (``ON CONFLICT DO NOTHING``).
|
||||
|
||||
Raises:
|
||||
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = _m2m_prop(rel_attr)
|
||||
if not related:
|
||||
return
|
||||
|
||||
secondary = cast(Table, prop.secondary)
|
||||
assert secondary is not None # guaranteed by _m2m_prop
|
||||
sync_pairs = prop.secondary_synchronize_pairs
|
||||
assert sync_pairs is not None # set whenever secondary is set
|
||||
|
||||
# synchronize_pairs: [(parent_col, assoc_col), ...]
|
||||
# secondary_synchronize_pairs: [(related_col, assoc_col), ...]
|
||||
rows: list[dict[str, Any]] = []
|
||||
for rel_instance in related:
|
||||
row: dict[str, Any] = {}
|
||||
for parent_col, assoc_col in prop.synchronize_pairs:
|
||||
row[assoc_col.name] = getattr(instance, cast(str, parent_col.key))
|
||||
for related_col, assoc_col in sync_pairs:
|
||||
row[assoc_col.name] = getattr(rel_instance, cast(str, related_col.key))
|
||||
rows.append(row)
|
||||
|
||||
stmt = pg_insert(secondary).values(rows)
|
||||
if ignore_conflicts:
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
await session.execute(stmt)
|
||||
|
||||
|
||||
async def m2m_remove(
|
||||
session: AsyncSession,
|
||||
instance: DeclarativeBase,
|
||||
rel_attr: QueryableAttribute,
|
||||
*related: DeclarativeBase,
|
||||
) -> None:
|
||||
"""Remove rows from a Many-to-Many association table without loading the ORM collection.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
|
||||
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
|
||||
*related: One or more related instances to disassociate from ``instance``.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = _m2m_prop(rel_attr)
|
||||
if not related:
|
||||
return
|
||||
|
||||
secondary = cast(Table, prop.secondary)
|
||||
assert secondary is not None # guaranteed by _m2m_prop
|
||||
related_pairs = prop.secondary_synchronize_pairs
|
||||
assert related_pairs is not None # set whenever secondary is set
|
||||
|
||||
parent_where = [
|
||||
assoc_col == getattr(instance, cast(str, parent_col.key))
|
||||
for parent_col, assoc_col in prop.synchronize_pairs
|
||||
]
|
||||
|
||||
if len(related_pairs) == 1:
|
||||
related_col, assoc_col = related_pairs[0]
|
||||
related_values = [getattr(r, cast(str, related_col.key)) for r in related]
|
||||
related_where = assoc_col.in_(related_values)
|
||||
else:
|
||||
assoc_cols = [ac for _, ac in related_pairs]
|
||||
rel_cols = [rc for rc, _ in related_pairs]
|
||||
related_values_t = [
|
||||
tuple(getattr(r, cast(str, rc.key)) for rc in rel_cols) for r in related
|
||||
]
|
||||
related_where = tuple_(*assoc_cols).in_(related_values_t)
|
||||
|
||||
await session.execute(delete(secondary).where(*parent_where, related_where))
|
||||
|
||||
|
||||
async def m2m_set(
|
||||
session: AsyncSession,
|
||||
instance: DeclarativeBase,
|
||||
rel_attr: QueryableAttribute,
|
||||
*related: DeclarativeBase,
|
||||
) -> None:
|
||||
"""Replace the entire Many-to-Many association set atomically.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
|
||||
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
|
||||
*related: The new complete set of related instances.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
|
||||
"""
|
||||
prop = _m2m_prop(rel_attr)
|
||||
secondary = cast(Table, prop.secondary)
|
||||
assert secondary is not None # guaranteed by _m2m_prop
|
||||
|
||||
parent_where = [
|
||||
assoc_col == getattr(instance, cast(str, parent_col.key))
|
||||
for parent_col, assoc_col in prop.synchronize_pairs
|
||||
]
|
||||
await session.execute(delete(secondary).where(*parent_where))
|
||||
|
||||
if related:
|
||||
await m2m_add(session, instance, rel_attr, *related)
|
||||
|
||||
Reference in New Issue
Block a user