diff --git a/docs/module/fixtures.md b/docs/module/fixtures.md index e3a1005..dd2df82 100644 --- a/docs/module/fixtures.md +++ b/docs/module/fixtures.md @@ -38,18 +38,20 @@ By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_to from fastapi_toolsets.fixtures import load_fixtures_by_context async with db_context() as session: - await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING) + await load_fixtures_by_context(session, fixtures, Context.TESTING) ``` -Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures): +Directly by name with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures): ```python from fastapi_toolsets.fixtures import load_fixtures async with db_context() as session: - await load_fixtures(session=session, registry=fixtures) + await load_fixtures(session, fixtures, "roles", "test_users") ``` +Both functions return a `dict[str, list[...]]` mapping each fixture name to the list of loaded instances. + ## Contexts [`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values: @@ -58,10 +60,60 @@ async with db_context() as session: |---------|-------------| | `Context.BASE` | Core data required in all environments | | `Context.TESTING` | Data only loaded during tests | +| `Context.DEVELOPMENT` | Data only loaded in development | | `Context.PRODUCTION` | Data only loaded in production | A fixture with no `contexts` defined takes `Context.BASE` by default. +### Custom contexts + +Plain strings and any `Enum` subclass are accepted wherever a `Context` enum is expected. + +```python +from enum import Enum + +class AppContext(str, Enum): + STAGING = "staging" + DEMO = "demo" + +@fixtures.register(contexts=[AppContext.STAGING]) +def staging_data(): + return [Config(key="feature_x", enabled=True)] + +await load_fixtures_by_context(session, fixtures, AppContext.STAGING) +``` + +### Default context for a registry + +Pass `contexts` to `FixtureRegistry` to set a default for all fixtures registered in it: + +```python +testing_registry = FixtureRegistry(contexts=[Context.TESTING]) + +@testing_registry.register # implicitly contexts=[Context.TESTING] +def test_orders(): + return [Order(id=1, total=99)] +``` + +### Same fixture name, multiple context variants + +The same fixture name may be registered under different (non-overlapping) context sets. When multiple contexts are loaded together, all matching variants are merged: + +```python +@fixtures.register(contexts=[Context.BASE]) +def users(): + return [User(id=1, username="admin")] + +@fixtures.register(contexts=[Context.TESTING]) +def users(): + return [User(id=2, username="tester")] + +# loads both admin and tester +await load_fixtures_by_context(session, fixtures, Context.BASE, Context.TESTING) +``` + +Registering two variants with overlapping context sets raises `ValueError`. + ## Load strategies [`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist: @@ -69,20 +121,44 @@ A fixture with no `contexts` defined takes `Context.BASE` by default. | Strategy | Description | |----------|-------------| | `LoadStrategy.INSERT` | Insert only, fail on duplicates | -| `LoadStrategy.UPSERT` | Insert or update on conflict | -| `LoadStrategy.SKIP` | Skip rows that already exist | +| `LoadStrategy.MERGE` | Insert or update on conflict (default) | +| `LoadStrategy.SKIP_EXISTING` | Skip rows that already exist | + +```python +await load_fixtures_by_context( + session, fixtures, Context.BASE, strategy=LoadStrategy.SKIP_EXISTING +) +``` ## Merging registries -Split fixtures definitions across modules and merge them: +Split fixture definitions across modules and merge them: ```python from myapp.fixtures.dev import dev_fixtures from myapp.fixtures.prod import prod_fixtures -fixtures = fixturesRegistry() +fixtures = FixtureRegistry() fixtures.include_registry(registry=dev_fixtures) fixtures.include_registry(registry=prod_fixtures) +``` + +Fixtures with the same name are allowed as long as their context sets do not overlap. Conflicting contexts raise `ValueError`. + +## Looking up fixture instances + +[`get_obj_by_attr`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.get_obj_by_attr) retrieves a specific instance from a fixture function by attribute value — useful when building cross-fixture `depends_on` relationships: + +```python +from fastapi_toolsets.fixtures import get_obj_by_attr + +@fixtures.register(depends_on=["roles"]) +def users(): + admin_role = get_obj_by_attr(roles, "name", "admin") + return [User(id=1, username="alice", role_id=admin_role.id)] +``` + +Raises `StopIteration` if no matching instance is found. ## Pytest integration @@ -111,7 +187,6 @@ async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Rol ... ``` - The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances. ## CLI integration diff --git a/src/fastapi_toolsets/fixtures/registry.py b/src/fastapi_toolsets/fixtures/registry.py index 6a4cebb..9dabe83 100644 --- a/src/fastapi_toolsets/fixtures/registry.py +++ b/src/fastapi_toolsets/fixtures/registry.py @@ -2,6 +2,7 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass, field +from enum import Enum from typing import Any, cast from sqlalchemy.orm import DeclarativeBase @@ -12,6 +13,13 @@ from .enum import Context logger = get_logger() +def _normalize_contexts( + contexts: list[str | Enum] | tuple[str | Enum, ...], +) -> list[str]: + """Convert a sequence of any Enum subclass and/or plain strings to a list of strings.""" + return [c.value if isinstance(c, Enum) else c for c in contexts] + + @dataclass class Fixture: """A fixture definition with metadata.""" @@ -50,26 +58,51 @@ class FixtureRegistry: Post(id=1, title="Test", user_id=1), ] ``` + + Fixtures with the same name may be registered for **different** contexts. + When multiple contexts are loaded together, their instances are merged: + + ```python + @fixtures.register(contexts=[Context.BASE]) + def users(): + return [User(id=1, username="admin")] + + @fixtures.register(contexts=[Context.TESTING]) + def users(): + return [User(id=2, username="tester")] + # load_fixtures_by_context(..., Context.BASE, Context.TESTING) + # → loads both User(admin) and User(tester) under the "users" name + ``` """ def __init__( self, - contexts: list[str | Context] | None = None, + contexts: list[str | Enum] | None = None, ) -> None: - self._fixtures: dict[str, Fixture] = {} + self._fixtures: dict[str, list[Fixture]] = {} self._default_contexts: list[str] | None = ( - [c.value if isinstance(c, Context) else c for c in contexts] - if contexts - else None + _normalize_contexts(contexts) if contexts else None ) + def _validate_no_context_overlap(self, name: str, new_contexts: list[str]) -> None: + """Raise ``ValueError`` if any existing variant for *name* overlaps.""" + existing_variants = self._fixtures.get(name, []) + new_set = set(new_contexts) + for variant in existing_variants: + if set(variant.contexts) & new_set: + raise ValueError( + f"Fixture '{name}' already exists in the current registry " + f"with overlapping contexts. Use distinct context sets for " + f"each variant of the same fixture name." + ) + def register( self, func: Callable[[], Sequence[DeclarativeBase]] | None = None, *, name: str | None = None, depends_on: list[str] | None = None, - contexts: list[str | Context] | None = None, + contexts: list[str | Enum] | None = None, ) -> Callable[..., Any]: """Register a fixture function. @@ -79,7 +112,8 @@ class FixtureRegistry: func: Fixture function returning list of model instances name: Fixture name (defaults to function name) depends_on: List of fixture names this depends on - contexts: List of contexts this fixture belongs to + contexts: List of contexts this fixture belongs to. Both + :class:`Context` enum values and plain strings are accepted. Example: ```python @@ -90,7 +124,6 @@ class FixtureRegistry: @fixtures.register(depends_on=["roles"], contexts=[Context.TESTING]) def test_users(): return [User(id=1, username="test", role_id=1)] - ``` """ def decorator( @@ -98,19 +131,20 @@ class FixtureRegistry: ) -> Callable[[], Sequence[DeclarativeBase]]: fixture_name = name or cast(Any, fn).__name__ if contexts is not None: - fixture_contexts = [ - c.value if isinstance(c, Context) else c for c in contexts - ] + fixture_contexts = _normalize_contexts(contexts) elif self._default_contexts is not None: fixture_contexts = self._default_contexts else: fixture_contexts = [Context.BASE.value] - self._fixtures[fixture_name] = Fixture( - name=fixture_name, - func=fn, - depends_on=depends_on or [], - contexts=fixture_contexts, + self._validate_no_context_overlap(fixture_name, fixture_contexts) + self._fixtures.setdefault(fixture_name, []).append( + Fixture( + name=fixture_name, + func=fn, + depends_on=depends_on or [], + contexts=fixture_contexts, + ) ) return fn @@ -121,11 +155,14 @@ class FixtureRegistry: def include_registry(self, registry: "FixtureRegistry") -> None: """Include another `FixtureRegistry` in the same current `FixtureRegistry`. + Fixtures with the same name are allowed as long as their context sets + do not overlap. Conflicting contexts raise :class:`ValueError`. + Args: registry: The `FixtureRegistry` to include Raises: - ValueError: If a fixture name already exists in the current registry + ValueError: If a fixture name already exists with overlapping contexts Example: ```python @@ -139,31 +176,73 @@ class FixtureRegistry: registry.include_registry(registry=dev_registry) ``` """ - for name, fixture in registry._fixtures.items(): - if name in self._fixtures: - raise ValueError( - f"Fixture '{name}' already exists in the current registry" - ) - self._fixtures[name] = fixture + for name, variants in registry._fixtures.items(): + for fixture in variants: + self._validate_no_context_overlap(name, fixture.contexts) + self._fixtures.setdefault(name, []).append(fixture) def get(self, name: str) -> Fixture: - """Get a fixture by name.""" + """Get a fixture by name. + + Raises: + KeyError: If no fixture with *name* is registered. + ValueError: If the fixture has multiple context variants — use + :meth:`get_variants` in that case. + """ if name not in self._fixtures: raise KeyError(f"Fixture '{name}' not found") - return self._fixtures[name] + variants = self._fixtures[name] + if len(variants) > 1: + raise ValueError( + f"Fixture '{name}' has {len(variants)} context variants. " + f"Use get_variants('{name}') to retrieve them." + ) + return variants[0] + + def get_variants(self, name: str, *contexts: str | Enum) -> list[Fixture]: + """Return all registered variants for *name*, optionally filtered by context. + + Args: + name: Fixture name. + *contexts: If given, only return variants whose context set + intersects with these values. Both :class:`Context` enum + values and plain strings are accepted. + + Returns: + List of matching :class:`Fixture` objects (may be empty when a + context filter is applied and nothing matches). + + Raises: + KeyError: If no fixture with *name* is registered. + """ + if name not in self._fixtures: + raise KeyError(f"Fixture '{name}' not found") + variants = self._fixtures[name] + if not contexts: + return list(variants) + context_values = set(_normalize_contexts(contexts)) + return [v for v in variants if set(v.contexts) & context_values] def get_all(self) -> list[Fixture]: - """Get all registered fixtures.""" - return list(self._fixtures.values()) + """Get all registered fixtures (all variants of all names).""" + return [f for variants in self._fixtures.values() for f in variants] - def get_by_context(self, *contexts: str | Context) -> list[Fixture]: + def get_by_context(self, *contexts: str | Enum) -> list[Fixture]: """Get fixtures for specific contexts.""" - context_values = {c.value if isinstance(c, Context) else c for c in contexts} - return [f for f in self._fixtures.values() if set(f.contexts) & context_values] + context_values = set(_normalize_contexts(contexts)) + return [ + f + for variants in self._fixtures.values() + for f in variants + if set(f.contexts) & context_values + ] def resolve_dependencies(self, *names: str) -> list[str]: """Resolve fixture dependencies in topological order. + When a fixture name has multiple context variants, the union of all + variants' ``depends_on`` lists is used. + Args: *names: Fixture names to resolve @@ -185,9 +264,20 @@ class FixtureRegistry: raise ValueError(f"Circular dependency detected: {name}") visiting.add(name) - fixture = self.get(name) + variants = self._fixtures.get(name) + if variants is None: + raise KeyError(f"Fixture '{name}' not found") - for dep in fixture.depends_on: + # Union of depends_on across all variants, preserving first-seen order. + seen_deps: set[str] = set() + all_deps: list[str] = [] + for variant in variants: + for dep in variant.depends_on: + if dep not in seen_deps: + all_deps.append(dep) + seen_deps.add(dep) + + for dep in all_deps: visit(dep) visiting.remove(name) @@ -199,7 +289,7 @@ class FixtureRegistry: return resolved - def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]: + def resolve_context_dependencies(self, *contexts: str | Enum) -> list[str]: """Resolve all fixtures for contexts with dependencies. Args: @@ -209,7 +299,9 @@ class FixtureRegistry: List of fixture names in load order """ context_fixtures = self.get_by_context(*contexts) - names = [f.name for f in context_fixtures] + # Deduplicate names while preserving first-seen order (a name can + # appear multiple times if it has variants in different contexts). + names = list(dict.fromkeys(f.name for f in context_fixtures)) all_deps: set[str] = set() for name in names: diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index 626c301..fcde2b2 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -1,6 +1,7 @@ """Fixture loading utilities for database seeding.""" from collections.abc import Callable, Sequence +from enum import Enum from typing import Any from sqlalchemy.ext.asyncio import AsyncSession @@ -10,7 +11,7 @@ from ..db import get_transaction from ..logger import get_logger from ..types import ModelType from .enum import LoadStrategy -from .registry import Context, FixtureRegistry +from .registry import FixtureRegistry, _normalize_contexts logger = get_logger() @@ -20,13 +21,35 @@ async def _load_ordered( registry: FixtureRegistry, ordered_names: list[str], strategy: LoadStrategy, + contexts: tuple[str, ...] | None = None, ) -> dict[str, list[DeclarativeBase]]: - """Load fixtures in order.""" + """Load fixtures in order. + + When *contexts* is provided only variants whose context set intersects with + *contexts* are called for each name; their instances are concatenated. + When *contexts* is ``None`` all variants of each name are loaded. + """ results: dict[str, list[DeclarativeBase]] = {} for name in ordered_names: - fixture = registry.get(name) - instances = list(fixture.func()) + variants = ( + registry.get_variants(name, *contexts) + if contexts is not None + else registry.get_variants(name) + ) + + # Cross-context dependency fallback: if we're loading by context but + # no variant matches (e.g. a "base"-only fixture required by a + # "testing" fixture), load all available variants so the dependency + # is satisfied. + if contexts is not None and not variants: + variants = registry.get_variants(name) + + if not variants: + results[name] = [] + continue + + instances = [inst for v in variants for inst in v.func()] if not instances: results[name] = [] @@ -109,6 +132,8 @@ async def load_fixtures( ) -> dict[str, list[DeclarativeBase]]: """Load specific fixtures by name with dependencies. + All context variants of each requested fixture are loaded and merged. + Args: session: Database session registry: Fixture registry @@ -125,19 +150,27 @@ async def load_fixtures( async def load_fixtures_by_context( session: AsyncSession, registry: FixtureRegistry, - *contexts: str | Context, + *contexts: str | Enum, strategy: LoadStrategy = LoadStrategy.MERGE, ) -> dict[str, list[DeclarativeBase]]: """Load all fixtures for specific contexts. + For each fixture name, only the variants whose context set intersects with + *contexts* are loaded. When a name has variants in multiple of the + requested contexts, their instances are merged before being inserted. + Args: session: Database session registry: Fixture registry - *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING) + *contexts: Contexts to load (e.g., ``Context.BASE``, ``Context.TESTING``, + or plain strings for custom contexts) strategy: How to handle existing records Returns: Dict mapping fixture names to loaded instances """ + context_strings = tuple(_normalize_contexts(contexts)) ordered = registry.resolve_context_dependencies(*contexts) - return await _load_ordered(session, registry, ordered, strategy) + return await _load_ordered( + session, registry, ordered, strategy, contexts=context_strings + ) diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 0ad3d4a..3fbe600 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -1,6 +1,7 @@ """Tests for fastapi_toolsets.fixtures module.""" import uuid +from enum import Enum import pytest from sqlalchemy.ext.asyncio import AsyncSession @@ -19,6 +20,19 @@ from fastapi_toolsets.fixtures.utils import _get_primary_key from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud +class AppContext(str, Enum): + """Example user-defined str+Enum context.""" + + STAGING = "staging" + DEMO = "demo" + + +class PlainEnumContext(Enum): + """Example user-defined plain Enum context (no str mixin).""" + + STAGING = "staging" + + class TestContext: """Tests for Context enum.""" @@ -39,6 +53,86 @@ class TestContext: assert Context.TESTING.value == "testing" +class TestCustomEnumContext: + """Custom Enum types are accepted wherever Context/str are expected.""" + + def test_cannot_subclass_context_with_members(self): + """Python prohibits extending an Enum that already has members.""" + with pytest.raises(TypeError): + + class MyContext(Context): # noqa: F841 # ty: ignore[subclass-of-final-class] + STAGING = "staging" + + def test_custom_enum_values_interchangeable_with_context(self): + """A custom enum with the same .value as a built-in Context member is + treated as the same context — fixtures registered under one are found + by the other.""" + + class AppContextFull(str, Enum): + BASE = "base" + STAGING = "staging" + + registry = FixtureRegistry() + + @registry.register(contexts=[Context.BASE]) + def roles(): + return [] + + # AppContextFull.BASE has value "base" — same as Context.BASE + fixtures = registry.get_by_context(AppContextFull.BASE) + assert len(fixtures) == 1 + + def test_custom_enum_registry_default_contexts(self): + """FixtureRegistry(contexts=[...]) accepts a custom Enum.""" + registry = FixtureRegistry(contexts=[AppContext.STAGING]) + + @registry.register + def data(): + return [] + + fixture = registry.get("data") + assert fixture.contexts == ["staging"] + + def test_custom_enum_resolve_context_dependencies(self): + """resolve_context_dependencies accepts a custom Enum context.""" + registry = FixtureRegistry() + + @registry.register(contexts=[AppContext.STAGING]) + def staging_roles(): + return [] + + order = registry.resolve_context_dependencies(AppContext.STAGING) + assert "staging_roles" in order + + @pytest.mark.anyio + async def test_custom_enum_e2e(self, db_session: AsyncSession): + """End-to-end: register with custom Enum, load with the same Enum.""" + registry = FixtureRegistry() + + @registry.register(contexts=[AppContext.STAGING]) + def staging_roles(): + return [Role(id=uuid.uuid4(), name="staging-admin")] + + result = await load_fixtures_by_context( + db_session, registry, AppContext.STAGING + ) + assert len(result["staging_roles"]) == 1 + + @pytest.mark.anyio + async def test_plain_enum_e2e(self, db_session: AsyncSession): + """End-to-end: register with plain Enum, load with the same Enum.""" + registry = FixtureRegistry() + + @registry.register(contexts=[PlainEnumContext.STAGING]) + def staging_roles(): + return [Role(id=uuid.uuid4(), name="plain-staging-admin")] + + result = await load_fixtures_by_context( + db_session, registry, PlainEnumContext.STAGING + ) + assert len(result["staging_roles"]) == 1 + + class TestLoadStrategy: """Tests for LoadStrategy enum.""" @@ -407,6 +501,37 @@ class TestDependencyResolution: with pytest.raises(ValueError, match="Circular dependency"): registry.resolve_dependencies("a") + def test_resolve_raises_for_unknown_dependency(self): + """KeyError when depends_on references an unregistered fixture.""" + registry = FixtureRegistry() + + @registry.register(depends_on=["ghost"]) + def users(): + return [] + + with pytest.raises(KeyError, match="ghost"): + registry.resolve_dependencies("users") + + def test_resolve_deduplicates_shared_depends_on_across_variants(self): + """A dep shared by two same-name variants appears only once in the order.""" + registry = FixtureRegistry() + + @registry.register(contexts=[Context.BASE]) + def roles(): + return [] + + @registry.register(depends_on=["roles"], contexts=[Context.BASE]) + def items(): + return [] + + @registry.register(depends_on=["roles"], contexts=[Context.TESTING]) + def items(): # noqa: F811 + return [] + + order = registry.resolve_dependencies("items") + assert order.count("roles") == 1 + assert order.index("roles") < order.index("items") + def test_resolve_context_dependencies(self): """Resolve all fixtures for a context with dependencies.""" registry = FixtureRegistry() @@ -795,3 +920,28 @@ class TestGetPrimaryKey: instance = Permission(subject="post") # action is None pk = _get_primary_key(instance) assert pk is None + + +class TestRegistryGetVariants: + """Tests for FixtureRegistry.get and get_variants edge cases.""" + + def test_get_raises_value_error_for_multi_variant(self): + """get() raises ValueError when the fixture has multiple context variants.""" + registry = FixtureRegistry() + + @registry.register(contexts=[Context.BASE]) + def items(): + return [] + + @registry.register(contexts=[Context.TESTING]) + def items(): # noqa: F811 + return [] + + with pytest.raises(ValueError, match="get_variants"): + registry.get("items") + + def test_get_variants_raises_key_error_for_unknown(self): + """get_variants() raises KeyError for an unregistered name.""" + registry = FixtureRegistry() + with pytest.raises(KeyError, match="not found"): + registry.get_variants("no_such_fixture")