diff --git a/src/fastapi_toolsets/fixtures/__init__.py b/src/fastapi_toolsets/fixtures/__init__.py index 4aaee2b..06f3748 100644 --- a/src/fastapi_toolsets/fixtures/__init__.py +++ b/src/fastapi_toolsets/fixtures/__init__.py @@ -6,11 +6,13 @@ from .fixtures import ( load_fixtures_by_context, ) from .pytest_plugin import register_fixtures +from .utils import get_obj_by_attr __all__ = [ "Context", "FixtureRegistry", "LoadStrategy", + "get_obj_by_attr", "load_fixtures", "load_fixtures_by_context", "register_fixtures", diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py new file mode 100644 index 0000000..106c706 --- /dev/null +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -0,0 +1,26 @@ +from collections.abc import Callable, Sequence +from typing import Any, TypeVar + +from sqlalchemy.orm import DeclarativeBase + +T = TypeVar("T", bound=DeclarativeBase) + + +def get_obj_by_attr( + fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any +) -> T: + """Get a SQLAlchemy model instance by matching an attribute value. + + Args: + fixtures: A fixture function registered via ``@registry.register`` + that returns a sequence of SQLAlchemy model instances. + attr_name: Name of the attribute to match against. + value: Value to match. + + Returns: + The first model instance where the attribute matches the given value. + + Raises: + StopIteration: If no matching object is found. + """ + return next(obj for obj in fixtures() if getattr(obj, attr_name) == value) diff --git a/tests/test_fixtures_utils.py b/tests/test_fixtures_utils.py new file mode 100644 index 0000000..b21bbe8 --- /dev/null +++ b/tests/test_fixtures_utils.py @@ -0,0 +1,57 @@ +"""Tests for fastapi_toolsets.fixtures.utils.""" + +import pytest + +from fastapi_toolsets.fixtures import FixtureRegistry +from fastapi_toolsets.fixtures.utils import get_obj_by_attr + +from .conftest import Role, User + +registry = FixtureRegistry() + + +@registry.register +def roles() -> list[Role]: + return [ + Role(id=1, name="admin"), + Role(id=2, name="user"), + Role(id=3, name="moderator"), + ] + + +@registry.register(depends_on=["roles"]) +def users() -> list[User]: + return [ + User(id=1, username="alice", email="alice@example.com", role_id=1), + User(id=2, username="bob", email="bob@example.com", role_id=1), + ] + + +class TestGetObjByAttr: + """Tests for get_obj_by_attr.""" + + def test_get_by_id(self): + """Get an object by its id attribute.""" + role = get_obj_by_attr(roles, "id", 1) + assert role.name == "admin" + + def test_get_user_by_username(self): + """Get a user by username.""" + user = get_obj_by_attr(users, "username", "bob") + assert user.id == 2 + assert user.email == "bob@example.com" + + def test_returns_first_match(self): + """Returns the first matching object when multiple could match.""" + user = get_obj_by_attr(users, "role_id", 1) + assert user.username == "alice" + + def test_no_match_raises_stop_iteration(self): + """Raises StopIteration when no object matches.""" + with pytest.raises(StopIteration): + get_obj_by_attr(roles, "name", "nonexistent") + + def test_no_match_on_wrong_value_type(self): + """Raises StopIteration when value type doesn't match.""" + with pytest.raises(StopIteration): + get_obj_by_attr(roles, "id", "1")