diff --git a/src/fastapi_toolsets/pytest/utils.py b/src/fastapi_toolsets/pytest/utils.py index 9e59c9a..e2cdaa3 100644 --- a/src/fastapi_toolsets/pytest/utils.py +++ b/src/fastapi_toolsets/pytest/utils.py @@ -1,7 +1,7 @@ """Pytest helper utilities for FastAPI testing.""" import os -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from contextlib import asynccontextmanager from typing import Any @@ -22,12 +22,16 @@ from ..db import create_db_context async def create_async_client( app: Any, base_url: str = "http://test", + dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None, ) -> AsyncGenerator[AsyncClient, None]: """Create an async httpx client for testing FastAPI applications. Args: app: FastAPI application instance. base_url: Base URL for requests. Defaults to "http://test". + dependency_overrides: Optional mapping of original dependencies to + their test replacements. Applied via ``app.dependency_overrides`` + before yielding and cleaned up after. Yields: An AsyncClient configured for the app. @@ -46,10 +50,37 @@ async def create_async_client( async def test_endpoint(client: AsyncClient): response = await client.get("/health") assert response.status_code == 200 + + Example with dependency overrides: + from fastapi_toolsets.pytest import create_async_client, create_db_session + from app.db import get_db + + @pytest.fixture + async def db_session(): + async with create_db_session(DATABASE_URL, Base, cleanup=True) as session: + yield session + + @pytest.fixture + async def client(db_session): + async def override(): + yield db_session + + async with create_async_client( + app, dependency_overrides={get_db: override} + ) as c: + yield c """ + if dependency_overrides: + app.dependency_overrides.update(dependency_overrides) + transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url=base_url) as client: - yield client + try: + async with AsyncClient(transport=transport, base_url=base_url) as client: + yield client + finally: + if dependency_overrides: + for key in dependency_overrides: + app.dependency_overrides.pop(key, None) @asynccontextmanager diff --git a/tests/test_pytest.py b/tests/test_pytest.py index 410a82b..d51947b 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -3,7 +3,7 @@ import uuid import pytest -from fastapi import FastAPI +from fastapi import Depends, FastAPI from httpx import AsyncClient from sqlalchemy import select, text from sqlalchemy.engine import make_url @@ -236,6 +236,30 @@ class TestCreateAsyncClient: assert client_ref.is_closed + @pytest.mark.anyio + async def test_dependency_overrides_applied_and_cleaned(self): + """Dependency overrides are applied during the context and removed after.""" + app = FastAPI() + + async def original_dep() -> str: + return "original" + + async def override_dep() -> str: + return "overridden" + + @app.get("/dep") + async def dep_endpoint(value: str = Depends(original_dep)): + return {"value": value} + + async with create_async_client( + app, dependency_overrides={original_dep: override_dep} + ) as client: + response = await client.get("/dep") + assert response.json() == {"value": "overridden"} + + # Overrides should be cleaned up + assert original_dep not in app.dependency_overrides + class TestCreateDbSession: """Tests for create_db_session helper."""