From 29326ab532a321ca5851d90fde4540ece6ea9cb6 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Thu, 26 Mar 2026 20:29:25 +0100 Subject: [PATCH] perf: batch insert fixtures (#188) --- src/fastapi_toolsets/fixtures/utils.py | 141 +++++++++++++++++++++---- tests/test_fixtures.py | 56 +++++++++- 2 files changed, 176 insertions(+), 21 deletions(-) diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index fcde2b2..bc3ad3b 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -4,6 +4,8 @@ from collections.abc import Callable, Sequence from enum import Enum from typing import Any +from sqlalchemy import inspect as sa_inspect +from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase @@ -16,6 +18,113 @@ from .registry import FixtureRegistry, _normalize_contexts logger = get_logger() +def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]: + """Extract column values from a model instance, skipping unset server-default columns.""" + state = sa_inspect(instance) + state_dict = state.dict + result: dict[str, Any] = {} + for prop in state.mapper.column_attrs: + if prop.key not in state_dict: + continue + val = state_dict[prop.key] + if val is None: + col = prop.columns[0] + + if col.server_default is not None or ( + col.default is not None and col.default.is_callable + ): + continue + result[prop.key] = val + return result + + +def _group_by_type( + instances: list[DeclarativeBase], +) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]: + """Group instances by their concrete model class, preserving insertion order.""" + groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {} + for instance in instances: + groups.setdefault(type(instance), []).append(instance) + return list(groups.items()) + + +async def _batch_insert( + session: AsyncSession, + model_cls: type[DeclarativeBase], + instances: list[DeclarativeBase], +) -> None: + """INSERT all instances — raises on conflict (no duplicate handling).""" + dicts = [_instance_to_dict(i) for i in instances] + await session.execute(pg_insert(model_cls).values(dicts)) + + +async def _batch_merge( + session: AsyncSession, + model_cls: type[DeclarativeBase], + instances: list[DeclarativeBase], +) -> None: + """UPSERT: insert new rows, update existing ones with the provided values.""" + mapper = model_cls.__mapper__ + pk_names = [col.name for col in mapper.primary_key] + pk_names_set = set(pk_names) + non_pk_cols = [ + prop.key + for prop in mapper.column_attrs + if not any(col.name in pk_names_set for col in prop.columns) + ] + + dicts = [_instance_to_dict(i) for i in instances] + stmt = pg_insert(model_cls).values(dicts) + + if non_pk_cols: + stmt = stmt.on_conflict_do_update( + index_elements=pk_names, + set_={col: stmt.excluded[col] for col in non_pk_cols}, + ) + else: + stmt = stmt.on_conflict_do_nothing(index_elements=pk_names) + + await session.execute(stmt) + + +async def _batch_skip_existing( + session: AsyncSession, + model_cls: type[DeclarativeBase], + instances: list[DeclarativeBase], +) -> list[DeclarativeBase]: + """INSERT only rows that do not already exist; return the inserted ones.""" + mapper = model_cls.__mapper__ + pk_names = [col.name for col in mapper.primary_key] + + no_pk: list[DeclarativeBase] = [] + with_pk_pairs: list[tuple[DeclarativeBase, Any]] = [] + for inst in instances: + pk = _get_primary_key(inst) + if pk is None: + no_pk.append(inst) + else: + with_pk_pairs.append((inst, pk)) + + loaded: list[DeclarativeBase] = list(no_pk) + if no_pk: + await session.execute( + pg_insert(model_cls).values([_instance_to_dict(i) for i in no_pk]) + ) + + if with_pk_pairs: + with_pk = [i for i, _ in with_pk_pairs] + stmt = ( + pg_insert(model_cls) + .values([_instance_to_dict(i) for i in with_pk]) + .on_conflict_do_nothing(index_elements=pk_names) + ) + result = await session.execute(stmt.returning(*mapper.primary_key)) + inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result} + loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks) + + return loaded + + async def _load_ordered( session: AsyncSession, registry: FixtureRegistry, @@ -23,7 +132,7 @@ async def _load_ordered( strategy: LoadStrategy, contexts: tuple[str, ...] | None = None, ) -> dict[str, list[DeclarativeBase]]: - """Load fixtures in order. + """Load fixtures in order using batch Core INSERT statements. When *contexts* is provided only variants whose context set intersects with *contexts* are called for each name; their instances are concatenated. @@ -59,25 +168,17 @@ async def _load_ordered( loaded: list[DeclarativeBase] = [] async with get_transaction(session): - for instance in instances: - if strategy == LoadStrategy.INSERT: - session.add(instance) - loaded.append(instance) - - elif strategy == LoadStrategy.MERGE: - merged = await session.merge(instance) - loaded.append(merged) - - else: # LoadStrategy.SKIP_EXISTING - pk = _get_primary_key(instance) - if pk is not None: - existing = await session.get(type(instance), pk) - if existing is None: - session.add(instance) - loaded.append(instance) - else: - session.add(instance) - loaded.append(instance) + for model_cls, group in _group_by_type(instances): + match strategy: + case LoadStrategy.INSERT: + await _batch_insert(session, model_cls, group) + loaded.extend(group) + case LoadStrategy.MERGE: + await _batch_merge(session, model_cls, group) + loaded.extend(group) + case LoadStrategy.SKIP_EXISTING: + inserted = await _batch_skip_existing(session, model_cls, group) + loaded.extend(inserted) results[name] = loaded logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 3fbe600..d2455ef 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -15,7 +15,7 @@ from fastapi_toolsets.fixtures import ( load_fixtures_by_context, ) -from fastapi_toolsets.fixtures.utils import _get_primary_key +from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud @@ -945,3 +945,57 @@ class TestRegistryGetVariants: registry = FixtureRegistry() with pytest.raises(KeyError, match="not found"): registry.get_variants("no_such_fixture") + + +class TestInstanceToDict: + """Unit tests for the _instance_to_dict helper.""" + + def test_explicit_values_included(self): + """All explicitly set column values appear in the result.""" + role_id = uuid.uuid4() + instance = Role(id=role_id, name="admin") + d = _instance_to_dict(instance) + assert d["id"] == role_id + assert d["name"] == "admin" + + def test_callable_default_none_excluded(self): + """A column whose value is None but has a callable Python-side default + (e.g. ``default=uuid.uuid4``) is excluded so the DB generates it.""" + instance = Role(id=None, name="admin") + d = _instance_to_dict(instance) + assert "id" not in d + assert d["name"] == "admin" + + def test_nullable_none_included(self): + """None on a nullable column with no default is kept (explicit NULL).""" + instance = User(id=uuid.uuid4(), username="u", email="e@e.com", role_id=None) + d = _instance_to_dict(instance) + assert "role_id" in d + assert d["role_id"] is None + + +class TestBatchMergeNonPkColumns: + """Batch MERGE on a model with no non-PK columns (PK-only table).""" + + @pytest.mark.anyio + async def test_merge_pk_only_model(self, db_session: AsyncSession): + """MERGE strategy on a PK-only model uses on_conflict_do_nothing.""" + registry = FixtureRegistry() + + @registry.register + def permissions(): + return [ + Permission(subject="post", action="read"), + Permission(subject="post", action="write"), + ] + + result = await load_fixtures( + db_session, registry, "permissions", strategy=LoadStrategy.MERGE + ) + assert len(result["permissions"]) == 2 + + # Run again — conflicts are silently ignored. + result2 = await load_fixtures( + db_session, registry, "permissions", strategy=LoadStrategy.MERGE + ) + assert len(result2["permissions"]) == 2