test: add missing tests for fixtures/utils.py

This commit is contained in:
2026-03-01 08:05:20 -05:00
parent e0828c7e71
commit 82ef96082e
3 changed files with 129 additions and 78 deletions

View File

@@ -15,6 +15,67 @@ from .registry import Context, FixtureRegistry
logger = get_logger() logger = get_logger()
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
else: # LoadStrategy.SKIP_EXISTING
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None
def get_obj_by_attr( def get_obj_by_attr(
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> ModelType: ) -> ModelType:
@@ -56,13 +117,6 @@ async def load_fixtures(
Returns: Returns:
Dict mapping fixture names to loaded instances Dict mapping fixture names to loaded instances
Example:
```python
# Loads 'roles' first (dependency), then 'users'
result = await load_fixtures(session, fixtures, "users")
print(result["users"]) # [User(...), ...]
```
""" """
ordered = registry.resolve_dependencies(*names) ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy) return await _load_ordered(session, registry, ordered, strategy)
@@ -84,76 +138,6 @@ async def load_fixtures_by_context(
Returns: Returns:
Dict mapping fixture names to loaded instances Dict mapping fixture names to loaded instances
Example:
```python
# Load base + testing fixtures
await load_fixtures_by_context(
session, fixtures,
Context.BASE, Context.TESTING
)
```
""" """
ordered = registry.resolve_context_dependencies(*contexts) ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy) return await _load_ordered(session, registry, ordered, strategy)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None

View File

@@ -92,6 +92,15 @@ class IntRole(Base):
name: Mapped[str] = mapped_column(String(50), unique=True) name: Mapped[str] = mapped_column(String(50), unique=True)
class Permission(Base):
"""Test model with composite primary key."""
__tablename__ = "permissions"
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
action: Mapped[str] = mapped_column(String(50), primary_key=True)
class Event(Base): class Event(Base):
"""Test model with DateTime and Date cursor columns.""" """Test model with DateTime and Date cursor columns."""

View File

@@ -14,7 +14,9 @@ from fastapi_toolsets.fixtures import (
load_fixtures_by_context, load_fixtures_by_context,
) )
from .conftest import Role, User from fastapi_toolsets.fixtures.utils import _get_primary_key
from .conftest import IntRole, Permission, Role, User
class TestContext: class TestContext:
@@ -597,6 +599,46 @@ class TestLoadFixtures:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@pytest.mark.anyio
async def test_skip_existing_skips_if_record_exists(self, db_session: AsyncSession):
"""SKIP_EXISTING returns empty loaded list when the record already exists."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register
def roles():
return [Role(id=role_id, name="admin")]
# First load — inserts the record.
result1 = await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
)
assert len(result1["roles"]) == 1
# Remove from identity map so session.get() queries the DB in the second load.
db_session.expunge_all()
# Second load — record exists in DB, nothing should be added.
result2 = await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
)
assert result2["roles"] == []
@pytest.mark.anyio
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
"""SKIP_EXISTING inserts when the instance has no PK set (auto-increment)."""
registry = FixtureRegistry()
@registry.register
def int_roles():
# No id provided — PK is None before INSERT (autoincrement).
return [IntRole(name="member")]
result = await load_fixtures(
db_session, registry, "int_roles", strategy=LoadStrategy.SKIP_EXISTING
)
assert len(result["int_roles"]) == 1
class TestLoadFixturesByContext: class TestLoadFixturesByContext:
"""Tests for load_fixtures_by_context function.""" """Tests for load_fixtures_by_context function."""
@@ -755,3 +797,19 @@ class TestGetObjByAttr:
"""Raises StopIteration when value type doesn't match.""" """Raises StopIteration when value type doesn't match."""
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
get_obj_by_attr(self.roles, "id", "not-a-uuid") get_obj_by_attr(self.roles, "id", "not-a-uuid")
class TestGetPrimaryKey:
"""Unit tests for the _get_primary_key helper (composite PK paths)."""
def test_composite_pk_all_set(self):
"""Returns a tuple when all composite PK values are set."""
instance = Permission(subject="post", action="read")
pk = _get_primary_key(instance)
assert pk == ("post", "read")
def test_composite_pk_partial_none(self):
"""Returns None when any composite PK value is None."""
instance = Permission(subject="post") # action is None
pk = _get_primary_key(instance)
assert pk is None