fix: batch insert normalizes away omitted nullable columns (#198)

This commit is contained in:
d3vyce
2026-03-30 18:50:05 +02:00
committed by GitHub
parent 1890d696bf
commit 0b3f097012
2 changed files with 425 additions and 49 deletions

View File

@@ -30,20 +30,16 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
if val is None:
col = prop.columns[0]
if col.server_default is not None or (
col.default is not None and col.default.is_callable
if (
col.server_default is not None
or (col.default is not None and col.default.is_callable)
or col.autoincrement is True
):
continue
result[prop.key] = val
return result
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Ensure all row dicts share the same key set."""
all_keys: set[str] = set().union(*dicts)
return [{k: d.get(k) for k in all_keys} for d in dicts]
def _group_by_type(
instances: list[DeclarativeBase],
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
@@ -54,14 +50,32 @@ def _group_by_type(
return list(groups.items())
def _group_by_column_set(
    dicts: list[dict[str, Any]],
    instances: list[DeclarativeBase],
) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]:
    """Partition parallel row dicts and instances by each dict's key set.

    *dicts* and *instances* are walked in lockstep. Pairs whose dict has
    exactly the same keys land in the same group.

    Returns:
        One ``(row_dicts, instances)`` tuple per distinct key set, in the
        order each key set was first encountered. Within a group the input
        order is preserved.
    """
    buckets: dict[
        frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]]
    ] = {}
    for row, instance in zip(dicts, instances):
        # setdefault creates the empty bucket the first time a key set is seen.
        bucket_rows, bucket_instances = buckets.setdefault(
            frozenset(row), ([], [])
        )
        bucket_rows.append(row)
        bucket_instances.append(instance)
    return [*buckets.values()]
async def _batch_insert(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> None:
    """INSERT all instances — raises on conflict (no duplicate handling).

    Rows are grouped by the exact set of columns they provide and one
    multi-row INSERT is executed per group. A column omitted from a row is
    therefore left out of that row's statement instead of being padded with
    an explicit ``None``.

    Args:
        session: Session used to execute the statements.
        model_cls: Mapped class the rows are inserted into.
        instances: Instances of *model_cls* to insert; an empty list executes
            no statement.
    """
    dicts = [_instance_to_dict(i) for i in instances]
    # Each group shares one key set, so its rows fit a single VALUES clause.
    for group_dicts, _ in _group_by_column_set(dicts, instances):
        await session.execute(pg_insert(model_cls).values(group_dicts))
async def _batch_merge(
@@ -79,21 +93,22 @@ async def _batch_merge(
if not any(col.name in pk_names_set for col in prop.columns)
]
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
stmt = pg_insert(model_cls).values(dicts)
dicts = [_instance_to_dict(i) for i in instances]
for group_dicts, _ in _group_by_column_set(dicts, instances):
stmt = pg_insert(model_cls).values(group_dicts)
inserted_keys = set(dicts[0]) if dicts else set()
update_cols = [col for col in non_pk_cols if col in inserted_keys]
inserted_keys = set(group_dicts[0])
update_cols = [col for col in non_pk_cols if col in inserted_keys]
if update_cols:
stmt = stmt.on_conflict_do_update(
index_elements=pk_names,
set_={col: stmt.excluded[col] for col in update_cols},
)
else:
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
if update_cols:
stmt = stmt.on_conflict_do_update(
index_elements=pk_names,
set_={col: stmt.excluded[col] for col in update_cols},
)
else:
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
await session.execute(stmt)
await session.execute(stmt)
async def _batch_skip_existing(
@@ -116,22 +131,30 @@ async def _batch_skip_existing(
loaded: list[DeclarativeBase] = list(no_pk)
if no_pk:
await session.execute(
pg_insert(model_cls).values(
_normalize_rows([_instance_to_dict(i) for i in no_pk])
)
)
no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
await session.execute(pg_insert(model_cls).values(group_dicts))
if with_pk_pairs:
with_pk = [i for i, _ in with_pk_pairs]
stmt = (
pg_insert(model_cls)
.values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
.on_conflict_do_nothing(index_elements=pk_names)
)
result = await session.execute(stmt.returning(*mapper.primary_key))
inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result}
loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks)
with_pk_dicts = [_instance_to_dict(i) for i in with_pk]
for group_dicts, group_insts in _group_by_column_set(with_pk_dicts, with_pk):
stmt = (
pg_insert(model_cls)
.values(group_dicts)
.on_conflict_do_nothing(index_elements=pk_names)
)
result = await session.execute(stmt.returning(*mapper.primary_key))
inserted_pks = {
row[0] if len(pk_names) == 1 else tuple(row) for row in result
}
loaded.extend(
inst
for inst, pk in zip(
group_insts, [_get_primary_key(i) for i in group_insts]
)
if pk in inserted_pks
)
return loaded
@@ -143,12 +166,7 @@ async def _load_ordered(
strategy: LoadStrategy,
contexts: tuple[str, ...] | None = None,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order using batch Core INSERT statements.
When *contexts* is provided only variants whose context set intersects with
*contexts* are called for each name; their instances are concatenated.
When *contexts* is ``None`` all variants of each name are loaded.
"""
"""Load fixtures in order using batch Core INSERT statements."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
@@ -158,10 +176,6 @@ async def _load_ordered(
else registry.get_variants(name)
)
# Cross-context dependency fallback: if we're loading by context but
# no variant matches (e.g. a "base"-only fixture required by a
# "testing" fixture), load all available variants so the dependency
# is satisfied.
if contexts is not None and not variants:
variants = registry.get_variants(name)
@@ -267,10 +281,6 @@ async def load_fixtures_by_context(
) -> dict[str, list[DeclarativeBase]]:
"""Load all fixtures for specific contexts.
For each fixture name, only the variants whose context set intersects with
*contexts* are loaded. When a name has variants in multiple of the
requested contexts, their instances are merged before being inserted.
Args:
session: Database session
registry: Fixture registry

View File

@@ -1011,6 +1011,14 @@ class TestInstanceToDict:
assert "id" not in d
assert d["name"] == "admin"
def test_autoincrement_none_excluded(self):
    """A None value on an autoincrement=True column is left out of the dict.

    Omitting the key lets the DB generate the value via its sequence.
    """
    role = IntRole(id=None, name="admin")
    row = _instance_to_dict(role)
    assert "id" not in row
    assert row["name"] == "admin"
def test_nullable_none_included(self):
"""None on a nullable column with no default is kept (explicit NULL)."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", role_id=None)
@@ -1107,3 +1115,361 @@ class TestBatchMergeNonPkColumns:
db_session, registry, "permissions", strategy=LoadStrategy.MERGE
)
assert len(result2["permissions"]) == 2
class TestBatchNullableColumnEdgeCases:
    """Edge cases for nullable / omitted columns during batch fixture import.

    Every test loads a batch whose rows provide different sets of columns and
    checks that each row ends up with exactly the values it was given.
    """

    @staticmethod
    async def _users_by_username(session: AsyncSession) -> dict[str, User]:
        """Fetch every ``User`` row and index the results by username."""
        from sqlalchemy import select

        result = await session.execute(select(User))
        return {user.username: user for user in result.scalars().all()}

    @pytest.mark.anyio
    async def test_insert_batch_mixed_nullable_fk(self, db_session: AsyncSession):
        """INSERT batch where some rows set a nullable FK and others don't.

        Both rows must insert successfully: one with the FK value, the other
        with NULL in ``role_id``.
        """
        registry = FixtureRegistry()
        admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
        uid1 = uuid.uuid4()
        uid2 = uuid.uuid4()

        @registry.register
        def users():
            return [
                User(
                    id=uid1, username="with_role", email="a@test.com", role_id=admin.id
                ),
                User(id=uid2, username="no_role", email="b@test.com"),
            ]

        await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT)

        rows = await self._users_by_username(db_session)
        assert rows["with_role"].role_id == admin.id
        assert rows["no_role"].role_id is None

    @pytest.mark.anyio
    async def test_insert_batch_mixed_nullable_notes(self, db_session: AsyncSession):
        """INSERT batch where some rows have notes and others don't.

        Covers a set value, an omitted value and an explicit ``None``; each
        row must get the intended value.
        """
        registry = FixtureRegistry()
        uid1 = uuid.uuid4()
        uid2 = uuid.uuid4()
        uid3 = uuid.uuid4()

        @registry.register
        def users():
            return [
                User(
                    id=uid1,
                    username="has_notes",
                    email="a@test.com",
                    notes="important",
                ),
                User(id=uid2, username="no_notes", email="b@test.com"),
                User(id=uid3, username="null_notes", email="c@test.com", notes=None),
            ]

        await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT)

        rows = await self._users_by_username(db_session)
        assert rows["has_notes"].notes == "important"
        assert rows["no_notes"].notes is None
        assert rows["null_notes"].notes is None

    @pytest.mark.anyio
    async def test_merge_batch_mixed_nullable_does_not_overwrite(
        self, db_session: AsyncSession
    ):
        """MERGE batch where one row sets a nullable column and another omits it.

        If both rows already exist in DB, the row that omits the column must
        NOT have its existing value overwritten with NULL.

        Regression test for the normalization bug: missing keys were filled
        with None, and MERGE's SET clause then included that column for ALL
        rows.
        """
        admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
        uid1 = uuid.uuid4()
        uid2 = uuid.uuid4()

        # Pre-populate: both users have role_id and notes
        registry_initial = FixtureRegistry()

        @registry_initial.register
        def users():
            return [
                User(
                    id=uid1,
                    username="alice",
                    email="a@test.com",
                    role_id=admin.id,
                    notes="alice notes",
                ),
                User(
                    id=uid2,
                    username="bob",
                    email="b@test.com",
                    role_id=admin.id,
                    notes="bob notes",
                ),
            ]

        await load_fixtures(
            db_session, registry_initial, "users", strategy=LoadStrategy.INSERT
        )

        # Re-merge: alice updates notes, bob omits notes entirely
        registry_merge = FixtureRegistry()

        @registry_merge.register
        def users():  # noqa: F811
            return [
                User(
                    id=uid1,
                    username="alice",
                    email="a@test.com",
                    role_id=admin.id,
                    notes="updated",
                ),
                User(
                    id=uid2,
                    username="bob",
                    email="b@test.com",
                    role_id=admin.id,
                ),  # notes omitted
            ]

        await load_fixtures(
            db_session, registry_merge, "users", strategy=LoadStrategy.MERGE
        )

        rows = await self._users_by_username(db_session)
        assert rows["alice"].notes == "updated"
        # Bob's notes must be preserved, NOT overwritten with NULL
        assert rows["bob"].notes == "bob notes"

    @pytest.mark.anyio
    async def test_merge_batch_mixed_nullable_fk_preserves_existing(
        self, db_session: AsyncSession
    ):
        """MERGE batch where one row sets role_id and another omits it.

        The row that omits role_id must keep its existing DB value.
        """
        admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
        editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
        uid1 = uuid.uuid4()
        uid2 = uuid.uuid4()

        # Pre-populate
        registry_initial = FixtureRegistry()

        @registry_initial.register
        def users():
            return [
                User(
                    id=uid1,
                    username="alice",
                    email="a@test.com",
                    role_id=admin.id,
                ),
                User(
                    id=uid2,
                    username="bob",
                    email="b@test.com",
                    role_id=editor.id,
                ),
            ]

        await load_fixtures(
            db_session, registry_initial, "users", strategy=LoadStrategy.INSERT
        )

        # Re-merge: alice changes role, bob omits role_id
        registry_merge = FixtureRegistry()

        @registry_merge.register
        def users():  # noqa: F811
            return [
                User(
                    id=uid1,
                    username="alice",
                    email="a@test.com",
                    role_id=editor.id,
                ),
                User(id=uid2, username="bob", email="b@test.com"),  # role_id omitted
            ]

        await load_fixtures(
            db_session, registry_merge, "users", strategy=LoadStrategy.MERGE
        )

        rows = await self._users_by_username(db_session)
        assert rows["alice"].role_id == editor.id  # updated
        assert rows["bob"].role_id == editor.id  # must be preserved, NOT NULL

    @pytest.mark.anyio
    async def test_insert_batch_mixed_pk_presence(self, db_session: AsyncSession):
        """INSERT batch where some rows have explicit PK and others rely on
        the callable default (uuid.uuid4).

        Regression test: padding the omitted PK key with None can cause
        NOT NULL violations on the PK column.
        """
        from sqlalchemy import select

        registry = FixtureRegistry()
        explicit_id = uuid.uuid4()

        @registry.register
        def roles():
            return [
                Role(id=explicit_id, name="admin"),
                Role(name="user"),  # PK omitted, relies on default=uuid.uuid4
            ]

        await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.INSERT)

        rows = (await db_session.execute(select(Role))).scalars().all()
        assert len(rows) == 2
        names = {r.name for r in rows}
        assert names == {"admin", "user"}

        # The "admin" row must have the explicit ID
        admin = next(r for r in rows if r.name == "admin")
        assert admin.id == explicit_id

        # The "user" row must have a generated UUID (not None)
        user = next(r for r in rows if r.name == "user")
        assert user.id is not None

    @pytest.mark.anyio
    async def test_skip_existing_batch_mixed_nullable(self, db_session: AsyncSession):
        """SKIP_EXISTING with mixed nullable columns inserts correctly.

        Only new rows are inserted; existing rows are untouched regardless of
        which columns the fixture provides.
        """
        admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
        uid1 = uuid.uuid4()
        uid2 = uuid.uuid4()

        # Pre-populate uid1 with notes
        registry_initial = FixtureRegistry()

        @registry_initial.register
        def users():
            return [
                User(
                    id=uid1,
                    username="alice",
                    email="a@test.com",
                    role_id=admin.id,
                    notes="keep me",
                ),
            ]

        await load_fixtures(
            db_session, registry_initial, "users", strategy=LoadStrategy.INSERT
        )

        # Load again with SKIP_EXISTING: uid1 already exists, uid2 is new
        registry_skip = FixtureRegistry()

        @registry_skip.register
        def users():  # noqa: F811
            return [
                User(id=uid1, username="alice-updated", email="a@test.com"),  # exists
                User(
                    id=uid2,
                    username="bob",
                    email="b@test.com",
                    notes="new user",
                ),  # new
            ]

        result = await load_fixtures(
            db_session, registry_skip, "users", strategy=LoadStrategy.SKIP_EXISTING
        )
        assert len(result["users"]) == 1  # only bob inserted

        rows = await self._users_by_username(db_session)
        # alice untouched
        assert rows["alice"].role_id == admin.id
        assert rows["alice"].notes == "keep me"
        # bob inserted correctly
        assert rows["bob"].notes == "new user"

    @pytest.mark.anyio
    async def test_insert_batch_every_row_different_nullable_columns(
        self, db_session: AsyncSession
    ):
        """Each row in the batch sets a different combination of nullable columns.

        Every row must insert with exactly the values it provided and NULL in
        the columns it omitted.
        """
        registry = FixtureRegistry()
        admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
        uid1 = uuid.uuid4()
        uid2 = uuid.uuid4()
        uid3 = uuid.uuid4()

        @registry.register
        def users():
            return [
                User(
                    id=uid1,
                    username="all_set",
                    email="a@test.com",
                    role_id=admin.id,
                    notes="full",
                ),
                User(
                    id=uid2, username="only_role", email="b@test.com", role_id=admin.id
                ),
                User(
                    id=uid3, username="only_notes", email="c@test.com", notes="partial"
                ),
            ]

        await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT)

        rows = await self._users_by_username(db_session)
        assert rows["all_set"].role_id == admin.id
        assert rows["all_set"].notes == "full"
        assert rows["only_role"].role_id == admin.id
        assert rows["only_role"].notes is None
        assert rows["only_notes"].role_id is None
        assert rows["only_notes"].notes == "partial"