2 Commits

Author SHA1 Message Date
d3vyce
19805ab376 feat: add dependency_overrides parameter to create_async_client (#61) 2026-02-13 18:11:11 +01:00
d3vyce
d4498e2063 feat: add cleanup parameter to create_db_session (#60) 2026-02-13 18:03:28 +01:00
2 changed files with 88 additions and 8 deletions

View File

@@ -1,7 +1,7 @@
"""Pytest helper utilities for FastAPI testing."""
import os
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager
from typing import Any
@@ -22,12 +22,16 @@ from ..db import create_db_context
async def create_async_client(
app: Any,
base_url: str = "http://test",
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
) -> AsyncGenerator[AsyncClient, None]:
"""Create an async httpx client for testing FastAPI applications.
Args:
app: FastAPI application instance.
base_url: Base URL for requests. Defaults to "http://test".
dependency_overrides: Optional mapping of original dependencies to
their test replacements. Applied via ``app.dependency_overrides``
before yielding and cleaned up after.
Yields:
An AsyncClient configured for the app.
@@ -46,10 +50,37 @@ async def create_async_client(
async def test_endpoint(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
Example with dependency overrides:
from fastapi_toolsets.pytest import create_async_client, create_db_session
from app.db import get_db
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
yield session
@pytest.fixture
async def client(db_session):
async def override():
yield db_session
async with create_async_client(
app, dependency_overrides={get_db: override}
) as c:
yield c
"""
if dependency_overrides:
app.dependency_overrides.update(dependency_overrides)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client
try:
async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client
finally:
if dependency_overrides:
for key in dependency_overrides:
app.dependency_overrides.pop(key, None)
@asynccontextmanager
@@ -60,6 +91,7 @@ async def create_db_session(
echo: bool = False,
expire_on_commit: bool = False,
drop_tables: bool = True,
cleanup: bool = False,
) -> AsyncGenerator[AsyncSession, None]:
"""Create a database session for testing.
@@ -72,6 +104,8 @@ async def create_db_session(
echo: Enable SQLAlchemy query logging. Defaults to False.
expire_on_commit: Expire objects after commit. Defaults to False.
drop_tables: Drop tables after test. Defaults to True.
cleanup: Truncate all tables after test using
:func:`cleanup_tables`. Defaults to False.
Yields:
An AsyncSession ready for database operations.
@@ -84,7 +118,9 @@ async def create_db_session(
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base) as session:
async with create_db_session(
DATABASE_URL, Base, cleanup=True
) as session:
yield session
async def test_create_user(db_session: AsyncSession):
@@ -106,6 +142,9 @@ async def create_db_session(
async with get_session() as session:
yield session
if cleanup:
await cleanup_tables(session, base)
if drop_tables:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.drop_all)
@@ -193,7 +232,7 @@ async def create_worker_database(
Example:
from fastapi_toolsets.pytest import (
create_worker_database, create_db_session, cleanup_tables
create_worker_database, create_db_session,
)
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
@@ -205,9 +244,10 @@ async def create_worker_database(
@pytest.fixture
async def db_session(worker_db_url):
async with create_db_session(worker_db_url, Base) as session:
async with create_db_session(
worker_db_url, Base, cleanup=True
) as session:
yield session
await cleanup_tables(session, Base)
"""
worker_url = worker_database_url(
database_url=database_url, default_test_db=default_test_db

View File

@@ -3,7 +3,7 @@
import uuid
import pytest
from fastapi import FastAPI
from fastapi import Depends, FastAPI
from httpx import AsyncClient
from sqlalchemy import select, text
from sqlalchemy.engine import make_url
@@ -236,6 +236,30 @@ class TestCreateAsyncClient:
assert client_ref.is_closed
@pytest.mark.anyio
async def test_dependency_overrides_applied_and_cleaned(self):
"""Dependency overrides are applied during the context and removed after."""
app = FastAPI()
async def original_dep() -> str:
return "original"
async def override_dep() -> str:
return "overridden"
@app.get("/dep")
async def dep_endpoint(value: str = Depends(original_dep)):
return {"value": value}
async with create_async_client(
app, dependency_overrides={original_dep: override_dep}
) as client:
response = await client.get("/dep")
assert response.json() == {"value": "overridden"}
# Overrides should be cleaned up
assert original_dep not in app.dependency_overrides
class TestCreateDbSession:
"""Tests for create_db_session helper."""
@@ -297,6 +321,22 @@ class TestCreateDbSession:
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass
@pytest.mark.anyio
async def test_cleanup_truncates_tables(self):
"""Tables are truncated after session closes when cleanup=True."""
role_id = uuid.uuid4()
async with create_db_session(
DATABASE_URL, Base, cleanup=True, drop_tables=False
) as session:
role = Role(id=role_id, name="will_be_cleaned")
session.add(role)
await session.commit()
# Data should have been truncated, but tables still exist
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
result = await session.execute(select(Role))
assert result.all() == []
class TestGetXdistWorker:
"""Tests for _get_xdist_worker helper."""