test: add missing tests for fixtures/utils.py

This commit is contained in:
2026-03-01 08:05:20 -05:00
parent bcf0c3becb
commit 858cca576e
3 changed files with 129 additions and 78 deletions

View File

@@ -15,6 +15,67 @@ from .registry import Context, FixtureRegistry
logger = get_logger()
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
else: # LoadStrategy.SKIP_EXISTING
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None
def get_obj_by_attr(
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> ModelType:
@@ -56,13 +117,6 @@ async def load_fixtures(
Returns:
Dict mapping fixture names to loaded instances
Example:
```python
# Loads 'roles' first (dependency), then 'users'
result = await load_fixtures(session, fixtures, "users")
print(result["users"]) # [User(...), ...]
```
"""
ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy)
@@ -84,76 +138,6 @@ async def load_fixtures_by_context(
Returns:
Dict mapping fixture names to loaded instances
Example:
```python
# Load base + testing fixtures
await load_fixtures_by_context(
session, fixtures,
Context.BASE, Context.TESTING
)
```
"""
ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None