mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
150 lines
4.6 KiB
Python
150 lines
4.6 KiB
Python
import logging
|
|
from collections.abc import Callable, Sequence
|
|
from typing import Any, TypeVar
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
|
|
from ..db import get_transaction
|
|
from .enum import LoadStrategy
|
|
from .registry import Context, FixtureRegistry
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
# Generic model type, restricted to subclasses of SQLAlchemy's DeclarativeBase.
# Lets get_obj_by_attr return the same model type its fixture function yields.
T = TypeVar("T", bound=DeclarativeBase)
|
|
|
|
|
|
def get_obj_by_attr(
    fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any
) -> T:
    """Get a SQLAlchemy model instance by matching an attribute value.

    Args:
        fixtures: A fixture function registered via ``@registry.register``
            that returns a sequence of SQLAlchemy model instances.
        attr_name: Name of the attribute to match against.
        value: Value to match.

    Returns:
        The first model instance where the attribute matches the given value.

    Raises:
        StopIteration: If no matching object is found. The exception carries
            a message naming the attribute, the value and the fixture.
        AttributeError: If an inspected instance has no attribute
            ``attr_name``.
    """
    for obj in fixtures():
        if getattr(obj, attr_name) == value:
            return obj

    # Keep the documented StopIteration type for existing callers, but raise
    # it explicitly with context instead of leaking the empty one from next().
    fixture_name = getattr(fixtures, "__name__", repr(fixtures))
    raise StopIteration(
        f"No object with {attr_name}={value!r} found in fixture '{fixture_name}'"
    )
|
|
|
|
|
|
async def load_fixtures(
    session: AsyncSession,
    registry: FixtureRegistry,
    *names: str,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load the named fixtures, together with the fixtures they depend on.

    The requested names are expanded and ordered by
    ``registry.resolve_dependencies`` and then loaded in that order.

    Args:
        session: Database session used to persist the instances.
        registry: Fixture registry the names are looked up in.
        *names: Names of the fixtures to load (dependencies auto-resolved).
        strategy: How to handle records that already exist.

    Returns:
        Dict mapping each loaded fixture name to its loaded instances.

    Example:
        # Loads 'roles' first (dependency), then 'users'
        result = await load_fixtures(session, registry, "users")
        print(result["users"])  # [User(...), ...]
    """
    load_order = registry.resolve_dependencies(*names)
    loaded = await _load_ordered(session, registry, load_order, strategy)
    return loaded
|
|
|
|
|
|
async def load_fixtures_by_context(
    session: AsyncSession,
    registry: FixtureRegistry,
    *contexts: str | Context,
    strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
    """Load every fixture that belongs to the given contexts.

    The contexts are expanded and ordered by
    ``registry.resolve_context_dependencies`` and the resulting fixtures are
    loaded in that order.

    Args:
        session: Database session used to persist the instances.
        registry: Fixture registry the contexts are resolved against.
        *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING).
        strategy: How to handle records that already exist.

    Returns:
        Dict mapping each loaded fixture name to its loaded instances.

    Example:
        # Load base + testing fixtures
        await load_fixtures_by_context(
            session, registry,
            Context.BASE, Context.TESTING
        )
    """
    load_order = registry.resolve_context_dependencies(*contexts)
    loaded = await _load_ordered(session, registry, load_order, strategy)
    return loaded
|
|
|
|
|
|
async def _load_ordered(
    session: AsyncSession,
    registry: FixtureRegistry,
    ordered_names: list[str],
    strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
    """Load fixtures in the given order.

    Each fixture function is called once, and its instances are persisted
    inside a separate ``get_transaction(session)`` block per fixture.

    Args:
        session: Database session used to persist the instances.
        registry: Registry the fixture names are looked up in.
        ordered_names: Fixture names, already in the order they must be loaded.
        strategy: How to handle records that already exist.

    Returns:
        Dict mapping every name in ``ordered_names`` to the instances that
        were loaded for it. Fixtures that produce no instances map to an
        empty list; instances skipped by ``SKIP_EXISTING`` are not included.
    """
    results: dict[str, list[DeclarativeBase]] = {}

    for name in ordered_names:
        fixture = registry.get(name)
        instances = list(fixture.func())

        if not instances:
            results[name] = []
            continue

        # The first instance's class names the whole fixture in the log line.
        model_name = type(instances[0]).__name__
        loaded: list[DeclarativeBase] = []

        async with get_transaction(session):
            for instance in instances:
                persisted = await _persist_instance(session, instance, strategy)
                if persisted is not None:
                    loaded.append(persisted)

        results[name] = loaded
        # Lazy %-style arguments: the message is only built if INFO is enabled.
        logger.info(
            "Loaded fixture '%s': %d %s(s)", name, len(loaded), model_name
        )

    return results


async def _persist_instance(
    session: AsyncSession,
    instance: DeclarativeBase,
    strategy: LoadStrategy,
) -> DeclarativeBase | None:
    """Persist one instance per ``strategy``; return it, or ``None`` if skipped."""
    if strategy == LoadStrategy.INSERT:
        session.add(instance)
        return instance

    if strategy == LoadStrategy.MERGE:
        # merge() returns the session-attached copy, which is what callers get.
        return await session.merge(instance)

    if strategy == LoadStrategy.SKIP_EXISTING:
        pk = _get_primary_key(instance)
        # Without a usable primary key there is nothing to look up, so the
        # instance is added unconditionally.
        if pk is not None and await session.get(type(instance), pk) is not None:
            return None
        session.add(instance)
        return instance

    # Any other strategy value loads nothing, as before.
    return None
|
|
|
|
|
|
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
    """Get the primary key value of a model instance.

    Values are read through the mapped attribute key of each primary-key
    column rather than the column's database name. This keeps models that map
    an attribute to a differently named column (for example
    ``id_ = mapped_column("id", primary_key=True)``) working.

    Args:
        instance: A mapped SQLAlchemy model instance.

    Returns:
        For a single-column primary key, its value (``None`` when unset).
        For a composite primary key, a tuple of the values in
        ``mapper.primary_key`` order when every component is set, otherwise
        ``None``.
    """
    mapper = instance.__class__.__mapper__

    # Column name and Python attribute name can differ; ask the mapper which
    # attribute each primary-key column is mapped to.
    pk_values = tuple(
        getattr(instance, mapper.get_property_by_column(col).key, None)
        for col in mapper.primary_key
    )

    if len(pk_values) == 1:
        return pk_values[0]

    if all(v is not None for v in pk_values):
        return pk_values
    return None
|