mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
fix: normalize batch insert rows to prevent silent data loss for nullable columns (#192)
This commit is contained in:
@@ -38,6 +38,12 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Ensure all row dicts share the same key set."""
|
||||
all_keys: set[str] = set().union(*dicts)
|
||||
return [{k: d.get(k) for k in all_keys} for d in dicts]
|
||||
|
||||
|
||||
def _group_by_type(
|
||||
instances: list[DeclarativeBase],
|
||||
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
||||
@@ -54,7 +60,7 @@ async def _batch_insert(
|
||||
instances: list[DeclarativeBase],
|
||||
) -> None:
|
||||
"""INSERT all instances — raises on conflict (no duplicate handling)."""
|
||||
dicts = [_instance_to_dict(i) for i in instances]
|
||||
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
|
||||
await session.execute(pg_insert(model_cls).values(dicts))
|
||||
|
||||
|
||||
@@ -73,13 +79,16 @@ async def _batch_merge(
|
||||
if not any(col.name in pk_names_set for col in prop.columns)
|
||||
]
|
||||
|
||||
dicts = [_instance_to_dict(i) for i in instances]
|
||||
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
|
||||
stmt = pg_insert(model_cls).values(dicts)
|
||||
|
||||
if non_pk_cols:
|
||||
inserted_keys = set(dicts[0]) if dicts else set()
|
||||
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||
|
||||
if update_cols:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=pk_names,
|
||||
set_={col: stmt.excluded[col] for col in non_pk_cols},
|
||||
set_={col: stmt.excluded[col] for col in update_cols},
|
||||
)
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||
@@ -108,14 +117,16 @@ async def _batch_skip_existing(
|
||||
loaded: list[DeclarativeBase] = list(no_pk)
|
||||
if no_pk:
|
||||
await session.execute(
|
||||
pg_insert(model_cls).values([_instance_to_dict(i) for i in no_pk])
|
||||
pg_insert(model_cls).values(
|
||||
_normalize_rows([_instance_to_dict(i) for i in no_pk])
|
||||
)
|
||||
)
|
||||
|
||||
if with_pk_pairs:
|
||||
with_pk = [i for i, _ in with_pk_pairs]
|
||||
stmt = (
|
||||
pg_insert(model_cls)
|
||||
.values([_instance_to_dict(i) for i in with_pk])
|
||||
.values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
|
||||
.on_conflict_do_nothing(index_elements=pk_names)
|
||||
)
|
||||
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||
|
||||
@@ -57,6 +57,7 @@ class User(Base):
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
is_active: Mapped[bool] = mapped_column(default=True)
|
||||
notes: Mapped[str | None]
|
||||
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("roles.id"), nullable=True
|
||||
)
|
||||
|
||||
@@ -14,10 +14,9 @@ from fastapi_toolsets.fixtures import (
|
||||
load_fixtures,
|
||||
load_fixtures_by_context,
|
||||
)
|
||||
|
||||
from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict
|
||||
|
||||
from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud
|
||||
from .conftest import IntRole, Permission, Role, RoleCreate, RoleCrud, User, UserCrud
|
||||
|
||||
|
||||
class AppContext(str, Enum):
|
||||
@@ -621,6 +620,52 @@ class TestLoadFixtures:
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_merge_does_not_overwrite_omitted_nullable_columns(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""MERGE must not clear nullable columns that the fixture didn't set.
|
||||
|
||||
When a fixture omits a nullable column (e.g. role_id or notes), a re-merge
|
||||
must leave the existing DB value untouched — not overwrite it with NULL.
|
||||
"""
|
||||
registry = FixtureRegistry()
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
uid = uuid.uuid4()
|
||||
|
||||
# First load: user has role_id and notes set
|
||||
@registry.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid,
|
||||
username="alice",
|
||||
email="a@test.com",
|
||||
role_id=admin.id,
|
||||
notes="original",
|
||||
)
|
||||
]
|
||||
|
||||
await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.MERGE)
|
||||
|
||||
# Second load: fixture omits role_id and notes
|
||||
registry2 = FixtureRegistry()
|
||||
|
||||
@registry2.register
|
||||
def users(): # noqa: F811
|
||||
return [User(id=uid, username="alice-updated", email="a@test.com")]
|
||||
|
||||
await load_fixtures(db_session, registry2, "users", strategy=LoadStrategy.MERGE)
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
row = (
|
||||
await db_session.execute(select(User).where(User.id == uid))
|
||||
).scalar_one()
|
||||
assert row.username == "alice-updated" # updated column changed
|
||||
assert row.role_id == admin.id # omitted → preserved
|
||||
assert row.notes == "original" # omitted → preserved
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
||||
"""Load fixtures with SKIP_EXISTING strategy."""
|
||||
@@ -973,6 +1018,69 @@ class TestInstanceToDict:
|
||||
assert "role_id" in d
|
||||
assert d["role_id"] is None
|
||||
|
||||
def test_nullable_str_no_default_omitted_not_in_dict(self):
|
||||
"""Mapped[str | None] with no default, not provided in constructor, is absent from dict."""
|
||||
instance = User(id=uuid.uuid4(), username="u", email="e@e.com")
|
||||
d = _instance_to_dict(instance)
|
||||
assert "notes" not in d
|
||||
|
||||
def test_nullable_str_no_default_explicit_none_included(self):
|
||||
"""Mapped[str | None] with no default, explicitly set to None, is included as NULL."""
|
||||
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes=None)
|
||||
d = _instance_to_dict(instance)
|
||||
assert "notes" in d
|
||||
assert d["notes"] is None
|
||||
|
||||
def test_nullable_str_no_default_with_value_included(self):
|
||||
"""Mapped[str | None] with no default and a value set is included normally."""
|
||||
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes="hello")
|
||||
d = _instance_to_dict(instance)
|
||||
assert d["notes"] == "hello"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_nullable_str_no_default_insert_roundtrip(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""Fixture loading works for models with Mapped[str | None] (no default).
|
||||
|
||||
Both the omitted-value (→ NULL) and explicit-None paths must insert without error.
|
||||
"""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
uid_a = uuid.uuid4()
|
||||
uid_b = uuid.uuid4()
|
||||
uid_c = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def users():
|
||||
return [
|
||||
User(
|
||||
id=uid_a, username="no_notes", email="a@test.com"
|
||||
), # notes omitted
|
||||
User(
|
||||
id=uid_b, username="null_notes", email="b@test.com", notes=None
|
||||
), # explicit None
|
||||
User(
|
||||
id=uid_c, username="has_notes", email="c@test.com", notes="hi"
|
||||
), # value set
|
||||
]
|
||||
|
||||
result = await load_fixtures(db_session, registry, "users")
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
rows = (
|
||||
(await db_session.execute(select(User).order_by(User.username)))
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
by_username = {r.username: r for r in rows}
|
||||
|
||||
assert by_username["no_notes"].notes is None
|
||||
assert by_username["null_notes"].notes is None
|
||||
assert by_username["has_notes"].notes == "hi"
|
||||
assert len(result["users"]) == 3
|
||||
|
||||
|
||||
class TestBatchMergeNonPkColumns:
|
||||
"""Batch MERGE on a model with no non-PK columns (PK-only table)."""
|
||||
|
||||
Reference in New Issue
Block a user