chore: move pytest fixture plugin + update fixture module structure

This commit is contained in:
2026-01-28 07:29:34 -05:00
parent 45001767aa
commit 53e80cd0d5
7 changed files with 165 additions and 166 deletions

View File

@@ -1,11 +1,6 @@
from .fixtures import ( from .enum import LoadStrategy
Context, from .registry import Context, FixtureRegistry
FixtureRegistry, from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
LoadStrategy,
load_fixtures,
load_fixtures_by_context,
)
from .utils import get_obj_by_attr
__all__ = [ __all__ = [
"Context", "Context",
@@ -16,12 +11,3 @@ __all__ = [
"load_fixtures_by_context", "load_fixtures_by_context",
"register_fixtures", "register_fixtures",
] ]
# We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI
def __getattr__(name: str):
if name == "register_fixtures":
from .pytest_plugin import register_fixtures
return register_fixtures
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -0,0 +1,30 @@
from enum import Enum
class LoadStrategy(str, Enum):
"""Strategy for loading fixtures into the database."""
INSERT = "insert"
"""Insert new records. Fails if record already exists."""
MERGE = "merge"
"""Insert or update based on primary key (SQLAlchemy merge)."""
SKIP_EXISTING = "skip_existing"
"""Insert only if record doesn't exist (based on primary key)."""
class Context(str, Enum):
"""Predefined fixture contexts."""
BASE = "base"
"""Base fixtures loaded in all environments."""
PRODUCTION = "production"
"""Production-only fixtures."""
DEVELOPMENT = "development"
"""Development fixtures."""
TESTING = "testing"
"""Test fixtures."""

View File

@@ -3,46 +3,15 @@
import logging import logging
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from typing import Any, cast from typing import Any, cast
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction from .enum import Context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoadStrategy(str, Enum):
"""Strategy for loading fixtures into the database."""
INSERT = "insert"
"""Insert new records. Fails if record already exists."""
MERGE = "merge"
"""Insert or update based on primary key (SQLAlchemy merge)."""
SKIP_EXISTING = "skip_existing"
"""Insert only if record doesn't exist (based on primary key)."""
class Context(str, Enum):
"""Predefined fixture contexts."""
BASE = "base"
"""Base fixtures loaded in all environments."""
PRODUCTION = "production"
"""Production-only fixtures."""
DEVELOPMENT = "development"
"""Development fixtures."""
TESTING = "testing"
"""Test fixtures."""
@dataclass @dataclass
class Fixture: class Fixture:
"""A fixture definition with metadata.""" """A fixture definition with metadata."""
@@ -204,118 +173,3 @@ class FixtureRegistry:
all_deps.update(deps) all_deps.update(deps)
return self.resolve_dependencies(*all_deps) return self.resolve_dependencies(*all_deps)
async def load_fixtures(
session: AsyncSession,
registry: FixtureRegistry,
*names: str,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load specific fixtures by name with dependencies.
Args:
session: Database session
registry: Fixture registry
*names: Fixture names to load (dependencies auto-resolved)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
Example:
# Loads 'roles' first (dependency), then 'users'
result = await load_fixtures(session, fixtures, "users")
print(result["users"]) # [User(...), ...]
"""
ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy)
async def load_fixtures_by_context(
session: AsyncSession,
registry: FixtureRegistry,
*contexts: str | Context,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load all fixtures for specific contexts.
Args:
session: Database session
registry: Fixture registry
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
Example:
# Load base + testing fixtures
await load_fixtures_by_context(
session, fixtures,
Context.BASE, Context.TESTING
)
"""
ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None

View File

@@ -1,8 +1,16 @@
import logging
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, TypeVar from typing import Any, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction
from .enum import LoadStrategy
from .registry import Context, FixtureRegistry
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=DeclarativeBase) T = TypeVar("T", bound=DeclarativeBase)
@@ -24,3 +32,118 @@ def get_obj_by_attr(
StopIteration: If no matching object is found. StopIteration: If no matching object is found.
""" """
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value) return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
async def load_fixtures(
session: AsyncSession,
registry: FixtureRegistry,
*names: str,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load specific fixtures by name with dependencies.
Args:
session: Database session
registry: Fixture registry
*names: Fixture names to load (dependencies auto-resolved)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
Example:
# Loads 'roles' first (dependency), then 'users'
result = await load_fixtures(session, fixtures, "users")
print(result["users"]) # [User(...), ...]
"""
ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy)
async def load_fixtures_by_context(
session: AsyncSession,
registry: FixtureRegistry,
*contexts: str | Context,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load all fixtures for specific contexts.
Args:
session: Database session
registry: Fixture registry
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
Example:
# Load base + testing fixtures
await load_fixtures_by_context(
session, fixtures,
Context.BASE, Context.TESTING
)
"""
ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None

View File

@@ -0,0 +1,5 @@
from .plugin import register_fixtures
__all__ = [
"register_fixtures",
]

View File

@@ -59,7 +59,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction from ..db import get_transaction
from .fixtures import FixtureRegistry, LoadStrategy from ..fixtures import FixtureRegistry, LoadStrategy
def register_fixtures( def register_fixtures(

View File

@@ -4,7 +4,8 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from fastapi_toolsets.fixtures import Context, FixtureRegistry, register_fixtures from fastapi_toolsets.fixtures import Context, FixtureRegistry
from fastapi_toolsets.pytest import register_fixtures
from .conftest import Role, RoleCrud, User, UserCrud from .conftest import Role, RoleCrud, User, UserCrud