feat: add create_async_client and create_db_session pytest utils

function
This commit is contained in:
2026-01-28 08:48:22 -05:00
parent a9f486d905
commit ba5180a73b
3 changed files with 222 additions and 3 deletions

View File

@@ -1,5 +1,8 @@
from .plugin import register_fixtures from .plugin import register_fixtures
from .utils import create_async_client, create_db_session
__all__ = [ __all__ = [
"create_async_client",
"create_db_session",
"register_fixtures", "register_fixtures",
] ]

View File

@@ -0,0 +1,110 @@
"""Pytest helper utilities for FastAPI testing."""
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from ..db import create_db_context
@asynccontextmanager
async def create_async_client(
app: Any,
base_url: str = "http://test",
) -> AsyncGenerator[AsyncClient, None]:
"""Create an async httpx client for testing FastAPI applications.
Args:
app: FastAPI application instance.
base_url: Base URL for requests. Defaults to "http://test".
Yields:
An AsyncClient configured for the app.
Example:
```python
from fastapi import FastAPI
from fastapi_toolsets.pytest import create_async_client
app = FastAPI()
@pytest.fixture
async def client():
async with create_async_client(app) as c:
yield c
async def test_endpoint(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
```
"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client
@asynccontextmanager
async def create_db_session(
database_url: str,
base: type[DeclarativeBase],
*,
echo: bool = False,
expire_on_commit: bool = False,
drop_tables: bool = True,
) -> AsyncGenerator[AsyncSession, None]:
"""Create a database session for testing.
Creates tables before yielding the session and optionally drops them after.
Each call creates a fresh engine and session for test isolation.
Args:
database_url: Database connection URL (e.g., "postgresql+asyncpg://...").
base: SQLAlchemy DeclarativeBase class containing model metadata.
echo: Enable SQLAlchemy query logging. Defaults to False.
expire_on_commit: Expire objects after commit. Defaults to False.
drop_tables: Drop tables after test. Defaults to True.
Yields:
An AsyncSession ready for database operations.
Example:
```python
from fastapi_toolsets.pytest import create_db_session
from app.models import Base
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base) as session:
yield session
async def test_create_user(db_session: AsyncSession):
user = User(name="test")
db_session.add(user)
await db_session.commit()
```
"""
engine = create_async_engine(database_url, echo=echo)
try:
# Create tables
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)
# Create session using existing db context utility
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
get_session = create_db_context(session_maker)
async with get_session() as session:
yield session
if drop_tables:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.drop_all)
finally:
await engine.dispose()

View File

@@ -1,13 +1,20 @@
"""Tests for fastapi_toolsets.pytest_plugin module.""" """Tests for fastapi_toolsets.pytest module."""
import pytest import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from fastapi_toolsets.fixtures import Context, FixtureRegistry from fastapi_toolsets.fixtures import Context, FixtureRegistry
from fastapi_toolsets.pytest import register_fixtures from fastapi_toolsets.pytest import (
create_async_client,
create_db_session,
register_fixtures,
)
from .conftest import Role, RoleCrud, User, UserCrud from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
test_registry = FixtureRegistry() test_registry = FixtureRegistry()
@@ -159,3 +166,102 @@ class TestGeneratedFixtures:
assert len(roles) == 2 assert len(roles) == 2
assert len(users) == 2 assert len(users) == 2
class TestCreateAsyncClient:
"""Tests for create_async_client helper."""
@pytest.mark.anyio
async def test_creates_working_client(self):
"""Client can make requests to the app."""
app = FastAPI()
@app.get("/health")
async def health():
return {"status": "ok"}
async with create_async_client(app) as client:
assert isinstance(client, AsyncClient)
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.anyio
async def test_custom_base_url(self):
"""Client uses custom base URL."""
app = FastAPI()
@app.get("/test")
async def test_endpoint():
return {"url": "test"}
async with create_async_client(app, base_url="http://custom") as client:
assert str(client.base_url) == "http://custom"
@pytest.mark.anyio
async def test_client_closes_properly(self):
"""Client is properly closed after context exit."""
app = FastAPI()
async with create_async_client(app) as client:
client_ref = client
assert client_ref.is_closed
class TestCreateDbSession:
"""Tests for create_db_session helper."""
@pytest.mark.anyio
async def test_creates_working_session(self):
"""Session can perform database operations."""
async with create_db_session(DATABASE_URL, Base) as session:
assert isinstance(session, AsyncSession)
role = Role(id=9001, name="test_helper_role")
session.add(role)
await session.commit()
result = await session.execute(select(Role).where(Role.id == 9001))
fetched = result.scalar_one()
assert fetched.name == "test_helper_role"
@pytest.mark.anyio
async def test_tables_created_before_session(self):
"""Tables exist when session is yielded."""
async with create_db_session(DATABASE_URL, Base) as session:
# Should not raise - tables exist
result = await session.execute(select(Role))
assert result.all() == []
@pytest.mark.anyio
async def test_tables_dropped_after_session(self):
"""Tables are dropped after session closes when drop_tables=True."""
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=9002, name="will_be_dropped")
session.add(role)
await session.commit()
# Verify tables were dropped by creating new session
async with create_db_session(DATABASE_URL, Base) as session:
result = await session.execute(select(Role))
assert result.all() == []
@pytest.mark.anyio
async def test_tables_preserved_when_drop_disabled(self):
"""Tables are preserved when drop_tables=False."""
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
role = Role(id=9003, name="preserved_role")
session.add(role)
await session.commit()
# Create another session without dropping
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
result = await session.execute(select(Role).where(Role.id == 9003))
fetched = result.scalar_one_or_none()
assert fetched is not None
assert fetched.name == "preserved_role"
# Cleanup: drop tables manually
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass