mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add create_async_client and create_db_session pytest utils
function
This commit is contained in:
@@ -1,5 +1,8 @@
|
|||||||
from .plugin import register_fixtures
|
from .plugin import register_fixtures
|
||||||
|
from .utils import create_async_client, create_db_session
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"create_async_client",
|
||||||
|
"create_db_session",
|
||||||
"register_fixtures",
|
"register_fixtures",
|
||||||
]
|
]
|
||||||
|
|||||||
110
src/fastapi_toolsets/pytest/utils.py
Normal file
110
src/fastapi_toolsets/pytest/utils.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Pytest helper utilities for FastAPI testing."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import create_db_context
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_async_client(
|
||||||
|
app: Any,
|
||||||
|
base_url: str = "http://test",
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
"""Create an async httpx client for testing FastAPI applications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
base_url: Base URL for requests. Defaults to "http://test".
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncClient configured for the app.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client():
|
||||||
|
async with create_async_client(app) as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
async def test_endpoint(client: AsyncClient):
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_db_session(
|
||||||
|
database_url: str,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
echo: bool = False,
|
||||||
|
expire_on_commit: bool = False,
|
||||||
|
drop_tables: bool = True,
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create a database session for testing.
|
||||||
|
|
||||||
|
Creates tables before yielding the session and optionally drops them after.
|
||||||
|
Each call creates a fresh engine and session for test isolation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Database connection URL (e.g., "postgresql+asyncpg://...").
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||||
|
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||||
|
drop_tables: Drop tables after test. Defaults to True.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncSession ready for database operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def test_create_user(db_session: AsyncSession):
|
||||||
|
user = User(name="test")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(database_url, echo=echo)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.create_all)
|
||||||
|
|
||||||
|
# Create session using existing db context utility
|
||||||
|
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
|
||||||
|
get_session = create_db_context(session_maker)
|
||||||
|
|
||||||
|
async with get_session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
if drop_tables:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.drop_all)
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
@@ -1,13 +1,20 @@
|
|||||||
"""Tests for fastapi_toolsets.pytest_plugin module."""
|
"""Tests for fastapi_toolsets.pytest module."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||||
from fastapi_toolsets.pytest import register_fixtures
|
from fastapi_toolsets.pytest import (
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
register_fixtures,
|
||||||
|
)
|
||||||
|
|
||||||
from .conftest import Role, RoleCrud, User, UserCrud
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
test_registry = FixtureRegistry()
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
@@ -159,3 +166,102 @@ class TestGeneratedFixtures:
|
|||||||
|
|
||||||
assert len(roles) == 2
|
assert len(roles) == 2
|
||||||
assert len(users) == 2
|
assert len(users) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateAsyncClient:
|
||||||
|
"""Tests for create_async_client helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_client(self):
|
||||||
|
"""Client can make requests to the app."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
assert isinstance(client, AsyncClient)
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_custom_base_url(self):
|
||||||
|
"""Client uses custom base URL."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
async def test_endpoint():
|
||||||
|
return {"url": "test"}
|
||||||
|
|
||||||
|
async with create_async_client(app, base_url="http://custom") as client:
|
||||||
|
assert str(client.base_url) == "http://custom"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_closes_properly(self):
|
||||||
|
"""Client is properly closed after context exit."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
client_ref = client
|
||||||
|
|
||||||
|
assert client_ref.is_closed
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateDbSession:
|
||||||
|
"""Tests for create_db_session helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_session(self):
|
||||||
|
"""Session can perform database operations."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
|
||||||
|
role = Role(id=9001, name="test_helper_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
result = await session.execute(select(Role).where(Role.id == 9001))
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.name == "test_helper_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_created_before_session(self):
|
||||||
|
"""Tables exist when session is yielded."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
# Should not raise - tables exist
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_dropped_after_session(self):
|
||||||
|
"""Tables are dropped after session closes when drop_tables=True."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=9002, name="will_be_dropped")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify tables were dropped by creating new session
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_preserved_when_drop_disabled(self):
|
||||||
|
"""Tables are preserved when drop_tables=False."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
role = Role(id=9003, name="preserved_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Create another session without dropping
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
result = await session.execute(select(Role).where(Role.id == 9003))
|
||||||
|
fetched = result.scalar_one_or_none()
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.name == "preserved_role"
|
||||||
|
|
||||||
|
# Cleanup: drop tables manually
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user