diff --git a/src/fastapi_toolsets/pytest/__init__.py b/src/fastapi_toolsets/pytest/__init__.py index 20fa819..7040c89 100644 --- a/src/fastapi_toolsets/pytest/__init__.py +++ b/src/fastapi_toolsets/pytest/__init__.py @@ -1,5 +1,8 @@ from .plugin import register_fixtures +from .utils import create_async_client, create_db_session __all__ = [ + "create_async_client", + "create_db_session", "register_fixtures", ] diff --git a/src/fastapi_toolsets/pytest/utils.py b/src/fastapi_toolsets/pytest/utils.py new file mode 100644 index 0000000..c327738 --- /dev/null +++ b/src/fastapi_toolsets/pytest/utils.py @@ -0,0 +1,110 @@ +"""Pytest helper utilities for FastAPI testing.""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +from ..db import create_db_context + + +@asynccontextmanager +async def create_async_client( + app: Any, + base_url: str = "http://test", +) -> AsyncGenerator[AsyncClient, None]: + """Create an async httpx client for testing FastAPI applications. + + Args: + app: FastAPI application instance. + base_url: Base URL for requests. Defaults to "http://test". + + Yields: + An AsyncClient configured for the app. + + Example: + ```python + from fastapi import FastAPI + from fastapi_toolsets.pytest import create_async_client + + app = FastAPI() + + @pytest.fixture + async def client(): + async with create_async_client(app) as c: + yield c + + async def test_endpoint(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + ``` + """ + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url=base_url) as client: + yield client + + +@asynccontextmanager +async def create_db_session( + database_url: str, + base: type[DeclarativeBase], + *, + echo: bool = False, + expire_on_commit: bool = False, + drop_tables: bool = True, +) -> AsyncGenerator[AsyncSession, None]: + """Create a database session for testing. + + Creates tables before yielding the session and optionally drops them after. + Each call creates a fresh engine and session for test isolation. + + Args: + database_url: Database connection URL (e.g., "postgresql+asyncpg://..."). + base: SQLAlchemy DeclarativeBase class containing model metadata. + echo: Enable SQLAlchemy query logging. Defaults to False. + expire_on_commit: Expire objects after commit. Defaults to False. + drop_tables: Drop tables after test. Defaults to True. + + Yields: + An AsyncSession ready for database operations. + + Example: + ```python + from fastapi_toolsets.pytest import create_db_session + from app.models import Base + + DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db" + + @pytest.fixture + async def db_session(): + async with create_db_session(DATABASE_URL, Base) as session: + yield session + + async def test_create_user(db_session: AsyncSession): + user = User(name="test") + db_session.add(user) + await db_session.commit() + ``` + """ + engine = create_async_engine(database_url, echo=echo) + + try: + # Create tables + async with engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + + # Create session using existing db context utility + session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit) + get_session = create_db_context(session_maker) + + async with get_session() as session: + yield session + + if drop_tables: + async with engine.begin() as conn: + await conn.run_sync(base.metadata.drop_all) + finally: + await engine.dispose() diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest.py similarity index 56% rename from tests/test_pytest_plugin.py rename to tests/test_pytest.py index d88095e..ef903bb 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest.py @@ -1,13 +1,20 @@ -"""Tests for fastapi_toolsets.pytest_plugin module.""" +"""Tests for fastapi_toolsets.pytest module.""" import pytest +from fastapi import FastAPI +from httpx import AsyncClient +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from fastapi_toolsets.fixtures import Context, FixtureRegistry -from fastapi_toolsets.pytest import register_fixtures +from fastapi_toolsets.pytest import ( + create_async_client, + create_db_session, + register_fixtures, +) -from .conftest import Role, RoleCrud, User, UserCrud +from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud test_registry = FixtureRegistry() @@ -159,3 +166,102 @@ class TestGeneratedFixtures: assert len(roles) == 2 assert len(users) == 2 + + +class TestCreateAsyncClient: + """Tests for create_async_client helper.""" + + @pytest.mark.anyio + async def test_creates_working_client(self): + """Client can make requests to the app.""" + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "ok"} + + async with create_async_client(app) as client: + assert isinstance(client, AsyncClient) + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.anyio + async def test_custom_base_url(self): + """Client uses custom base URL.""" + app = FastAPI() + + @app.get("/test") + async def test_endpoint(): + return {"url": "test"} + + async with create_async_client(app, base_url="http://custom") as client: + assert str(client.base_url) == "http://custom" + + @pytest.mark.anyio + async def test_client_closes_properly(self): + """Client is properly closed after context exit.""" + app = FastAPI() + + async with create_async_client(app) as client: + client_ref = client + + assert client_ref.is_closed + + +class TestCreateDbSession: + """Tests for create_db_session helper.""" + + @pytest.mark.anyio + async def test_creates_working_session(self): + """Session can perform database operations.""" + async with create_db_session(DATABASE_URL, Base) as session: + assert isinstance(session, AsyncSession) + + role = Role(id=9001, name="test_helper_role") + session.add(role) + await session.commit() + + result = await session.execute(select(Role).where(Role.id == 9001)) + fetched = result.scalar_one() + assert fetched.name == "test_helper_role" + + @pytest.mark.anyio + async def test_tables_created_before_session(self): + """Tables exist when session is yielded.""" + async with create_db_session(DATABASE_URL, Base) as session: + # Should not raise - tables exist + result = await session.execute(select(Role)) + assert result.all() == [] + + @pytest.mark.anyio + async def test_tables_dropped_after_session(self): + """Tables are dropped after session closes when drop_tables=True.""" + async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: + role = Role(id=9002, name="will_be_dropped") + session.add(role) + await session.commit() + + # Verify tables were dropped by creating new session + async with create_db_session(DATABASE_URL, Base) as session: + result = await session.execute(select(Role)) + assert result.all() == [] + + @pytest.mark.anyio + async def test_tables_preserved_when_drop_disabled(self): + """Tables are preserved when drop_tables=False.""" + async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: + role = Role(id=9003, name="preserved_role") + session.add(role) + await session.commit() + + # Create another session without dropping + async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: + result = await session.execute(select(Role).where(Role.id == 9003)) + fetched = result.scalar_one_or_none() + assert fetched is not None + assert fetched.name == "preserved_role" + + # Cleanup: drop tables manually + async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _: + pass