mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
fix: batch insert normalizes away omitted nullable columns (#198)
This commit is contained in:
@@ -30,20 +30,16 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
|
||||
if val is None:
|
||||
col = prop.columns[0]
|
||||
|
||||
if col.server_default is not None or (
|
||||
col.default is not None and col.default.is_callable
|
||||
if (
|
||||
col.server_default is not None
|
||||
or (col.default is not None and col.default.is_callable)
|
||||
or col.autoincrement is True
|
||||
):
|
||||
continue
|
||||
result[prop.key] = val
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Ensure all row dicts share the same key set."""
|
||||
all_keys: set[str] = set().union(*dicts)
|
||||
return [{k: d.get(k) for k in all_keys} for d in dicts]
|
||||
|
||||
|
||||
def _group_by_type(
|
||||
instances: list[DeclarativeBase],
|
||||
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
||||
@@ -54,14 +50,32 @@ def _group_by_type(
|
||||
return list(groups.items())
|
||||
|
||||
|
||||
def _group_by_column_set(
|
||||
dicts: list[dict[str, Any]],
|
||||
instances: list[DeclarativeBase],
|
||||
) -> list[tuple[list[dict[str, Any]], list[DeclarativeBase]]]:
|
||||
"""Group (dict, instance) pairs by their dict key sets."""
|
||||
groups: dict[
|
||||
frozenset[str], tuple[list[dict[str, Any]], list[DeclarativeBase]]
|
||||
] = {}
|
||||
for d, inst in zip(dicts, instances):
|
||||
key = frozenset(d)
|
||||
if key not in groups:
|
||||
groups[key] = ([], [])
|
||||
groups[key][0].append(d)
|
||||
groups[key][1].append(inst)
|
||||
return list(groups.values())
|
||||
|
||||
|
||||
async def _batch_insert(
    session: AsyncSession,
    model_cls: type[DeclarativeBase],
    instances: list[DeclarativeBase],
) -> None:
    """INSERT all instances — raises on conflict (no duplicate handling).

    Rows are grouped by their key set and one multi-row INSERT is executed
    per group, so every statement has a uniform column list.  Keys that
    ``_instance_to_dict`` left out of a row are therefore absent from that
    row's statement instead of being padded with ``None``.

    Args:
        session: Session used to execute the INSERT statements.
        model_cls: Mapped class that is the target of the INSERT.
        instances: Instances to insert; nothing is executed when empty.
    """
    dicts = [_instance_to_dict(i) for i in instances]
    # Only the row dicts are needed here; the grouped instances are unused.
    for group_dicts, _ in _group_by_column_set(dicts, instances):
        await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||
|
||||
|
||||
async def _batch_merge(
|
||||
@@ -79,21 +93,22 @@ async def _batch_merge(
|
||||
if not any(col.name in pk_names_set for col in prop.columns)
|
||||
]
|
||||
|
||||
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
|
||||
stmt = pg_insert(model_cls).values(dicts)
|
||||
dicts = [_instance_to_dict(i) for i in instances]
|
||||
for group_dicts, _ in _group_by_column_set(dicts, instances):
|
||||
stmt = pg_insert(model_cls).values(group_dicts)
|
||||
|
||||
inserted_keys = set(dicts[0]) if dicts else set()
|
||||
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||
inserted_keys = set(group_dicts[0])
|
||||
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||
|
||||
if update_cols:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=pk_names,
|
||||
set_={col: stmt.excluded[col] for col in update_cols},
|
||||
)
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||
if update_cols:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=pk_names,
|
||||
set_={col: stmt.excluded[col] for col in update_cols},
|
||||
)
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||
|
||||
await session.execute(stmt)
|
||||
await session.execute(stmt)
|
||||
|
||||
|
||||
async def _batch_skip_existing(
|
||||
@@ -116,22 +131,30 @@ async def _batch_skip_existing(
|
||||
|
||||
loaded: list[DeclarativeBase] = list(no_pk)
|
||||
if no_pk:
|
||||
await session.execute(
|
||||
pg_insert(model_cls).values(
|
||||
_normalize_rows([_instance_to_dict(i) for i in no_pk])
|
||||
)
|
||||
)
|
||||
no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
|
||||
for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
|
||||
await session.execute(pg_insert(model_cls).values(group_dicts))
|
||||
|
||||
if with_pk_pairs:
|
||||
with_pk = [i for i, _ in with_pk_pairs]
|
||||
stmt = (
|
||||
pg_insert(model_cls)
|
||||
.values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
|
||||
.on_conflict_do_nothing(index_elements=pk_names)
|
||||
)
|
||||
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||
inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result}
|
||||
loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks)
|
||||
with_pk_dicts = [_instance_to_dict(i) for i in with_pk]
|
||||
for group_dicts, group_insts in _group_by_column_set(with_pk_dicts, with_pk):
|
||||
stmt = (
|
||||
pg_insert(model_cls)
|
||||
.values(group_dicts)
|
||||
.on_conflict_do_nothing(index_elements=pk_names)
|
||||
)
|
||||
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||
inserted_pks = {
|
||||
row[0] if len(pk_names) == 1 else tuple(row) for row in result
|
||||
}
|
||||
loaded.extend(
|
||||
inst
|
||||
for inst, pk in zip(
|
||||
group_insts, [_get_primary_key(i) for i in group_insts]
|
||||
)
|
||||
if pk in inserted_pks
|
||||
)
|
||||
|
||||
return loaded
|
||||
|
||||
@@ -143,12 +166,7 @@ async def _load_ordered(
|
||||
strategy: LoadStrategy,
|
||||
contexts: tuple[str, ...] | None = None,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order using batch Core INSERT statements.
|
||||
|
||||
When *contexts* is provided only variants whose context set intersects with
|
||||
*contexts* are called for each name; their instances are concatenated.
|
||||
When *contexts* is ``None`` all variants of each name are loaded.
|
||||
"""
|
||||
"""Load fixtures in order using batch Core INSERT statements."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
@@ -158,10 +176,6 @@ async def _load_ordered(
|
||||
else registry.get_variants(name)
|
||||
)
|
||||
|
||||
# Cross-context dependency fallback: if we're loading by context but
|
||||
# no variant matches (e.g. a "base"-only fixture required by a
|
||||
# "testing" fixture), load all available variants so the dependency
|
||||
# is satisfied.
|
||||
if contexts is not None and not variants:
|
||||
variants = registry.get_variants(name)
|
||||
|
||||
@@ -267,10 +281,6 @@ async def load_fixtures_by_context(
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load all fixtures for specific contexts.
|
||||
|
||||
For each fixture name, only the variants whose context set intersects with
|
||||
*contexts* are loaded. When a name has variants in multiple of the
|
||||
requested contexts, their instances are merged before being inserted.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
|
||||
Reference in New Issue
Block a user