feat: add dependency_overrides parameter to create_async_client (#61)

This commit is contained in:
d3vyce
2026-02-13 18:11:11 +01:00
committed by GitHub
parent d4498e2063
commit 19805ab376
2 changed files with 59 additions and 4 deletions

View File

@@ -1,7 +1,7 @@
"""Pytest helper utilities for FastAPI testing.""" """Pytest helper utilities for FastAPI testing."""
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any from typing import Any
@@ -22,12 +22,16 @@ from ..db import create_db_context
async def create_async_client( async def create_async_client(
app: Any, app: Any,
base_url: str = "http://test", base_url: str = "http://test",
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
) -> AsyncGenerator[AsyncClient, None]: ) -> AsyncGenerator[AsyncClient, None]:
"""Create an async httpx client for testing FastAPI applications. """Create an async httpx client for testing FastAPI applications.
Args: Args:
app: FastAPI application instance. app: FastAPI application instance.
base_url: Base URL for requests. Defaults to "http://test". base_url: Base URL for requests. Defaults to "http://test".
dependency_overrides: Optional mapping of original dependencies to
their test replacements. Applied via ``app.dependency_overrides``
before yielding and cleaned up after.
Yields: Yields:
An AsyncClient configured for the app. An AsyncClient configured for the app.
@@ -46,10 +50,37 @@ async def create_async_client(
async def test_endpoint(client: AsyncClient): async def test_endpoint(client: AsyncClient):
response = await client.get("/health") response = await client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
Example with dependency overrides:
from fastapi_toolsets.pytest import create_async_client, create_db_session
from app.db import get_db
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
yield session
@pytest.fixture
async def client(db_session):
async def override():
yield db_session
async with create_async_client(
app, dependency_overrides={get_db: override}
) as c:
yield c
""" """
if dependency_overrides:
app.dependency_overrides.update(dependency_overrides)
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url=base_url) as client: try:
yield client async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client
finally:
if dependency_overrides:
for key in dependency_overrides:
app.dependency_overrides.pop(key, None)
@asynccontextmanager @asynccontextmanager

View File

@@ -3,7 +3,7 @@
import uuid import uuid
import pytest import pytest
from fastapi import FastAPI from fastapi import Depends, FastAPI
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy import select, text from sqlalchemy import select, text
from sqlalchemy.engine import make_url from sqlalchemy.engine import make_url
@@ -236,6 +236,30 @@ class TestCreateAsyncClient:
assert client_ref.is_closed assert client_ref.is_closed
@pytest.mark.anyio
async def test_dependency_overrides_applied_and_cleaned(self):
"""Dependency overrides are applied during the context and removed after."""
app = FastAPI()
async def original_dep() -> str:
return "original"
async def override_dep() -> str:
return "overridden"
@app.get("/dep")
async def dep_endpoint(value: str = Depends(original_dep)):
return {"value": value}
async with create_async_client(
app, dependency_overrides={original_dep: override_dep}
) as client:
response = await client.get("/dep")
assert response.json() == {"value": "overridden"}
# Overrides should be cleaned up
assert original_dep not in app.dependency_overrides
class TestCreateDbSession: class TestCreateDbSession:
"""Tests for create_db_session helper.""" """Tests for create_db_session helper."""