Files
fastapi-toolsets/src/fastapi_toolsets/fixtures/utils.py
d3vyce 432e0722e0 refactor: remove deprecated parameter and function/cleanup code (#101)
* refactor: remove deprecated parameter and function

* refactor: centralize type aliases in types.py and simplify crud layer

* test: add missing tests for fixtures/utils.py

* refactor: simplify and deduplicate across crud, metrics, cli, and
exceptions

* docs: fix old Paginate references

* docs: add migration + fix icon

* docs: update README/migration to v2
2026-03-04 17:19:38 +01:00

144 lines
4.5 KiB
Python

"""Fixture loading utilities for database seeding."""
from collections.abc import Callable, Sequence
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction
from ..logger import get_logger
from ..types import ModelType
from .enum import LoadStrategy
from .registry import Context, FixtureRegistry
logger = get_logger()
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
else: # LoadStrategy.SKIP_EXISTING
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None
def get_obj_by_attr(
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> ModelType:
"""Get a SQLAlchemy model instance by matching an attribute value.
Args:
fixtures: A fixture function registered via ``@registry.register``
that returns a sequence of SQLAlchemy model instances.
attr_name: Name of the attribute to match against.
value: Value to match.
Returns:
The first model instance where the attribute matches the given value.
Raises:
StopIteration: If no matching object is found in the fixture group.
"""
try:
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
except StopIteration:
raise StopIteration(
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
) from None
async def load_fixtures(
session: AsyncSession,
registry: FixtureRegistry,
*names: str,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load specific fixtures by name with dependencies.
Args:
session: Database session
registry: Fixture registry
*names: Fixture names to load (dependencies auto-resolved)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
"""
ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy)
async def load_fixtures_by_context(
session: AsyncSession,
registry: FixtureRegistry,
*contexts: str | Context,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load all fixtures for specific contexts.
Args:
session: Database session
registry: Fixture registry
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
"""
ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy)