diff --git a/src/fastapi_toolsets/fixtures/__init__.py b/src/fastapi_toolsets/fixtures/__init__.py index fce6fee..0157c09 100644 --- a/src/fastapi_toolsets/fixtures/__init__.py +++ b/src/fastapi_toolsets/fixtures/__init__.py @@ -1,11 +1,6 @@ -from .fixtures import ( - Context, - FixtureRegistry, - LoadStrategy, - load_fixtures, - load_fixtures_by_context, -) -from .utils import get_obj_by_attr +from .enum import LoadStrategy +from .registry import Context, FixtureRegistry +from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context __all__ = [ "Context", @@ -16,12 +11,3 @@ __all__ = [ "load_fixtures_by_context", "register_fixtures", ] - - -# We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI -def __getattr__(name: str): - if name == "register_fixtures": - from .pytest_plugin import register_fixtures - - return register_fixtures - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/fastapi_toolsets/fixtures/enum.py b/src/fastapi_toolsets/fixtures/enum.py new file mode 100644 index 0000000..6a0b53b --- /dev/null +++ b/src/fastapi_toolsets/fixtures/enum.py @@ -0,0 +1,30 @@ +from enum import Enum + + +class LoadStrategy(str, Enum): + """Strategy for loading fixtures into the database.""" + + INSERT = "insert" + """Insert new records. Fails if record already exists.""" + + MERGE = "merge" + """Insert or update based on primary key (SQLAlchemy merge).""" + + SKIP_EXISTING = "skip_existing" + """Insert only if record doesn't exist (based on primary key).""" + + +class Context(str, Enum): + """Predefined fixture contexts.""" + + BASE = "base" + """Base fixtures loaded in all environments.""" + + PRODUCTION = "production" + """Production-only fixtures.""" + + DEVELOPMENT = "development" + """Development fixtures.""" + + TESTING = "testing" + """Test fixtures.""" diff --git a/src/fastapi_toolsets/fixtures/fixtures.py b/src/fastapi_toolsets/fixtures/registry.py similarity index 54% rename from src/fastapi_toolsets/fixtures/fixtures.py rename to src/fastapi_toolsets/fixtures/registry.py index 11273c5..17ca750 100644 --- a/src/fastapi_toolsets/fixtures/fixtures.py +++ b/src/fastapi_toolsets/fixtures/registry.py @@ -3,46 +3,15 @@ import logging from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from enum import Enum from typing import Any, cast -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase -from ..db import get_transaction +from .enum import Context logger = logging.getLogger(__name__) -class LoadStrategy(str, Enum): - """Strategy for loading fixtures into the database.""" - - INSERT = "insert" - """Insert new records. Fails if record already exists.""" - - MERGE = "merge" - """Insert or update based on primary key (SQLAlchemy merge).""" - - SKIP_EXISTING = "skip_existing" - """Insert only if record doesn't exist (based on primary key).""" - - -class Context(str, Enum): - """Predefined fixture contexts.""" - - BASE = "base" - """Base fixtures loaded in all environments.""" - - PRODUCTION = "production" - """Production-only fixtures.""" - - DEVELOPMENT = "development" - """Development fixtures.""" - - TESTING = "testing" - """Test fixtures.""" - - @dataclass class Fixture: """A fixture definition with metadata.""" @@ -204,118 +173,3 @@ class FixtureRegistry: all_deps.update(deps) return self.resolve_dependencies(*all_deps) - - -async def load_fixtures( - session: AsyncSession, - registry: FixtureRegistry, - *names: str, - strategy: LoadStrategy = LoadStrategy.MERGE, -) -> dict[str, list[DeclarativeBase]]: - """Load specific fixtures by name with dependencies. - - Args: - session: Database session - registry: Fixture registry - *names: Fixture names to load (dependencies auto-resolved) - strategy: How to handle existing records - - Returns: - Dict mapping fixture names to loaded instances - - Example: - # Loads 'roles' first (dependency), then 'users' - result = await load_fixtures(session, fixtures, "users") - print(result["users"]) # [User(...), ...] - """ - ordered = registry.resolve_dependencies(*names) - return await _load_ordered(session, registry, ordered, strategy) - - -async def load_fixtures_by_context( - session: AsyncSession, - registry: FixtureRegistry, - *contexts: str | Context, - strategy: LoadStrategy = LoadStrategy.MERGE, -) -> dict[str, list[DeclarativeBase]]: - """Load all fixtures for specific contexts. - - Args: - session: Database session - registry: Fixture registry - *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING) - strategy: How to handle existing records - - Returns: - Dict mapping fixture names to loaded instances - - Example: - # Load base + testing fixtures - await load_fixtures_by_context( - session, fixtures, - Context.BASE, Context.TESTING - ) - """ - ordered = registry.resolve_context_dependencies(*contexts) - return await _load_ordered(session, registry, ordered, strategy) - - -async def _load_ordered( - session: AsyncSession, - registry: FixtureRegistry, - ordered_names: list[str], - strategy: LoadStrategy, -) -> dict[str, list[DeclarativeBase]]: - """Load fixtures in order.""" - results: dict[str, list[DeclarativeBase]] = {} - - for name in ordered_names: - fixture = registry.get(name) - instances = list(fixture.func()) - - if not instances: - results[name] = [] - continue - - model_name = type(instances[0]).__name__ - loaded: list[DeclarativeBase] = [] - - async with get_transaction(session): - for instance in instances: - if strategy == LoadStrategy.INSERT: - session.add(instance) - loaded.append(instance) - - elif strategy == LoadStrategy.MERGE: - merged = await session.merge(instance) - loaded.append(merged) - - elif strategy == LoadStrategy.SKIP_EXISTING: - pk = _get_primary_key(instance) - if pk is not None: - existing = await session.get(type(instance), pk) - if existing is None: - session.add(instance) - loaded.append(instance) - else: - session.add(instance) - loaded.append(instance) - - results[name] = loaded - logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") - - return results - - -def _get_primary_key(instance: DeclarativeBase) -> Any | None: - """Get the primary key value of a model instance.""" - mapper = instance.__class__.__mapper__ - pk_cols = mapper.primary_key - - if len(pk_cols) == 1: - return getattr(instance, pk_cols[0].name, None) - - pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols) - if all(v is not None for v in pk_values): - return pk_values - return None diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index 106c706..e1a88eb 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -1,8 +1,16 @@ +import logging from collections.abc import Callable, Sequence from typing import Any, TypeVar +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase +from ..db import get_transaction +from .enum import LoadStrategy +from .registry import Context, FixtureRegistry + +logger = logging.getLogger(__name__) + T = TypeVar("T", bound=DeclarativeBase) @@ -24,3 +32,118 @@ def get_obj_by_attr( StopIteration: If no matching object is found. """ return next(obj for obj in fixtures() if getattr(obj, attr_name) == value) + + +async def load_fixtures( + session: AsyncSession, + registry: FixtureRegistry, + *names: str, + strategy: LoadStrategy = LoadStrategy.MERGE, +) -> dict[str, list[DeclarativeBase]]: + """Load specific fixtures by name with dependencies. + + Args: + session: Database session + registry: Fixture registry + *names: Fixture names to load (dependencies auto-resolved) + strategy: How to handle existing records + + Returns: + Dict mapping fixture names to loaded instances + + Example: + # Loads 'roles' first (dependency), then 'users' + result = await load_fixtures(session, fixtures, "users") + print(result["users"]) # [User(...), ...] + """ + ordered = registry.resolve_dependencies(*names) + return await _load_ordered(session, registry, ordered, strategy) + + +async def load_fixtures_by_context( + session: AsyncSession, + registry: FixtureRegistry, + *contexts: str | Context, + strategy: LoadStrategy = LoadStrategy.MERGE, +) -> dict[str, list[DeclarativeBase]]: + """Load all fixtures for specific contexts. + + Args: + session: Database session + registry: Fixture registry + *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING) + strategy: How to handle existing records + + Returns: + Dict mapping fixture names to loaded instances + + Example: + # Load base + testing fixtures + await load_fixtures_by_context( + session, fixtures, + Context.BASE, Context.TESTING + ) + """ + ordered = registry.resolve_context_dependencies(*contexts) + return await _load_ordered(session, registry, ordered, strategy) + + +async def _load_ordered( + session: AsyncSession, + registry: FixtureRegistry, + ordered_names: list[str], + strategy: LoadStrategy, +) -> dict[str, list[DeclarativeBase]]: + """Load fixtures in order.""" + results: dict[str, list[DeclarativeBase]] = {} + + for name in ordered_names: + fixture = registry.get(name) + instances = list(fixture.func()) + + if not instances: + results[name] = [] + continue + + model_name = type(instances[0]).__name__ + loaded: list[DeclarativeBase] = [] + + async with get_transaction(session): + for instance in instances: + if strategy == LoadStrategy.INSERT: + session.add(instance) + loaded.append(instance) + + elif strategy == LoadStrategy.MERGE: + merged = await session.merge(instance) + loaded.append(merged) + + elif strategy == LoadStrategy.SKIP_EXISTING: + pk = _get_primary_key(instance) + if pk is not None: + existing = await session.get(type(instance), pk) + if existing is None: + session.add(instance) + loaded.append(instance) + else: + session.add(instance) + loaded.append(instance) + + results[name] = loaded + logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") + + return results + + +def _get_primary_key(instance: DeclarativeBase) -> Any | None: + """Get the primary key value of a model instance.""" + mapper = instance.__class__.__mapper__ + pk_cols = mapper.primary_key + + if len(pk_cols) == 1: + return getattr(instance, pk_cols[0].name, None) + + pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols) + if all(v is not None for v in pk_values): + return pk_values + return None diff --git a/src/fastapi_toolsets/pytest/__init__.py b/src/fastapi_toolsets/pytest/__init__.py new file mode 100644 index 0000000..7040c89 --- /dev/null +++ b/src/fastapi_toolsets/pytest/__init__.py @@ -0,0 +1,8 @@ +from .plugin import register_fixtures +from .utils import create_async_client, create_db_session + +__all__ = [ + "create_async_client", + "create_db_session", + "register_fixtures", +] diff --git a/src/fastapi_toolsets/fixtures/pytest_plugin.py b/src/fastapi_toolsets/pytest/plugin.py similarity index 99% rename from src/fastapi_toolsets/fixtures/pytest_plugin.py rename to src/fastapi_toolsets/pytest/plugin.py index 29b3446..e10a770 100644 --- a/src/fastapi_toolsets/fixtures/pytest_plugin.py +++ b/src/fastapi_toolsets/pytest/plugin.py @@ -59,7 +59,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase from ..db import get_transaction -from .fixtures import FixtureRegistry, LoadStrategy +from ..fixtures import FixtureRegistry, LoadStrategy def register_fixtures( diff --git a/src/fastapi_toolsets/pytest/utils.py b/src/fastapi_toolsets/pytest/utils.py new file mode 100644 index 0000000..c327738 --- /dev/null +++ b/src/fastapi_toolsets/pytest/utils.py @@ -0,0 +1,110 @@ +"""Pytest helper utilities for FastAPI testing.""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +from ..db import create_db_context + + +@asynccontextmanager +async def create_async_client( + app: Any, + base_url: str = "http://test", +) -> AsyncGenerator[AsyncClient, None]: + """Create an async httpx client for testing FastAPI applications. + + Args: + app: FastAPI application instance. + base_url: Base URL for requests. Defaults to "http://test". + + Yields: + An AsyncClient configured for the app. + + Example: + ```python + from fastapi import FastAPI + from fastapi_toolsets.pytest import create_async_client + + app = FastAPI() + + @pytest.fixture + async def client(): + async with create_async_client(app) as c: + yield c + + async def test_endpoint(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + ``` + """ + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url=base_url) as client: + yield client + + +@asynccontextmanager +async def create_db_session( + database_url: str, + base: type[DeclarativeBase], + *, + echo: bool = False, + expire_on_commit: bool = False, + drop_tables: bool = True, +) -> AsyncGenerator[AsyncSession, None]: + """Create a database session for testing. + + Creates tables before yielding the session and optionally drops them after. + Each call creates a fresh engine and session for test isolation. + + Args: + database_url: Database connection URL (e.g., "postgresql+asyncpg://..."). + base: SQLAlchemy DeclarativeBase class containing model metadata. + echo: Enable SQLAlchemy query logging. Defaults to False. + expire_on_commit: Expire objects after commit. Defaults to False. + drop_tables: Drop tables after test. Defaults to True. + + Yields: + An AsyncSession ready for database operations. + + Example: + ```python + from fastapi_toolsets.pytest import create_db_session + from app.models import Base + + DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db" + + @pytest.fixture + async def db_session(): + async with create_db_session(DATABASE_URL, Base) as session: + yield session + + async def test_create_user(db_session: AsyncSession): + user = User(name="test") + db_session.add(user) + await db_session.commit() + ``` + """ + engine = create_async_engine(database_url, echo=echo) + + try: + # Create tables + async with engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + # Create session using existing db context utility + session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit) + get_session = create_db_context(session_maker) + + async with get_session() as session: + yield session + + if drop_tables: + async with engine.begin() as conn: + await conn.run_sync(base.metadata.drop_all) + finally: + await engine.dispose() diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 79b5f0a..e80156b 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -7,6 +7,7 @@ from fastapi_toolsets.fixtures import ( Context, FixtureRegistry, LoadStrategy, + get_obj_by_attr, load_fixtures, load_fixtures_by_context, ) @@ -330,6 +331,69 @@ class TestLoadFixtures: assert role is not None assert role.name == "original" + @pytest.mark.anyio + async def test_load_with_insert_strategy(self, db_session: AsyncSession): + """Load fixtures with INSERT strategy.""" + registry = FixtureRegistry() + + @registry.register + def roles(): + return [ + Role(id=1, name="admin"), + Role(id=2, name="user"), + ] + + result = await load_fixtures( + db_session, registry, "roles", strategy=LoadStrategy.INSERT + ) + + assert "roles" in result + assert len(result["roles"]) == 2 + + from .conftest import RoleCrud + + count = await RoleCrud.count(db_session) + assert count == 2 + + @pytest.mark.anyio + async def test_load_empty_fixture(self, db_session: AsyncSession): + """Load a fixture that returns an empty list.""" + registry = FixtureRegistry() + + @registry.register + def empty_roles(): + return [] + + result = await load_fixtures(db_session, registry, "empty_roles") + + assert "empty_roles" in result + assert result["empty_roles"] == [] + + @pytest.mark.anyio + async def test_load_multiple_fixtures_without_dependencies( + self, db_session: AsyncSession + ): + """Load multiple independent fixtures.""" + registry = FixtureRegistry() + + @registry.register + def roles(): + return [Role(id=1, name="admin")] + + @registry.register + def other_roles(): + return [Role(id=2, name="user")] + + result = await load_fixtures(db_session, registry, "roles", "other_roles") + + assert "roles" in result + assert "other_roles" in result + + from .conftest import RoleCrud + + count = await RoleCrud.count(db_session) + assert count == 2 + class TestLoadFixturesByContext: """Tests for load_fixtures_by_context function.""" @@ -399,3 +463,55 @@ class TestLoadFixturesByContext: assert await RoleCrud.count(db_session) == 1 assert await UserCrud.count(db_session) == 1 + + +class TestGetObjByAttr: + """Tests for get_obj_by_attr helper function.""" + + def setup_method(self): + """Set up test fixtures for each test.""" + self.registry = FixtureRegistry() + + @self.registry.register + def roles() -> list[Role]: + return [ + Role(id=1, name="admin"), + Role(id=2, name="user"), + Role(id=3, name="moderator"), + ] + + @self.registry.register(depends_on=["roles"]) + def users() -> list[User]: + return [ + User(id=1, username="alice", email="alice@example.com", role_id=1), + User(id=2, username="bob", email="bob@example.com", role_id=1), + ] + + self.roles = roles + self.users = users + + def test_get_by_id(self): + """Get an object by its id attribute.""" + role = get_obj_by_attr(self.roles, "id", 1) + assert role.name == "admin" + + def test_get_user_by_username(self): + """Get a user by username.""" + user = get_obj_by_attr(self.users, "username", "bob") + assert user.id == 2 + assert user.email == "bob@example.com" + + def test_returns_first_match(self): + """Returns the first matching object when multiple could match.""" + user = get_obj_by_attr(self.users, "role_id", 1) + assert user.username == "alice" + + def test_no_match_raises_stop_iteration(self): + """Raises StopIteration when no object matches.""" + with pytest.raises(StopIteration): + get_obj_by_attr(self.roles, "name", "nonexistent") + + def test_no_match_on_wrong_value_type(self): + """Raises StopIteration when value type doesn't match.""" + with pytest.raises(StopIteration): + get_obj_by_attr(self.roles, "id", "1") diff --git a/tests/test_fixtures_utils.py b/tests/test_fixtures_utils.py deleted file mode 100644 index b21bbe8..0000000 --- a/tests/test_fixtures_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Tests for fastapi_toolsets.fixtures.utils.""" - -import pytest - -from fastapi_toolsets.fixtures import FixtureRegistry -from fastapi_toolsets.fixtures.utils import get_obj_by_attr - -from .conftest import Role, User - -registry = FixtureRegistry() - - -@registry.register -def roles() -> list[Role]: - return [ - Role(id=1, name="admin"), - Role(id=2, name="user"), - Role(id=3, name="moderator"), - ] - - -@registry.register(depends_on=["roles"]) -def users() -> list[User]: - return [ - User(id=1, username="alice", email="alice@example.com", role_id=1), - User(id=2, username="bob", email="bob@example.com", role_id=1), - ] - - -class TestGetObjByAttr: - """Tests for get_obj_by_attr.""" - - def test_get_by_id(self): - """Get an object by its id attribute.""" - role = get_obj_by_attr(roles, "id", 1) - assert role.name == "admin" - - def test_get_user_by_username(self): - """Get a user by username.""" - user = get_obj_by_attr(users, "username", "bob") - assert user.id == 2 - assert user.email == "bob@example.com" - - def test_returns_first_match(self): - """Returns the first matching object when multiple could match.""" - user = get_obj_by_attr(users, "role_id", 1) - assert user.username == "alice" - - def test_no_match_raises_stop_iteration(self): - """Raises StopIteration when no object matches.""" - with pytest.raises(StopIteration): - get_obj_by_attr(roles, "name", "nonexistent") - - def test_no_match_on_wrong_value_type(self): - """Raises StopIteration when value type doesn't match.""" - with pytest.raises(StopIteration): - get_obj_by_attr(roles, "id", "1") diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest.py similarity index 55% rename from tests/test_pytest_plugin.py rename to tests/test_pytest.py index 62aabf5..ef903bb 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest.py @@ -1,12 +1,20 @@ -"""Tests for fastapi_toolsets.pytest_plugin module.""" +"""Tests for fastapi_toolsets.pytest module.""" import pytest +from fastapi import FastAPI +from httpx import AsyncClient +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from fastapi_toolsets.fixtures import Context, FixtureRegistry, register_fixtures +from fastapi_toolsets.fixtures import Context, FixtureRegistry +from fastapi_toolsets.pytest import ( + create_async_client, + create_db_session, + register_fixtures, +) -from .conftest import Role, RoleCrud, User, UserCrud +from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud test_registry = FixtureRegistry() @@ -158,3 +166,102 @@ class TestGeneratedFixtures: assert len(roles) == 2 assert len(users) == 2 + + +class TestCreateAsyncClient: + """Tests for create_async_client helper.""" + + @pytest.mark.anyio + async def test_creates_working_client(self): + """Client can make requests to the app.""" + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "ok"} + + async with create_async_client(app) as client: + assert isinstance(client, AsyncClient) + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.anyio + async def test_custom_base_url(self): + """Client uses custom base URL.""" + app = FastAPI() + + @app.get("/test") + async def test_endpoint(): + return {"url": "test"} + + async with create_async_client(app, base_url="http://custom") as client: + assert str(client.base_url) == "http://custom" + + @pytest.mark.anyio + async def test_client_closes_properly(self): + """Client is properly closed after context exit.""" + app = FastAPI() + + async with create_async_client(app) as client: + client_ref = client + + assert client_ref.is_closed + + +class TestCreateDbSession: + """Tests for create_db_session helper.""" + + @pytest.mark.anyio + async def test_creates_working_session(self): + """Session can perform database operations.""" + async with create_db_session(DATABASE_URL, Base) as session: + assert isinstance(session, AsyncSession) + + role = Role(id=9001, name="test_helper_role") + session.add(role) + await session.commit() + + result = await session.execute(select(Role).where(Role.id == 9001)) + fetched = result.scalar_one() + assert fetched.name == "test_helper_role" + + @pytest.mark.anyio + async def test_tables_created_before_session(self): + """Tables exist when session is yielded.""" + async with create_db_session(DATABASE_URL, Base) as session: + # Should not raise - tables exist + result = await session.execute(select(Role)) + assert result.all() == [] + + @pytest.mark.anyio + async def test_tables_dropped_after_session(self): + """Tables are dropped after session closes when drop_tables=True.""" + async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: + role = Role(id=9002, name="will_be_dropped") + session.add(role) + await session.commit() + + # Verify tables were dropped by creating new session + async with create_db_session(DATABASE_URL, Base) as session: + result = await session.execute(select(Role)) + assert result.all() == [] + + @pytest.mark.anyio + async def test_tables_preserved_when_drop_disabled(self): + """Tables are preserved when drop_tables=False.""" + async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: + role = Role(id=9003, name="preserved_role") + session.add(role) + await session.commit() + + # Create another session without dropping + async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: + result = await session.execute(select(Role).where(Role.id == 9003)) + fetched = result.scalar_one_or_none() + assert fetched is not None + assert fetched.name == "preserved_role" + + # Cleanup: drop tables manually + async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _: + pass