mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-02 17:30:48 +01:00
feat: rework Exception/ApiError (#107)
* feat: rework Exception/ApiError * docs: update exceptions module * fix: docstring
This commit is contained in:
@@ -21,30 +21,37 @@ init_exceptions_handlers(app=app)
|
|||||||
This registers handlers for:
|
This registers handlers for:
|
||||||
|
|
||||||
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
- [`ApiException`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ApiException) — all custom exceptions below
|
||||||
|
- `HTTPException` — Starlette/FastAPI HTTP errors
|
||||||
- `RequestValidationError` — Pydantic request validation (422)
|
- `RequestValidationError` — Pydantic request validation (422)
|
||||||
- `ResponseValidationError` — Pydantic response validation (422)
|
- `ResponseValidationError` — Pydantic response validation (422)
|
||||||
- `Exception` — unhandled errors (500)
|
- `Exception` — unhandled errors (500)
|
||||||
|
|
||||||
|
It also patches `app.openapi()` to replace the default Pydantic 422 schema with a structured example matching the `ErrorResponse` format.
|
||||||
|
|
||||||
## Built-in exceptions
|
## Built-in exceptions
|
||||||
|
|
||||||
| Exception | Status | Default message |
|
| Exception | Status | Default message |
|
||||||
|-----------|--------|-----------------|
|
|-----------|--------|-----------------|
|
||||||
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
| [`UnauthorizedError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.UnauthorizedError) | 401 | Unauthorized |
|
||||||
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
| [`ForbiddenError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ForbiddenError) | 403 | Forbidden |
|
||||||
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not found |
|
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not Found |
|
||||||
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||||
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields |
|
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No Searchable Fields |
|
||||||
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid facet filter |
|
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid Facet Filter |
|
||||||
|
| [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) | 422 | Invalid Order Field |
|
||||||
|
|
||||||
|
### Per-instance overrides
|
||||||
|
|
||||||
|
All built-in exceptions accept optional keyword arguments to customise the response for a specific raise site without changing the class defaults:
|
||||||
|
|
||||||
|
| Argument | Effect |
|
||||||
|
|----------|--------|
|
||||||
|
| `detail` | Overrides both `str(exc)` (log output) and the `message` field in the response body |
|
||||||
|
| `desc` | Overrides the `description` field |
|
||||||
|
| `data` | Overrides the `data` field |
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from fastapi_toolsets.exceptions import NotFoundError
|
raise NotFoundError(detail="User 42 not found", desc="No user with that ID exists in the database.")
|
||||||
|
|
||||||
@router.get("/users/{id}")
|
|
||||||
async def get_user(id: int, session: AsyncSession = Depends(get_db)):
|
|
||||||
user = await UserCrud.first(session=session, filters=[User.id == id])
|
|
||||||
if not user:
|
|
||||||
raise NotFoundError
|
|
||||||
return user
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Custom exceptions
|
## Custom exceptions
|
||||||
@@ -58,12 +65,51 @@ from fastapi_toolsets.schemas import ApiError
|
|||||||
class PaymentRequiredError(ApiException):
|
class PaymentRequiredError(ApiException):
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
code=402,
|
code=402,
|
||||||
msg="Payment required",
|
msg="Payment Required",
|
||||||
desc="Your subscription has expired.",
|
desc="Your subscription has expired.",
|
||||||
err_code="PAYMENT_REQUIRED",
|
err_code="BILLING-402",
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
Subclasses that do not define `api_error` raise a `TypeError` at **class creation time**, not at raise time.
|
||||||
|
|
||||||
|
### Custom `__init__`
|
||||||
|
|
||||||
|
Override `__init__` to compute `detail`, `desc`, or `data` dynamically, then delegate to `super().__init__()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class OrderValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Order Validation Failed",
|
||||||
|
desc="One or more order fields are invalid.",
|
||||||
|
err_code="ORDER-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *field_errors: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"{len(field_errors)} validation error(s)",
|
||||||
|
desc=", ".join(field_errors),
|
||||||
|
data={"errors": [{"message": e} for e in field_errors]},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Intermediate base classes
|
||||||
|
|
||||||
|
Use `abstract=True` when creating a shared base that is not meant to be raised directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class BillingError(ApiException, abstract=True):
|
||||||
|
"""Base for all billing-related errors."""
|
||||||
|
|
||||||
|
class PaymentRequiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Payment Required", desc="...", err_code="BILLING-402")
|
||||||
|
|
||||||
|
class SubscriptionExpiredError(BillingError):
|
||||||
|
api_error = ApiError(code=402, msg="Subscription Expired", desc="...", err_code="BILLING-402-EXP")
|
||||||
|
```
|
||||||
|
|
||||||
## OpenAPI response documentation
|
## OpenAPI response documentation
|
||||||
|
|
||||||
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
Use [`generate_error_responses`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.generate_error_responses) to add error schemas to your endpoint's OpenAPI spec:
|
||||||
@@ -78,8 +124,7 @@ from fastapi_toolsets.exceptions import generate_error_responses, NotFoundError,
|
|||||||
async def get_user(...): ...
|
async def get_user(...): ...
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! info
|
Multiple exceptions sharing the same HTTP status code are grouped under one entry, each appearing as a named example keyed by its `err_code`. This keeps the OpenAPI UI readable when several error variants map to the same status.
|
||||||
The pydantic validation error is automatically added by FastAPI.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -6,32 +6,46 @@ from ..schemas import ApiError, ErrorResponse, ResponseStatus
|
|||||||
|
|
||||||
|
|
||||||
class ApiException(Exception):
|
class ApiException(Exception):
|
||||||
"""Base exception for API errors with structured response.
|
"""Base exception for API errors with structured response."""
|
||||||
|
|
||||||
Subclass this to create custom API exceptions with consistent error format.
|
|
||||||
The exception handler will use api_error to generate the response.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
class CustomError(ApiException):
|
|
||||||
api_error = ApiError(
|
|
||||||
code=400,
|
|
||||||
msg="Bad Request",
|
|
||||||
desc="The request was invalid.",
|
|
||||||
err_code="CUSTOM-400",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
api_error: ClassVar[ApiError]
|
api_error: ClassVar[ApiError]
|
||||||
|
|
||||||
def __init__(self, detail: str | None = None):
|
def __init_subclass__(cls, abstract: bool = False, **kwargs: Any) -> None:
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
if not abstract and not hasattr(cls, "api_error"):
|
||||||
|
raise TypeError(
|
||||||
|
f"{cls.__name__} must define an 'api_error' class attribute. "
|
||||||
|
"Pass abstract=True when creating intermediate base classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
detail: str | None = None,
|
||||||
|
*,
|
||||||
|
desc: str | None = None,
|
||||||
|
data: Any = None,
|
||||||
|
) -> None:
|
||||||
"""Initialize the exception.
|
"""Initialize the exception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
detail: Optional override for the error message
|
detail: Optional human-readable message
|
||||||
|
desc: Optional per-instance override for the ``description`` field
|
||||||
|
in the HTTP response body.
|
||||||
|
data: Optional per-instance override for the ``data`` field in the
|
||||||
|
HTTP response body.
|
||||||
"""
|
"""
|
||||||
super().__init__(detail or self.api_error.msg)
|
updates: dict[str, Any] = {}
|
||||||
|
if detail is not None:
|
||||||
|
updates["msg"] = detail
|
||||||
|
if desc is not None:
|
||||||
|
updates["desc"] = desc
|
||||||
|
if data is not None:
|
||||||
|
updates["data"] = data
|
||||||
|
if updates:
|
||||||
|
object.__setattr__(
|
||||||
|
self, "api_error", self.__class__.api_error.model_copy(update=updates)
|
||||||
|
)
|
||||||
|
super().__init__(self.api_error.msg)
|
||||||
|
|
||||||
|
|
||||||
class UnauthorizedError(ApiException):
|
class UnauthorizedError(ApiException):
|
||||||
@@ -92,14 +106,15 @@ class NoSearchableFieldsError(ApiException):
|
|||||||
"""Initialize the exception.
|
"""Initialize the exception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The SQLAlchemy model class that has no searchable fields
|
model: The model class that has no searchable fields configured.
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
detail = (
|
super().__init__(
|
||||||
|
desc=(
|
||||||
f"No searchable fields found for model '{model.__name__}'. "
|
f"No searchable fields found for model '{model.__name__}'. "
|
||||||
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
||||||
)
|
)
|
||||||
super().__init__(detail)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidFacetFilterError(ApiException):
|
class InvalidFacetFilterError(ApiException):
|
||||||
@@ -116,16 +131,17 @@ class InvalidFacetFilterError(ApiException):
|
|||||||
"""Initialize the exception.
|
"""Initialize the exception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: The unknown filter key provided by the caller
|
key: The unknown filter key provided by the caller.
|
||||||
valid_keys: Set of valid keys derived from the declared facet_fields
|
valid_keys: Set of valid keys derived from the declared facet_fields.
|
||||||
"""
|
"""
|
||||||
self.key = key
|
self.key = key
|
||||||
self.valid_keys = valid_keys
|
self.valid_keys = valid_keys
|
||||||
detail = (
|
super().__init__(
|
||||||
|
desc=(
|
||||||
f"'{key}' is not a declared facet field. "
|
f"'{key}' is not a declared facet field. "
|
||||||
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
|
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
|
||||||
)
|
)
|
||||||
super().__init__(detail)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidOrderFieldError(ApiException):
|
class InvalidOrderFieldError(ApiException):
|
||||||
@@ -142,15 +158,14 @@ class InvalidOrderFieldError(ApiException):
|
|||||||
"""Initialize the exception.
|
"""Initialize the exception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
field: The unknown order field provided by the caller
|
field: The unknown order field provided by the caller.
|
||||||
valid_fields: List of valid field names
|
valid_fields: List of valid field names.
|
||||||
"""
|
"""
|
||||||
self.field = field
|
self.field = field
|
||||||
self.valid_fields = valid_fields
|
self.valid_fields = valid_fields
|
||||||
detail = (
|
super().__init__(
|
||||||
f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
|
desc=f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
|
||||||
)
|
)
|
||||||
super().__init__(detail)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_error_responses(
|
def generate_error_responses(
|
||||||
@@ -158,44 +173,39 @@ def generate_error_responses(
|
|||||||
) -> dict[int | str, dict[str, Any]]:
|
) -> dict[int | str, dict[str, Any]]:
|
||||||
"""Generate OpenAPI response documentation for exceptions.
|
"""Generate OpenAPI response documentation for exceptions.
|
||||||
|
|
||||||
Use this to document possible error responses for an endpoint.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*errors: Exception classes that inherit from ApiException
|
*errors: Exception classes that inherit from ApiException.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict suitable for FastAPI's responses parameter
|
Dict suitable for FastAPI's ``responses`` parameter.
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
|
||||||
|
|
||||||
@app.get(
|
|
||||||
"/admin",
|
|
||||||
responses=generate_error_responses(UnauthorizedError, ForbiddenError)
|
|
||||||
)
|
|
||||||
async def admin_endpoint():
|
|
||||||
...
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
responses: dict[int | str, dict[str, Any]] = {}
|
responses: dict[int | str, dict[str, Any]] = {}
|
||||||
|
|
||||||
for error in errors:
|
for error in errors:
|
||||||
api_error = error.api_error
|
api_error = error.api_error
|
||||||
|
code = api_error.code
|
||||||
|
|
||||||
responses[api_error.code] = {
|
if code not in responses:
|
||||||
|
responses[code] = {
|
||||||
"model": ErrorResponse,
|
"model": ErrorResponse,
|
||||||
"description": api_error.msg,
|
"description": api_error.msg,
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"examples": {},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
responses[code]["content"]["application/json"]["examples"][
|
||||||
|
api_error.err_code
|
||||||
|
] = {
|
||||||
|
"summary": api_error.msg,
|
||||||
|
"value": {
|
||||||
"data": api_error.data,
|
"data": api_error.data,
|
||||||
"status": ResponseStatus.FAIL.value,
|
"status": ResponseStatus.FAIL.value,
|
||||||
"message": api_error.msg,
|
"message": api_error.msg,
|
||||||
"description": api_error.desc,
|
"description": api_error.desc,
|
||||||
"error_code": api_error.err_code,
|
"error_code": api_error.err_code,
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
"""Exception handlers for FastAPI applications."""
|
"""Exception handlers for FastAPI applications."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, Response, status
|
from fastapi import FastAPI, Request, Response, status
|
||||||
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
from fastapi.exceptions import (
|
||||||
from fastapi.openapi.utils import get_openapi
|
HTTPException,
|
||||||
|
RequestValidationError,
|
||||||
|
ResponseValidationError,
|
||||||
|
)
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..schemas import ErrorResponse, ResponseStatus
|
from ..schemas import ErrorResponse, ResponseStatus
|
||||||
@@ -14,43 +18,20 @@ from .exceptions import ApiException
|
|||||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||||
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||||
|
|
||||||
Installs handlers for :class:`ApiException`, validation errors, and
|
|
||||||
unhandled exceptions, and replaces the default 422 schema with a
|
|
||||||
consistent error format.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: FastAPI application instance
|
app: FastAPI application instance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The same FastAPI instance (for chaining)
|
The same FastAPI instance (for chaining).
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app)
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
_register_exception_handlers(app)
|
_register_exception_handlers(app)
|
||||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
_original_openapi = app.openapi
|
||||||
|
app.openapi = lambda: _patched_openapi(app, _original_openapi) # type: ignore[method-assign]
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def _register_exception_handlers(app: FastAPI) -> None:
|
def _register_exception_handlers(app: FastAPI) -> None:
|
||||||
"""Register all exception handlers on a FastAPI application.
|
"""Register all exception handlers on a FastAPI application."""
|
||||||
|
|
||||||
Args:
|
|
||||||
app: FastAPI application instance
|
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@app.exception_handler(ApiException)
|
@app.exception_handler(ApiException)
|
||||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||||
@@ -62,12 +43,25 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
description=api_error.desc,
|
description=api_error.desc,
|
||||||
error_code=api_error.err_code,
|
error_code=api_error.err_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=api_error.code,
|
status_code=api_error.code,
|
||||||
content=error_response.model_dump(),
|
content=error_response.model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(HTTPException)
|
||||||
|
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
|
||||||
|
"""Handle Starlette/FastAPI HTTPException with a consistent error format."""
|
||||||
|
detail = exc.detail if isinstance(exc.detail, str) else "HTTP Error"
|
||||||
|
error_response = ErrorResponse(
|
||||||
|
message=detail,
|
||||||
|
error_code=f"HTTP-{exc.status_code}",
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content=error_response.model_dump(),
|
||||||
|
headers=getattr(exc, "headers", None),
|
||||||
|
)
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
async def request_validation_handler(
|
async def request_validation_handler(
|
||||||
request: Request, exc: RequestValidationError
|
request: Request, exc: RequestValidationError
|
||||||
@@ -90,7 +84,6 @@ def _register_exception_handlers(app: FastAPI) -> None:
|
|||||||
description="An unexpected error occurred. Please try again later.",
|
description="An unexpected error occurred. Please try again later.",
|
||||||
error_code="SERVER-500",
|
error_code="SERVER-500",
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
content=error_response.model_dump(),
|
content=error_response.model_dump(),
|
||||||
@@ -105,11 +98,10 @@ def _format_validation_error(
|
|||||||
formatted_errors = []
|
formatted_errors = []
|
||||||
|
|
||||||
for error in errors:
|
for error in errors:
|
||||||
field_path = ".".join(
|
locs = error["loc"]
|
||||||
str(loc)
|
if locs and locs[0] in ("body", "query", "path", "header", "cookie"):
|
||||||
for loc in error["loc"]
|
locs = locs[1:]
|
||||||
if loc not in ("body", "query", "path", "header", "cookie")
|
field_path = ".".join(str(loc) for loc in locs)
|
||||||
)
|
|
||||||
formatted_errors.append(
|
formatted_errors.append(
|
||||||
{
|
{
|
||||||
"field": field_path or "root",
|
"field": field_path or "root",
|
||||||
@@ -131,34 +123,22 @@ def _format_validation_error(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
def _patched_openapi(
|
||||||
"""Generate custom OpenAPI schema with standardized error format.
|
app: FastAPI, original_openapi: Callable[[], dict[str, Any]]
|
||||||
|
) -> dict[str, Any]:
|
||||||
Replaces default 422 validation error responses with the custom format.
|
"""Generate the OpenAPI schema and replace default 422 responses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app: FastAPI application instance
|
app: FastAPI application instance.
|
||||||
|
original_openapi: The previous ``app.openapi`` callable to delegate to.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
OpenAPI schema dict
|
Patched OpenAPI schema dict.
|
||||||
|
|
||||||
Example:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
init_exceptions_handlers(app) # Automatically sets custom OpenAPI
|
|
||||||
"""
|
"""
|
||||||
if app.openapi_schema:
|
if app.openapi_schema:
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
openapi_schema = get_openapi(
|
openapi_schema = original_openapi()
|
||||||
title=app.title,
|
|
||||||
version=app.version,
|
|
||||||
openapi_version=app.openapi_version,
|
|
||||||
description=app.description,
|
|
||||||
routes=app.routes,
|
|
||||||
)
|
|
||||||
|
|
||||||
for path_data in openapi_schema.get("paths", {}).values():
|
for path_data in openapi_schema.get("paths", {}).values():
|
||||||
for operation in path_data.values():
|
for operation in path_data.values():
|
||||||
@@ -168,7 +148,10 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
|||||||
"description": "Validation Error",
|
"description": "Validation Error",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"example": {
|
"examples": {
|
||||||
|
"VAL-422": {
|
||||||
|
"summary": "Validation Error",
|
||||||
|
"value": {
|
||||||
"data": {
|
"data": {
|
||||||
"errors": [
|
"errors": [
|
||||||
{
|
{
|
||||||
@@ -182,6 +165,8 @@ def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
|||||||
"message": "Validation Error",
|
"message": "Validation Error",
|
||||||
"description": "1 validation error(s) detected",
|
"description": "1 validation error(s) detected",
|
||||||
"error_code": "VAL-422",
|
"error_code": "VAL-422",
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ class TestNoSearchableFieldsError:
|
|||||||
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
error = NoSearchableFieldsError(User)
|
error = NoSearchableFieldsError(User)
|
||||||
assert "User" in str(error)
|
assert "User" in error.api_error.desc
|
||||||
assert error.model is User
|
assert error.model is User
|
||||||
|
|
||||||
def test_error_raised_when_no_fields(self):
|
def test_error_raised_when_no_fields(self):
|
||||||
@@ -434,7 +434,7 @@ class TestNoSearchableFieldsError:
|
|||||||
build_search_filters(NoStringModel, "test")
|
build_search_filters(NoStringModel, "test")
|
||||||
|
|
||||||
assert exc_info.value.model is NoStringModel
|
assert exc_info.value.model is NoStringModel
|
||||||
assert "NoStringModel" in str(exc_info.value)
|
assert "NoStringModel" in exc_info.value.api_error.desc
|
||||||
|
|
||||||
|
|
||||||
class TestGetSearchableFields:
|
class TestGetSearchableFields:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from fastapi_toolsets.exceptions import (
|
from fastapi_toolsets.exceptions import (
|
||||||
@@ -36,8 +37,8 @@ class TestApiException:
|
|||||||
assert error.api_error.msg == "I'm a teapot"
|
assert error.api_error.msg == "I'm a teapot"
|
||||||
assert str(error) == "I'm a teapot"
|
assert str(error) == "I'm a teapot"
|
||||||
|
|
||||||
def test_custom_detail_message(self):
|
def test_detail_overrides_msg_and_str(self):
|
||||||
"""Custom detail overrides default message."""
|
"""detail sets both str(exc) and api_error.msg; class-level msg is unchanged."""
|
||||||
|
|
||||||
class CustomError(ApiException):
|
class CustomError(ApiException):
|
||||||
api_error = ApiError(
|
api_error = ApiError(
|
||||||
@@ -47,8 +48,172 @@ class TestApiException:
|
|||||||
err_code="BAD-400",
|
err_code="BAD-400",
|
||||||
)
|
)
|
||||||
|
|
||||||
error = CustomError("Custom message")
|
error = CustomError("Widget not found")
|
||||||
assert str(error) == "Custom message"
|
assert str(error) == "Widget not found"
|
||||||
|
assert error.api_error.msg == "Widget not found"
|
||||||
|
assert CustomError.api_error.msg == "Bad Request" # class unchanged
|
||||||
|
|
||||||
|
def test_desc_override(self):
|
||||||
|
"""desc kwarg overrides api_error.desc on the instance only."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = MyError(desc="Custom desc.")
|
||||||
|
assert err.api_error.desc == "Custom desc."
|
||||||
|
assert MyError.api_error.desc == "Default." # class unchanged
|
||||||
|
|
||||||
|
def test_data_override(self):
|
||||||
|
"""data kwarg sets api_error.data on the instance only."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = MyError(data={"key": "value"})
|
||||||
|
assert err.api_error.data == {"key": "value"}
|
||||||
|
assert MyError.api_error.data is None # class unchanged
|
||||||
|
|
||||||
|
def test_desc_and_data_override(self):
|
||||||
|
"""detail, desc and data can all be overridden together."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = MyError("custom msg", desc="New desc.", data={"x": 1})
|
||||||
|
assert str(err) == "custom msg"
|
||||||
|
assert err.api_error.msg == "custom msg" # detail also updates msg
|
||||||
|
assert err.api_error.desc == "New desc."
|
||||||
|
assert err.api_error.data == {"x": 1}
|
||||||
|
assert err.api_error.code == 400 # other fields unchanged
|
||||||
|
|
||||||
|
def test_class_api_error_not_mutated_after_instance_override(self):
|
||||||
|
"""Raising with desc/data does not mutate the class-level api_error."""
|
||||||
|
|
||||||
|
class MyError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
MyError(desc="Changed", data={"x": 1})
|
||||||
|
assert MyError.api_error.desc == "Default."
|
||||||
|
assert MyError.api_error.data is None
|
||||||
|
|
||||||
|
def test_subclass_uses_super_with_desc_and_data(self):
|
||||||
|
"""Subclasses can delegate detail/desc/data to super().__init__()."""
|
||||||
|
|
||||||
|
class BuildValidationError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Build Validation Error",
|
||||||
|
desc="The build configuration is invalid.",
|
||||||
|
err_code="BUILD-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, *errors: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"{len(errors)} validation error(s)",
|
||||||
|
desc=", ".join(errors),
|
||||||
|
data={"errors": [{"message": e} for e in errors]},
|
||||||
|
)
|
||||||
|
|
||||||
|
err = BuildValidationError("Field A is required", "Field B is invalid")
|
||||||
|
assert str(err) == "2 validation error(s)"
|
||||||
|
assert err.api_error.msg == "2 validation error(s)" # detail set msg
|
||||||
|
assert err.api_error.desc == "Field A is required, Field B is invalid"
|
||||||
|
assert err.api_error.data == {
|
||||||
|
"errors": [
|
||||||
|
{"message": "Field A is required"},
|
||||||
|
{"message": "Field B is invalid"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert err.api_error.code == 422 # other fields unchanged
|
||||||
|
|
||||||
|
def test_detail_desc_data_in_http_response(self):
|
||||||
|
"""detail/desc/data overrides all appear correctly in the FastAPI HTTP response."""
|
||||||
|
|
||||||
|
class DynamicError(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Default.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
message,
|
||||||
|
desc=f"Detail: {message}",
|
||||||
|
data={"reason": message},
|
||||||
|
)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/error")
|
||||||
|
async def raise_error():
|
||||||
|
raise DynamicError("something went wrong")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/error")
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
body = response.json()
|
||||||
|
assert body["message"] == "something went wrong"
|
||||||
|
assert body["description"] == "Detail: something went wrong"
|
||||||
|
assert body["data"] == {"reason": "something went wrong"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiExceptionGuard:
|
||||||
|
"""Tests for the __init_subclass__ api_error guard."""
|
||||||
|
|
||||||
|
def test_missing_api_error_raises_type_error(self):
|
||||||
|
"""Defining a subclass without api_error raises TypeError at class creation time."""
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError, match="must define an 'api_error' class attribute"
|
||||||
|
):
|
||||||
|
|
||||||
|
class BrokenError(ApiException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_abstract_subclass_skips_guard(self):
|
||||||
|
"""abstract=True allows intermediate base classes without api_error."""
|
||||||
|
|
||||||
|
class BaseGroupError(ApiException, abstract=True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Concrete child must still define it
|
||||||
|
class ConcreteError(BaseGroupError):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Error", desc="Desc.", err_code="ERR-400"
|
||||||
|
)
|
||||||
|
|
||||||
|
err = ConcreteError()
|
||||||
|
assert err.api_error.code == 400
|
||||||
|
|
||||||
|
def test_abstract_child_still_requires_api_error_on_concrete(self):
|
||||||
|
"""Concrete subclass of an abstract class must define api_error."""
|
||||||
|
|
||||||
|
class Base(ApiException, abstract=True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError, match="must define an 'api_error' class attribute"
|
||||||
|
):
|
||||||
|
|
||||||
|
class Concrete(Base):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_inherited_api_error_satisfies_guard(self):
|
||||||
|
"""Subclass that inherits api_error from a parent does not need its own."""
|
||||||
|
|
||||||
|
class ConcreteError(NotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
err = ConcreteError()
|
||||||
|
assert err.api_error.code == 404
|
||||||
|
|
||||||
|
|
||||||
class TestBuiltInExceptions:
|
class TestBuiltInExceptions:
|
||||||
@@ -90,7 +255,7 @@ class TestGenerateErrorResponses:
|
|||||||
assert responses[404]["description"] == "Not Found"
|
assert responses[404]["description"] == "Not Found"
|
||||||
|
|
||||||
def test_generates_multiple_responses(self):
|
def test_generates_multiple_responses(self):
|
||||||
"""Generates responses for multiple exceptions."""
|
"""Generates responses for multiple exceptions with distinct status codes."""
|
||||||
responses = generate_error_responses(
|
responses = generate_error_responses(
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
@@ -101,15 +266,24 @@ class TestGenerateErrorResponses:
|
|||||||
assert 403 in responses
|
assert 403 in responses
|
||||||
assert 404 in responses
|
assert 404 in responses
|
||||||
|
|
||||||
def test_response_has_example(self):
|
def test_response_has_named_example(self):
|
||||||
"""Generated response includes example."""
|
"""Generated response uses named examples keyed by err_code."""
|
||||||
responses = generate_error_responses(NotFoundError)
|
responses = generate_error_responses(NotFoundError)
|
||||||
example = responses[404]["content"]["application/json"]["example"]
|
examples = responses[404]["content"]["application/json"]["examples"]
|
||||||
|
|
||||||
assert example["status"] == "FAIL"
|
assert "RES-404" in examples
|
||||||
assert example["error_code"] == "RES-404"
|
value = examples["RES-404"]["value"]
|
||||||
assert example["message"] == "Not Found"
|
assert value["status"] == "FAIL"
|
||||||
assert example["data"] is None
|
assert value["error_code"] == "RES-404"
|
||||||
|
assert value["message"] == "Not Found"
|
||||||
|
assert value["data"] is None
|
||||||
|
|
||||||
|
def test_response_example_has_summary(self):
|
||||||
|
"""Each named example carries a summary equal to api_error.msg."""
|
||||||
|
responses = generate_error_responses(NotFoundError)
|
||||||
|
example = responses[404]["content"]["application/json"]["examples"]["RES-404"]
|
||||||
|
|
||||||
|
assert example["summary"] == "Not Found"
|
||||||
|
|
||||||
def test_response_example_with_data(self):
|
def test_response_example_with_data(self):
|
||||||
"""Generated response includes data when set on ApiError."""
|
"""Generated response includes data when set on ApiError."""
|
||||||
@@ -124,9 +298,49 @@ class TestGenerateErrorResponses:
|
|||||||
)
|
)
|
||||||
|
|
||||||
responses = generate_error_responses(ErrorWithData)
|
responses = generate_error_responses(ErrorWithData)
|
||||||
example = responses[400]["content"]["application/json"]["example"]
|
value = responses[400]["content"]["application/json"]["examples"]["BAD-400"][
|
||||||
|
"value"
|
||||||
|
]
|
||||||
|
|
||||||
assert example["data"] == {"details": "some context"}
|
assert value["data"] == {"details": "some context"}
|
||||||
|
|
||||||
|
def test_two_errors_same_code_both_present(self):
|
||||||
|
"""Two exceptions with the same HTTP code produce two named examples."""
|
||||||
|
|
||||||
|
class BadRequestA(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
|
||||||
|
)
|
||||||
|
|
||||||
|
class BadRequestB(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(BadRequestA, BadRequestB)
|
||||||
|
|
||||||
|
assert 400 in responses
|
||||||
|
examples = responses[400]["content"]["application/json"]["examples"]
|
||||||
|
assert "ERR-A" in examples
|
||||||
|
assert "ERR-B" in examples
|
||||||
|
assert examples["ERR-A"]["value"]["message"] == "Bad A"
|
||||||
|
assert examples["ERR-B"]["value"]["message"] == "Bad B"
|
||||||
|
|
||||||
|
def test_two_errors_same_code_single_top_level_entry(self):
|
||||||
|
"""Two exceptions with the same HTTP code produce exactly one top-level entry."""
|
||||||
|
|
||||||
|
class BadRequestA(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad A", desc="Reason A.", err_code="ERR-A"
|
||||||
|
)
|
||||||
|
|
||||||
|
class BadRequestB(ApiException):
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400, msg="Bad B", desc="Reason B.", err_code="ERR-B"
|
||||||
|
)
|
||||||
|
|
||||||
|
responses = generate_error_responses(BadRequestA, BadRequestB)
|
||||||
|
assert len([k for k in responses if k == 400]) == 1
|
||||||
|
|
||||||
|
|
||||||
class TestInitExceptionsHandlers:
|
class TestInitExceptionsHandlers:
|
||||||
@@ -250,13 +464,68 @@ class TestInitExceptionsHandlers:
|
|||||||
assert data["status"] == "FAIL"
|
assert data["status"] == "FAIL"
|
||||||
assert data["error_code"] == "SERVER-500"
|
assert data["error_code"] == "SERVER-500"
|
||||||
|
|
||||||
def test_custom_openapi_schema(self):
|
def test_handles_http_exception(self):
|
||||||
"""Customizes OpenAPI schema for 422 responses."""
|
"""Handles starlette HTTPException with consistent ErrorResponse envelope."""
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
init_exceptions_handlers(app)
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/protected")
|
||||||
|
async def protected():
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/protected")
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
assert data["error_code"] == "HTTP-403"
|
||||||
|
assert data["message"] == "Forbidden"
|
||||||
|
|
||||||
|
def test_handles_http_exception_404_from_route(self):
|
||||||
|
"""HTTPException(404) raised inside a route uses the consistent ErrorResponse envelope."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/items/{item_id}")
|
||||||
|
async def get_item(item_id: int):
|
||||||
|
raise HTTPException(status_code=404, detail="Item not found")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/items/99")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
assert data["error_code"] == "HTTP-404"
|
||||||
|
assert data["message"] == "Item not found"
|
||||||
|
|
||||||
|
def test_handles_http_exception_forwards_headers(self):
|
||||||
|
"""HTTPException with WWW-Authenticate header forwards it in the response."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/secure")
|
||||||
|
async def secure():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Not authenticated",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/secure")
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.headers.get("www-authenticate") == "Bearer"
|
||||||
|
|
||||||
|
def test_custom_openapi_schema(self):
|
||||||
|
"""Customises OpenAPI schema for 422 responses using named examples."""
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
class Item(BaseModel):
|
class Item(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
@@ -269,8 +538,128 @@ class TestInitExceptionsHandlers:
|
|||||||
post_op = openapi["paths"]["/items"]["post"]
|
post_op = openapi["paths"]["/items"]["post"]
|
||||||
assert "422" in post_op["responses"]
|
assert "422" in post_op["responses"]
|
||||||
resp_422 = post_op["responses"]["422"]
|
resp_422 = post_op["responses"]["422"]
|
||||||
example = resp_422["content"]["application/json"]["example"]
|
examples = resp_422["content"]["application/json"]["examples"]
|
||||||
assert example["error_code"] == "VAL-422"
|
assert "VAL-422" in examples
|
||||||
|
assert examples["VAL-422"]["value"]["error_code"] == "VAL-422"
|
||||||
|
|
||||||
|
def test_custom_openapi_preserves_app_metadata(self):
|
||||||
|
"""_patched_openapi preserves custom FastAPI app-level metadata."""
|
||||||
|
app = FastAPI(
|
||||||
|
title="My API",
|
||||||
|
version="2.0.0",
|
||||||
|
description="Custom description",
|
||||||
|
)
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
assert schema["info"]["title"] == "My API"
|
||||||
|
assert schema["info"]["version"] == "2.0.0"
|
||||||
|
|
||||||
|
def test_handles_response_validation_error(self):
|
||||||
|
"""Handles ResponseValidationError with a structured 422 response."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class CountResponse(BaseModel):
|
||||||
|
count: int
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/broken", response_model=CountResponse)
|
||||||
|
async def broken():
|
||||||
|
return {"count": "not-a-number"} # triggers ResponseValidationError
|
||||||
|
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get("/broken")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
assert data["error_code"] == "VAL-422"
|
||||||
|
assert "errors" in data["data"]
|
||||||
|
|
||||||
|
def test_handles_validation_error_with_non_standard_loc(self):
|
||||||
|
"""Validation error with empty loc tuple maps the field to 'root'."""
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/root-error")
|
||||||
|
async def root_error():
|
||||||
|
raise RequestValidationError(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "custom",
|
||||||
|
"loc": (),
|
||||||
|
"msg": "root level error",
|
||||||
|
"input": None,
|
||||||
|
"url": "",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/root-error")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["data"]["errors"][0]["field"] == "root"
|
||||||
|
|
||||||
|
def test_openapi_schema_cached_after_first_call(self):
|
||||||
|
"""app.openapi() returns the cached schema on subsequent calls."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
@app.post("/items")
|
||||||
|
async def create_item(item: Item):
|
||||||
|
return item
|
||||||
|
|
||||||
|
schema_first = app.openapi()
|
||||||
|
schema_second = app.openapi()
|
||||||
|
assert schema_first is schema_second
|
||||||
|
|
||||||
|
def test_openapi_skips_operations_without_422(self):
|
||||||
|
"""_patched_openapi leaves operations that have no 422 response unchanged."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/ping")
|
||||||
|
async def ping():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
schema = app.openapi()
|
||||||
|
get_op = schema["paths"]["/ping"]["get"]
|
||||||
|
assert "422" not in get_op["responses"]
|
||||||
|
assert "200" in get_op["responses"]
|
||||||
|
|
||||||
|
def test_openapi_skips_non_dict_path_item_values(self):
|
||||||
|
"""_patched_openapi ignores non-dict values in path items (e.g. path-level parameters)."""
|
||||||
|
from fastapi_toolsets.exceptions.handler import _patched_openapi
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
def fake_openapi() -> dict:
|
||||||
|
return {
|
||||||
|
"paths": {
|
||||||
|
"/items": {
|
||||||
|
"parameters": [
|
||||||
|
{"name": "q", "in": "query"}
|
||||||
|
], # list, not a dict
|
||||||
|
"get": {"responses": {"200": {"description": "OK"}}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = _patched_openapi(app, fake_openapi)
|
||||||
|
# The list value was skipped without error; the GET operation is intact
|
||||||
|
assert schema["paths"]["/items"]["parameters"] == [{"name": "q", "in": "query"}]
|
||||||
|
assert "422" not in schema["paths"]["/items"]["get"]["responses"]
|
||||||
|
|
||||||
|
|
||||||
class TestExceptionIntegration:
|
class TestExceptionIntegration:
|
||||||
@@ -352,12 +741,12 @@ class TestInvalidOrderFieldError:
|
|||||||
assert error.field == "unknown"
|
assert error.field == "unknown"
|
||||||
assert error.valid_fields == ["name", "created_at"]
|
assert error.valid_fields == ["name", "created_at"]
|
||||||
|
|
||||||
def test_message_contains_field_and_valid_fields(self):
|
def test_description_contains_field_and_valid_fields(self):
|
||||||
"""Exception message mentions the bad field and valid options."""
|
"""api_error.desc mentions the bad field and valid options."""
|
||||||
error = InvalidOrderFieldError("bad_field", ["name", "email"])
|
error = InvalidOrderFieldError("bad_field", ["name", "email"])
|
||||||
assert "bad_field" in str(error)
|
assert "bad_field" in error.api_error.desc
|
||||||
assert "name" in str(error)
|
assert "name" in error.api_error.desc
|
||||||
assert "email" in str(error)
|
assert "email" in error.api_error.desc
|
||||||
|
|
||||||
def test_handled_as_422_by_exception_handler(self):
|
def test_handled_as_422_by_exception_handler(self):
|
||||||
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
|
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
|
||||||
|
|||||||
Reference in New Issue
Block a user