fix: normalize batch insert rows to prevent silent data loss for nullable columns (#192)

This commit is contained in:
d3vyce
2026-03-27 19:20:41 +01:00
committed by GitHub
parent 5215b921ae
commit f4698bea8a
3 changed files with 128 additions and 8 deletions

View File

@@ -38,6 +38,12 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
return result
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Ensure all row dicts share the same key set."""
all_keys: set[str] = set().union(*dicts)
return [{k: d.get(k) for k in all_keys} for d in dicts]
def _group_by_type(
instances: list[DeclarativeBase],
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
@@ -54,7 +60,7 @@ async def _batch_insert(
instances: list[DeclarativeBase],
) -> None:
"""INSERT all instances — raises on conflict (no duplicate handling)."""
dicts = [_instance_to_dict(i) for i in instances]
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
await session.execute(pg_insert(model_cls).values(dicts))
@@ -73,13 +79,16 @@ async def _batch_merge(
if not any(col.name in pk_names_set for col in prop.columns)
]
dicts = [_instance_to_dict(i) for i in instances]
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
stmt = pg_insert(model_cls).values(dicts)
if non_pk_cols:
inserted_keys = set(dicts[0]) if dicts else set()
update_cols = [col for col in non_pk_cols if col in inserted_keys]
if update_cols:
stmt = stmt.on_conflict_do_update(
index_elements=pk_names,
set_={col: stmt.excluded[col] for col in non_pk_cols},
set_={col: stmt.excluded[col] for col in update_cols},
)
else:
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
@@ -108,14 +117,16 @@ async def _batch_skip_existing(
loaded: list[DeclarativeBase] = list(no_pk)
if no_pk:
await session.execute(
pg_insert(model_cls).values([_instance_to_dict(i) for i in no_pk])
pg_insert(model_cls).values(
_normalize_rows([_instance_to_dict(i) for i in no_pk])
)
)
if with_pk_pairs:
with_pk = [i for i, _ in with_pk_pairs]
stmt = (
pg_insert(model_cls)
.values([_instance_to_dict(i) for i in with_pk])
.values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
.on_conflict_do_nothing(index_elements=pk_names)
)
result = await session.execute(stmt.returning(*mapper.primary_key))