From 82ef96082efa68d37f3776f3789e322650b9b41e Mon Sep 17 00:00:00 2001 From: d3vyce Date: Sun, 1 Mar 2026 08:05:20 -0500 Subject: [PATCH] test: add missing tests for fixtures/utils.py --- src/fastapi_toolsets/fixtures/utils.py | 138 +++++++++++-------------- tests/conftest.py | 9 ++ tests/test_fixtures.py | 60 ++++++++++- 3 files changed, 129 insertions(+), 78 deletions(-) diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index 1263c90..626c301 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -15,6 +15,67 @@ from .registry import Context, FixtureRegistry logger = get_logger() +async def _load_ordered( + session: AsyncSession, + registry: FixtureRegistry, + ordered_names: list[str], + strategy: LoadStrategy, +) -> dict[str, list[DeclarativeBase]]: + """Load fixtures in order.""" + results: dict[str, list[DeclarativeBase]] = {} + + for name in ordered_names: + fixture = registry.get(name) + instances = list(fixture.func()) + + if not instances: + results[name] = [] + continue + + model_name = type(instances[0]).__name__ + loaded: list[DeclarativeBase] = [] + + async with get_transaction(session): + for instance in instances: + if strategy == LoadStrategy.INSERT: + session.add(instance) + loaded.append(instance) + + elif strategy == LoadStrategy.MERGE: + merged = await session.merge(instance) + loaded.append(merged) + + else: # LoadStrategy.SKIP_EXISTING + pk = _get_primary_key(instance) + if pk is not None: + existing = await session.get(type(instance), pk) + if existing is None: + session.add(instance) + loaded.append(instance) + else: + session.add(instance) + loaded.append(instance) + + results[name] = loaded + logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") + + return results + + +def _get_primary_key(instance: DeclarativeBase) -> Any | None: + """Get the primary key value of a model instance.""" + mapper = instance.__class__.__mapper__ + pk_cols = mapper.primary_key + + if len(pk_cols) == 1: + return getattr(instance, pk_cols[0].name, None) + + pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols) + if all(v is not None for v in pk_values): + return pk_values + return None + + def get_obj_by_attr( fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any ) -> ModelType: @@ -56,13 +117,6 @@ async def load_fixtures( Returns: Dict mapping fixture names to loaded instances - - Example: - ```python - # Loads 'roles' first (dependency), then 'users' - result = await load_fixtures(session, fixtures, "users") - print(result["users"]) # [User(...), ...] - ``` """ ordered = registry.resolve_dependencies(*names) return await _load_ordered(session, registry, ordered, strategy) @@ -84,76 +138,6 @@ async def load_fixtures_by_context( Returns: Dict mapping fixture names to loaded instances - - Example: - ```python - # Load base + testing fixtures - await load_fixtures_by_context( - session, fixtures, - Context.BASE, Context.TESTING - ) - ``` """ ordered = registry.resolve_context_dependencies(*contexts) return await _load_ordered(session, registry, ordered, strategy) - - -async def _load_ordered( - session: AsyncSession, - registry: FixtureRegistry, - ordered_names: list[str], - strategy: LoadStrategy, -) -> dict[str, list[DeclarativeBase]]: - """Load fixtures in order.""" - results: dict[str, list[DeclarativeBase]] = {} - - for name in ordered_names: - fixture = registry.get(name) - instances = list(fixture.func()) - - if not instances: - results[name] = [] - continue - - model_name = type(instances[0]).__name__ - loaded: list[DeclarativeBase] = [] - - async with get_transaction(session): - for instance in instances: - if strategy == LoadStrategy.INSERT: - session.add(instance) - loaded.append(instance) - - elif strategy == LoadStrategy.MERGE: - merged = await session.merge(instance) - loaded.append(merged) - - elif strategy == LoadStrategy.SKIP_EXISTING: - pk = _get_primary_key(instance) - if pk is not None: - existing = await session.get(type(instance), pk) - if existing is None: - session.add(instance) - loaded.append(instance) - else: - session.add(instance) - loaded.append(instance) - - results[name] = loaded - logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") - - return results - - -def _get_primary_key(instance: DeclarativeBase) -> Any | None: - """Get the primary key value of a model instance.""" - mapper = instance.__class__.__mapper__ - pk_cols = mapper.primary_key - - if len(pk_cols) == 1: - return getattr(instance, pk_cols[0].name, None) - - pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols) - if all(v is not None for v in pk_values): - return pk_values - return None diff --git a/tests/conftest.py b/tests/conftest.py index 5b3a190..68be228 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -92,6 +92,15 @@ class IntRole(Base): name: Mapped[str] = mapped_column(String(50), unique=True) +class Permission(Base): + """Test model with composite primary key.""" + + __tablename__ = "permissions" + + subject: Mapped[str] = mapped_column(String(50), primary_key=True) + action: Mapped[str] = mapped_column(String(50), primary_key=True) + + class Event(Base): """Test model with DateTime and Date cursor columns.""" diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 8e465b1..169765b 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -14,7 +14,9 @@ from fastapi_toolsets.fixtures import ( load_fixtures_by_context, ) -from .conftest import Role, User +from fastapi_toolsets.fixtures.utils import _get_primary_key + +from .conftest import IntRole, Permission, Role, User class TestContext: @@ -597,6 +599,46 @@ class TestLoadFixtures: count = await RoleCrud.count(db_session) assert count == 2 + @pytest.mark.anyio + async def test_skip_existing_skips_if_record_exists(self, db_session: AsyncSession): + """SKIP_EXISTING returns empty loaded list when the record already exists.""" + registry = FixtureRegistry() + role_id = uuid.uuid4() + + @registry.register + def roles(): + return [Role(id=role_id, name="admin")] + + # First load — inserts the record. + result1 = await load_fixtures( + db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING + ) + assert len(result1["roles"]) == 1 + + # Remove from identity map so session.get() queries the DB in the second load. + db_session.expunge_all() + + # Second load — record exists in DB, nothing should be added. + result2 = await load_fixtures( + db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING + ) + assert result2["roles"] == [] + + @pytest.mark.anyio + async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession): + """SKIP_EXISTING inserts when the instance has no PK set (auto-increment).""" + registry = FixtureRegistry() + + @registry.register + def int_roles(): + # No id provided — PK is None before INSERT (autoincrement). + return [IntRole(name="member")] + + result = await load_fixtures( + db_session, registry, "int_roles", strategy=LoadStrategy.SKIP_EXISTING + ) + assert len(result["int_roles"]) == 1 + class TestLoadFixturesByContext: """Tests for load_fixtures_by_context function.""" @@ -755,3 +797,19 @@ class TestGetObjByAttr: """Raises StopIteration when value type doesn't match.""" with pytest.raises(StopIteration): get_obj_by_attr(self.roles, "id", "not-a-uuid") + + +class TestGetPrimaryKey: + """Unit tests for the _get_primary_key helper (composite PK paths).""" + + def test_composite_pk_all_set(self): + """Returns a tuple when all composite PK values are set.""" + instance = Permission(subject="post", action="read") + pk = _get_primary_key(instance) + assert pk == ("post", "read") + + def test_composite_pk_partial_none(self): + """Returns None when any composite PK value is None.""" + instance = Permission(subject="post") # action is None + pk = _get_primary_key(instance) + assert pk is None