mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
perf: batch insert fixtures (#188)
This commit is contained in:
@@ -4,6 +4,8 @@ from collections.abc import Callable, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
@@ -16,6 +18,113 @@ from .registry import FixtureRegistry, _normalize_contexts
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
    """Extract the populated column values of a model instance.

    Only column attributes that are present in the instance's state dict are
    considered, so attributes that were never assigned are left out.  A value
    of ``None`` is also left out when the column declares a
    ``server_default`` or any Python-side ``default`` (callable, scalar,
    sequence or SQL expression), so the default is applied instead of an
    explicit NULL being sent.  This matches how the ORM unit of work treats
    ``None`` on defaulted columns for ``session.add()``.

    Args:
        instance: Mapped model instance to read values from.

    Returns:
        Mapping of ORM attribute key to the value to insert.
    """
    state = sa_inspect(instance)
    state_dict = state.dict
    result: dict[str, Any] = {}
    for prop in state.mapper.column_attrs:
        key = prop.key
        if key not in state_dict:
            # Attribute was never set on this instance.
            continue
        val = state_dict[key]
        if val is None:
            # Only the first column backing the attribute is consulted.
            col = prop.columns[0]
            # Any kind of default wins over an explicit None; checking only
            # callable defaults would push NULL into columns that declare a
            # scalar or sequence default.
            if col.server_default is not None or col.default is not None:
                continue
        result[key] = val
    return result
|
||||
|
||||
|
||||
def _group_by_type(
|
||||
instances: list[DeclarativeBase],
|
||||
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
||||
"""Group instances by their concrete model class, preserving insertion order."""
|
||||
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
|
||||
for instance in instances:
|
||||
groups.setdefault(type(instance), []).append(instance)
|
||||
return list(groups.items())
|
||||
|
||||
|
||||
async def _batch_insert(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> None:
    """Insert all instances with a single multi-row INSERT statement.

    No ON CONFLICT clause is attached, so duplicates are not handled here and
    a conflicting row is expected to make the statement raise.

    Args:
        session: Session used to execute the statement.
        model_cls: Mapped class whose table receives the rows.
        instances: Instances of ``model_cls`` to insert.
    """
    rows = list(map(_instance_to_dict, instances))
    statement = pg_insert(model_cls).values(rows)
    await session.execute(statement)
|
||||
|
||||
|
||||
async def _batch_merge(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> None:
    """UPSERT: insert new rows, update existing ones with the provided values.

    A single ``INSERT ... ON CONFLICT`` statement is issued, using the mapper's
    primary-key columns as the conflict target.  When the model has at least
    one mapped non-PK column, conflicting rows get those columns overwritten
    from ``excluded``; a PK-only model falls back to ``DO NOTHING``.

    Args:
        session: Session used to execute the statement.
        model_cls: Mapped class whose table receives the rows.
        instances: Instances of ``model_cls`` to upsert.
    """
    mapper = model_cls.__mapper__
    pk_names = [col.name for col in mapper.primary_key]
    pk_names_set = set(pk_names)

    # ``set_`` and ``stmt.excluded`` are addressed by table Column key, not by
    # ORM attribute key; the two differ when a column is mapped under another
    # attribute name, so take the key from the column itself.
    update_keys = [
        prop.columns[0].key
        for prop in mapper.column_attrs
        if not any(col.name in pk_names_set for col in prop.columns)
    ]

    dicts = [_instance_to_dict(i) for i in instances]
    stmt = pg_insert(model_cls).values(dicts)

    if update_keys:
        stmt = stmt.on_conflict_do_update(
            index_elements=pk_names,
            set_={key: stmt.excluded[key] for key in update_keys},
        )
    else:
        # Nothing to update on a PK-only table: ignore conflicting rows.
        stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)

    await session.execute(stmt)
|
||||
|
||||
|
||||
async def _batch_skip_existing(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> list[DeclarativeBase]:
    """Insert the rows that are not yet present and report which went in.

    Instances for which ``_get_primary_key`` yields ``None`` are inserted
    unconditionally and always reported.  The remaining instances are inserted
    with ``ON CONFLICT DO NOTHING`` on the primary key; the primary keys in
    the statement's RETURNING rows decide which of them are reported.

    Args:
        session: Session used to execute the statements.
        model_cls: Mapped class whose table receives the rows.
        instances: Instances of ``model_cls`` to load.

    Returns:
        The reported instances: keyless ones first, then the keyed ones whose
        primary key came back from the insert.
    """
    mapper = model_cls.__mapper__
    pk_columns = mapper.primary_key
    conflict_target = [column.name for column in pk_columns]
    single_pk = len(conflict_target) == 1

    keyless: list[DeclarativeBase] = []
    keyed: list[tuple[DeclarativeBase, Any]] = []
    for instance in instances:
        key = _get_primary_key(instance)
        if key is None:
            keyless.append(instance)
            continue
        keyed.append((instance, key))

    inserted: list[DeclarativeBase] = list(keyless)

    if keyless:
        keyless_rows = [_instance_to_dict(instance) for instance in keyless]
        await session.execute(pg_insert(model_cls).values(keyless_rows))

    if not keyed:
        return inserted

    keyed_rows = [_instance_to_dict(instance) for instance, _ in keyed]
    statement = (
        pg_insert(model_cls)
        .values(keyed_rows)
        .on_conflict_do_nothing(index_elements=conflict_target)
        .returning(*pk_columns)
    )
    outcome = await session.execute(statement)

    # Single-column keys are compared as scalars, composite keys as tuples.
    new_keys = set()
    for row in outcome:
        new_keys.add(row[0] if single_pk else tuple(row))

    inserted.extend(instance for instance, key in keyed if key in new_keys)
    return inserted
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
@@ -23,7 +132,7 @@ async def _load_ordered(
|
||||
strategy: LoadStrategy,
|
||||
contexts: tuple[str, ...] | None = None,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order.
|
||||
"""Load fixtures in order using batch Core INSERT statements.
|
||||
|
||||
When *contexts* is provided only variants whose context set intersects with
|
||||
*contexts* are called for each name; their instances are concatenated.
|
||||
@@ -59,25 +168,17 @@ async def _load_ordered(
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
else: # LoadStrategy.SKIP_EXISTING
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
for model_cls, group in _group_by_type(instances):
|
||||
match strategy:
|
||||
case LoadStrategy.INSERT:
|
||||
await _batch_insert(session, model_cls, group)
|
||||
loaded.extend(group)
|
||||
case LoadStrategy.MERGE:
|
||||
await _batch_merge(session, model_cls, group)
|
||||
loaded.extend(group)
|
||||
case LoadStrategy.SKIP_EXISTING:
|
||||
inserted = await _batch_skip_existing(session, model_cls, group)
|
||||
loaded.extend(inserted)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
@@ -15,7 +15,7 @@ from fastapi_toolsets.fixtures import (
|
||||
load_fixtures_by_context,
|
||||
)
|
||||
|
||||
from fastapi_toolsets.fixtures.utils import _get_primary_key
|
||||
from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict
|
||||
|
||||
from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud
|
||||
|
||||
@@ -945,3 +945,57 @@ class TestRegistryGetVariants:
|
||||
registry = FixtureRegistry()
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
registry.get_variants("no_such_fixture")
|
||||
|
||||
|
||||
class TestInstanceToDict:
    """Checks which column values ``_instance_to_dict`` keeps or drops."""

    def test_explicit_values_included(self):
        """Columns that were given explicit values are returned unchanged."""
        expected_id = uuid.uuid4()
        values = _instance_to_dict(Role(id=expected_id, name="admin"))
        assert values["id"] == expected_id
        assert values["name"] == "admin"

    def test_callable_default_none_excluded(self):
        """``None`` on a column with a callable Python-side default
        (e.g. ``default=uuid.uuid4``) is dropped so the value gets generated
        rather than inserted as NULL."""
        values = _instance_to_dict(Role(id=None, name="admin"))
        assert "id" not in values
        assert values["name"] == "admin"

    def test_nullable_none_included(self):
        """``None`` on a nullable column without a default stays in the result
        as an explicit NULL."""
        user = User(id=uuid.uuid4(), username="u", email="e@e.com", role_id=None)
        values = _instance_to_dict(user)
        assert "role_id" in values
        assert values["role_id"] is None
|
||||
|
||||
|
||||
class TestBatchMergeNonPkColumns:
    """MERGE loading of a model that has no non-PK columns (PK-only table)."""

    @pytest.mark.anyio
    async def test_merge_pk_only_model(self, db_session: AsyncSession):
        """Loading a PK-only model twice with MERGE reports both rows each time."""
        registry = FixtureRegistry()

        @registry.register
        def permissions():
            return [
                Permission(subject="post", action=action)
                for action in ("read", "write")
            ]

        async def load_once():
            return await load_fixtures(
                db_session, registry, "permissions", strategy=LoadStrategy.MERGE
            )

        first_run = await load_once()
        assert len(first_run["permissions"]) == 2

        # The rows already exist on the second pass; the conflicts must not
        # raise and both instances are still reported.
        second_run = await load_once()
        assert len(second_run["permissions"]) == 2
|
||||
|
||||
Reference in New Issue
Block a user