diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index bc3ad3b..4edb64e 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -38,6 +38,12 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]: return result +def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Ensure all row dicts share the same key set.""" + all_keys: set[str] = set().union(*dicts) + return [{k: d.get(k) for k in all_keys} for d in dicts] + + def _group_by_type( instances: list[DeclarativeBase], ) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]: @@ -54,7 +60,7 @@ async def _batch_insert( instances: list[DeclarativeBase], ) -> None: """INSERT all instances — raises on conflict (no duplicate handling).""" - dicts = [_instance_to_dict(i) for i in instances] + dicts = _normalize_rows([_instance_to_dict(i) for i in instances]) await session.execute(pg_insert(model_cls).values(dicts)) @@ -73,13 +79,16 @@ async def _batch_merge( if not any(col.name in pk_names_set for col in prop.columns) ] - dicts = [_instance_to_dict(i) for i in instances] + dicts = _normalize_rows([_instance_to_dict(i) for i in instances]) stmt = pg_insert(model_cls).values(dicts) - if non_pk_cols: + inserted_keys = set(dicts[0]) if dicts else set() + update_cols = [col for col in non_pk_cols if col in inserted_keys] + + if update_cols: stmt = stmt.on_conflict_do_update( index_elements=pk_names, - set_={col: stmt.excluded[col] for col in non_pk_cols}, + set_={col: stmt.excluded[col] for col in update_cols}, ) else: stmt = stmt.on_conflict_do_nothing(index_elements=pk_names) @@ -108,14 +117,16 @@ async def _batch_skip_existing( loaded: list[DeclarativeBase] = list(no_pk) if no_pk: await session.execute( - pg_insert(model_cls).values([_instance_to_dict(i) for i in no_pk]) + pg_insert(model_cls).values( + _normalize_rows([_instance_to_dict(i) for i in no_pk]) + ) ) if with_pk_pairs: with_pk = [i for i, _ in with_pk_pairs] stmt = ( pg_insert(model_cls) - .values([_instance_to_dict(i) for i in with_pk]) + .values(_normalize_rows([_instance_to_dict(i) for i in with_pk])) .on_conflict_do_nothing(index_elements=pk_names) ) result = await session.execute(stmt.returning(*mapper.primary_key)) diff --git a/tests/conftest.py b/tests/conftest.py index 10ac078..18388b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,6 +57,7 @@ class User(Base): username: Mapped[str] = mapped_column(String(50), unique=True) email: Mapped[str] = mapped_column(String(100), unique=True) is_active: Mapped[bool] = mapped_column(default=True) + notes: Mapped[str | None] role_id: Mapped[uuid.UUID | None] = mapped_column( ForeignKey("roles.id"), nullable=True ) diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index d2455ef..f3c0e0a 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -14,10 +14,9 @@ from fastapi_toolsets.fixtures import ( load_fixtures, load_fixtures_by_context, ) - from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict -from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud +from .conftest import IntRole, Permission, Role, RoleCreate, RoleCrud, User, UserCrud class AppContext(str, Enum): @@ -621,6 +620,52 @@ class TestLoadFixtures: count = await RoleCrud.count(db_session) assert count == 1 + @pytest.mark.anyio + async def test_merge_does_not_overwrite_omitted_nullable_columns( + self, db_session: AsyncSession + ): + """MERGE must not clear nullable columns that the fixture didn't set. + + When a fixture omits a nullable column (e.g. role_id or notes), a re-merge + must leave the existing DB value untouched — not overwrite it with NULL. + """ + registry = FixtureRegistry() + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + uid = uuid.uuid4() + + # First load: user has role_id and notes set + @registry.register + def users(): + return [ + User( + id=uid, + username="alice", + email="a@test.com", + role_id=admin.id, + notes="original", + ) + ] + + await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.MERGE) + + # Second load: fixture omits role_id and notes + registry2 = FixtureRegistry() + + @registry2.register + def users(): # noqa: F811 + return [User(id=uid, username="alice-updated", email="a@test.com")] + + await load_fixtures(db_session, registry2, "users", strategy=LoadStrategy.MERGE) + + from sqlalchemy import select + + row = ( + await db_session.execute(select(User).where(User.id == uid)) + ).scalar_one() + assert row.username == "alice-updated" # updated column changed + assert row.role_id == admin.id # omitted → preserved + assert row.notes == "original" # omitted → preserved + @pytest.mark.anyio async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession): """Load fixtures with SKIP_EXISTING strategy.""" @@ -973,6 +1018,69 @@ class TestInstanceToDict: assert "role_id" in d assert d["role_id"] is None + def test_nullable_str_no_default_omitted_not_in_dict(self): + """Mapped[str | None] with no default, not provided in constructor, is absent from dict.""" + instance = User(id=uuid.uuid4(), username="u", email="e@e.com") + d = _instance_to_dict(instance) + assert "notes" not in d + + def test_nullable_str_no_default_explicit_none_included(self): + """Mapped[str | None] with no default, explicitly set to None, is included as NULL.""" + instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes=None) + d = _instance_to_dict(instance) + assert "notes" in d + assert d["notes"] is None + + def test_nullable_str_no_default_with_value_included(self): + """Mapped[str | None] with no default and a value set is included normally.""" + instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes="hello") + d = _instance_to_dict(instance) + assert d["notes"] == "hello" + + @pytest.mark.anyio + async def test_nullable_str_no_default_insert_roundtrip( + self, db_session: AsyncSession + ): + """Fixture loading works for models with Mapped[str | None] (no default). + + Both the omitted-value (→ NULL) and explicit-None paths must insert without error. + """ + registry = FixtureRegistry() + + uid_a = uuid.uuid4() + uid_b = uuid.uuid4() + uid_c = uuid.uuid4() + + @registry.register + def users(): + return [ + User( + id=uid_a, username="no_notes", email="a@test.com" + ), # notes omitted + User( + id=uid_b, username="null_notes", email="b@test.com", notes=None + ), # explicit None + User( + id=uid_c, username="has_notes", email="c@test.com", notes="hi" + ), # value set + ] + + result = await load_fixtures(db_session, registry, "users") + + from sqlalchemy import select + + rows = ( + (await db_session.execute(select(User).order_by(User.username))) + .scalars() + .all() + ) + by_username = {r.username: r for r in rows} + + assert by_username["no_notes"].notes is None + assert by_username["null_notes"].notes is None + assert by_username["has_notes"].notes == "hi" + assert len(result["users"]) == 3 + class TestBatchMergeNonPkColumns: """Batch MERGE on a model with no non-PK columns (PK-only table)."""