4 Commits

Author SHA1 Message Date
d971261f98 Version 0.9.0 2026-02-14 14:38:58 -05:00
d3vyce
74a54b7396 feat: add optional data field in ApiError (#63) 2026-02-14 20:37:50 +01:00
d3vyce
19805ab376 feat: add dependency_overrides parameter to create_async_client (#61) 2026-02-13 18:11:11 +01:00
d3vyce
d4498e2063 feat: add cleanup parameter to create_db_session (#60) 2026-02-13 18:03:28 +01:00
10 changed files with 214 additions and 36 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.8.1" version = "0.9.0"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "0.8.1" __version__ = "0.9.0"

View File

@@ -183,7 +183,7 @@ def generate_error_responses(
"content": { "content": {
"application/json": { "application/json": {
"example": { "example": {
"data": None, "data": api_error.data,
"status": ResponseStatus.FAIL.value, "status": ResponseStatus.FAIL.value,
"message": api_error.msg, "message": api_error.msg,
"description": api_error.desc, "description": api_error.desc,

View File

@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from ..schemas import ResponseStatus from ..schemas import ErrorResponse, ResponseStatus
from .exceptions import ApiException from .exceptions import ApiException
@@ -54,16 +54,16 @@ def _register_exception_handlers(app: FastAPI) -> None:
async def api_exception_handler(request: Request, exc: ApiException) -> Response: async def api_exception_handler(request: Request, exc: ApiException) -> Response:
"""Handle custom API exceptions with structured response.""" """Handle custom API exceptions with structured response."""
api_error = exc.api_error api_error = exc.api_error
error_response = ErrorResponse(
data=api_error.data,
message=api_error.msg,
description=api_error.desc,
error_code=api_error.err_code,
)
return JSONResponse( return JSONResponse(
status_code=api_error.code, status_code=api_error.code,
content={ content=error_response.model_dump(),
"data": None,
"status": ResponseStatus.FAIL.value,
"message": api_error.msg,
"description": api_error.desc,
"error_code": api_error.err_code,
},
) )
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
@@ -83,15 +83,15 @@ def _register_exception_handlers(app: FastAPI) -> None:
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception) -> Response: async def generic_exception_handler(request: Request, exc: Exception) -> Response:
"""Handle all unhandled exceptions with a generic 500 response.""" """Handle all unhandled exceptions with a generic 500 response."""
error_response = ErrorResponse(
message="Internal Server Error",
description="An unexpected error occurred. Please try again later.",
error_code="SERVER-500",
)
return JSONResponse( return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={ content=error_response.model_dump(),
"data": None,
"status": ResponseStatus.FAIL.value,
"message": "Internal Server Error",
"description": "An unexpected error occurred. Please try again later.",
"error_code": "SERVER-500",
},
) )
@@ -116,15 +116,16 @@ def _format_validation_error(
} }
) )
error_response = ErrorResponse(
data={"errors": formatted_errors},
message="Validation Error",
description=f"{len(formatted_errors)} validation error(s) detected",
error_code="VAL-422",
)
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={ content=error_response.model_dump(),
"data": {"errors": formatted_errors},
"status": ResponseStatus.FAIL.value,
"message": "Validation Error",
"description": f"{len(formatted_errors)} validation error(s) detected",
"error_code": "VAL-422",
},
) )

View File

@@ -1,7 +1,7 @@
"""Pytest helper utilities for FastAPI testing.""" """Pytest helper utilities for FastAPI testing."""
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any from typing import Any
@@ -22,12 +22,16 @@ from ..db import create_db_context
async def create_async_client( async def create_async_client(
app: Any, app: Any,
base_url: str = "http://test", base_url: str = "http://test",
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
) -> AsyncGenerator[AsyncClient, None]: ) -> AsyncGenerator[AsyncClient, None]:
"""Create an async httpx client for testing FastAPI applications. """Create an async httpx client for testing FastAPI applications.
Args: Args:
app: FastAPI application instance. app: FastAPI application instance.
base_url: Base URL for requests. Defaults to "http://test". base_url: Base URL for requests. Defaults to "http://test".
dependency_overrides: Optional mapping of original dependencies to
their test replacements. Applied via ``app.dependency_overrides``
before yielding and cleaned up after.
Yields: Yields:
An AsyncClient configured for the app. An AsyncClient configured for the app.
@@ -46,10 +50,37 @@ async def create_async_client(
async def test_endpoint(client: AsyncClient): async def test_endpoint(client: AsyncClient):
response = await client.get("/health") response = await client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
Example with dependency overrides:
from fastapi_toolsets.pytest import create_async_client, create_db_session
from app.db import get_db
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
yield session
@pytest.fixture
async def client(db_session):
async def override():
yield db_session
async with create_async_client(
app, dependency_overrides={get_db: override}
) as c:
yield c
""" """
if dependency_overrides:
app.dependency_overrides.update(dependency_overrides)
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url=base_url) as client: async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client yield client
finally:
if dependency_overrides:
for key in dependency_overrides:
app.dependency_overrides.pop(key, None)
@asynccontextmanager @asynccontextmanager
@@ -60,6 +91,7 @@ async def create_db_session(
echo: bool = False, echo: bool = False,
expire_on_commit: bool = False, expire_on_commit: bool = False,
drop_tables: bool = True, drop_tables: bool = True,
cleanup: bool = False,
) -> AsyncGenerator[AsyncSession, None]: ) -> AsyncGenerator[AsyncSession, None]:
"""Create a database session for testing. """Create a database session for testing.
@@ -72,6 +104,8 @@ async def create_db_session(
echo: Enable SQLAlchemy query logging. Defaults to False. echo: Enable SQLAlchemy query logging. Defaults to False.
expire_on_commit: Expire objects after commit. Defaults to False. expire_on_commit: Expire objects after commit. Defaults to False.
drop_tables: Drop tables after test. Defaults to True. drop_tables: Drop tables after test. Defaults to True.
cleanup: Truncate all tables after test using
:func:`cleanup_tables`. Defaults to False.
Yields: Yields:
An AsyncSession ready for database operations. An AsyncSession ready for database operations.
@@ -84,7 +118,9 @@ async def create_db_session(
@pytest.fixture @pytest.fixture
async def db_session(): async def db_session():
async with create_db_session(DATABASE_URL, Base) as session: async with create_db_session(
DATABASE_URL, Base, cleanup=True
) as session:
yield session yield session
async def test_create_user(db_session: AsyncSession): async def test_create_user(db_session: AsyncSession):
@@ -106,6 +142,9 @@ async def create_db_session(
async with get_session() as session: async with get_session() as session:
yield session yield session
if cleanup:
await cleanup_tables(session, base)
if drop_tables: if drop_tables:
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(base.metadata.drop_all) await conn.run_sync(base.metadata.drop_all)
@@ -193,7 +232,7 @@ async def create_worker_database(
Example: Example:
from fastapi_toolsets.pytest import ( from fastapi_toolsets.pytest import (
create_worker_database, create_db_session, cleanup_tables create_worker_database, create_db_session,
) )
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db" DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
@@ -205,9 +244,10 @@ async def create_worker_database(
@pytest.fixture @pytest.fixture
async def db_session(worker_db_url): async def db_session(worker_db_url):
async with create_db_session(worker_db_url, Base) as session: async with create_db_session(
worker_db_url, Base, cleanup=True
) as session:
yield session yield session
await cleanup_tables(session, Base)
""" """
worker_url = worker_database_url( worker_url = worker_database_url(
database_url=database_url, default_test_db=default_test_db database_url=database_url, default_test_db=default_test_db

View File

@@ -1,7 +1,7 @@
"""Base Pydantic schemas for API responses.""" """Base Pydantic schemas for API responses."""
from enum import Enum from enum import Enum
from typing import ClassVar, Generic, TypeVar from typing import Any, ClassVar, Generic, TypeVar
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@@ -50,6 +50,7 @@ class ApiError(PydanticBase):
msg: str msg: str
desc: str desc: str
err_code: str err_code: str
data: Any | None = None
class BaseResponse(PydanticBase): class BaseResponse(PydanticBase):
@@ -84,7 +85,7 @@ class ErrorResponse(BaseResponse):
status: ResponseStatus = ResponseStatus.FAIL status: ResponseStatus = ResponseStatus.FAIL
description: str | None = None description: str | None = None
data: None = None data: Any | None = None
class Pagination(PydanticBase): class Pagination(PydanticBase):

View File

@@ -108,6 +108,24 @@ class TestGenerateErrorResponses:
assert example["status"] == "FAIL" assert example["status"] == "FAIL"
assert example["error_code"] == "RES-404" assert example["error_code"] == "RES-404"
assert example["message"] == "Not Found" assert example["message"] == "Not Found"
assert example["data"] is None
def test_response_example_with_data(self):
"""Generated response includes data when set on ApiError."""
class ErrorWithData(ApiException):
api_error = ApiError(
code=400,
msg="Bad Request",
desc="Invalid input.",
err_code="BAD-400",
data={"details": "some context"},
)
responses = generate_error_responses(ErrorWithData)
example = responses[400]["content"]["application/json"]["example"]
assert example["data"] == {"details": "some context"}
class TestInitExceptionsHandlers: class TestInitExceptionsHandlers:
@@ -137,6 +155,59 @@ class TestInitExceptionsHandlers:
assert data["error_code"] == "RES-404" assert data["error_code"] == "RES-404"
assert data["message"] == "Not Found" assert data["message"] == "Not Found"
def test_handles_api_exception_without_data(self):
"""ApiException without data returns null data field."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/error")
async def raise_error():
raise NotFoundError()
client = TestClient(app)
response = client.get("/error")
assert response.status_code == 404
assert response.json()["data"] is None
def test_handles_api_exception_with_data(self):
"""ApiException with data returns the data payload."""
app = FastAPI()
init_exceptions_handlers(app)
class CustomValidationError(ApiException):
api_error = ApiError(
code=422,
msg="Validation Error",
desc="1 validation error(s) detected",
err_code="CUSTOM-422",
data={
"errors": [
{
"field": "email",
"message": "invalid format",
"type": "value_error",
}
]
},
)
@app.get("/error")
async def raise_error():
raise CustomValidationError()
client = TestClient(app)
response = client.get("/error")
assert response.status_code == 422
data = response.json()
assert data["data"] == {
"errors": [
{"field": "email", "message": "invalid format", "type": "value_error"}
]
}
assert data["error_code"] == "CUSTOM-422"
def test_handles_validation_error(self): def test_handles_validation_error(self):
"""Handles validation errors with structured response.""" """Handles validation errors with structured response."""
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -3,7 +3,7 @@
import uuid import uuid
import pytest import pytest
from fastapi import FastAPI from fastapi import Depends, FastAPI
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy import select, text from sqlalchemy import select, text
from sqlalchemy.engine import make_url from sqlalchemy.engine import make_url
@@ -236,6 +236,30 @@ class TestCreateAsyncClient:
assert client_ref.is_closed assert client_ref.is_closed
@pytest.mark.anyio
async def test_dependency_overrides_applied_and_cleaned(self):
"""Dependency overrides are applied during the context and removed after."""
app = FastAPI()
async def original_dep() -> str:
return "original"
async def override_dep() -> str:
return "overridden"
@app.get("/dep")
async def dep_endpoint(value: str = Depends(original_dep)):
return {"value": value}
async with create_async_client(
app, dependency_overrides={original_dep: override_dep}
) as client:
response = await client.get("/dep")
assert response.json() == {"value": "overridden"}
# Overrides should be cleaned up
assert original_dep not in app.dependency_overrides
class TestCreateDbSession: class TestCreateDbSession:
"""Tests for create_db_session helper.""" """Tests for create_db_session helper."""
@@ -297,6 +321,22 @@ class TestCreateDbSession:
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _: async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass pass
@pytest.mark.anyio
async def test_cleanup_truncates_tables(self):
"""Tables are truncated after session closes when cleanup=True."""
role_id = uuid.uuid4()
async with create_db_session(
DATABASE_URL, Base, cleanup=True, drop_tables=False
) as session:
role = Role(id=role_id, name="will_be_cleaned")
session.add(role)
await session.commit()
# Data should have been truncated, but tables still exist
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
result = await session.execute(select(Role))
assert result.all() == []
class TestGetXdistWorker: class TestGetXdistWorker:
"""Tests for _get_xdist_worker helper.""" """Tests for _get_xdist_worker helper."""

View File

@@ -46,6 +46,31 @@ class TestApiError:
assert error.desc == "The resource was not found." assert error.desc == "The resource was not found."
assert error.err_code == "RES-404" assert error.err_code == "RES-404"
def test_data_defaults_to_none(self):
"""ApiError data field defaults to None."""
error = ApiError(
code=404,
msg="Not Found",
desc="The resource was not found.",
err_code="RES-404",
)
assert error.data is None
def test_create_with_data(self):
"""ApiError can be created with a data payload."""
error = ApiError(
code=422,
msg="Validation Error",
desc="2 validation error(s) detected",
err_code="VAL-422",
data={
"errors": [{"field": "name", "message": "required", "type": "missing"}]
},
)
assert error.data == {
"errors": [{"field": "name", "message": "required", "type": "missing"}]
}
def test_requires_all_fields(self): def test_requires_all_fields(self):
"""ApiError requires all fields.""" """ApiError requires all fields."""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):

2
uv.lock generated
View File

@@ -242,7 +242,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.8.1" version = "0.9.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },