mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
fix: batch insert normalizes away omitted nullable columns (#198)
This commit is contained in:
@@ -30,20 +30,16 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
|
||||
if val is None:
|
||||
col = prop.columns[0]
|
||||
|
||||
if col.server_default is not None or (
|
||||
col.default is not None and col.default.is_callable
|
||||
if (
|
||||
col.server_default is not None
|
||||
or (col.default is not None and col.default.is_callable)
|
||||
or col.autoincrement is True
|
||||
):
|
||||
continue
|
||||
result[prop.key] = val
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Ensure all row dicts share the same key set."""
|
||||
all_keys: set[str] = set().union(*dicts)
|
||||
return [{k: d.get(k) for k in all_keys} for d in dicts]
|
||||
|
||||
|
||||
def _group_by_type(
|
||||
instances: list[DeclarativeBase],
|
||||
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
||||
@@ -54,14 +50,32 @@ def _group_by_type(
|
||||
return list(groups.items())
|
||||
|
||||
|
||||
def _group_by_column_set(
|
||||
dicts: list[dict[str, Any]],
|
||||
instances: list[DeclarativeBase],
|
||||
) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]:
|
||||
"""Group (dict, instance) pairs by their dict key sets."""
|
||||
groups: dict[
|
||||
frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]]
|
||||
] = {}
|
||||
for d, inst in zip(dicts, instances):
|
||||
key = frozenset(d)
|
||||
if key not in groups:
|
||||
groups[key] = ([], [])
|
||||
groups[key][0].append(d)
|
||||
groups[key][1].append(inst)
|
||||
return list(groups.values())
|
||||
|
||||
|
||||
async def _batch_insert(
|
||||
session: AsyncSession,
|
||||
model_cls: type[DeclarativeBase],
|
||||
instances: list[DeclarativeBase],
|
||||
) -> None:
|
||||
"""INSERT all instances — raises on conflict (no duplicate handling)."""
|
||||
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
|
||||
await session.execute(pg_insert(model_cls).values(dicts))
|
||||
dicts = [_instance_to_dict(i) for i in instances]
|
||||
for group_dicts, _ in _group_by_column_set(dicts, instances):
|
||||
await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||
|
||||
|
||||
async def _batch_merge(
|
||||
@@ -79,21 +93,22 @@ async def _batch_merge(
|
||||
if not any(col.name in pk_names_set for col in prop.columns)
|
||||
]
|
||||
|
||||
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
|
||||
stmt = pg_insert(model_cls).values(dicts)
|
||||
dicts = [_instance_to_dict(i) for i in instances]
|
||||
for group_dicts, _ in _group_by_column_set(dicts, instances):
|
||||
stmt = pg_insert(model_cls).values(group_dicts)
|
||||
|
||||
inserted_keys = set(dicts[0]) if dicts else set()
|
||||
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||
inserted_keys = set(group_dicts[0])
|
||||
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||
|
||||
if update_cols:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=pk_names,
|
||||
set_={col: stmt.excluded[col] for col in update_cols},
|
||||
)
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||
if update_cols:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=pk_names,
|
||||
set_={col: stmt.excluded[col] for col in update_cols},
|
||||
)
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||
|
||||
await session.execute(stmt)
|
||||
await session.execute(stmt)
|
||||
|
||||
|
||||
async def _batch_skip_existing(
|
||||
@@ -116,22 +131,30 @@ async def _batch_skip_existing(
|
||||
|
||||
loaded: list[DeclarativeBase] = list(no_pk)
|
||||
if no_pk:
|
||||
await session.execute(
|
||||
pg_insert(model_cls).values(
|
||||
_normalize_rows([_instance_to_dict(i) for i in no_pk])
|
||||
)
|
||||
)
|
||||
no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
|
||||
for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
|
||||
await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||
|
||||
if with_pk_pairs:
|
||||
with_pk = [i for i, _ in with_pk_pairs]
|
||||
stmt = (
|
||||
pg_insert(model_cls)
|
||||
.values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
|
||||
.on_conflict_do_nothing(index_elements=pk_names)
|
||||
)
|
||||
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||
inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result}
|
||||
loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks)
|
||||
with_pk_dicts = [_instance_to_dict(i) for i in with_pk]
|
||||
for group_dicts, group_insts in _group_by_column_set(with_pk_dicts, with_pk):
|
||||
stmt = (
|
||||
pg_insert(model_cls)
|
||||
.values(group_dicts)
|
||||
.on_conflict_do_nothing(index_elements=pk_names)
|
||||
)
|
||||
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||
inserted_pks = {
|
||||
row[0] if len(pk_names) == 1 else tuple(row) for row in result
|
||||
}
|
||||
loaded.extend(
|
||||
inst
|
||||
for inst, pk in zip(
|
||||
group_insts, [_get_primary_key(i) for i in group_insts]
|
||||
)
|
||||
if pk in inserted_pks
|
||||
)
|
||||
|
||||
return loaded
|
||||
|
||||
@@ -143,12 +166,7 @@ async def _load_ordered(
|
||||
strategy: LoadStrategy,
|
||||
contexts: tuple[str, ...] | None = None,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order using batch Core INSERT statements.
|
||||
|
||||
When *contexts* is provided only variants whose context set intersects with
|
||||
*contexts* are called for each name; their instances are concatenated.
|
||||
When *contexts* is ``None`` all variants of each name are loaded.
|
||||
"""
|
||||
"""Load fixtures in order using batch Core INSERT statements."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
@@ -158,10 +176,6 @@ async def _load_ordered(
|
||||
else registry.get_variants(name)
|
||||
)
|
||||
|
||||
# Cross-context dependency fallback: if we're loading by context but
|
||||
# no variant matches (e.g. a "base"-only fixture required by a
|
||||
# "testing" fixture), load all available variants so the dependency
|
||||
# is satisfied.
|
||||
if contexts is not None and not variants:
|
||||
variants = registry.get_variants(name)
|
||||
|
||||
@@ -267,10 +281,6 @@ async def load_fixtures_by_context(
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load all fixtures for specific contexts.
|
||||
|
||||
For each fixture name, only the variants whose context set intersects with
|
||||
*contexts* are loaded. When a name has variants in multiple of the
|
||||
requested contexts, their instances are merged before being inserted.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
|
||||
@@ -1011,6 +1011,14 @@ class TestInstanceToDict:
|
||||
assert "id" not in d
|
||||
assert d["name"] == "admin"
|
||||
|
||||
def test_autoincrement_none_excluded(self):
|
||||
"""A column whose value is None but has autoincrement=True is excluded
|
||||
so the DB generates the value via its sequence."""
|
||||
instance = IntRole(id=None, name="admin")
|
||||
d = _instance_to_dict(instance)
|
||||
assert "id" not in d
|
||||
assert d["name"] == "admin"
|
||||
|
||||
def test_nullable_none_included(self):
|
||||
"""None on a nullable column with no default is kept (explicit NULL)."""
|
||||
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", role_id=None)
|
||||
@@ -1107,3 +1115,361 @@ class TestBatchMergeNonPkColumns:
|
||||
db_session, registry, "permissions", strategy=LoadStrategy.MERGE
|
||||
)
|
||||
assert len(result2["permissions"]) == 2
|
||||
|
||||
|
||||
class TestBatchNullableColumnEdgeCases:
|
||||
"""Deep tests for nullable column handling during batch import."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_insert_batch_mixed_nullable_fk(self, db_session: AsyncSession):
|
||||
"""INSERT batch where some rows set a nullable FK and others don't.
|
||||
|
||||
After normalization the omitted role_id becomes None. For INSERT this
|
||||
is acceptable — both rows should insert successfully with the correct
|
||||
values (one with FK, one with NULL).
|
||||
"""
|
||||
registry = FixtureRegistry()
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
uid1 = uuid.uuid4()
|
||||
uid2 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid1, username="with_role", email="a@test.com", role_id=admin.id
|
||||
),
|
||||
User(id=uid2, username="no_role", email="b@test.com"),
|
||||
]
|
||||
|
||||
await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT)
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
rows = {
|
||||
r.username: r
|
||||
for r in (await db_session.execute(select(User))).scalars().all()
|
||||
}
|
||||
assert rows["with_role"].role_id == admin.id
|
||||
assert rows["no_role"].role_id is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_insert_batch_mixed_nullable_notes(self, db_session: AsyncSession):
|
||||
"""INSERT batch where some rows have notes and others don't.
|
||||
|
||||
Ensures normalization doesn't break the insert and that each row gets
|
||||
the intended value.
|
||||
"""
|
||||
registry = FixtureRegistry()
|
||||
uid1 = uuid.uuid4()
|
||||
uid2 = uuid.uuid4()
|
||||
uid3 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid1,
|
||||
username="has_notes",
|
||||
email="a@test.com",
|
||||
notes="important",
|
||||
),
|
||||
User(id=uid2, username="no_notes", email="b@test.com"),
|
||||
User(id=uid3, username="null_notes", email="c@test.com", notes=None),
|
||||
]
|
||||
|
||||
await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT)
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
rows = {
|
||||
r.username: r
|
||||
for r in (await db_session.execute(select(User))).scalars().all()
|
||||
}
|
||||
assert rows["has_notes"].notes == "important"
|
||||
assert rows["no_notes"].notes is None
|
||||
assert rows["null_notes"].notes is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_merge_batch_mixed_nullable_does_not_overwrite(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""MERGE batch where one row sets a nullable column and another omits it.
|
||||
|
||||
If both rows already exist in DB, the row that omits the column must
|
||||
NOT have its existing value overwritten with NULL.
|
||||
|
||||
This is the core normalization bug: _normalize_rows fills missing keys
|
||||
with None, and then MERGE's SET clause includes that column for ALL rows.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
uid1 = uuid.uuid4()
|
||||
uid2 = uuid.uuid4()
|
||||
|
||||
# Pre-populate: both users have role_id and notes
|
||||
registry_initial = FixtureRegistry()
|
||||
|
||||
@registry_initial.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid1,
|
||||
username="alice",
|
||||
email="a@test.com",
|
||||
role_id=admin.id,
|
||||
notes="alice notes",
|
||||
),
|
||||
User(
|
||||
id=uid2,
|
||||
username="bob",
|
||||
email="b@test.com",
|
||||
role_id=admin.id,
|
||||
notes="bob notes",
|
||||
),
|
||||
]
|
||||
|
||||
await load_fixtures(
|
||||
db_session, registry_initial, "users", strategy=LoadStrategy.INSERT
|
||||
)
|
||||
|
||||
# Re-merge: alice updates notes, bob omits notes entirely
|
||||
registry_merge = FixtureRegistry()
|
||||
|
||||
@registry_merge.register
|
||||
def users(): # noqa: F811
|
||||
return [
|
||||
User(
|
||||
id=uid1,
|
||||
username="alice",
|
||||
email="a@test.com",
|
||||
role_id=admin.id,
|
||||
notes="updated",
|
||||
),
|
||||
User(
|
||||
id=uid2,
|
||||
username="bob",
|
||||
email="b@test.com",
|
||||
role_id=admin.id,
|
||||
), # notes omitted
|
||||
]
|
||||
|
||||
await load_fixtures(
|
||||
db_session, registry_merge, "users", strategy=LoadStrategy.MERGE
|
||||
)
|
||||
|
||||
rows = {
|
||||
r.username: r
|
||||
for r in (await db_session.execute(select(User))).scalars().all()
|
||||
}
|
||||
assert rows["alice"].notes == "updated"
|
||||
# Bob's notes must be preserved, NOT overwritten with NULL
|
||||
assert rows["bob"].notes == "bob notes"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_merge_batch_mixed_nullable_fk_preserves_existing(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""MERGE batch where one row sets role_id and another omits it.
|
||||
|
||||
The row that omits role_id must keep its existing DB value.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||
uid1 = uuid.uuid4()
|
||||
uid2 = uuid.uuid4()
|
||||
|
||||
# Pre-populate
|
||||
registry_initial = FixtureRegistry()
|
||||
|
||||
@registry_initial.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid1,
|
||||
username="alice",
|
||||
email="a@test.com",
|
||||
role_id=admin.id,
|
||||
),
|
||||
User(
|
||||
id=uid2,
|
||||
username="bob",
|
||||
email="b@test.com",
|
||||
role_id=editor.id,
|
||||
),
|
||||
]
|
||||
|
||||
await load_fixtures(
|
||||
db_session, registry_initial, "users", strategy=LoadStrategy.INSERT
|
||||
)
|
||||
|
||||
# Re-merge: alice changes role, bob omits role_id
|
||||
registry_merge = FixtureRegistry()
|
||||
|
||||
@registry_merge.register
|
||||
def users(): # noqa: F811
|
||||
return [
|
||||
User(
|
||||
id=uid1,
|
||||
username="alice",
|
||||
email="a@test.com",
|
||||
role_id=editor.id,
|
||||
),
|
||||
User(id=uid2, username="bob", email="b@test.com"), # role_id omitted
|
||||
]
|
||||
|
||||
await load_fixtures(
|
||||
db_session, registry_merge, "users", strategy=LoadStrategy.MERGE
|
||||
)
|
||||
|
||||
rows = {
|
||||
r.username: r
|
||||
for r in (await db_session.execute(select(User))).scalars().all()
|
||||
}
|
||||
assert rows["alice"].role_id == editor.id # updated
|
||||
assert rows["bob"].role_id == editor.id # must be preserved, NOT NULL
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_insert_batch_mixed_pk_presence(self, db_session: AsyncSession):
|
||||
"""INSERT batch where some rows have explicit PK and others rely on
|
||||
the callable default (uuid.uuid4).
|
||||
|
||||
Normalization adds the PK key with None to rows that omitted it,
|
||||
which can cause NOT NULL violations on the PK column.
|
||||
"""
|
||||
registry = FixtureRegistry()
|
||||
explicit_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [
|
||||
Role(id=explicit_id, name="admin"),
|
||||
Role(name="user"), # PK omitted, relies on default=uuid.uuid4
|
||||
]
|
||||
|
||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.INSERT)
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
rows = (await db_session.execute(select(Role))).scalars().all()
|
||||
assert len(rows) == 2
|
||||
names = {r.name for r in rows}
|
||||
assert names == {"admin", "user"}
|
||||
# The "admin" row must have the explicit ID
|
||||
admin = next(r for r in rows if r.name == "admin")
|
||||
assert admin.id == explicit_id
|
||||
# The "user" row must have a generated UUID (not None)
|
||||
user = next(r for r in rows if r.name == "user")
|
||||
assert user.id is not None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_batch_mixed_nullable(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING with mixed nullable columns inserts correctly.
|
||||
|
||||
Only new rows are inserted; existing rows are untouched regardless of
|
||||
which columns the fixture provides.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
uid1 = uuid.uuid4()
|
||||
uid2 = uuid.uuid4()
|
||||
|
||||
# Pre-populate uid1 with notes
|
||||
registry_initial = FixtureRegistry()
|
||||
|
||||
@registry_initial.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid1,
|
||||
username="alice",
|
||||
email="a@test.com",
|
||||
role_id=admin.id,
|
||||
notes="keep me",
|
||||
),
|
||||
]
|
||||
|
||||
await load_fixtures(
|
||||
db_session, registry_initial, "users", strategy=LoadStrategy.INSERT
|
||||
)
|
||||
|
||||
# Load again with SKIP_EXISTING: uid1 already exists, uid2 is new
|
||||
registry_skip = FixtureRegistry()
|
||||
|
||||
@registry_skip.register
|
||||
def users(): # noqa: F811
|
||||
return [
|
||||
User(id=uid1, username="alice-updated", email="a@test.com"), # exists
|
||||
User(
|
||||
id=uid2,
|
||||
username="bob",
|
||||
email="b@test.com",
|
||||
notes="new user",
|
||||
), # new
|
||||
]
|
||||
|
||||
result = await load_fixtures(
|
||||
db_session, registry_skip, "users", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert len(result["users"]) == 1 # only bob inserted
|
||||
|
||||
rows = {
|
||||
r.username: r
|
||||
for r in (await db_session.execute(select(User))).scalars().all()
|
||||
}
|
||||
# alice untouched
|
||||
assert rows["alice"].role_id == admin.id
|
||||
assert rows["alice"].notes == "keep me"
|
||||
# bob inserted correctly
|
||||
assert rows["bob"].notes == "new user"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_insert_batch_every_row_different_nullable_columns(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""Each row in the batch sets a different combination of nullable columns.
|
||||
|
||||
Tests that normalization produces valid SQL for all rows.
|
||||
"""
|
||||
registry = FixtureRegistry()
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
uid1 = uuid.uuid4()
|
||||
uid2 = uuid.uuid4()
|
||||
uid3 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid1,
|
||||
username="all_set",
|
||||
email="a@test.com",
|
||||
role_id=admin.id,
|
||||
notes="full",
|
||||
),
|
||||
User(
|
||||
id=uid2, username="only_role", email="b@test.com", role_id=admin.id
|
||||
),
|
||||
User(
|
||||
id=uid3, username="only_notes", email="c@test.com", notes="partial"
|
||||
),
|
||||
]
|
||||
|
||||
await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.INSERT)
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
rows = {
|
||||
r.username: r
|
||||
for r in (await db_session.execute(select(User))).scalars().all()
|
||||
}
|
||||
assert rows["all_set"].role_id == admin.id
|
||||
assert rows["all_set"].notes == "full"
|
||||
assert rows["only_role"].role_id == admin.id
|
||||
assert rows["only_role"].notes is None
|
||||
assert rows["only_notes"].role_id is None
|
||||
assert rows["only_notes"].notes == "partial"
|
||||
|
||||
Reference in New Issue
Block a user