From 05b5a2c876d5f2422625fc14000ac2538725f8d1 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Mon, 2 Mar 2026 16:34:29 +0100 Subject: [PATCH] feat: rework Exception/ApiError (#107) * feat: rework Exception/ApiError * docs: update exceptions module * fix: docstring --- docs/module/exceptions.md | 75 ++- src/fastapi_toolsets/exceptions/exceptions.py | 136 +++--- src/fastapi_toolsets/exceptions/handler.py | 127 +++-- tests/test_crud_search.py | 4 +- tests/test_exceptions.py | 435 +++++++++++++++++- 5 files changed, 603 insertions(+), 174 deletions(-) diff --git a/docs/module/exceptions.md b/docs/module/exceptions.md index 4f318d1..f3b22fe 100644 --- a/docs/module/exceptions.md +++ b/docs/module/exceptions.md @@ -21,30 +21,37 @@ init_exceptions_handlers(app=app) This registers handlers for: - [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below +- `HTTPException` — Starlette/FastAPI HTTP errors - `RequestValidationError` — Pydantic request validation (422) - `ResponseValidationError` — Pydantic response validation (422) - `Exception` — unhandled errors (500) +It also patches `app.openapi()` to replace the default Pydantic 422 schema with a structured example matching the `ErrorResponse` format. + ## Built-in exceptions | Exception | Status | Default message | |-----------|--------|-----------------| | [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized | | [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden | -| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not found | +| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not Found | | [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict | -| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields | -| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid facet filter | +| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No Searchable Fields | +| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid Facet Filter | +| [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) | 422 | Invalid Order Field | + +### Per-instance overrides + +All built-in exceptions accept optional keyword arguments to customise the response for a specific raise site without changing the class defaults: + +| Argument | Effect | +|----------|--------| +| `detail` | Overrides both `str(exc)` (log output) and the `message` field in the response body | +| `desc` | Overrides the `description` field | +| `data` | Overrides the `data` field | ```python -from fastapi_toolsets.exceptions import NotFoundError - -@router.get("/users/{id}") -async def get_user(id: int, session: AsyncSession = Depends(get_db)): - user = await UserCrud.first(session=session, filters=[User.id == id]) - if not user: - raise NotFoundError - return user +raise NotFoundError(detail="User 42 not found", desc="No user with that ID exists in the database.") ``` ## Custom exceptions @@ -58,12 +65,51 @@ from fastapi_toolsets.schemas import ApiError class PaymentRequiredError(ApiException): api_error = ApiError( code=402, - msg="Payment required", + msg="Payment Required", desc="Your subscription has expired.", - err_code="PAYMENT_REQUIRED", + err_code="BILLING-402", ) ``` +!!! warning + Subclasses that do not define `api_error` raise a `TypeError` at **class creation time**, not at raise time. + +### Custom `__init__` + +Override `__init__` to compute `detail`, `desc`, or `data` dynamically, then delegate to `super().__init__()`: + +```python +class OrderValidationError(ApiException): + api_error = ApiError( + code=422, + msg="Order Validation Failed", + desc="One or more order fields are invalid.", + err_code="ORDER-422", + ) + + def __init__(self, *field_errors: str) -> None: + super().__init__( + f"{len(field_errors)} validation error(s)", + desc=", ".join(field_errors), + data={"errors": [{"message": e} for e in field_errors]}, + ) +``` + +### Intermediate base classes + +Use `abstract=True` when creating a shared base that is not meant to be raised directly: + +```python +class BillingError(ApiException, abstract=True): + """Base for all billing-related errors.""" + +class PaymentRequiredError(BillingError): + api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402") + +class SubscriptionExpiredError(BillingError): + api_error = ApiError(code=402, msg="Subscription Expired", desc="...", err_code="BILLING-402-EXP") +``` + ## OpenAPI response documentation Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec: @@ -78,8 +124,7 @@ from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError, async def get_user(...): ... ``` -!!! info - The pydantic validation error is automatically added by FastAPI. +Multiple exceptions sharing the same HTTP status code are grouped under one entry, each appearing as a named example keyed by its `err_code`. This keeps the OpenAPI UI readable when several error variants map to the same status. --- diff --git a/src/fastapi_toolsets/exceptions/exceptions.py b/src/fastapi_toolsets/exceptions/exceptions.py index 16b9d4a..be5a762 100644 --- a/src/fastapi_toolsets/exceptions/exceptions.py +++ b/src/fastapi_toolsets/exceptions/exceptions.py @@ -6,32 +6,46 @@ from ..schemas import ApiError, ErrorResponse, ResponseStatus class ApiException(Exception): - """Base exception for API errors with structured response. - - Subclass this to create custom API exceptions with consistent error format. - The exception handler will use api_error to generate the response. - - Example: - ```python - class CustomError(ApiException): - api_error = ApiError( - code=400, - msg="Bad Request", - desc="The request was invalid.", - err_code="CUSTOM-400", - ) - ``` - """ + """Base exception for API errors with structured response.""" api_error: ClassVar[ApiError] - def __init__(self, detail: str | None = None): + def __init_subclass__(cls, abstract: bool = False, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if not abstract and not hasattr(cls, "api_error"): + raise TypeError( + f"{cls.__name__} must define an 'api_error' class attribute. " + "Pass abstract=True when creating intermediate base classes." + ) + + def __init__( + self, + detail: str | None = None, + *, + desc: str | None = None, + data: Any = None, + ) -> None: """Initialize the exception. Args: - detail: Optional override for the error message + detail: Optional human-readable message + desc: Optional per-instance override for the ``description`` field + in the HTTP response body. + data: Optional per-instance override for the ``data`` field in the + HTTP response body. """ - super().__init__(detail or self.api_error.msg) + updates: dict[str, Any] = {} + if detail is not None: + updates["msg"] = detail + if desc is not None: + updates["desc"] = desc + if data is not None: + updates["data"] = data + if updates: + object.__setattr__( + self, "api_error", self.__class__.api_error.model_copy(update=updates) + ) + super().__init__(self.api_error.msg) class UnauthorizedError(ApiException): @@ -92,14 +106,15 @@ class NoSearchableFieldsError(ApiException): """Initialize the exception. Args: - model: The SQLAlchemy model class that has no searchable fields + model: The model class that has no searchable fields configured. """ self.model = model - detail = ( - f"No searchable fields found for model '{model.__name__}'. " - "Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class." + super().__init__( + desc=( + f"No searchable fields found for model '{model.__name__}'. " + "Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class." + ) ) - super().__init__(detail) class InvalidFacetFilterError(ApiException): @@ -116,16 +131,17 @@ class InvalidFacetFilterError(ApiException): """Initialize the exception. Args: - key: The unknown filter key provided by the caller - valid_keys: Set of valid keys derived from the declared facet_fields + key: The unknown filter key provided by the caller. + valid_keys: Set of valid keys derived from the declared facet_fields. """ self.key = key self.valid_keys = valid_keys - detail = ( - f"'{key}' is not a declared facet field. " - f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}." + super().__init__( + desc=( + f"'{key}' is not a declared facet field. " + f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}." + ) ) - super().__init__(detail) class InvalidOrderFieldError(ApiException): @@ -142,15 +158,14 @@ class InvalidOrderFieldError(ApiException): """Initialize the exception. Args: - field: The unknown order field provided by the caller - valid_fields: List of valid field names + field: The unknown order field provided by the caller. + valid_fields: List of valid field names. """ self.field = field self.valid_fields = valid_fields - detail = ( - f"'{field}' is not an allowed order field. Valid fields: {valid_fields}." + super().__init__( + desc=f"'{field}' is not an allowed order field. Valid fields: {valid_fields}." ) - super().__init__(detail) def generate_error_responses( @@ -158,44 +173,39 @@ def generate_error_responses( ) -> dict[int | str, dict[str, Any]]: """Generate OpenAPI response documentation for exceptions. - Use this to document possible error responses for an endpoint. - Args: - *errors: Exception classes that inherit from ApiException + *errors: Exception classes that inherit from ApiException. Returns: - Dict suitable for FastAPI's responses parameter - - Example: - ```python - from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError - - @app.get( - "/admin", - responses=generate_error_responses(UnauthorizedError, ForbiddenError) - ) - async def admin_endpoint(): - ... - ``` + Dict suitable for FastAPI's ``responses`` parameter. """ responses: dict[int | str, dict[str, Any]] = {} for error in errors: api_error = error.api_error + code = api_error.code - responses[api_error.code] = { - "model": ErrorResponse, - "description": api_error.msg, - "content": { - "application/json": { - "example": { - "data": api_error.data, - "status": ResponseStatus.FAIL.value, - "message": api_error.msg, - "description": api_error.desc, - "error_code": api_error.err_code, + if code not in responses: + responses[code] = { + "model": ErrorResponse, + "description": api_error.msg, + "content": { + "application/json": { + "examples": {}, } - } + }, + } + + responses[code]["content"]["application/json"]["examples"][ + api_error.err_code + ] = { + "summary": api_error.msg, + "value": { + "data": api_error.data, + "status": ResponseStatus.FAIL.value, + "message": api_error.msg, + "description": api_error.desc, + "error_code": api_error.err_code, }, } diff --git a/src/fastapi_toolsets/exceptions/handler.py b/src/fastapi_toolsets/exceptions/handler.py index d27f2ca..859de5a 100644 --- a/src/fastapi_toolsets/exceptions/handler.py +++ b/src/fastapi_toolsets/exceptions/handler.py @@ -1,10 +1,14 @@ """Exception handlers for FastAPI applications.""" +from collections.abc import Callable from typing import Any from fastapi import FastAPI, Request, Response, status -from fastapi.exceptions import RequestValidationError, ResponseValidationError -from fastapi.openapi.utils import get_openapi +from fastapi.exceptions import ( + HTTPException, + RequestValidationError, + ResponseValidationError, +) from fastapi.responses import JSONResponse from ..schemas import ErrorResponse, ResponseStatus @@ -14,43 +18,20 @@ from .exceptions import ApiException def init_exceptions_handlers(app: FastAPI) -> FastAPI: """Register exception handlers and custom OpenAPI schema on a FastAPI app. - Installs handlers for :class:`ApiException`, validation errors, and - unhandled exceptions, and replaces the default 422 schema with a - consistent error format. - Args: - app: FastAPI application instance + app: FastAPI application instance. Returns: - The same FastAPI instance (for chaining) - - Example: - ```python - from fastapi import FastAPI - from fastapi_toolsets.exceptions import init_exceptions_handlers - - app = FastAPI() - init_exceptions_handlers(app) - ``` + The same FastAPI instance (for chaining). """ _register_exception_handlers(app) - app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign] + _original_openapi = app.openapi + app.openapi = lambda: _patched_openapi(app, _original_openapi) # type: ignore[method-assign] return app def _register_exception_handlers(app: FastAPI) -> None: - """Register all exception handlers on a FastAPI application. - - Args: - app: FastAPI application instance - - Example: - from fastapi import FastAPI - from fastapi_toolsets.exceptions import init_exceptions_handlers - - app = FastAPI() - init_exceptions_handlers(app) - """ + """Register all exception handlers on a FastAPI application.""" @app.exception_handler(ApiException) async def api_exception_handler(request: Request, exc: ApiException) -> Response: @@ -62,12 +43,25 @@ def _register_exception_handlers(app: FastAPI) -> None: description=api_error.desc, error_code=api_error.err_code, ) - return JSONResponse( status_code=api_error.code, content=error_response.model_dump(), ) + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, exc: HTTPException) -> Response: + """Handle Starlette/FastAPI HTTPException with a consistent error format.""" + detail = exc.detail if isinstance(exc.detail, str) else "HTTP Error" + error_response = ErrorResponse( + message=detail, + error_code=f"HTTP-{exc.status_code}", + ) + return JSONResponse( + status_code=exc.status_code, + content=error_response.model_dump(), + headers=getattr(exc, "headers", None), + ) + @app.exception_handler(RequestValidationError) async def request_validation_handler( request: Request, exc: RequestValidationError @@ -90,7 +84,6 @@ def _register_exception_handlers(app: FastAPI) -> None: description="An unexpected error occurred. Please try again later.", error_code="SERVER-500", ) - return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=error_response.model_dump(), @@ -105,11 +98,10 @@ def _format_validation_error( formatted_errors = [] for error in errors: - field_path = ".".join( - str(loc) - for loc in error["loc"] - if loc not in ("body", "query", "path", "header", "cookie") - ) + locs = error["loc"] + if locs and locs[0] in ("body", "query", "path", "header", "cookie"): + locs = locs[1:] + field_path = ".".join(str(loc) for loc in locs) formatted_errors.append( { "field": field_path or "root", @@ -131,34 +123,22 @@ def _format_validation_error( ) -def _custom_openapi(app: FastAPI) -> dict[str, Any]: - """Generate custom OpenAPI schema with standardized error format. - - Replaces default 422 validation error responses with the custom format. +def _patched_openapi( + app: FastAPI, original_openapi: Callable[[], dict[str, Any]] +) -> dict[str, Any]: + """Generate the OpenAPI schema and replace default 422 responses. Args: - app: FastAPI application instance + app: FastAPI application instance. + original_openapi: The previous ``app.openapi`` callable to delegate to. Returns: - OpenAPI schema dict - - Example: - from fastapi import FastAPI - from fastapi_toolsets.exceptions import init_exceptions_handlers - - app = FastAPI() - init_exceptions_handlers(app) # Automatically sets custom OpenAPI + Patched OpenAPI schema dict. """ if app.openapi_schema: return app.openapi_schema - openapi_schema = get_openapi( - title=app.title, - version=app.version, - openapi_version=app.openapi_version, - description=app.description, - routes=app.routes, - ) + openapi_schema = original_openapi() for path_data in openapi_schema.get("paths", {}).values(): for operation in path_data.values(): @@ -168,20 +148,25 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]: "description": "Validation Error", "content": { "application/json": { - "example": { - "data": { - "errors": [ - { - "field": "field_name", - "message": "value is not valid", - "type": "value_error", - } - ] - }, - "status": ResponseStatus.FAIL.value, - "message": "Validation Error", - "description": "1 validation error(s) detected", - "error_code": "VAL-422", + "examples": { + "VAL-422": { + "summary": "Validation Error", + "value": { + "data": { + "errors": [ + { + "field": "field_name", + "message": "value is not valid", + "type": "value_error", + } + ] + }, + "status": ResponseStatus.FAIL.value, + "message": "Validation Error", + "description": "1 validation error(s) detected", + "error_code": "VAL-422", + }, + } } } }, diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 5fbe1e9..49e94aa 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -410,7 +410,7 @@ class TestNoSearchableFieldsError: from fastapi_toolsets.exceptions import NoSearchableFieldsError error = NoSearchableFieldsError(User) - assert "User" in str(error) + assert "User" in error.api_error.desc assert error.model is User def test_error_raised_when_no_fields(self): @@ -434,7 +434,7 @@ class TestNoSearchableFieldsError: build_search_filters(NoStringModel, "test") assert exc_info.value.model is NoStringModel - assert "NoStringModel" in str(exc_info.value) + assert "NoStringModel" in exc_info.value.api_error.desc class TestGetSearchableFields: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index d088ed5..700792a 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,6 +2,7 @@ import pytest from fastapi import FastAPI +from fastapi.exceptions import HTTPException from fastapi.testclient import TestClient from fastapi_toolsets.exceptions import ( @@ -36,8 +37,8 @@ class TestApiException: assert error.api_error.msg == "I'm a teapot" assert str(error) == "I'm a teapot" - def test_custom_detail_message(self): - """Custom detail overrides default message.""" + def test_detail_overrides_msg_and_str(self): + """detail sets both str(exc) and api_error.msg; class-level msg is unchanged.""" class CustomError(ApiException): api_error = ApiError( @@ -47,8 +48,172 @@ class TestApiException: err_code="BAD-400", ) - error = CustomError("Custom message") - assert str(error) == "Custom message" + error = CustomError("Widget not found") + assert str(error) == "Widget not found" + assert error.api_error.msg == "Widget not found" + assert CustomError.api_error.msg == "Bad Request" # class unchanged + + def test_desc_override(self): + """desc kwarg overrides api_error.desc on the instance only.""" + + class MyError(ApiException): + api_error = ApiError( + code=400, msg="Error", desc="Default.", err_code="ERR-400" + ) + + err = MyError(desc="Custom desc.") + assert err.api_error.desc == "Custom desc." + assert MyError.api_error.desc == "Default." # class unchanged + + def test_data_override(self): + """data kwarg sets api_error.data on the instance only.""" + + class MyError(ApiException): + api_error = ApiError( + code=400, msg="Error", desc="Default.", err_code="ERR-400" + ) + + err = MyError(data={"key": "value"}) + assert err.api_error.data == {"key": "value"} + assert MyError.api_error.data is None # class unchanged + + def test_desc_and_data_override(self): + """detail, desc and data can all be overridden together.""" + + class MyError(ApiException): + api_error = ApiError( + code=400, msg="Error", desc="Default.", err_code="ERR-400" + ) + + err = MyError("custom msg", desc="New desc.", data={"x": 1}) + assert str(err) == "custom msg" + assert err.api_error.msg == "custom msg" # detail also updates msg + assert err.api_error.desc == "New desc." + assert err.api_error.data == {"x": 1} + assert err.api_error.code == 400 # other fields unchanged + + def test_class_api_error_not_mutated_after_instance_override(self): + """Raising with desc/data does not mutate the class-level api_error.""" + + class MyError(ApiException): + api_error = ApiError( + code=400, msg="Error", desc="Default.", err_code="ERR-400" + ) + + MyError(desc="Changed", data={"x": 1}) + assert MyError.api_error.desc == "Default." + assert MyError.api_error.data is None + + def test_subclass_uses_super_with_desc_and_data(self): + """Subclasses can delegate detail/desc/data to super().__init__().""" + + class BuildValidationError(ApiException): + api_error = ApiError( + code=422, + msg="Build Validation Error", + desc="The build configuration is invalid.", + err_code="BUILD-422", + ) + + def __init__(self, *errors: str) -> None: + super().__init__( + f"{len(errors)} validation error(s)", + desc=", ".join(errors), + data={"errors": [{"message": e} for e in errors]}, + ) + + err = BuildValidationError("Field A is required", "Field B is invalid") + assert str(err) == "2 validation error(s)" + assert err.api_error.msg == "2 validation error(s)" # detail set msg + assert err.api_error.desc == "Field A is required, Field B is invalid" + assert err.api_error.data == { + "errors": [ + {"message": "Field A is required"}, + {"message": "Field B is invalid"}, + ] + } + assert err.api_error.code == 422 # other fields unchanged + + def test_detail_desc_data_in_http_response(self): + """detail/desc/data overrides all appear correctly in the FastAPI HTTP response.""" + + class DynamicError(ApiException): + api_error = ApiError( + code=400, msg="Error", desc="Default.", err_code="ERR-400" + ) + + def __init__(self, message: str) -> None: + super().__init__( + message, + desc=f"Detail: {message}", + data={"reason": message}, + ) + + app = FastAPI() + init_exceptions_handlers(app) + + @app.get("/error") + async def raise_error(): + raise DynamicError("something went wrong") + + client = TestClient(app) + response = client.get("/error") + + assert response.status_code == 400 + body = response.json() + assert body["message"] == "something went wrong" + assert body["description"] == "Detail: something went wrong" + assert body["data"] == {"reason": "something went wrong"} + + +class TestApiExceptionGuard: + """Tests for the __init_subclass__ api_error guard.""" + + def test_missing_api_error_raises_type_error(self): + """Defining a subclass without api_error raises TypeError at class creation time.""" + with pytest.raises( + TypeError, match="must define an 'api_error' class attribute" + ): + + class BrokenError(ApiException): + pass + + def test_abstract_subclass_skips_guard(self): + """abstract=True allows intermediate base classes without api_error.""" + + class BaseGroupError(ApiException, abstract=True): + pass + + # Concrete child must still define it + class ConcreteError(BaseGroupError): + api_error = ApiError( + code=400, msg="Error", desc="Desc.", err_code="ERR-400" + ) + + err = ConcreteError() + assert err.api_error.code == 400 + + def test_abstract_child_still_requires_api_error_on_concrete(self): + """Concrete subclass of an abstract class must define api_error.""" + + class Base(ApiException, abstract=True): + pass + + with pytest.raises( + TypeError, match="must define an 'api_error' class attribute" + ): + + class Concrete(Base): + pass + + def test_inherited_api_error_satisfies_guard(self): + """Subclass that inherits api_error from a parent does not need its own.""" + + class ConcreteError(NotFoundError): + pass + + err = ConcreteError() + assert err.api_error.code == 404 class TestBuiltInExceptions: @@ -90,7 +255,7 @@ class TestGenerateErrorResponses: assert responses[404]["description"] == "Not Found" def test_generates_multiple_responses(self): - """Generates responses for multiple exceptions.""" + """Generates responses for multiple exceptions with distinct status codes.""" responses = generate_error_responses( UnauthorizedError, ForbiddenError, @@ -101,15 +266,24 @@ class TestGenerateErrorResponses: assert 403 in responses assert 404 in responses - def test_response_has_example(self): - """Generated response includes example.""" + def test_response_has_named_example(self): + """Generated response uses named examples keyed by err_code.""" responses = generate_error_responses(NotFoundError) - example = responses[404]["content"]["application/json"]["example"] + examples = responses[404]["content"]["application/json"]["examples"] - assert example["status"] == "FAIL" - assert example["error_code"] == "RES-404" - assert example["message"] == "Not Found" - assert example["data"] is None + assert "RES-404" in examples + value = examples["RES-404"]["value"] + assert value["status"] == "FAIL" + assert value["error_code"] == "RES-404" + assert value["message"] == "Not Found" + assert value["data"] is None + + def test_response_example_has_summary(self): + """Each named example carries a summary equal to api_error.msg.""" + responses = generate_error_responses(NotFoundError) + example = responses[404]["content"]["application/json"]["examples"]["RES-404"] + + assert example["summary"] == "Not Found" def test_response_example_with_data(self): """Generated response includes data when set on ApiError.""" @@ -124,9 +298,49 @@ class TestGenerateErrorResponses: ) responses = generate_error_responses(ErrorWithData) - example = responses[400]["content"]["application/json"]["example"] + value = responses[400]["content"]["application/json"]["examples"]["BAD-400"][ + "value" + ] - assert example["data"] == {"details": "some context"} + assert value["data"] == {"details": "some context"} + + def test_two_errors_same_code_both_present(self): + """Two exceptions with the same HTTP code produce two named examples.""" + + class BadRequestA(ApiException): + api_error = ApiError( + code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A" + ) + + class BadRequestB(ApiException): + api_error = ApiError( + code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B" + ) + + responses = generate_error_responses(BadRequestA, BadRequestB) + + assert 400 in responses + examples = responses[400]["content"]["application/json"]["examples"] + assert "ERR-A" in examples + assert "ERR-B" in examples + assert examples["ERR-A"]["value"]["message"] == "Bad A" + assert examples["ERR-B"]["value"]["message"] == "Bad B" + + def test_two_errors_same_code_single_top_level_entry(self): + """Two exceptions with the same HTTP code produce exactly one top-level entry.""" + + class BadRequestA(ApiException): + api_error = ApiError( + code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A" + ) + + class BadRequestB(ApiException): + api_error = ApiError( + code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B" + ) + + responses = generate_error_responses(BadRequestA, BadRequestB) + assert len([k for k in responses if k == 400]) == 1 class TestInitExceptionsHandlers: @@ -250,13 +464,68 @@ class TestInitExceptionsHandlers: assert data["status"] == "FAIL" assert data["error_code"] == "SERVER-500" - def test_custom_openapi_schema(self): - """Customizes OpenAPI schema for 422 responses.""" + def test_handles_http_exception(self): + """Handles starlette HTTPException with consistent ErrorResponse envelope.""" app = FastAPI() init_exceptions_handlers(app) + @app.get("/protected") + async def protected(): + raise HTTPException(status_code=403, detail="Forbidden") + + client = TestClient(app) + response = client.get("/protected") + + assert response.status_code == 403 + data = response.json() + assert data["status"] == "FAIL" + assert data["error_code"] == "HTTP-403" + assert data["message"] == "Forbidden" + + def test_handles_http_exception_404_from_route(self): + """HTTPException(404) raised inside a route uses the consistent ErrorResponse envelope.""" + app = FastAPI() + init_exceptions_handlers(app) + + @app.get("/items/{item_id}") + async def get_item(item_id: int): + raise HTTPException(status_code=404, detail="Item not found") + + client = TestClient(app) + response = client.get("/items/99") + + assert response.status_code == 404 + data = response.json() + assert data["status"] == "FAIL" + assert data["error_code"] == "HTTP-404" + assert data["message"] == "Item not found" + + def test_handles_http_exception_forwards_headers(self): + """HTTPException with WWW-Authenticate header forwards it in the response.""" + app = FastAPI() + init_exceptions_handlers(app) + + @app.get("/secure") + async def secure(): + raise HTTPException( + status_code=401, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + client = TestClient(app) + response = client.get("/secure") + + assert response.status_code == 401 + assert response.headers.get("www-authenticate") == "Bearer" + + def test_custom_openapi_schema(self): + """Customises OpenAPI schema for 422 responses using named examples.""" from pydantic import BaseModel + app = FastAPI() + init_exceptions_handlers(app) + class Item(BaseModel): name: str @@ -269,8 +538,128 @@ class TestInitExceptionsHandlers: post_op = openapi["paths"]["/items"]["post"] assert "422" in post_op["responses"] resp_422 = post_op["responses"]["422"] - example = resp_422["content"]["application/json"]["example"] - assert example["error_code"] == "VAL-422" + examples = resp_422["content"]["application/json"]["examples"] + assert "VAL-422" in examples + assert examples["VAL-422"]["value"]["error_code"] == "VAL-422" + + def test_custom_openapi_preserves_app_metadata(self): + """_patched_openapi preserves custom FastAPI app-level metadata.""" + app = FastAPI( + title="My API", + version="2.0.0", + description="Custom description", + ) + init_exceptions_handlers(app) + + schema = app.openapi() + assert schema["info"]["title"] == "My API" + assert schema["info"]["version"] == "2.0.0" + + def test_handles_response_validation_error(self): + """Handles ResponseValidationError with a structured 422 response.""" + from pydantic import BaseModel + + class CountResponse(BaseModel): + count: int + + app = FastAPI() + init_exceptions_handlers(app) + + @app.get("/broken", response_model=CountResponse) + async def broken(): + return {"count": "not-a-number"} # triggers ResponseValidationError + + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/broken") + + assert response.status_code == 422 + data = response.json() + assert data["status"] == "FAIL" + assert data["error_code"] == "VAL-422" + assert "errors" in data["data"] + + def test_handles_validation_error_with_non_standard_loc(self): + """Validation error with empty loc tuple maps the field to 'root'.""" + from fastapi.exceptions import RequestValidationError + + app = FastAPI() + init_exceptions_handlers(app) + + @app.get("/root-error") + async def root_error(): + raise RequestValidationError( + [ + { + "type": "custom", + "loc": (), + "msg": "root level error", + "input": None, + "url": "", + } + ] + ) + + client = TestClient(app) + response = client.get("/root-error") + + assert response.status_code == 422 + data = response.json() + assert data["data"]["errors"][0]["field"] == "root" + + def test_openapi_schema_cached_after_first_call(self): + """app.openapi() returns the cached schema on subsequent calls.""" + from pydantic import BaseModel + + app = FastAPI() + init_exceptions_handlers(app) + + class Item(BaseModel): + name: str + + @app.post("/items") + async def create_item(item: Item): + return item + + schema_first = app.openapi() + schema_second = app.openapi() + assert schema_first is schema_second + + def test_openapi_skips_operations_without_422(self): + """_patched_openapi leaves operations that have no 422 response unchanged.""" + app = FastAPI() + init_exceptions_handlers(app) + + @app.get("/ping") + async def ping(): + return {"ok": True} + + schema = app.openapi() + get_op = schema["paths"]["/ping"]["get"] + assert "422" not in get_op["responses"] + assert "200" in get_op["responses"] + + def test_openapi_skips_non_dict_path_item_values(self): + """_patched_openapi ignores non-dict values in path items (e.g. path-level parameters).""" + from fastapi_toolsets.exceptions.handler import _patched_openapi + + app = FastAPI() + + def fake_openapi() -> dict: + return { + "paths": { + "/items": { + "parameters": [ + {"name": "q", "in": "query"} + ], # list, not a dict + "get": {"responses": {"200": {"description": "OK"}}}, + } + } + } + + schema = _patched_openapi(app, fake_openapi) + # The list value was skipped without error; the GET operation is intact + assert schema["paths"]["/items"]["parameters"] == [{"name": "q", "in": "query"}] + assert "422" not in schema["paths"]["/items"]["get"]["responses"] class TestExceptionIntegration: @@ -352,12 +741,12 @@ class TestInvalidOrderFieldError: assert error.field == "unknown" assert error.valid_fields == ["name", "created_at"] - def test_message_contains_field_and_valid_fields(self): - """Exception message mentions the bad field and valid options.""" + def test_description_contains_field_and_valid_fields(self): + """api_error.desc mentions the bad field and valid options.""" error = InvalidOrderFieldError("bad_field", ["name", "email"]) - assert "bad_field" in str(error) - assert "name" in str(error) - assert "email" in str(error) + assert "bad_field" in error.api_error.desc + assert "name" in error.api_error.desc + assert "email" in error.api_error.desc def test_handled_as_422_by_exception_handler(self): """init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""