mirror of https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
299 lines · 9.9 KiB · Python
"""Fixture loading utilities for database seeding."""
|
|
|
|
from collections.abc import Callable, Sequence
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from sqlalchemy import inspect as sa_inspect
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
|
|
from ..db import get_transaction
|
|
from ..logger import get_logger
|
|
from ..types import ModelType
|
|
from .enum import LoadStrategy
|
|
from .registry import FixtureRegistry, _normalize_contexts
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
    """Extract column values from a model instance, skipping unset server-default columns.

    A ``None`` value is omitted when its column has a server-side default,
    a callable Python-side default, or is autoincrementing — emitting an
    explicit NULL there would override the database-generated value.
    """
    insp = sa_inspect(instance)
    loaded_values = insp.dict

    out: dict[str, Any] = {}
    for attr in insp.mapper.column_attrs:
        # Attributes never set on the instance are simply absent.
        if attr.key not in loaded_values:
            continue
        value = loaded_values[attr.key]
        if value is None:
            column = attr.columns[0]
            defaulted_elsewhere = (
                column.server_default is not None
                or (column.default is not None and column.default.is_callable)
                or column.autoincrement is True
            )
            if defaulted_elsewhere:
                continue
        out[attr.key] = value
    return out
|
|
|
|
|
|
def _group_by_type(
|
|
instances: list[DeclarativeBase],
|
|
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
|
"""Group instances by their concrete model class, preserving insertion order."""
|
|
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
|
|
for instance in instances:
|
|
groups.setdefault(type(instance), []).append(instance)
|
|
return list(groups.items())
|
|
|
|
|
|
def _group_by_column_set(
|
|
dicts: list[dict[str, Any]],
|
|
instances: list[DeclarativeBase],
|
|
) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]:
|
|
"""Group (dict, instance) pairs by their dict key sets."""
|
|
groups: dict[
|
|
frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]]
|
|
] = {}
|
|
for d, inst in zip(dicts, instances):
|
|
key = frozenset(d)
|
|
if key not in groups:
|
|
groups[key] = ([], [])
|
|
groups[key][0].append(d)
|
|
groups[key][1].append(inst)
|
|
return list(groups.values())
|
|
|
|
|
|
async def _batch_insert(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> None:
    """INSERT all instances — raises on conflict (no duplicate handling)."""
    rows = [_instance_to_dict(inst) for inst in instances]
    # One multi-row INSERT per distinct column set.
    for row_group, _ in _group_by_column_set(rows, instances):
        stmt = pg_insert(model_cls).values(row_group)
        await session.execute(stmt)
|
|
|
|
|
|
async def _batch_merge(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> None:
    """UPSERT: insert new rows, update existing ones with the provided values."""
    mapper = model_cls.__mapper__
    pk_names = [col.name for col in mapper.primary_key]
    pk_name_set = set(pk_names)

    # Attribute keys whose mapped columns are all non-primary-key.
    non_pk_cols: list[str] = []
    for prop in mapper.column_attrs:
        if any(col.name in pk_name_set for col in prop.columns):
            continue
        non_pk_cols.append(prop.key)

    rows = [_instance_to_dict(inst) for inst in instances]
    for row_group, _ in _group_by_column_set(rows, instances):
        stmt = pg_insert(model_cls).values(row_group)

        # Only columns actually present in this group can be updated;
        # every dict in the group shares the same key set by construction.
        present = set(row_group[0])
        update_cols = [c for c in non_pk_cols if c in present]

        if not update_cols:
            # Nothing beyond the PK was provided — just skip duplicates.
            stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
        else:
            stmt = stmt.on_conflict_do_update(
                index_elements=pk_names,
                set_={c: stmt.excluded[c] for c in update_cols},
            )

        await session.execute(stmt)
|
|
|
|
|
|
async def _batch_skip_existing(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> list[DeclarativeBase]:
    """INSERT only rows that do not already exist; return the inserted ones."""
    mapper = model_cls.__mapper__
    pk_names = [col.name for col in mapper.primary_key]

    # Partition instances on whether a primary key value is set.
    without_pk: list[DeclarativeBase] = []
    with_pk: list[DeclarativeBase] = []
    for inst in instances:
        if _get_primary_key(inst) is None:
            without_pk.append(inst)
        else:
            with_pk.append(inst)

    # PK-less rows cannot conflict on the primary key; insert them plainly
    # and count them all as loaded.
    loaded: list[DeclarativeBase] = list(without_pk)
    if without_pk:
        rows = [_instance_to_dict(inst) for inst in without_pk]
        for row_group, _ in _group_by_column_set(rows, without_pk):
            await session.execute(pg_insert(model_cls).values(row_group))

    if with_pk:
        rows = [_instance_to_dict(inst) for inst in with_pk]
        for row_group, group_insts in _group_by_column_set(rows, with_pk):
            stmt = (
                pg_insert(model_cls)
                .values(row_group)
                .on_conflict_do_nothing(index_elements=pk_names)
            )
            # RETURNING yields only the rows that were actually inserted.
            result = await session.execute(stmt.returning(*mapper.primary_key))
            if len(pk_names) == 1:
                inserted_pks = {row[0] for row in result}
            else:
                inserted_pks = {tuple(row) for row in result}
            loaded.extend(
                inst
                for inst in group_insts
                if _get_primary_key(inst) in inserted_pks
            )

    return loaded
|
|
|
|
|
|
async def _load_ordered(
    session: AsyncSession,
    registry: FixtureRegistry,
    ordered_names: list[str],
    strategy: LoadStrategy,
    contexts: tuple[str, ...] | None = None,
) -> dict[str, list[DeclarativeBase]]:
    """Load fixtures in order using batch Core INSERT statements.

    Each fixture name is loaded inside its own transaction. When contexts
    are given but match no variant of a name, all variants are used as a
    fallback.
    """
    results: dict[str, list[DeclarativeBase]] = {}

    for name in ordered_names:
        if contexts is None:
            variants = registry.get_variants(name)
        else:
            variants = registry.get_variants(name, *contexts)
            # Fall back to every variant when none matches the contexts.
            if not variants:
                variants = registry.get_variants(name)

        if not variants:
            results[name] = []
            continue

        instances = [inst for variant in variants for inst in variant.func()]
        if not instances:
            results[name] = []
            continue

        # Used only for the log line; assumes the first instance's class
        # is representative of the fixture.
        model_name = type(instances[0]).__name__
        loaded: list[DeclarativeBase] = []

        async with get_transaction(session):
            for model_cls, group in _group_by_type(instances):
                if strategy is LoadStrategy.INSERT:
                    await _batch_insert(session, model_cls, group)
                    loaded.extend(group)
                elif strategy is LoadStrategy.MERGE:
                    await _batch_merge(session, model_cls, group)
                    loaded.extend(group)
                elif strategy is LoadStrategy.SKIP_EXISTING:
                    loaded.extend(
                        await _batch_skip_existing(session, model_cls, group)
                    )

        results[name] = loaded
        logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")

    return results
|
|
|
|
|
|
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
|
"""Get the primary key value of a model instance."""
|
|
mapper = instance.__class__.__mapper__
|
|
pk_cols = mapper.primary_key
|
|
|
|
if len(pk_cols) == 1:
|
|
return getattr(instance, pk_cols[0].name, None)
|
|
|
|
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
|
if all(v is not None for v in pk_values):
|
|
return pk_values
|
|
return None
|
|
|
|
|
|
def get_obj_by_attr(
    fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> ModelType:
    """Get a SQLAlchemy model instance by matching an attribute value.

    Args:
        fixtures: A fixture function registered via ``@registry.register``
            that returns a sequence of SQLAlchemy model instances.
        attr_name: Name of the attribute to match against.
        value: Value to match.

    Returns:
        The first model instance where the attribute matches the given value.

    Raises:
        StopIteration: If no matching object is found in the fixture group.
            NOTE(review): raising ``StopIteration`` is hazardous inside
            generators (PEP 479 turns it into ``RuntimeError``) — kept for
            backward compatibility with existing callers.

    """
    for obj in fixtures():
        if getattr(obj, attr_name) == value:
            return obj
    fixture_name = getattr(fixtures, "__name__", repr(fixtures))
    raise StopIteration(
        f"No object with {attr_name}={value} found in fixture '{fixture_name}'"
    )
|
|
|
|
|
|
async def load_fixtures(
    session: AsyncSession,
    registry: FixtureRegistry,
    *names: str,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load specific fixtures by name with dependencies.

    All context variants of each requested fixture are loaded and merged.

    Args:
        session: Database session
        registry: Fixture registry
        *names: Fixture names to load (dependencies auto-resolved)
        strategy: How to handle existing records

    Returns:
        Dict mapping fixture names to loaded instances
    """
    # Dependencies are resolved first so prerequisites load before dependents.
    load_order = registry.resolve_dependencies(*names)
    return await _load_ordered(session, registry, load_order, strategy)
|
|
|
|
|
|
async def load_fixtures_by_context(
    session: AsyncSession,
    registry: FixtureRegistry,
    *contexts: str | Enum,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load all fixtures for specific contexts.

    Args:
        session: Database session
        registry: Fixture registry
        *contexts: Contexts to load (e.g., ``Context.BASE``, ``Context.TESTING``,
            or plain strings for custom contexts)
        strategy: How to handle existing records

    Returns:
        Dict mapping fixture names to loaded instances
    """
    # Enum members are normalized to their string values for variant lookup.
    normalized = tuple(_normalize_contexts(contexts))
    load_order = registry.resolve_context_dependencies(*contexts)
    return await _load_ordered(
        session, registry, load_order, strategy, contexts=normalized
    )
|