diff --git a/src/fastapi_toolsets/fixtures/__init__.py b/src/fastapi_toolsets/fixtures/__init__.py index fce6fee..0157c09 100644 --- a/src/fastapi_toolsets/fixtures/__init__.py +++ b/src/fastapi_toolsets/fixtures/__init__.py @@ -1,11 +1,6 @@ -from .fixtures import ( - Context, - FixtureRegistry, - LoadStrategy, - load_fixtures, - load_fixtures_by_context, -) -from .utils import get_obj_by_attr +from .enum import LoadStrategy +from .registry import Context, FixtureRegistry +from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context __all__ = [ "Context", @@ -16,12 +11,3 @@ __all__ = [ "load_fixtures_by_context", "register_fixtures", ] - - -# We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI -def __getattr__(name: str): - if name == "register_fixtures": - from .pytest_plugin import register_fixtures - - return register_fixtures - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/fastapi_toolsets/fixtures/enum.py b/src/fastapi_toolsets/fixtures/enum.py new file mode 100644 index 0000000..6a0b53b --- /dev/null +++ b/src/fastapi_toolsets/fixtures/enum.py @@ -0,0 +1,30 @@ +from enum import Enum + + +class LoadStrategy(str, Enum): + """Strategy for loading fixtures into the database.""" + + INSERT = "insert" + """Insert new records. Fails if record already exists.""" + + MERGE = "merge" + """Insert or update based on primary key (SQLAlchemy merge).""" + + SKIP_EXISTING = "skip_existing" + """Insert only if record doesn't exist (based on primary key).""" + + +class Context(str, Enum): + """Predefined fixture contexts.""" + + BASE = "base" + """Base fixtures loaded in all environments.""" + + PRODUCTION = "production" + """Production-only fixtures.""" + + DEVELOPMENT = "development" + """Development fixtures.""" + + TESTING = "testing" + """Test fixtures.""" diff --git a/src/fastapi_toolsets/fixtures/fixtures.py b/src/fastapi_toolsets/fixtures/registry.py similarity index 54% rename from src/fastapi_toolsets/fixtures/fixtures.py rename to src/fastapi_toolsets/fixtures/registry.py index 11273c5..17ca750 100644 --- a/src/fastapi_toolsets/fixtures/fixtures.py +++ b/src/fastapi_toolsets/fixtures/registry.py @@ -3,46 +3,15 @@ import logging from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from enum import Enum from typing import Any, cast -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase -from ..db import get_transaction +from .enum import Context logger = logging.getLogger(__name__) -class LoadStrategy(str, Enum): - """Strategy for loading fixtures into the database.""" - - INSERT = "insert" - """Insert new records. Fails if record already exists.""" - - MERGE = "merge" - """Insert or update based on primary key (SQLAlchemy merge).""" - - SKIP_EXISTING = "skip_existing" - """Insert only if record doesn't exist (based on primary key).""" - - -class Context(str, Enum): - """Predefined fixture contexts.""" - - BASE = "base" - """Base fixtures loaded in all environments.""" - - PRODUCTION = "production" - """Production-only fixtures.""" - - DEVELOPMENT = "development" - """Development fixtures.""" - - TESTING = "testing" - """Test fixtures.""" - - @dataclass class Fixture: """A fixture definition with metadata.""" @@ -204,118 +173,3 @@ class FixtureRegistry: all_deps.update(deps) return self.resolve_dependencies(*all_deps) - - -async def load_fixtures( - session: AsyncSession, - registry: FixtureRegistry, - *names: str, - strategy: LoadStrategy = LoadStrategy.MERGE, -) -> dict[str, list[DeclarativeBase]]: - """Load specific fixtures by name with dependencies. - - Args: - session: Database session - registry: Fixture registry - *names: Fixture names to load (dependencies auto-resolved) - strategy: How to handle existing records - - Returns: - Dict mapping fixture names to loaded instances - - Example: - # Loads 'roles' first (dependency), then 'users' - result = await load_fixtures(session, fixtures, "users") - print(result["users"]) # [User(...), ...] - """ - ordered = registry.resolve_dependencies(*names) - return await _load_ordered(session, registry, ordered, strategy) - - -async def load_fixtures_by_context( - session: AsyncSession, - registry: FixtureRegistry, - *contexts: str | Context, - strategy: LoadStrategy = LoadStrategy.MERGE, -) -> dict[str, list[DeclarativeBase]]: - """Load all fixtures for specific contexts. - - Args: - session: Database session - registry: Fixture registry - *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING) - strategy: How to handle existing records - - Returns: - Dict mapping fixture names to loaded instances - - Example: - # Load base + testing fixtures - await load_fixtures_by_context( - session, fixtures, - Context.BASE, Context.TESTING - ) - """ - ordered = registry.resolve_context_dependencies(*contexts) - return await _load_ordered(session, registry, ordered, strategy) - - -async def _load_ordered( - session: AsyncSession, - registry: FixtureRegistry, - ordered_names: list[str], - strategy: LoadStrategy, -) -> dict[str, list[DeclarativeBase]]: - """Load fixtures in order.""" - results: dict[str, list[DeclarativeBase]] = {} - - for name in ordered_names: - fixture = registry.get(name) - instances = list(fixture.func()) - - if not instances: - results[name] = [] - continue - - model_name = type(instances[0]).__name__ - loaded: list[DeclarativeBase] = [] - - async with get_transaction(session): - for instance in instances: - if strategy == LoadStrategy.INSERT: - session.add(instance) - loaded.append(instance) - - elif strategy == LoadStrategy.MERGE: - merged = await session.merge(instance) - loaded.append(merged) - - elif strategy == LoadStrategy.SKIP_EXISTING: - pk = _get_primary_key(instance) - if pk is not None: - existing = await session.get(type(instance), pk) - if existing is None: - session.add(instance) - loaded.append(instance) - else: - session.add(instance) - loaded.append(instance) - - results[name] = loaded - logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") - - return results - - -def _get_primary_key(instance: DeclarativeBase) -> Any | None: - """Get the primary key value of a model instance.""" - mapper = instance.__class__.__mapper__ - pk_cols = mapper.primary_key - - if len(pk_cols) == 1: - return getattr(instance, pk_cols[0].name, None) - - pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols) - if all(v is not None for v in pk_values): - return pk_values - return None diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index 106c706..e1a88eb 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -1,8 +1,16 @@ +import logging from collections.abc import Callable, Sequence from typing import Any, TypeVar +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase +from ..db import get_transaction +from .enum import LoadStrategy +from .registry import Context, FixtureRegistry + +logger = logging.getLogger(__name__) + T = TypeVar("T", bound=DeclarativeBase) @@ -24,3 +32,118 @@ def get_obj_by_attr( StopIteration: If no matching object is found. """ return next(obj for obj in fixtures() if getattr(obj, attr_name) == value) + + +async def load_fixtures( + session: AsyncSession, + registry: FixtureRegistry, + *names: str, + strategy: LoadStrategy = LoadStrategy.MERGE, +) -> dict[str, list[DeclarativeBase]]: + """Load specific fixtures by name with dependencies. + + Args: + session: Database session + registry: Fixture registry + *names: Fixture names to load (dependencies auto-resolved) + strategy: How to handle existing records + + Returns: + Dict mapping fixture names to loaded instances + + Example: + # Loads 'roles' first (dependency), then 'users' + result = await load_fixtures(session, fixtures, "users") + print(result["users"]) # [User(...), ...] + """ + ordered = registry.resolve_dependencies(*names) + return await _load_ordered(session, registry, ordered, strategy) + + +async def load_fixtures_by_context( + session: AsyncSession, + registry: FixtureRegistry, + *contexts: str | Context, + strategy: LoadStrategy = LoadStrategy.MERGE, +) -> dict[str, list[DeclarativeBase]]: + """Load all fixtures for specific contexts. + + Args: + session: Database session + registry: Fixture registry + *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING) + strategy: How to handle existing records + + Returns: + Dict mapping fixture names to loaded instances + + Example: + # Load base + testing fixtures + await load_fixtures_by_context( + session, fixtures, + Context.BASE, Context.TESTING + ) + """ + ordered = registry.resolve_context_dependencies(*contexts) + return await _load_ordered(session, registry, ordered, strategy) + + +async def _load_ordered( + session: AsyncSession, + registry: FixtureRegistry, + ordered_names: list[str], + strategy: LoadStrategy, +) -> dict[str, list[DeclarativeBase]]: + """Load fixtures in order.""" + results: dict[str, list[DeclarativeBase]] = {} + + for name in ordered_names: + fixture = registry.get(name) + instances = list(fixture.func()) + + if not instances: + results[name] = [] + continue + + model_name = type(instances[0]).__name__ + loaded: list[DeclarativeBase] = [] + + async with get_transaction(session): + for instance in instances: + if strategy == LoadStrategy.INSERT: + session.add(instance) + loaded.append(instance) + + elif strategy == LoadStrategy.MERGE: + merged = await session.merge(instance) + loaded.append(merged) + + elif strategy == LoadStrategy.SKIP_EXISTING: + pk = _get_primary_key(instance) + if pk is not None: + existing = await session.get(type(instance), pk) + if existing is None: + session.add(instance) + loaded.append(instance) + else: + session.add(instance) + loaded.append(instance) + + results[name] = loaded + logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") + + return results + + +def _get_primary_key(instance: DeclarativeBase) -> Any | None: + """Get the primary key value of a model instance.""" + mapper = instance.__class__.__mapper__ + pk_cols = mapper.primary_key + + if len(pk_cols) == 1: + return getattr(instance, pk_cols[0].name, None) + + pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols) + if all(v is not None for v in pk_values): + return pk_values + return None diff --git a/src/fastapi_toolsets/pytest/__init__.py b/src/fastapi_toolsets/pytest/__init__.py new file mode 100644 index 0000000..20fa819 --- /dev/null +++ b/src/fastapi_toolsets/pytest/__init__.py @@ -0,0 +1,5 @@ +from .plugin import register_fixtures + +__all__ = [ + "register_fixtures", +] diff --git a/src/fastapi_toolsets/fixtures/pytest_plugin.py b/src/fastapi_toolsets/pytest/plugin.py similarity index 99% rename from src/fastapi_toolsets/fixtures/pytest_plugin.py rename to src/fastapi_toolsets/pytest/plugin.py index 29b3446..e10a770 100644 --- a/src/fastapi_toolsets/fixtures/pytest_plugin.py +++ b/src/fastapi_toolsets/pytest/plugin.py @@ -59,7 +59,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase from ..db import get_transaction -from .fixtures import FixtureRegistry, LoadStrategy +from ..fixtures import FixtureRegistry, LoadStrategy def register_fixtures( diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 62aabf5..d88095e 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -4,7 +4,8 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from fastapi_toolsets.fixtures import Context, FixtureRegistry, register_fixtures +from fastapi_toolsets.fixtures import Context, FixtureRegistry +from fastapi_toolsets.pytest import register_fixtures from .conftest import Role, RoleCrud, User, UserCrud