perf: batch insert fixtures (#188)

This commit is contained in:
d3vyce
2026-03-26 20:29:25 +01:00
committed by GitHub
parent 04afef7e33
commit 29326ab532
2 changed files with 176 additions and 21 deletions

View File

@@ -4,6 +4,8 @@ from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
@@ -16,6 +18,113 @@ from .registry import FixtureRegistry, _normalize_contexts
logger = get_logger()
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
"""Extract column values from a model instance, skipping unset server-default columns."""
state = sa_inspect(instance)
state_dict = state.dict
result: dict[str, Any] = {}
for prop in state.mapper.column_attrs:
if prop.key not in state_dict:
continue
val = state_dict[prop.key]
if val is None:
col = prop.columns[0]
if col.server_default is not None or (
col.default is not None and col.default.is_callable
):
continue
result[prop.key] = val
return result
def _group_by_type(
instances: list[DeclarativeBase],
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
"""Group instances by their concrete model class, preserving insertion order."""
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
for instance in instances:
groups.setdefault(type(instance), []).append(instance)
return list(groups.items())
async def _batch_insert(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> None:
"""INSERT all instances — raises on conflict (no duplicate handling)."""
dicts = [_instance_to_dict(i) for i in instances]
await session.execute(pg_insert(model_cls).values(dicts))
async def _batch_merge(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> None:
"""UPSERT: insert new rows, update existing ones with the provided values."""
mapper = model_cls.__mapper__
pk_names = [col.name for col in mapper.primary_key]
pk_names_set = set(pk_names)
non_pk_cols = [
prop.key
for prop in mapper.column_attrs
if not any(col.name in pk_names_set for col in prop.columns)
]
dicts = [_instance_to_dict(i) for i in instances]
stmt = pg_insert(model_cls).values(dicts)
if non_pk_cols:
stmt = stmt.on_conflict_do_update(
index_elements=pk_names,
set_={col: stmt.excluded[col] for col in non_pk_cols},
)
else:
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
await session.execute(stmt)
async def _batch_skip_existing(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> list[DeclarativeBase]:
"""INSERT only rows that do not already exist; return the inserted ones."""
mapper = model_cls.__mapper__
pk_names = [col.name for col in mapper.primary_key]
no_pk: list[DeclarativeBase] = []
with_pk_pairs: list[tuple[DeclarativeBase, Any]] = []
for inst in instances:
pk = _get_primary_key(inst)
if pk is None:
no_pk.append(inst)
else:
with_pk_pairs.append((inst, pk))
loaded: list[DeclarativeBase] = list(no_pk)
if no_pk:
await session.execute(
pg_insert(model_cls).values([_instance_to_dict(i) for i in no_pk])
)
if with_pk_pairs:
with_pk = [i for i, _ in with_pk_pairs]
stmt = (
pg_insert(model_cls)
.values([_instance_to_dict(i) for i in with_pk])
.on_conflict_do_nothing(index_elements=pk_names)
)
result = await session.execute(stmt.returning(*mapper.primary_key))
inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result}
loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks)
return loaded
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
@@ -23,7 +132,7 @@ async def _load_ordered(
strategy: LoadStrategy,
contexts: tuple[str, ...] | None = None,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order.
"""Load fixtures in order using batch Core INSERT statements.
When *contexts* is provided only variants whose context set intersects with
*contexts* are called for each name; their instances are concatenated.
@@ -59,25 +168,17 @@ async def _load_ordered(
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
else: # LoadStrategy.SKIP_EXISTING
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
for model_cls, group in _group_by_type(instances):
match strategy:
case LoadStrategy.INSERT:
await _batch_insert(session, model_cls, group)
loaded.extend(group)
case LoadStrategy.MERGE:
await _batch_merge(session, model_cls, group)
loaded.extend(group)
case LoadStrategy.SKIP_EXISTING:
inserted = await _batch_skip_existing(session, model_cls, group)
loaded.extend(inserted)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")