From 691fb78fda592673ade46fd80f9fcb665019ebd2 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:59:36 +0100 Subject: [PATCH] feat: add include_registry to FixtureRegistry + add context default to the registry (#25) --- src/fastapi_toolsets/fixtures/registry.py | 48 +++++- tests/test_fixtures.py | 172 ++++++++++++++++++++++ 2 files changed, 215 insertions(+), 5 deletions(-) diff --git a/src/fastapi_toolsets/fixtures/registry.py b/src/fastapi_toolsets/fixtures/registry.py index 17ca750..5df178f 100644 --- a/src/fastapi_toolsets/fixtures/registry.py +++ b/src/fastapi_toolsets/fixtures/registry.py @@ -50,8 +50,16 @@ class FixtureRegistry: ] """ - def __init__(self) -> None: + def __init__( + self, + contexts: list[str | Context] | None = None, + ) -> None: self._fixtures: dict[str, Fixture] = {} + self._default_contexts: list[str] | None = ( + [c.value if isinstance(c, Context) else c for c in contexts] + if contexts + else None + ) def register( self, @@ -85,10 +93,14 @@ class FixtureRegistry: fn: Callable[[], Sequence[DeclarativeBase]], ) -> Callable[[], Sequence[DeclarativeBase]]: fixture_name = name or cast(Any, fn).__name__ - fixture_contexts = [ - c.value if isinstance(c, Context) else c - for c in (contexts or [Context.BASE]) - ] + if contexts is not None: + fixture_contexts = [ + c.value if isinstance(c, Context) else c for c in contexts + ] + elif self._default_contexts is not None: + fixture_contexts = self._default_contexts + else: + fixture_contexts = [Context.BASE.value] self._fixtures[fixture_name] = Fixture( name=fixture_name, @@ -102,6 +114,32 @@ class FixtureRegistry: return decorator(func) return decorator + def include_registry(self, registry: "FixtureRegistry") -> None: + """Include another `FixtureRegistry` in the same current `FixtureRegistry`. + + Args: + registry: The `FixtureRegistry` to include + + Raises: + ValueError: If a fixture name already exists in the current registry + + Example: + registry = FixtureRegistry() + dev_registry = FixtureRegistry() + + @dev_registry.register + def dev_data(): + return [...] + + registry.include_registry(registry=dev_registry) + """ + for name, fixture in registry._fixtures.items(): + if name in self._fixtures: + raise ValueError( + f"Fixture '{name}' already exists in the current registry" + ) + self._fixtures[name] = fixture + def get(self, name: str) -> Fixture: """Get a fixture by name.""" if name not in self._fixtures: diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 0d8ffcf..6e0dcea 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -159,6 +159,178 @@ class TestFixtureRegistry: assert names == {"test_data"} +class TestIncludeRegistry: + """Tests for FixtureRegistry.include_registry method.""" + + def test_include_empty_registry(self): + """Include an empty registry does nothing.""" + main_registry = FixtureRegistry() + other_registry = FixtureRegistry() + + @main_registry.register + def roles(): + return [] + + main_registry.include_registry(other_registry) + + assert len(main_registry.get_all()) == 1 + + def test_include_registry_adds_fixtures(self): + """Include registry adds all fixtures from the other registry.""" + main_registry = FixtureRegistry() + other_registry = FixtureRegistry() + + @main_registry.register + def roles(): + return [] + + @other_registry.register + def users(): + return [] + + @other_registry.register + def posts(): + return [] + + main_registry.include_registry(other_registry) + + names = {f.name for f in main_registry.get_all()} + assert names == {"roles", "users", "posts"} + + def test_include_registry_preserves_dependencies(self): + """Include registry preserves fixture dependencies.""" + main_registry = FixtureRegistry() + other_registry = FixtureRegistry() + + @main_registry.register + def roles(): + return [] + + @other_registry.register(depends_on=["roles"]) + def users(): + return [] + + main_registry.include_registry(other_registry) + + fixture = main_registry.get("users") + assert fixture.depends_on == ["roles"] + + def test_include_registry_preserves_contexts(self): + """Include registry preserves fixture contexts.""" + main_registry = FixtureRegistry() + other_registry = FixtureRegistry() + + @other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT]) + def test_data(): + return [] + + main_registry.include_registry(other_registry) + + fixture = main_registry.get("test_data") + assert Context.TESTING.value in fixture.contexts + assert Context.DEVELOPMENT.value in fixture.contexts + + def test_include_registry_raises_on_duplicate(self): + """Include registry raises ValueError on duplicate fixture names.""" + main_registry = FixtureRegistry() + other_registry = FixtureRegistry() + + @main_registry.register(name="roles") + def roles_main(): + return [] + + @other_registry.register(name="roles") + def roles_other(): + return [] + + with pytest.raises(ValueError, match="already exists"): + main_registry.include_registry(other_registry) + + def test_include_multiple_registries(self): + """Include multiple registries sequentially.""" + main_registry = FixtureRegistry() + dev_registry = FixtureRegistry() + test_registry = FixtureRegistry() + + @main_registry.register + def base(): + return [] + + @dev_registry.register + def dev_data(): + return [] + + @test_registry.register + def test_data(): + return [] + + main_registry.include_registry(dev_registry) + main_registry.include_registry(test_registry) + + names = {f.name for f in main_registry.get_all()} + assert names == {"base", "dev_data", "test_data"} + + +class TestDefaultContexts: + """Tests for FixtureRegistry default contexts.""" + + def test_default_contexts_applied_to_fixtures(self): + """Default contexts are applied when no contexts specified.""" + registry = FixtureRegistry(contexts=[Context.TESTING]) + + @registry.register + def test_data(): + return [] + + fixture = registry.get("test_data") + assert fixture.contexts == [Context.TESTING.value] + + def test_explicit_contexts_override_default(self): + """Explicit contexts override default contexts.""" + registry = FixtureRegistry(contexts=[Context.TESTING]) + + @registry.register(contexts=[Context.PRODUCTION]) + def prod_data(): + return [] + + fixture = registry.get("prod_data") + assert fixture.contexts == [Context.PRODUCTION.value] + + def test_no_default_contexts_uses_base(self): + """Without default contexts, BASE is used.""" + registry = FixtureRegistry() + + @registry.register + def data(): + return [] + + fixture = registry.get("data") + assert fixture.contexts == [Context.BASE.value] + + def test_multiple_default_contexts(self): + """Multiple default contexts are applied.""" + registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING]) + + @registry.register + def dev_test_data(): + return [] + + fixture = registry.get("dev_test_data") + assert Context.DEVELOPMENT.value in fixture.contexts + assert Context.TESTING.value in fixture.contexts + + def test_default_contexts_with_string_values(self): + """Default contexts work with string values.""" + registry = FixtureRegistry(contexts=["custom_context"]) + + @registry.register + def custom_data(): + return [] + + fixture = registry.get("custom_data") + assert fixture.contexts == ["custom_context"] + + class TestDependencyResolution: """Tests for fixture dependency resolution."""