mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
chore: move pytest fixture plugin + update fixture module structure
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
from .fixtures import (
|
||||
Context,
|
||||
FixtureRegistry,
|
||||
LoadStrategy,
|
||||
load_fixtures,
|
||||
load_fixtures_by_context,
|
||||
)
|
||||
from .utils import get_obj_by_attr
|
||||
from .enum import LoadStrategy
|
||||
from .registry import Context, FixtureRegistry
|
||||
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||
|
||||
__all__ = [
|
||||
"Context",
|
||||
@@ -16,12 +11,3 @@ __all__ = [
|
||||
"load_fixtures_by_context",
|
||||
"register_fixtures",
|
||||
]
|
||||
|
||||
|
||||
# We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI
|
||||
def __getattr__(name: str):
|
||||
if name == "register_fixtures":
|
||||
from .pytest_plugin import register_fixtures
|
||||
|
||||
return register_fixtures
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
30
src/fastapi_toolsets/fixtures/enum.py
Normal file
30
src/fastapi_toolsets/fixtures/enum.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LoadStrategy(str, Enum):
|
||||
"""Strategy for loading fixtures into the database."""
|
||||
|
||||
INSERT = "insert"
|
||||
"""Insert new records. Fails if record already exists."""
|
||||
|
||||
MERGE = "merge"
|
||||
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
||||
|
||||
SKIP_EXISTING = "skip_existing"
|
||||
"""Insert only if record doesn't exist (based on primary key)."""
|
||||
|
||||
|
||||
class Context(str, Enum):
|
||||
"""Predefined fixture contexts."""
|
||||
|
||||
BASE = "base"
|
||||
"""Base fixtures loaded in all environments."""
|
||||
|
||||
PRODUCTION = "production"
|
||||
"""Production-only fixtures."""
|
||||
|
||||
DEVELOPMENT = "development"
|
||||
"""Development fixtures."""
|
||||
|
||||
TESTING = "testing"
|
||||
"""Test fixtures."""
|
||||
@@ -1,204 +0,0 @@
|
||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
||||
|
||||
This module provides utilities to automatically generate pytest fixtures
|
||||
from your FixtureRegistry, with proper dependency resolution.
|
||||
|
||||
Example:
|
||||
# conftest.py
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from app.fixtures import fixtures # Your FixtureRegistry
|
||||
from app.models import Base
|
||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
||||
|
||||
@pytest.fixture
|
||||
async def engine():
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(engine):
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
session = session_factory()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
# Automatically generate pytest fixtures from registry
|
||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
||||
register_fixtures(fixtures, globals())
|
||||
|
||||
Usage in tests:
|
||||
# test_users.py
|
||||
async def test_user_count(db_session, fixture_users):
|
||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
||||
# and returns the list of User models
|
||||
assert len(fixture_users) > 0
|
||||
|
||||
async def test_user_role(db_session, fixture_users):
|
||||
user = fixture_users[0]
|
||||
assert user.role_id is not None
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from .fixtures import FixtureRegistry, LoadStrategy
|
||||
|
||||
|
||||
def register_fixtures(
|
||||
registry: FixtureRegistry,
|
||||
namespace: dict[str, Any],
|
||||
*,
|
||||
prefix: str = "fixture_",
|
||||
session_fixture: str = "db_session",
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> list[str]:
|
||||
"""Register pytest fixtures from a FixtureRegistry.
|
||||
|
||||
Automatically creates pytest fixtures for each fixture in the registry.
|
||||
Dependencies are resolved via pytest fixture dependencies.
|
||||
|
||||
Args:
|
||||
registry: The FixtureRegistry containing fixtures
|
||||
namespace: The module's globals() dict to add fixtures to
|
||||
prefix: Prefix for generated fixture names (default: "fixture_")
|
||||
session_fixture: Name of the db session fixture (default: "db_session")
|
||||
strategy: Loading strategy for fixtures (default: MERGE)
|
||||
|
||||
Returns:
|
||||
List of created fixture names
|
||||
|
||||
Example:
|
||||
# conftest.py
|
||||
from app.fixtures import fixtures
|
||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||
|
||||
register_fixtures(fixtures, globals())
|
||||
|
||||
# Creates fixtures like:
|
||||
# - fixture_roles
|
||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||
"""
|
||||
created_fixtures: list[str] = []
|
||||
|
||||
for fixture in registry.get_all():
|
||||
fixture_name = f"{prefix}{fixture.name}"
|
||||
|
||||
# Build list of pytest fixture dependencies
|
||||
pytest_deps = [session_fixture]
|
||||
for dep in fixture.depends_on:
|
||||
pytest_deps.append(f"{prefix}{dep}")
|
||||
|
||||
# Create the fixture function
|
||||
fixture_func = _create_fixture_function(
|
||||
registry=registry,
|
||||
fixture_name=fixture.name,
|
||||
dependencies=pytest_deps,
|
||||
strategy=strategy,
|
||||
)
|
||||
|
||||
# Apply pytest.fixture decorator
|
||||
decorated = pytest.fixture(fixture_func)
|
||||
|
||||
# Add to namespace
|
||||
namespace[fixture_name] = decorated
|
||||
created_fixtures.append(fixture_name)
|
||||
|
||||
return created_fixtures
|
||||
|
||||
|
||||
def _create_fixture_function(
|
||||
registry: FixtureRegistry,
|
||||
fixture_name: str,
|
||||
dependencies: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> Callable[..., Any]:
|
||||
"""Create a fixture function with the correct signature.
|
||||
|
||||
The function signature must include all dependencies as parameters
|
||||
for pytest to resolve them correctly.
|
||||
"""
|
||||
# Get the fixture definition
|
||||
fixture_def = registry.get(fixture_name)
|
||||
|
||||
# Build the function dynamically with correct parameters
|
||||
# We need the session as first param, then all dependencies
|
||||
async def fixture_func(**kwargs: Any) -> Sequence[DeclarativeBase]:
|
||||
# Get session from kwargs (first dependency)
|
||||
session: AsyncSession = kwargs[dependencies[0]]
|
||||
|
||||
# Load the fixture data
|
||||
instances = list(fixture_def.func())
|
||||
|
||||
if not instances:
|
||||
return []
|
||||
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
loaded.append(existing)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
return loaded
|
||||
|
||||
# Update function signature to include dependencies
|
||||
# This is needed for pytest to inject the right fixtures
|
||||
params = ", ".join(dependencies)
|
||||
code = f"async def {fixture_name}_fixture({params}):\n return await _impl({', '.join(f'{d}={d}' for d in dependencies)})"
|
||||
|
||||
local_ns: dict[str, Any] = {"_impl": fixture_func}
|
||||
exec(code, local_ns) # noqa: S102
|
||||
|
||||
created_func = local_ns[f"{fixture_name}_fixture"]
|
||||
created_func.__doc__ = f"Load {fixture_name} fixture data."
|
||||
|
||||
return created_func
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
@@ -3,46 +3,15 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from .enum import Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadStrategy(str, Enum):
|
||||
"""Strategy for loading fixtures into the database."""
|
||||
|
||||
INSERT = "insert"
|
||||
"""Insert new records. Fails if record already exists."""
|
||||
|
||||
MERGE = "merge"
|
||||
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
||||
|
||||
SKIP_EXISTING = "skip_existing"
|
||||
"""Insert only if record doesn't exist (based on primary key)."""
|
||||
|
||||
|
||||
class Context(str, Enum):
|
||||
"""Predefined fixture contexts."""
|
||||
|
||||
BASE = "base"
|
||||
"""Base fixtures loaded in all environments."""
|
||||
|
||||
PRODUCTION = "production"
|
||||
"""Production-only fixtures."""
|
||||
|
||||
DEVELOPMENT = "development"
|
||||
"""Development fixtures."""
|
||||
|
||||
TESTING = "testing"
|
||||
"""Test fixtures."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fixture:
|
||||
"""A fixture definition with metadata."""
|
||||
@@ -204,118 +173,3 @@ class FixtureRegistry:
|
||||
all_deps.update(deps)
|
||||
|
||||
return self.resolve_dependencies(*all_deps)
|
||||
|
||||
|
||||
async def load_fixtures(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*names: str,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load specific fixtures by name with dependencies.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*names: Fixture names to load (dependencies auto-resolved)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Loads 'roles' first (dependency), then 'users'
|
||||
result = await load_fixtures(session, fixtures, "users")
|
||||
print(result["users"]) # [User(...), ...]
|
||||
"""
|
||||
ordered = registry.resolve_dependencies(*names)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def load_fixtures_by_context(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*contexts: str | Context,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load all fixtures for specific contexts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Load base + testing fixtures
|
||||
await load_fixtures_by_context(
|
||||
session, fixtures,
|
||||
Context.BASE, Context.TESTING
|
||||
)
|
||||
"""
|
||||
ordered = registry.resolve_context_dependencies(*contexts)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
ordered_names: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
|
||||
if not instances:
|
||||
results[name] = []
|
||||
continue
|
||||
|
||||
model_name = type(instances[0]).__name__
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
@@ -1,8 +1,16 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from .enum import LoadStrategy
|
||||
from .registry import Context, FixtureRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound=DeclarativeBase)
|
||||
|
||||
|
||||
@@ -24,3 +32,118 @@ def get_obj_by_attr(
|
||||
StopIteration: If no matching object is found.
|
||||
"""
|
||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||
|
||||
|
||||
async def load_fixtures(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*names: str,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load specific fixtures by name with dependencies.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*names: Fixture names to load (dependencies auto-resolved)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Loads 'roles' first (dependency), then 'users'
|
||||
result = await load_fixtures(session, fixtures, "users")
|
||||
print(result["users"]) # [User(...), ...]
|
||||
"""
|
||||
ordered = registry.resolve_dependencies(*names)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def load_fixtures_by_context(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*contexts: str | Context,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load all fixtures for specific contexts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Load base + testing fixtures
|
||||
await load_fixtures_by_context(
|
||||
session, fixtures,
|
||||
Context.BASE, Context.TESTING
|
||||
)
|
||||
"""
|
||||
ordered = registry.resolve_context_dependencies(*contexts)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
ordered_names: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
|
||||
if not instances:
|
||||
results[name] = []
|
||||
continue
|
||||
|
||||
model_name = type(instances[0]).__name__
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user