mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
d971261f98
|
|||
|
|
74a54b7396 | ||
|
|
19805ab376 | ||
|
|
d4498e2063 |
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "0.8.1"
|
||||
version = "0.9.0"
|
||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "0.8.1"
|
||||
__version__ = "0.9.0"
|
||||
|
||||
@@ -183,7 +183,7 @@ def generate_error_responses(
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": {
|
||||
"data": None,
|
||||
"data": api_error.data,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": api_error.msg,
|
||||
"description": api_error.desc,
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..schemas import ResponseStatus
|
||||
from ..schemas import ErrorResponse, ResponseStatus
|
||||
from .exceptions import ApiException
|
||||
|
||||
|
||||
@@ -54,16 +54,16 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||
"""Handle custom API exceptions with structured response."""
|
||||
api_error = exc.api_error
|
||||
error_response = ErrorResponse(
|
||||
data=api_error.data,
|
||||
message=api_error.msg,
|
||||
description=api_error.desc,
|
||||
error_code=api_error.err_code,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=api_error.code,
|
||||
content={
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": api_error.msg,
|
||||
"description": api_error.desc,
|
||||
"error_code": api_error.err_code,
|
||||
},
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
@@ -83,15 +83,15 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||
error_response = ErrorResponse(
|
||||
message="Internal Server Error",
|
||||
description="An unexpected error occurred. Please try again later.",
|
||||
error_code="SERVER-500",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Internal Server Error",
|
||||
"description": "An unexpected error occurred. Please try again later.",
|
||||
"error_code": "SERVER-500",
|
||||
},
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@@ -116,15 +116,16 @@ def _format_validation_error(
|
||||
}
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
data={"errors": formatted_errors},
|
||||
message="Validation Error",
|
||||
description=f"{len(formatted_errors)} validation error(s) detected",
|
||||
error_code="VAL-422",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"data": {"errors": formatted_errors},
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Validation Error",
|
||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
||||
"error_code": "VAL-422",
|
||||
},
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Pytest helper utilities for FastAPI testing."""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
@@ -22,12 +22,16 @@ from ..db import create_db_context
|
||||
async def create_async_client(
|
||||
app: Any,
|
||||
base_url: str = "http://test",
|
||||
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
|
||||
) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create an async httpx client for testing FastAPI applications.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance.
|
||||
base_url: Base URL for requests. Defaults to "http://test".
|
||||
dependency_overrides: Optional mapping of original dependencies to
|
||||
their test replacements. Applied via ``app.dependency_overrides``
|
||||
before yielding and cleaned up after.
|
||||
|
||||
Yields:
|
||||
An AsyncClient configured for the app.
|
||||
@@ -46,10 +50,37 @@ async def create_async_client(
|
||||
async def test_endpoint(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
Example with dependency overrides:
|
||||
from fastapi_toolsets.pytest import create_async_client, create_db_session
|
||||
from app.db import get_db
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
|
||||
yield session
|
||||
|
||||
@pytest.fixture
|
||||
async def client(db_session):
|
||||
async def override():
|
||||
yield db_session
|
||||
|
||||
async with create_async_client(
|
||||
app, dependency_overrides={get_db: override}
|
||||
) as c:
|
||||
yield c
|
||||
"""
|
||||
if dependency_overrides:
|
||||
app.dependency_overrides.update(dependency_overrides)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
try:
|
||||
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||
yield client
|
||||
finally:
|
||||
if dependency_overrides:
|
||||
for key in dependency_overrides:
|
||||
app.dependency_overrides.pop(key, None)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -60,6 +91,7 @@ async def create_db_session(
|
||||
echo: bool = False,
|
||||
expire_on_commit: bool = False,
|
||||
drop_tables: bool = True,
|
||||
cleanup: bool = False,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a database session for testing.
|
||||
|
||||
@@ -72,6 +104,8 @@ async def create_db_session(
|
||||
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||
drop_tables: Drop tables after test. Defaults to True.
|
||||
cleanup: Truncate all tables after test using
|
||||
:func:`cleanup_tables`. Defaults to False.
|
||||
|
||||
Yields:
|
||||
An AsyncSession ready for database operations.
|
||||
@@ -84,7 +118,9 @@ async def create_db_session(
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
async with create_db_session(DATABASE_URL, Base) as session:
|
||||
async with create_db_session(
|
||||
DATABASE_URL, Base, cleanup=True
|
||||
) as session:
|
||||
yield session
|
||||
|
||||
async def test_create_user(db_session: AsyncSession):
|
||||
@@ -106,6 +142,9 @@ async def create_db_session(
|
||||
async with get_session() as session:
|
||||
yield session
|
||||
|
||||
if cleanup:
|
||||
await cleanup_tables(session, base)
|
||||
|
||||
if drop_tables:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(base.metadata.drop_all)
|
||||
@@ -193,7 +232,7 @@ async def create_worker_database(
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.pytest import (
|
||||
create_worker_database, create_db_session, cleanup_tables
|
||||
create_worker_database, create_db_session,
|
||||
)
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||
@@ -205,9 +244,10 @@ async def create_worker_database(
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(worker_db_url):
|
||||
async with create_db_session(worker_db_url, Base) as session:
|
||||
async with create_db_session(
|
||||
worker_db_url, Base, cleanup=True
|
||||
) as session:
|
||||
yield session
|
||||
await cleanup_tables(session, Base)
|
||||
"""
|
||||
worker_url = worker_database_url(
|
||||
database_url=database_url, default_test_db=default_test_db
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Base Pydantic schemas for API responses."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
from typing import Any, ClassVar, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -50,6 +50,7 @@ class ApiError(PydanticBase):
|
||||
msg: str
|
||||
desc: str
|
||||
err_code: str
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class BaseResponse(PydanticBase):
|
||||
@@ -84,7 +85,7 @@ class ErrorResponse(BaseResponse):
|
||||
|
||||
status: ResponseStatus = ResponseStatus.FAIL
|
||||
description: str | None = None
|
||||
data: None = None
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class Pagination(PydanticBase):
|
||||
|
||||
@@ -108,6 +108,24 @@ class TestGenerateErrorResponses:
|
||||
assert example["status"] == "FAIL"
|
||||
assert example["error_code"] == "RES-404"
|
||||
assert example["message"] == "Not Found"
|
||||
assert example["data"] is None
|
||||
|
||||
def test_response_example_with_data(self):
|
||||
"""Generated response includes data when set on ApiError."""
|
||||
|
||||
class ErrorWithData(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400,
|
||||
msg="Bad Request",
|
||||
desc="Invalid input.",
|
||||
err_code="BAD-400",
|
||||
data={"details": "some context"},
|
||||
)
|
||||
|
||||
responses = generate_error_responses(ErrorWithData)
|
||||
example = responses[400]["content"]["application/json"]["example"]
|
||||
|
||||
assert example["data"] == {"details": "some context"}
|
||||
|
||||
|
||||
class TestInitExceptionsHandlers:
|
||||
@@ -137,6 +155,59 @@ class TestInitExceptionsHandlers:
|
||||
assert data["error_code"] == "RES-404"
|
||||
assert data["message"] == "Not Found"
|
||||
|
||||
def test_handles_api_exception_without_data(self):
|
||||
"""ApiException without data returns null data field."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/error")
|
||||
async def raise_error():
|
||||
raise NotFoundError()
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/error")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["data"] is None
|
||||
|
||||
def test_handles_api_exception_with_data(self):
|
||||
"""ApiException with data returns the data payload."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
class CustomValidationError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=422,
|
||||
msg="Validation Error",
|
||||
desc="1 validation error(s) detected",
|
||||
err_code="CUSTOM-422",
|
||||
data={
|
||||
"errors": [
|
||||
{
|
||||
"field": "email",
|
||||
"message": "invalid format",
|
||||
"type": "value_error",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@app.get("/error")
|
||||
async def raise_error():
|
||||
raise CustomValidationError()
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/error")
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["data"] == {
|
||||
"errors": [
|
||||
{"field": "email", "message": "invalid format", "type": "value_error"}
|
||||
]
|
||||
}
|
||||
assert data["error_code"] == "CUSTOM-422"
|
||||
|
||||
def test_handles_validation_error(self):
|
||||
"""Handles validation errors with structured response."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Depends, FastAPI
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.engine import make_url
|
||||
@@ -236,6 +236,30 @@ class TestCreateAsyncClient:
|
||||
|
||||
assert client_ref.is_closed
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dependency_overrides_applied_and_cleaned(self):
|
||||
"""Dependency overrides are applied during the context and removed after."""
|
||||
app = FastAPI()
|
||||
|
||||
async def original_dep() -> str:
|
||||
return "original"
|
||||
|
||||
async def override_dep() -> str:
|
||||
return "overridden"
|
||||
|
||||
@app.get("/dep")
|
||||
async def dep_endpoint(value: str = Depends(original_dep)):
|
||||
return {"value": value}
|
||||
|
||||
async with create_async_client(
|
||||
app, dependency_overrides={original_dep: override_dep}
|
||||
) as client:
|
||||
response = await client.get("/dep")
|
||||
assert response.json() == {"value": "overridden"}
|
||||
|
||||
# Overrides should be cleaned up
|
||||
assert original_dep not in app.dependency_overrides
|
||||
|
||||
|
||||
class TestCreateDbSession:
|
||||
"""Tests for create_db_session helper."""
|
||||
@@ -297,6 +321,22 @@ class TestCreateDbSession:
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||
pass
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup_truncates_tables(self):
|
||||
"""Tables are truncated after session closes when cleanup=True."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(
|
||||
DATABASE_URL, Base, cleanup=True, drop_tables=False
|
||||
) as session:
|
||||
role = Role(id=role_id, name="will_be_cleaned")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
# Data should have been truncated, but tables still exist
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
result = await session.execute(select(Role))
|
||||
assert result.all() == []
|
||||
|
||||
|
||||
class TestGetXdistWorker:
|
||||
"""Tests for _get_xdist_worker helper."""
|
||||
|
||||
@@ -46,6 +46,31 @@ class TestApiError:
|
||||
assert error.desc == "The resource was not found."
|
||||
assert error.err_code == "RES-404"
|
||||
|
||||
def test_data_defaults_to_none(self):
|
||||
"""ApiError data field defaults to None."""
|
||||
error = ApiError(
|
||||
code=404,
|
||||
msg="Not Found",
|
||||
desc="The resource was not found.",
|
||||
err_code="RES-404",
|
||||
)
|
||||
assert error.data is None
|
||||
|
||||
def test_create_with_data(self):
|
||||
"""ApiError can be created with a data payload."""
|
||||
error = ApiError(
|
||||
code=422,
|
||||
msg="Validation Error",
|
||||
desc="2 validation error(s) detected",
|
||||
err_code="VAL-422",
|
||||
data={
|
||||
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||
},
|
||||
)
|
||||
assert error.data == {
|
||||
"errors": [{"field": "name", "message": "required", "type": "missing"}]
|
||||
}
|
||||
|
||||
def test_requires_all_fields(self):
|
||||
"""ApiError requires all fields."""
|
||||
with pytest.raises(ValidationError):
|
||||
|
||||
Reference in New Issue
Block a user