From 0b3f097012fa9a63d66b532880dd3afb7e4fc3af Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Mon, 30 Mar 2026 18:50:05 +0200 Subject: [PATCH] fix: batch insert normalizes away omitted nullable columns (#198) --- src/fastapi_toolsets/fixtures/utils.py | 108 ++++---- tests/test_fixtures.py | 366 +++++++++++++++++++++++++ 2 files changed, 425 insertions(+), 49 deletions(-) diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index 4edb64e..ac26d3d 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -30,20 +30,16 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]: if val is None: col = prop.columns[0] - if col.server_default is not None or ( - col.default is not None and col.default.is_callable + if ( + col.server_default is not None + or (col.default is not None and col.default.is_callable) + or col.autoincrement is True ): continue result[prop.key] = val return result -def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Ensure all row dicts share the same key set.""" - all_keys: set[str] = set().union(*dicts) - return [{k: d.get(k) for k in all_keys} for d in dicts] - - def _group_by_type( instances: list[DeclarativeBase], ) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]: @@ -54,14 +50,32 @@ def _group_by_type( return list(groups.items()) +def _group_by_column_set( + dicts: list[dict[str, Any]], + instances: list[DeclarativeBase], +) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]: + """Group (dict, instance) pairs by their dict key sets.""" + groups: dict[ + frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]] + ] = {} + for d, inst in zip(dicts, instances): + key = frozenset(d) + if key not in groups: + groups[key] = ([], []) + groups[key][0].append(d) + groups[key][1].append(inst) + return list(groups.values()) + + async def _batch_insert( session: AsyncSession, model_cls: type[DeclarativeBase], instances: list[DeclarativeBase], ) -> None: """INSERT all instances — raises on conflict (no duplicate handling).""" - dicts = _normalize_rows([_instance_to_dict(i) for i in instances]) - await session.execute(pg_insert(model_cls).values(dicts)) + dicts = [_instance_to_dict(i) for i in instances] + for group_dicts, _ in _group_by_column_set(dicts, instances): + await session.execute(pg_insert(model_cls).values(group_dicts)) async def _batch_merge( @@ -79,21 +93,22 @@ async def _batch_merge( if not any(col.name in pk_names_set for col in prop.columns) ] - dicts = _normalize_rows([_instance_to_dict(i) for i in instances]) - stmt = pg_insert(model_cls).values(dicts) + dicts = [_instance_to_dict(i) for i in instances] + for group_dicts, _ in _group_by_column_set(dicts, instances): + stmt = pg_insert(model_cls).values(group_dicts) - inserted_keys = set(dicts[0]) if dicts else set() - update_cols = [col for col in non_pk_cols if col in inserted_keys] + inserted_keys = set(group_dicts[0]) + update_cols = [col for col in non_pk_cols if col in inserted_keys] - if update_cols: - stmt = stmt.on_conflict_do_update( - index_elements=pk_names, - set_={col: stmt.excluded[col] for col in update_cols}, - ) - else: - stmt = stmt.on_conflict_do_nothing(index_elements=pk_names) + if update_cols: + stmt = stmt.on_conflict_do_update( + index_elements=pk_names, + set_={col: stmt.excluded[col] for col in update_cols}, + ) + else: + stmt = stmt.on_conflict_do_nothing(index_elements=pk_names) - await session.execute(stmt) + await session.execute(stmt) async def _batch_skip_existing( @@ -116,22 +131,30 @@ async def _batch_skip_existing( loaded: list[DeclarativeBase] = list(no_pk) if no_pk: - await session.execute( - pg_insert(model_cls).values( - _normalize_rows([_instance_to_dict(i) for i in no_pk]) - ) - ) + no_pk_dicts = [_instance_to_dict(i) for i in no_pk] + for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk): + await session.execute(pg_insert(model_cls).values(group_dicts)) if with_pk_pairs: with_pk = [i for i, _ in with_pk_pairs] - stmt = ( - pg_insert(model_cls) - .values(_normalize_rows([_instance_to_dict(i) for i in with_pk])) - .on_conflict_do_nothing(index_elements=pk_names) - ) - result = await session.execute(stmt.returning(*mapper.primary_key)) - inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result} - loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks) + with_pk_dicts = [_instance_to_dict(i) for i in with_pk] + for group_dicts, group_insts in _group_by_column_set(with_pk_dicts, with_pk): + stmt = ( + pg_insert(model_cls) + .values(group_dicts) + .on_conflict_do_nothing(index_elements=pk_names) + ) + result = await session.execute(stmt.returning(*mapper.primary_key)) + inserted_pks = { + row[0] if len(pk_names) == 1 else tuple(row) for row in result + } + loaded.extend( + inst + for inst, pk in zip( + group_insts, [_get_primary_key(i) for i in group_insts] + ) + if pk in inserted_pks + ) return loaded @@ -143,12 +166,7 @@ async def _load_ordered( strategy: LoadStrategy, contexts: tuple[str, ...] | None = None, ) -> dict[str, list[DeclarativeBase]]: - """Load fixtures in order using batch Core INSERT statements. - - When *contexts* is provided only variants whose context set intersects with - *contexts* are called for each name; their instances are concatenated. - When *contexts* is ``None`` all variants of each name are loaded. - """ + """Load fixtures in order using batch Core INSERT statements.""" results: dict[str, list[DeclarativeBase]] = {} for name in ordered_names: @@ -158,10 +176,6 @@ async def _load_ordered( else registry.get_variants(name) ) - # Cross-context dependency fallback: if we're loading by context but - # no variant matches (e.g. a "base"-only fixture required by a - # "testing" fixture), load all available variants so the dependency - # is satisfied. if contexts is not None and not variants: variants = registry.get_variants(name) @@ -267,10 +281,6 @@ async def load_fixtures_by_context( ) -> dict[str, list[DeclarativeBase]]: """Load all fixtures for specific contexts. - For each fixture name, only the variants whose context set intersects with - *contexts* are loaded. When a name has variants in multiple of the - requested contexts, their instances are merged before being inserted. - Args: session: Database session registry: Fixture registry diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index f3c0e0a..81f2c2d 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -1011,6 +1011,14 @@ class TestInstanceToDict: assert "id" not in d assert d["name"] == "admin" + def test_autoincrement_none_excluded(self): + """A column whose value is None but has autoincrement=True is excluded + so the DB generates the value via its sequence.""" + instance = IntRole(id=None, name="admin") + d = _instance_to_dict(instance) + assert "id" not in d + assert d["name"] == "admin" + def test_nullable_none_included(self): """None on a nullable column with no default is kept (explicit NULL).""" instance = User(id=uuid.uuid4(), username="u", email="e@e.com", role_id=None) @@ -1107,3 +1115,361 @@ class TestBatchMergeNonPkColumns: db_session, registry, "permissions", strategy=LoadStrategy.MERGE ) assert len(result2["permissions"]) == 2 + + +class TestBatchNullableColumnEdgeCases: + """Deep tests for nullable column handling during batch import.""" + + @pytest.mark.anyio + async def test_insert_batch_mixed_nullable_fk(self, db_session: AsyncSession): + """INSERT batch where some rows set a nullable FK and others don't. + + After normalization the omitted role_id becomes None. For INSERT this + is acceptable — both rows should insert successfully with the correct + values (one with FK, one with NULL). + """ + registry = FixtureRegistry() + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + uid1 = uuid.uuid4() + uid2 = uuid.uuid4() + + @registry.register + def users(): + return [ + User( + id=uid1, username="with_role", email="a@test.com", role_id=admin.id + ), + User(id=uid2, username="no_role", email="b@test.com"), + ] + + await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT) + + from sqlalchemy import select + + rows = { + r.username: r + for r in (await db_session.execute(select(User))).scalars().all() + } + assert rows["with_role"].role_id == admin.id + assert rows["no_role"].role_id is None + + @pytest.mark.anyio + async def test_insert_batch_mixed_nullable_notes(self, db_session: AsyncSession): + """INSERT batch where some rows have notes and others don't. + + Ensures normalization doesn't break the insert and that each row gets + the intended value. + """ + registry = FixtureRegistry() + uid1 = uuid.uuid4() + uid2 = uuid.uuid4() + uid3 = uuid.uuid4() + + @registry.register + def users(): + return [ + User( + id=uid1, + username="has_notes", + email="a@test.com", + notes="important", + ), + User(id=uid2, username="no_notes", email="b@test.com"), + User(id=uid3, username="null_notes", email="c@test.com", notes=None), + ] + + await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT) + + from sqlalchemy import select + + rows = { + r.username: r + for r in (await db_session.execute(select(User))).scalars().all() + } + assert rows["has_notes"].notes == "important" + assert rows["no_notes"].notes is None + assert rows["null_notes"].notes is None + + @pytest.mark.anyio + async def test_merge_batch_mixed_nullable_does_not_overwrite( + self, db_session: AsyncSession + ): + """MERGE batch where one row sets a nullable column and another omits it. + + If both rows already exist in DB, the row that omits the column must + NOT have its existing value overwritten with NULL. + + This is the core normalization bug: _normalize_rows fills missing keys + with None, and then MERGE's SET clause includes that column for ALL rows. + """ + from sqlalchemy import select + + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + uid1 = uuid.uuid4() + uid2 = uuid.uuid4() + + # Pre-populate: both users have role_id and notes + registry_initial = FixtureRegistry() + + @registry_initial.register + def users(): + return [ + User( + id=uid1, + username="alice", + email="a@test.com", + role_id=admin.id, + notes="alice notes", + ), + User( + id=uid2, + username="bob", + email="b@test.com", + role_id=admin.id, + notes="bob notes", + ), + ] + + await load_fixtures( + db_session, registry_initial, "users", strategy=LoadStrategy.INSERT + ) + + # Re-merge: alice updates notes, bob omits notes entirely + registry_merge = FixtureRegistry() + + @registry_merge.register + def users(): # noqa: F811 + return [ + User( + id=uid1, + username="alice", + email="a@test.com", + role_id=admin.id, + notes="updated", + ), + User( + id=uid2, + username="bob", + email="b@test.com", + role_id=admin.id, + ), # notes omitted + ] + + await load_fixtures( + db_session, registry_merge, "users", strategy=LoadStrategy.MERGE + ) + + rows = { + r.username: r + for r in (await db_session.execute(select(User))).scalars().all() + } + assert rows["alice"].notes == "updated" + # Bob's notes must be preserved, NOT overwritten with NULL + assert rows["bob"].notes == "bob notes" + + @pytest.mark.anyio + async def test_merge_batch_mixed_nullable_fk_preserves_existing( + self, db_session: AsyncSession + ): + """MERGE batch where one row sets role_id and another omits it. + + The row that omits role_id must keep its existing DB value. + """ + from sqlalchemy import select + + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + editor = await RoleCrud.create(db_session, RoleCreate(name="editor")) + uid1 = uuid.uuid4() + uid2 = uuid.uuid4() + + # Pre-populate + registry_initial = FixtureRegistry() + + @registry_initial.register + def users(): + return [ + User( + id=uid1, + username="alice", + email="a@test.com", + role_id=admin.id, + ), + User( + id=uid2, + username="bob", + email="b@test.com", + role_id=editor.id, + ), + ] + + await load_fixtures( + db_session, registry_initial, "users", strategy=LoadStrategy.INSERT + ) + + # Re-merge: alice changes role, bob omits role_id + registry_merge = FixtureRegistry() + + @registry_merge.register + def users(): # noqa: F811 + return [ + User( + id=uid1, + username="alice", + email="a@test.com", + role_id=editor.id, + ), + User(id=uid2, username="bob", email="b@test.com"), # role_id omitted + ] + + await load_fixtures( + db_session, registry_merge, "users", strategy=LoadStrategy.MERGE + ) + + rows = { + r.username: r + for r in (await db_session.execute(select(User))).scalars().all() + } + assert rows["alice"].role_id == editor.id # updated + assert rows["bob"].role_id == editor.id # must be preserved, NOT NULL + + @pytest.mark.anyio + async def test_insert_batch_mixed_pk_presence(self, db_session: AsyncSession): + """INSERT batch where some rows have explicit PK and others rely on + the callable default (uuid.uuid4). + + Normalization adds the PK key with None to rows that omitted it, + which can cause NOT NULL violations on the PK column. + """ + registry = FixtureRegistry() + explicit_id = uuid.uuid4() + + @registry.register + def roles(): + return [ + Role(id=explicit_id, name="admin"), + Role(name="user"), # PK omitted, relies on default=uuid.uuid4 + ] + + await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.INSERT) + + from sqlalchemy import select + + rows = (await db_session.execute(select(Role))).scalars().all() + assert len(rows) == 2 + names = {r.name for r in rows} + assert names == {"admin", "user"} + # The "admin" row must have the explicit ID + admin = next(r for r in rows if r.name == "admin") + assert admin.id == explicit_id + # The "user" row must have a generated UUID (not None) + user = next(r for r in rows if r.name == "user") + assert user.id is not None + + @pytest.mark.anyio + async def test_skip_existing_batch_mixed_nullable(self, db_session: AsyncSession): + """SKIP_EXISTING with mixed nullable columns inserts correctly. + + Only new rows are inserted; existing rows are untouched regardless of + which columns the fixture provides. + """ + from sqlalchemy import select + + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + uid1 = uuid.uuid4() + uid2 = uuid.uuid4() + + # Pre-populate uid1 with notes + registry_initial = FixtureRegistry() + + @registry_initial.register + def users(): + return [ + User( + id=uid1, + username="alice", + email="a@test.com", + role_id=admin.id, + notes="keep me", + ), + ] + + await load_fixtures( + db_session, registry_initial, "users", strategy=LoadStrategy.INSERT + ) + + # Load again with SKIP_EXISTING: uid1 already exists, uid2 is new + registry_skip = FixtureRegistry() + + @registry_skip.register + def users(): # noqa: F811 + return [ + User(id=uid1, username="alice-updated", email="a@test.com"), # exists + User( + id=uid2, + username="bob", + email="b@test.com", + notes="new user", + ), # new + ] + + result = await load_fixtures( + db_session, registry_skip, "users", strategy=LoadStrategy.SKIP_EXISTING + ) + assert len(result["users"]) == 1 # only bob inserted + + rows = { + r.username: r + for r in (await db_session.execute(select(User))).scalars().all() + } + # alice untouched + assert rows["alice"].role_id == admin.id + assert rows["alice"].notes == "keep me" + # bob inserted correctly + assert rows["bob"].notes == "new user" + + @pytest.mark.anyio + async def test_insert_batch_every_row_different_nullable_columns( + self, db_session: AsyncSession + ): + """Each row in the batch sets a different combination of nullable columns. + + Tests that normalization produces valid SQL for all rows. + """ + registry = FixtureRegistry() + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + uid1 = uuid.uuid4() + uid2 = uuid.uuid4() + uid3 = uuid.uuid4() + + @registry.register + def users(): + return [ + User( + id=uid1, + username="all_set", + email="a@test.com", + role_id=admin.id, + notes="full", + ), + User( + id=uid2, username="only_role", email="b@test.com", role_id=admin.id + ), + User( + id=uid3, username="only_notes", email="c@test.com", notes="partial" + ), + ] + + await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT) + + from sqlalchemy import select + + rows = { + r.username: r + for r in (await db_session.execute(select(User))).scalars().all() + } + assert rows["all_set"].role_id == admin.id + assert rows["all_set"].notes == "full" + assert rows["only_role"].role_id == admin.id + assert rows["only_role"].notes is None + assert rows["only_notes"].role_id is None + assert rows["only_notes"].notes == "partial"