7 Commits

Author SHA1 Message Date
3299a439fe Version 0.10.0 2026-02-17 07:28:10 -05:00
d3vyce
d5b22a72fd feat: add a metrics module (#67) 2026-02-17 13:24:53 +01:00
d3vyce
c32f2e18be feat: add many to many support in CrudFactory (#65) 2026-02-15 15:57:15 +01:00
d971261f98 Version 0.9.0 2026-02-14 14:38:58 -05:00
d3vyce
74a54b7396 feat: add optional data field in ApiError (#63) 2026-02-14 20:37:50 +01:00
d3vyce
19805ab376 feat: add dependency_overrides parameter to create_async_client (#61) 2026-02-13 18:11:11 +01:00
d3vyce
d4498e2063 feat: add cleanup parameter to create_db_session (#60) 2026-02-13 18:03:28 +01:00
18 changed files with 1507 additions and 46 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.8.1" version = "0.10.0"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
@@ -46,6 +46,9 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets"
Issues = "https://github.com/d3vyce/fastapi-toolsets/issues" Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
[project.optional-dependencies] [project.optional-dependencies]
metrics = [
"prometheus_client>=0.20.0",
]
test = [ test = [
"pytest>=8.0.0", "pytest>=8.0.0",
"pytest-anyio>=0.0.0", "pytest-anyio>=0.0.0",
@@ -54,7 +57,7 @@ test = [
"pytest-cov>=4.0.0", "pytest-cov>=4.0.0",
] ]
dev = [ dev = [
"fastapi-toolsets[test]", "fastapi-toolsets[metrics,test]",
"ruff>=0.1.0", "ruff>=0.1.0",
"ty>=0.0.1a0", "ty>=0.0.1a0",
] ]

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "0.8.1" __version__ = "0.10.0"

View File

@@ -1,7 +1,7 @@
"""Generic async CRUD operations for SQLAlchemy models.""" """Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory from .factory import CrudFactory, JoinType, M2MFieldType
from .search import ( from .search import (
SearchConfig, SearchConfig,
get_searchable_fields, get_searchable_fields,
@@ -10,6 +10,8 @@ from .search import (
__all__ = [ __all__ = [
"CrudFactory", "CrudFactory",
"get_searchable_fields", "get_searchable_fields",
"JoinType",
"M2MFieldType",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"SearchConfig", "SearchConfig",
] ]

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel from pydantic import BaseModel
@@ -11,7 +11,7 @@ from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
@@ -21,6 +21,7 @@ from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]] JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
@@ -31,6 +32,7 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
@overload @overload
@classmethod @classmethod
@@ -52,6 +54,62 @@ class AsyncCrud(Generic[ModelType]):
as_response: Literal[False] = ..., as_response: Literal[False] = ...,
) -> ModelType: ... ) -> ModelType: ...
@classmethod
async def _resolve_m2m(
    cls: type[Self],
    session: AsyncSession,
    obj: BaseModel,
    *,
    only_set: bool = False,
) -> dict[str, list[Any]]:
    """Resolve M2M fields from a Pydantic schema into related model instances.

    Args:
        session: DB async session
        obj: Pydantic model containing M2M ID fields
        only_set: If True, only process fields explicitly set on the schema

    Returns:
        Dict mapping relationship attr names to lists of related instances

    Raises:
        NotFoundError: If any referenced ID has no matching related row
    """
    result: dict[str, list[Any]] = {}
    if not cls.m2m_fields:
        return result
    for schema_field, rel in cls.m2m_fields.items():
        rel_attr = rel.property.key
        # NOTE(review): assumes the related model's primary key attribute
        # is named ``id`` — confirm this holds for all configured models.
        related_model = rel.property.mapper.class_
        if only_set and schema_field not in obj.model_fields_set:
            continue
        ids = getattr(obj, schema_field, None)
        if ids is None:
            # A None value clears the relationship entirely.
            result[rel_attr] = []
            continue
        # Deduplicate up front: the SELECT returns each row at most once,
        # so repeated IDs in the payload would otherwise trigger a
        # spurious "not found" with an empty missing-ID set.
        unique_ids = set(ids)
        related = (
            (
                await session.execute(
                    select(related_model).where(related_model.id.in_(unique_ids))
                )
            )
            .scalars()
            .all()
        )
        if len(related) != len(unique_ids):
            found_ids = {r.id for r in related}
            missing = unique_ids - found_ids
            raise NotFoundError(
                f"Related {related_model.__name__} not found for IDs: {missing}"
            )
        result[rel_attr] = list(related)
    return result
@classmethod
def _m2m_schema_fields(cls: type[Self]) -> set[str]:
    """Return the set of schema field names that are M2M fields."""
    mapping = cls.m2m_fields
    return set(mapping) if mapping else set()
@classmethod @classmethod
async def create( async def create(
cls: type[Self], cls: type[Self],
@@ -71,7 +129,17 @@ class AsyncCrud(Generic[ModelType]):
Created model instance or Response wrapping it Created model instance or Response wrapping it
""" """
async with get_transaction(session): async with get_transaction(session):
db_model = cls.model(**obj.model_dump()) m2m_exclude = cls._m2m_schema_fields()
data = (
obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump()
)
db_model = cls.model(**data)
if m2m_exclude:
m2m_resolved = await cls._resolve_m2m(session, obj)
for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances)
session.add(db_model) session.add(db_model)
await session.refresh(db_model) await session.refresh(db_model)
result = cast(ModelType, db_model) result = cast(ModelType, db_model)
@@ -299,12 +367,33 @@ class AsyncCrud(Generic[ModelType]):
NotFoundError: If no record found NotFoundError: If no record found
""" """
async with get_transaction(session): async with get_transaction(session):
db_model = await cls.get(session=session, filters=filters) m2m_exclude = cls._m2m_schema_fields()
# Eagerly load M2M relationships that will be updated so that
# setattr does not trigger a lazy load (which fails in async).
m2m_load_options: list[Any] = []
if m2m_exclude and cls.m2m_fields:
for schema_field, rel in cls.m2m_fields.items():
if schema_field in obj.model_fields_set:
m2m_load_options.append(selectinload(rel))
db_model = await cls.get(
session=session,
filters=filters,
load_options=m2m_load_options or None,
)
values = obj.model_dump( values = obj.model_dump(
exclude_unset=exclude_unset, exclude_none=exclude_none exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude=m2m_exclude,
) )
for key, value in values.items(): for key, value in values.items():
setattr(db_model, key, value) setattr(db_model, key, value)
if m2m_exclude:
m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True)
for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model) await session.refresh(db_model)
if as_response: if as_response:
return Response(data=db_model) return Response(data=db_model)
@@ -578,12 +667,16 @@ def CrudFactory(
model: type[ModelType], model: type[ModelType],
*, *,
searchable_fields: Sequence[SearchFieldType] | None = None, searchable_fields: Sequence[SearchFieldType] | None = None,
m2m_fields: M2MFieldType | None = None,
) -> type[AsyncCrud[ModelType]]: ) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model. """Create a CRUD class for a specific model.
Args: Args:
model: SQLAlchemy model class model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields searchable_fields: Optional list of searchable fields
m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes.
Returns: Returns:
AsyncCrud subclass bound to the model AsyncCrud subclass bound to the model
@@ -601,10 +694,20 @@ def CrudFactory(
searchable_fields=[User.username, User.email, (User.role, Role.name)] searchable_fields=[User.username, User.email, (User.role, Role.name)]
) )
# With many-to-many fields:
# Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
PostCrud = CrudFactory(
Post,
m2m_fields={"tag_ids": Post.tags},
)
# Usage # Usage
user = await UserCrud.get(session, [User.id == 1]) user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
# Create with M2M - tag_ids are automatically resolved
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
# With search # With search
result = await UserCrud.paginate(session, search="john") result = await UserCrud.paginate(session, search="john")
@@ -628,6 +731,7 @@ def CrudFactory(
{ {
"model": model, "model": model,
"searchable_fields": searchable_fields, "searchable_fields": searchable_fields,
"m2m_fields": m2m_fields,
}, },
) )
return cast(type[AsyncCrud[ModelType]], cls) return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -183,7 +183,7 @@ def generate_error_responses(
"content": { "content": {
"application/json": { "application/json": {
"example": { "example": {
"data": None, "data": api_error.data,
"status": ResponseStatus.FAIL.value, "status": ResponseStatus.FAIL.value,
"message": api_error.msg, "message": api_error.msg,
"description": api_error.desc, "description": api_error.desc,

View File

@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from ..schemas import ResponseStatus from ..schemas import ErrorResponse, ResponseStatus
from .exceptions import ApiException from .exceptions import ApiException
@@ -54,16 +54,16 @@ def _register_exception_handlers(app: FastAPI) -> None:
async def api_exception_handler(request: Request, exc: ApiException) -> Response: async def api_exception_handler(request: Request, exc: ApiException) -> Response:
"""Handle custom API exceptions with structured response.""" """Handle custom API exceptions with structured response."""
api_error = exc.api_error api_error = exc.api_error
error_response = ErrorResponse(
data=api_error.data,
message=api_error.msg,
description=api_error.desc,
error_code=api_error.err_code,
)
return JSONResponse( return JSONResponse(
status_code=api_error.code, status_code=api_error.code,
content={ content=error_response.model_dump(),
"data": None,
"status": ResponseStatus.FAIL.value,
"message": api_error.msg,
"description": api_error.desc,
"error_code": api_error.err_code,
},
) )
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
@@ -83,15 +83,15 @@ def _register_exception_handlers(app: FastAPI) -> None:
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception) -> Response: async def generic_exception_handler(request: Request, exc: Exception) -> Response:
"""Handle all unhandled exceptions with a generic 500 response.""" """Handle all unhandled exceptions with a generic 500 response."""
error_response = ErrorResponse(
message="Internal Server Error",
description="An unexpected error occurred. Please try again later.",
error_code="SERVER-500",
)
return JSONResponse( return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={ content=error_response.model_dump(),
"data": None,
"status": ResponseStatus.FAIL.value,
"message": "Internal Server Error",
"description": "An unexpected error occurred. Please try again later.",
"error_code": "SERVER-500",
},
) )
@@ -116,15 +116,16 @@ def _format_validation_error(
} }
) )
error_response = ErrorResponse(
data={"errors": formatted_errors},
message="Validation Error",
description=f"{len(formatted_errors)} validation error(s) detected",
error_code="VAL-422",
)
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={ content=error_response.model_dump(),
"data": {"errors": formatted_errors},
"status": ResponseStatus.FAIL.value,
"message": "Validation Error",
"description": f"{len(formatted_errors)} validation error(s) detected",
"error_code": "VAL-422",
},
) )

View File

@@ -0,0 +1,10 @@
"""Prometheus metrics integration for FastAPI applications."""
from .handler import init_metrics
from .registry import Metric, MetricsRegistry
__all__ = [
"Metric",
"MetricsRegistry",
"init_metrics",
]

View File

@@ -0,0 +1,73 @@
"""Prometheus metrics endpoint for FastAPI applications."""
import asyncio
import os
from fastapi import FastAPI
from fastapi.responses import Response
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
generate_latest,
multiprocess,
)
from ..logger import get_logger
from .registry import MetricsRegistry
logger = get_logger()
def _is_multiprocess() -> bool:
"""Check if prometheus multi-process mode is enabled."""
return "PROMETHEUS_MULTIPROC_DIR" in os.environ
def init_metrics(
    app: FastAPI,
    registry: MetricsRegistry,
    *,
    path: str = "/metrics",
) -> FastAPI:
    """Register a Prometheus ``/metrics`` endpoint on a FastAPI app.

    Providers registered on *registry* are invoked once, up front, to
    instantiate their metric objects; collectors are invoked on every
    scrape, right before the exposition text is rendered.

    Args:
        app: FastAPI application instance.
        registry: A :class:`MetricsRegistry` containing providers and collectors.
        path: URL path for the metrics endpoint (default ``/metrics``).

    Returns:
        The same FastAPI instance (for chaining).

    Example:
        from fastapi import FastAPI

        from fastapi_toolsets.metrics import MetricsRegistry, init_metrics

        metrics = MetricsRegistry()

        app = FastAPI()
        init_metrics(app, registry=metrics)
    """
    # Instantiate every provider exactly once at init time.
    for provider in registry.get_providers():
        logger.debug("Initialising metric provider '%s'", provider.name)
        provider.func()

    # Snapshot collectors now; hooks registered after init are not scraped.
    scrape_hooks = registry.get_collectors()

    @app.get(path, include_in_schema=False)
    async def metrics_endpoint() -> Response:
        # Refresh collector-backed metrics, awaiting async hooks.
        for hook in scrape_hooks:
            if asyncio.iscoroutinefunction(hook.func):
                await hook.func()
            else:
                hook.func()

        if _is_multiprocess():
            # Aggregate per-process metric files into a dedicated registry.
            mp_registry = CollectorRegistry()
            multiprocess.MultiProcessCollector(mp_registry)
            payload = generate_latest(mp_registry)
        else:
            payload = generate_latest()
        return Response(content=payload, media_type=CONTENT_TYPE_LATEST)

    return app

View File

@@ -0,0 +1,122 @@
"""Metrics registry with decorator-based registration."""
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, cast
from ..logger import get_logger
logger = get_logger()
@dataclass
class Metric:
    """A metric definition with metadata."""

    # Unique metric name; used as the registry key.
    name: str
    # Provider (creates the metric once) or collector (refreshes it per scrape).
    func: Callable[..., Any]
    # True -> called on every scrape; False (default) -> called once at init.
    # A plain default is the dataclass idiom for immutable values;
    # ``field(default=...)`` is only needed alongside other field options.
    collect: bool = False
class MetricsRegistry:
    """Registry for managing Prometheus metric providers and collectors.

    Example:
        from prometheus_client import Counter, Gauge

        from fastapi_toolsets.metrics import MetricsRegistry

        metrics = MetricsRegistry()

        @metrics.register
        def http_requests():
            return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])

        @metrics.register(name="db_pool")
        def database_pool_size():
            return Gauge("db_pool_size", "Database connection pool size")

        @metrics.register(collect=True)
        def collect_queue_depth(gauge=Gauge("queue_depth", "Current queue depth")):
            gauge.set(get_current_queue_depth())
    """

    def __init__(self) -> None:
        # Keyed by metric name; insertion order is preserved.
        self._metrics: dict[str, Metric] = {}

    def register(
        self,
        func: Callable[..., Any] | None = None,
        *,
        name: str | None = None,
        collect: bool = False,
    ) -> Callable[..., Any]:
        """Register a metric provider or collector function.

        Can be used as a decorator with or without arguments.

        Args:
            func: The metric function to register.
            name: Metric name (defaults to function name).
            collect: If ``True``, the function is called on every scrape.
                If ``False`` (default), called once at init time.

        Example:
            @metrics.register
            def my_counter():
                return Counter("my_counter", "A counter")

            @metrics.register(collect=True, name="queue")
            def collect_queue_depth():
                gauge.set(compute_depth())
        """

        def add(fn: Callable[..., Any]) -> Callable[..., Any]:
            # Fall back to the function's own name when none is given.
            key = name or getattr(fn, "__name__")
            self._metrics[key] = Metric(name=key, func=fn, collect=collect)
            return fn

        # Bare-decorator form (``@register`` without parentheses).
        return add if func is None else add(func)

    def include_registry(self, registry: "MetricsRegistry") -> None:
        """Include another :class:`MetricsRegistry` into this one.

        Args:
            registry: The registry to merge in.

        Raises:
            ValueError: If a metric name already exists in the current registry.

        Example:
            main = MetricsRegistry()
            sub = MetricsRegistry()

            @sub.register
            def sub_metric():
                return Counter("sub_total", "Sub counter")

            main.include_registry(sub)
        """
        # Checked-and-added per item, so a clash aborts mid-merge with the
        # earlier entries already copied (matches existing behaviour).
        for key, definition in registry._metrics.items():
            if key in self._metrics:
                raise ValueError(
                    f"Metric '{key}' already exists in the current registry"
                )
            self._metrics[key] = definition

    def get_all(self) -> list[Metric]:
        """Get all registered metric definitions."""
        return [*self._metrics.values()]

    def get_providers(self) -> list[Metric]:
        """Get metric providers (called once at init)."""
        return [m for m in self._metrics.values() if not m.collect]

    def get_collectors(self) -> list[Metric]:
        """Get collectors (called on each scrape)."""
        return [m for m in self._metrics.values() if m.collect]

View File

@@ -1,7 +1,7 @@
"""Pytest helper utilities for FastAPI testing.""" """Pytest helper utilities for FastAPI testing."""
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any from typing import Any
@@ -22,12 +22,16 @@ from ..db import create_db_context
async def create_async_client( async def create_async_client(
app: Any, app: Any,
base_url: str = "http://test", base_url: str = "http://test",
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
) -> AsyncGenerator[AsyncClient, None]: ) -> AsyncGenerator[AsyncClient, None]:
"""Create an async httpx client for testing FastAPI applications. """Create an async httpx client for testing FastAPI applications.
Args: Args:
app: FastAPI application instance. app: FastAPI application instance.
base_url: Base URL for requests. Defaults to "http://test". base_url: Base URL for requests. Defaults to "http://test".
dependency_overrides: Optional mapping of original dependencies to
their test replacements. Applied via ``app.dependency_overrides``
before yielding and cleaned up after.
Yields: Yields:
An AsyncClient configured for the app. An AsyncClient configured for the app.
@@ -46,10 +50,37 @@ async def create_async_client(
async def test_endpoint(client: AsyncClient): async def test_endpoint(client: AsyncClient):
response = await client.get("/health") response = await client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
Example with dependency overrides:
from fastapi_toolsets.pytest import create_async_client, create_db_session
from app.db import get_db
@pytest.fixture
async def db_session():
async with create_db_session(DATABASE_URL, Base, cleanup=True) as session:
yield session
@pytest.fixture
async def client(db_session):
async def override():
yield db_session
async with create_async_client(
app, dependency_overrides={get_db: override}
) as c:
yield c
""" """
if dependency_overrides:
app.dependency_overrides.update(dependency_overrides)
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
try:
async with AsyncClient(transport=transport, base_url=base_url) as client: async with AsyncClient(transport=transport, base_url=base_url) as client:
yield client yield client
finally:
if dependency_overrides:
for key in dependency_overrides:
app.dependency_overrides.pop(key, None)
@asynccontextmanager @asynccontextmanager
@@ -60,6 +91,7 @@ async def create_db_session(
echo: bool = False, echo: bool = False,
expire_on_commit: bool = False, expire_on_commit: bool = False,
drop_tables: bool = True, drop_tables: bool = True,
cleanup: bool = False,
) -> AsyncGenerator[AsyncSession, None]: ) -> AsyncGenerator[AsyncSession, None]:
"""Create a database session for testing. """Create a database session for testing.
@@ -72,6 +104,8 @@ async def create_db_session(
echo: Enable SQLAlchemy query logging. Defaults to False. echo: Enable SQLAlchemy query logging. Defaults to False.
expire_on_commit: Expire objects after commit. Defaults to False. expire_on_commit: Expire objects after commit. Defaults to False.
drop_tables: Drop tables after test. Defaults to True. drop_tables: Drop tables after test. Defaults to True.
cleanup: Truncate all tables after test using
:func:`cleanup_tables`. Defaults to False.
Yields: Yields:
An AsyncSession ready for database operations. An AsyncSession ready for database operations.
@@ -84,7 +118,9 @@ async def create_db_session(
@pytest.fixture @pytest.fixture
async def db_session(): async def db_session():
async with create_db_session(DATABASE_URL, Base) as session: async with create_db_session(
DATABASE_URL, Base, cleanup=True
) as session:
yield session yield session
async def test_create_user(db_session: AsyncSession): async def test_create_user(db_session: AsyncSession):
@@ -106,6 +142,9 @@ async def create_db_session(
async with get_session() as session: async with get_session() as session:
yield session yield session
if cleanup:
await cleanup_tables(session, base)
if drop_tables: if drop_tables:
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(base.metadata.drop_all) await conn.run_sync(base.metadata.drop_all)
@@ -193,7 +232,7 @@ async def create_worker_database(
Example: Example:
from fastapi_toolsets.pytest import ( from fastapi_toolsets.pytest import (
create_worker_database, create_db_session, cleanup_tables create_worker_database, create_db_session,
) )
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db" DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
@@ -205,9 +244,10 @@ async def create_worker_database(
@pytest.fixture @pytest.fixture
async def db_session(worker_db_url): async def db_session(worker_db_url):
async with create_db_session(worker_db_url, Base) as session: async with create_db_session(
worker_db_url, Base, cleanup=True
) as session:
yield session yield session
await cleanup_tables(session, Base)
""" """
worker_url = worker_database_url( worker_url = worker_database_url(
database_url=database_url, default_test_db=default_test_db database_url=database_url, default_test_db=default_test_db

View File

@@ -1,7 +1,7 @@
"""Base Pydantic schemas for API responses.""" """Base Pydantic schemas for API responses."""
from enum import Enum from enum import Enum
from typing import ClassVar, Generic, TypeVar from typing import Any, ClassVar, Generic, TypeVar
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@@ -50,6 +50,7 @@ class ApiError(PydanticBase):
msg: str msg: str
desc: str desc: str
err_code: str err_code: str
data: Any | None = None
class BaseResponse(PydanticBase): class BaseResponse(PydanticBase):
@@ -84,7 +85,7 @@ class ErrorResponse(BaseResponse):
status: ResponseStatus = ResponseStatus.FAIL status: ResponseStatus = ResponseStatus.FAIL
description: str | None = None description: str | None = None
data: None = None data: Any | None = None
class Pagination(PydanticBase): class Pagination(PydanticBase):

View File

@@ -5,7 +5,7 @@ import uuid
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import ForeignKey, String, Uuid from sqlalchemy import Column, ForeignKey, String, Table, Uuid
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -56,6 +56,25 @@ class User(Base):
role: Mapped[Role | None] = relationship(back_populates="users") role: Mapped[Role | None] = relationship(back_populates="users")
class Tag(Base):
"""Test tag model."""
__tablename__ = "tags"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True)
post_tags = Table(
"post_tags",
Base.metadata,
Column(
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
),
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
)
class Post(Base): class Post(Base):
"""Test post model.""" """Test post model."""
@@ -67,6 +86,8 @@ class Post(Base):
is_published: Mapped[bool] = mapped_column(default=False) is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id")) author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
# ============================================================================= # =============================================================================
# Test Schemas # Test Schemas
@@ -105,6 +126,13 @@ class UserUpdate(BaseModel):
role_id: uuid.UUID | None = None role_id: uuid.UUID | None = None
class TagCreate(BaseModel):
"""Schema for creating a tag."""
id: uuid.UUID | None = None
name: str
class PostCreate(BaseModel): class PostCreate(BaseModel):
"""Schema for creating a post.""" """Schema for creating a post."""
@@ -123,6 +151,26 @@ class PostUpdate(BaseModel):
is_published: bool | None = None is_published: bool | None = None
class PostM2MCreate(BaseModel):
"""Schema for creating a post with M2M tag IDs."""
id: uuid.UUID | None = None
title: str
content: str = ""
is_published: bool = False
author_id: uuid.UUID
tag_ids: list[uuid.UUID] = []
class PostM2MUpdate(BaseModel):
"""Schema for updating a post with M2M tag IDs."""
title: str | None = None
content: str | None = None
is_published: bool | None = None
tag_ids: list[uuid.UUID] | None = None
# ============================================================================= # =============================================================================
# CRUD Classes # CRUD Classes
# ============================================================================= # =============================================================================
@@ -130,6 +178,8 @@ class PostUpdate(BaseModel):
RoleCrud = CrudFactory(Role) RoleCrud = CrudFactory(Role)
UserCrud = CrudFactory(User) UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post) PostCrud = CrudFactory(Post)
TagCrud = CrudFactory(Tag)
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
# ============================================================================= # =============================================================================

View File

@@ -4,6 +4,7 @@ import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from fastapi_toolsets.crud import CrudFactory from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.crud.factory import AsyncCrud from fastapi_toolsets.crud.factory import AsyncCrud
@@ -13,10 +14,15 @@ from .conftest import (
Post, Post,
PostCreate, PostCreate,
PostCrud, PostCrud,
PostM2MCreate,
PostM2MCrud,
PostM2MUpdate,
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
RoleUpdate, RoleUpdate,
TagCreate,
TagCrud,
User, User,
UserCreate, UserCreate,
UserCrud, UserCrud,
@@ -812,3 +818,383 @@ class TestAsResponse:
assert isinstance(result, Response) assert isinstance(result, Response)
assert result.data is None assert result.data is None
class TestCrudFactoryM2M:
    """Tests for CrudFactory with the m2m_fields parameter."""

    def test_creates_crud_with_m2m_fields(self):
        """CrudFactory configures m2m_fields on the class."""
        crud_cls = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
        assert crud_cls.m2m_fields is not None
        assert "tag_ids" in crud_cls.m2m_fields

    def test_creates_crud_without_m2m_fields(self):
        """CrudFactory without m2m_fields leaves the attribute as None."""
        assert CrudFactory(Post).m2m_fields is None

    def test_m2m_schema_fields(self):
        """_m2m_schema_fields exposes the configured schema field names."""
        crud_cls = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
        assert crud_cls._m2m_schema_fields() == {"tag_ids"}

    def test_m2m_schema_fields_empty_when_none(self):
        """_m2m_schema_fields is empty when no m2m_fields are configured."""
        assert CrudFactory(Post)._m2m_schema_fields() == set()

    @pytest.mark.anyio
    async def test_resolve_m2m_returns_empty_without_m2m_fields(
        self, db_session: AsyncSession
    ):
        """_resolve_m2m yields an empty dict when m2m_fields is not configured."""
        from pydantic import BaseModel

        class PlainSchema(BaseModel):
            name: str

        resolved = await PostCrud._resolve_m2m(db_session, PlainSchema(name="test"))
        assert resolved == {}
class TestM2MResolveNone:
    """Tests for _resolve_m2m when the IDs field is None."""

    @pytest.mark.anyio
    async def test_resolve_m2m_with_none_ids(self, db_session: AsyncSession):
        """A None ids value resolves to an empty relationship list."""
        from pydantic import BaseModel

        class NullableTagsSchema(BaseModel):
            tag_ids: list[uuid.UUID] | None = None

        resolved = await PostM2MCrud._resolve_m2m(
            db_session, NullableTagsSchema(tag_ids=None)
        )
        assert resolved == {"tags": []}
class TestM2MCreate:
    """Tests for create with M2M relationships."""

    @staticmethod
    async def _make_author(db_session: AsyncSession):
        # Shared helper: every test needs exactly one post author.
        return await UserCrud.create(
            db_session, UserCreate(username="author", email="author@test.com")
        )

    @staticmethod
    async def _load_with_tags(db_session: AsyncSession, post_id):
        # Reload a post with its tags eagerly loaded for assertions.
        return await PostM2MCrud.get(
            db_session,
            [Post.id == post_id],
            load_options=[selectinload(Post.tags)],
        )

    @pytest.mark.anyio
    async def test_create_with_m2m_tags(self, db_session: AsyncSession):
        """Creating a post with tag IDs resolves them to Tag rows."""
        author = await self._make_author(db_session)
        first_tag = await TagCrud.create(db_session, TagCreate(name="python"))
        second_tag = await TagCrud.create(db_session, TagCreate(name="fastapi"))

        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="M2M Post",
                author_id=author.id,
                tag_ids=[first_tag.id, second_tag.id],
            ),
        )

        assert created.id is not None
        assert created.title == "M2M Post"
        reloaded = await self._load_with_tags(db_session, created.id)
        assert sorted(tag.name for tag in reloaded.tags) == ["fastapi", "python"]

    @pytest.mark.anyio
    async def test_create_with_empty_m2m(self, db_session: AsyncSession):
        """An explicit empty tag_ids list yields a post with no tags."""
        author = await self._make_author(db_session)
        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="No Tags Post",
                author_id=author.id,
                tag_ids=[],
            ),
        )
        assert created.id is not None
        reloaded = await self._load_with_tags(db_session, created.id)
        assert reloaded.tags == []

    @pytest.mark.anyio
    async def test_create_with_default_m2m(self, db_session: AsyncSession):
        """Omitting tag_ids falls back to the schema default (no tags)."""
        author = await self._make_author(db_session)
        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(title="Default Tags", author_id=author.id),
        )
        reloaded = await self._load_with_tags(db_session, created.id)
        assert reloaded.tags == []

    @pytest.mark.anyio
    async def test_create_with_nonexistent_tag_id_raises(
        self, db_session: AsyncSession
    ):
        """A tag ID with no matching row triggers NotFoundError."""
        author = await self._make_author(db_session)
        real_tag = await TagCrud.create(db_session, TagCreate(name="valid"))
        missing_id = uuid.uuid4()
        with pytest.raises(NotFoundError):
            await PostM2MCrud.create(
                db_session,
                PostM2MCreate(
                    title="Bad Tags",
                    author_id=author.id,
                    tag_ids=[real_tag.id, missing_id],
                ),
            )

    @pytest.mark.anyio
    async def test_create_with_single_tag(self, db_session: AsyncSession):
        """A one-element tag_ids list links exactly one tag."""
        author = await self._make_author(db_session)
        only_tag = await TagCrud.create(db_session, TagCreate(name="solo"))
        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="Single Tag",
                author_id=author.id,
                tag_ids=[only_tag.id],
            ),
        )
        reloaded = await self._load_with_tags(db_session, created.id)
        assert len(reloaded.tags) == 1
        assert reloaded.tags[0].name == "solo"
class TestM2MUpdate:
    """Tests for update with M2M relationships."""

    @staticmethod
    async def _make_author(db_session: AsyncSession):
        # Shared helper: every test needs exactly one post author.
        return await UserCrud.create(
            db_session, UserCreate(username="author", email="author@test.com")
        )

    @staticmethod
    async def _load_with_tags(db_session: AsyncSession, post_id):
        # Reload a post with its tags eagerly loaded for assertions.
        return await PostM2MCrud.get(
            db_session,
            [Post.id == post_id],
            load_options=[selectinload(Post.tags)],
        )

    @pytest.mark.anyio
    async def test_update_m2m_tags(self, db_session: AsyncSession):
        """Setting tag_ids on update replaces the existing tag set."""
        author = await self._make_author(db_session)
        old_tag = await TagCrud.create(db_session, TagCreate(name="old_tag"))
        new_tag = await TagCrud.create(db_session, TagCreate(name="new_tag"))

        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="Update Test",
                author_id=author.id,
                tag_ids=[old_tag.id],
            ),
        )
        updated = await PostM2MCrud.update(
            db_session,
            PostM2MUpdate(tag_ids=[new_tag.id]),
            [Post.id == created.id],
        )

        reloaded = await self._load_with_tags(db_session, updated.id)
        assert len(reloaded.tags) == 1
        assert reloaded.tags[0].name == "new_tag"

    @pytest.mark.anyio
    async def test_update_without_m2m_preserves_tags(self, db_session: AsyncSession):
        """Leaving tag_ids unset keeps the current tags untouched."""
        author = await self._make_author(db_session)
        kept_tag = await TagCrud.create(db_session, TagCreate(name="keep_me"))
        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="Keep Tags",
                author_id=author.id,
                tag_ids=[kept_tag.id],
            ),
        )

        # Only the title changes; tag_ids is not set on the payload.
        await PostM2MCrud.update(
            db_session,
            PostM2MUpdate(title="Updated Title"),
            [Post.id == created.id],
        )

        reloaded = await self._load_with_tags(db_session, created.id)
        assert reloaded.title == "Updated Title"
        assert len(reloaded.tags) == 1
        assert reloaded.tags[0].name == "keep_me"

    @pytest.mark.anyio
    async def test_update_clear_m2m_tags(self, db_session: AsyncSession):
        """An explicit empty tag_ids list removes every tag."""
        author = await self._make_author(db_session)
        doomed_tag = await TagCrud.create(db_session, TagCreate(name="remove_me"))
        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="Clear Tags",
                author_id=author.id,
                tag_ids=[doomed_tag.id],
            ),
        )

        await PostM2MCrud.update(
            db_session,
            PostM2MUpdate(tag_ids=[]),
            [Post.id == created.id],
        )

        reloaded = await self._load_with_tags(db_session, created.id)
        assert reloaded.tags == []

    @pytest.mark.anyio
    async def test_update_m2m_with_nonexistent_id_raises(
        self, db_session: AsyncSession
    ):
        """Updating to a tag ID with no matching row raises NotFoundError."""
        author = await self._make_author(db_session)
        real_tag = await TagCrud.create(db_session, TagCreate(name="existing"))
        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="Bad Update",
                author_id=author.id,
                tag_ids=[real_tag.id],
            ),
        )

        missing_id = uuid.uuid4()
        with pytest.raises(NotFoundError):
            await PostM2MCrud.update(
                db_session,
                PostM2MUpdate(tag_ids=[missing_id]),
                [Post.id == created.id],
            )

    @pytest.mark.anyio
    async def test_update_m2m_and_scalar_fields(self, db_session: AsyncSession):
        """Scalar columns and M2M tags can be updated in one call."""
        author = await self._make_author(db_session)
        tag_one = await TagCrud.create(db_session, TagCreate(name="tag1"))
        tag_two = await TagCrud.create(db_session, TagCreate(name="tag2"))
        created = await PostM2MCrud.create(
            db_session,
            PostM2MCreate(
                title="Original",
                author_id=author.id,
                tag_ids=[tag_one.id],
            ),
        )

        await PostM2MCrud.update(
            db_session,
            PostM2MUpdate(title="Updated", tag_ids=[tag_one.id, tag_two.id]),
            [Post.id == created.id],
        )

        reloaded = await self._load_with_tags(db_session, created.id)
        assert reloaded.title == "Updated"
        assert sorted(tag.name for tag in reloaded.tags) == ["tag1", "tag2"]
class TestM2MWithNonM2MCrud:
    """Tests that CRUD classes without m2m_fields are unaffected."""

    @pytest.mark.anyio
    async def test_create_without_m2m_unchanged(self, db_session: AsyncSession):
        """PostCrud.create keeps working when no m2m_fields are configured."""
        from .conftest import PostCreate

        author = await UserCrud.create(
            db_session, UserCreate(username="author", email="author@test.com")
        )
        created = await PostCrud.create(
            db_session,
            PostCreate(title="Plain Post", author_id=author.id),
        )
        assert created.id is not None
        assert created.title == "Plain Post"

    @pytest.mark.anyio
    async def test_update_without_m2m_unchanged(self, db_session: AsyncSession):
        """PostCrud.update keeps working when no m2m_fields are configured."""
        from .conftest import PostCreate, PostUpdate

        author = await UserCrud.create(
            db_session, UserCreate(username="author", email="author@test.com")
        )
        created = await PostCrud.create(
            db_session,
            PostCreate(title="Plain Post", author_id=author.id),
        )
        renamed = await PostCrud.update(
            db_session,
            PostUpdate(title="Updated Plain"),
            [Post.id == created.id],
        )
        assert renamed.title == "Updated Plain"

View File

@@ -108,6 +108,24 @@ class TestGenerateErrorResponses:
assert example["status"] == "FAIL" assert example["status"] == "FAIL"
assert example["error_code"] == "RES-404" assert example["error_code"] == "RES-404"
assert example["message"] == "Not Found" assert example["message"] == "Not Found"
assert example["data"] is None
def test_response_example_with_data(self):
"""Generated response includes data when set on ApiError."""
class ErrorWithData(ApiException):
api_error = ApiError(
code=400,
msg="Bad Request",
desc="Invalid input.",
err_code="BAD-400",
data={"details": "some context"},
)
responses = generate_error_responses(ErrorWithData)
example = responses[400]["content"]["application/json"]["example"]
assert example["data"] == {"details": "some context"}
class TestInitExceptionsHandlers: class TestInitExceptionsHandlers:
@@ -137,6 +155,59 @@ class TestInitExceptionsHandlers:
assert data["error_code"] == "RES-404" assert data["error_code"] == "RES-404"
assert data["message"] == "Not Found" assert data["message"] == "Not Found"
def test_handles_api_exception_without_data(self):
"""ApiException without data returns null data field."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/error")
async def raise_error():
raise NotFoundError()
client = TestClient(app)
response = client.get("/error")
assert response.status_code == 404
assert response.json()["data"] is None
def test_handles_api_exception_with_data(self):
"""ApiException with data returns the data payload."""
app = FastAPI()
init_exceptions_handlers(app)
class CustomValidationError(ApiException):
api_error = ApiError(
code=422,
msg="Validation Error",
desc="1 validation error(s) detected",
err_code="CUSTOM-422",
data={
"errors": [
{
"field": "email",
"message": "invalid format",
"type": "value_error",
}
]
},
)
@app.get("/error")
async def raise_error():
raise CustomValidationError()
client = TestClient(app)
response = client.get("/error")
assert response.status_code == 422
data = response.json()
assert data["data"] == {
"errors": [
{"field": "email", "message": "invalid format", "type": "value_error"}
]
}
assert data["error_code"] == "CUSTOM-422"
def test_handles_validation_error(self): def test_handles_validation_error(self):
"""Handles validation errors with structured response.""" """Handles validation errors with structured response."""
from pydantic import BaseModel from pydantic import BaseModel

519
tests/test_metrics.py Normal file
View File

@@ -0,0 +1,519 @@
"""Tests for fastapi_toolsets.metrics module."""
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge
from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics
@pytest.fixture(autouse=True)
def _clean_prometheus_registry():
    """Unregister test collectors from the global registry after each test.

    Prometheus collectors live in the process-global ``REGISTRY``; without
    this teardown, a metric name created in one test would collide with the
    same name registered in a later test (duplicate-timeseries errors).
    """
    yield
    # Snapshot the values first: unregistering mutates the mapping.
    for collector in list(REGISTRY._names_to_collectors.values()):
        try:
            REGISTRY.unregister(collector)
        except KeyError:
            # A collector exposed under several names appears multiple times
            # in the snapshot; after the first unregister the later attempts
            # raise KeyError. Anything else should surface, not be swallowed.
            pass
class TestMetric:
    """Tests for the Metric dataclass."""

    def test_default_collect_is_false(self):
        """Provider mode (collect=False) is the default."""
        metric = Metric(name="test", func=lambda: None)
        assert metric.collect is False

    def test_collect_true(self):
        """Collector mode is opted into via collect=True."""
        metric = Metric(name="test", func=lambda: None, collect=True)
        assert metric.collect is True
class TestMetricsRegistry:
    """Tests for the MetricsRegistry class."""

    def test_register_with_decorator(self):
        """A bare @register decorator records the metric under its own name."""
        registry = MetricsRegistry()

        @registry.register
        def my_counter():
            return Counter("test_counter", "A test counter")

        assert "my_counter" in {m.name for m in registry.get_all()}

    def test_register_with_custom_name(self):
        """@register(name=...) overrides the function name."""
        registry = MetricsRegistry()

        @registry.register(name="custom_name")
        def my_counter():
            return Counter("test_counter_2", "A test counter")

        entry = registry.get_all()[0]
        assert entry.name == "custom_name"

    def test_register_as_collector(self):
        """@register(collect=True) flags the metric as a collector."""
        registry = MetricsRegistry()

        @registry.register(collect=True)
        def collect_something():
            pass

        entry = registry.get_all()[0]
        assert entry.collect is True

    def test_register_preserves_function(self):
        """The bare decorator hands back the undecorated function."""
        registry = MetricsRegistry()

        def my_func():
            return "original"

        decorated = registry.register(my_func)
        assert decorated is my_func
        assert decorated() == "original"

    def test_register_parameterized_preserves_function(self):
        """The parameterized decorator hands back the undecorated function."""
        registry = MetricsRegistry()

        def my_func():
            return "original"

        decorated = registry.register(name="custom")(my_func)
        assert decorated is my_func
        assert decorated() == "original"

    def test_get_all(self):
        """get_all returns every registered metric."""
        registry = MetricsRegistry()

        @registry.register
        def metric_a():
            pass

        @registry.register
        def metric_b():
            pass

        assert {m.name for m in registry.get_all()} == {"metric_a", "metric_b"}

    def test_get_providers(self):
        """get_providers filters down to collect=False entries."""
        registry = MetricsRegistry()

        @registry.register
        def provider():
            pass

        @registry.register(collect=True)
        def collector():
            pass

        only_providers = registry.get_providers()
        assert len(only_providers) == 1
        assert only_providers[0].name == "provider"

    def test_get_collectors(self):
        """get_collectors filters down to collect=True entries."""
        registry = MetricsRegistry()

        @registry.register
        def provider():
            pass

        @registry.register(collect=True)
        def collector():
            pass

        only_collectors = registry.get_collectors()
        assert len(only_collectors) == 1
        assert only_collectors[0].name == "collector"

    def test_register_overwrites_same_name(self):
        """Re-registering a name replaces the earlier entry."""
        registry = MetricsRegistry()

        @registry.register(name="metric")
        def first():
            pass

        @registry.register(name="metric")
        def second():
            pass

        entries = registry.get_all()
        assert len(entries) == 1
        assert entries[0].func is second
class TestIncludeRegistry:
    """Tests for MetricsRegistry.include_registry."""

    def test_include_empty_registry(self):
        """Merging an empty registry is a no-op."""
        main = MetricsRegistry()
        other = MetricsRegistry()

        @main.register
        def metric_a():
            pass

        main.include_registry(other)
        assert len(main.get_all()) == 1

    def test_include_registry_adds_metrics(self):
        """Merging pulls every metric from the other registry."""
        main = MetricsRegistry()
        other = MetricsRegistry()

        @main.register
        def metric_a():
            pass

        @other.register
        def metric_b():
            pass

        @other.register
        def metric_c():
            pass

        main.include_registry(other)
        merged = {m.name for m in main.get_all()}
        assert merged == {"metric_a", "metric_b", "metric_c"}

    def test_include_registry_preserves_collect_flag(self):
        """The collect flag survives a merge."""
        main = MetricsRegistry()
        other = MetricsRegistry()

        @other.register(collect=True)
        def collector():
            pass

        main.include_registry(other)
        assert main.get_all()[0].collect is True

    def test_include_registry_raises_on_duplicate(self):
        """A name collision between registries raises ValueError."""
        main = MetricsRegistry()
        other = MetricsRegistry()

        @main.register(name="metric")
        def metric_main():
            pass

        @other.register(name="metric")
        def metric_other():
            pass

        with pytest.raises(ValueError, match="already exists"):
            main.include_registry(other)

    def test_include_multiple_registries(self):
        """Several registries can be merged one after another."""
        main = MetricsRegistry()
        sub1 = MetricsRegistry()
        sub2 = MetricsRegistry()

        @main.register
        def base():
            pass

        @sub1.register
        def sub1_metric():
            pass

        @sub2.register
        def sub2_metric():
            pass

        for sub in (sub1, sub2):
            main.include_registry(sub)
        merged = {m.name for m in main.get_all()}
        assert merged == {"base", "sub1_metric", "sub2_metric"}
class TestInitMetrics:
    """Tests for the init_metrics function."""

    @staticmethod
    def _scrape_client(registry, **init_kwargs):
        # Build a fresh app, wire the registry in, and return a test client.
        app = FastAPI()
        init_metrics(app, registry, **init_kwargs)
        return TestClient(app)

    def test_returns_app(self):
        """init_metrics hands back the same app it was given."""
        app = FastAPI()
        assert init_metrics(app, MetricsRegistry()) is app

    def test_metrics_endpoint_responds(self):
        """Scraping /metrics succeeds with HTTP 200."""
        client = self._scrape_client(MetricsRegistry())
        assert client.get("/metrics").status_code == 200

    def test_metrics_endpoint_content_type(self):
        """The scrape response uses the Prometheus text exposition type."""
        client = self._scrape_client(MetricsRegistry())
        content_type = client.get("/metrics").headers["content-type"]
        assert "text/plain" in content_type

    def test_custom_path(self):
        """A custom path replaces the default /metrics route."""
        client = self._scrape_client(MetricsRegistry(), path="/custom-metrics")
        assert client.get("/custom-metrics").status_code == 200
        assert client.get("/metrics").status_code == 404

    def test_providers_called_at_init(self):
        """Providers run exactly once, at init time."""
        registry = MetricsRegistry()
        spy = MagicMock()

        @registry.register
        def my_provider():
            spy()

        init_metrics(FastAPI(), registry)
        spy.assert_called_once()

    def test_collectors_called_on_scrape(self):
        """Collectors run once per scrape."""
        registry = MetricsRegistry()
        spy = MagicMock()

        @registry.register(collect=True)
        def my_collector():
            spy()

        client = self._scrape_client(registry)
        client.get("/metrics")
        client.get("/metrics")
        assert spy.call_count == 2

    def test_collectors_not_called_at_init(self):
        """Collectors do not run at init time."""
        registry = MetricsRegistry()
        spy = MagicMock()

        @registry.register(collect=True)
        def my_collector():
            spy()

        init_metrics(FastAPI(), registry)
        spy.assert_not_called()

    def test_async_collectors_called_on_scrape(self):
        """Async collectors are awaited once per scrape."""
        registry = MetricsRegistry()
        spy = AsyncMock()

        @registry.register(collect=True)
        async def my_async_collector():
            await spy()

        client = self._scrape_client(registry)
        client.get("/metrics")
        client.get("/metrics")
        assert spy.call_count == 2

    def test_mixed_sync_and_async_collectors(self):
        """Sync and async collectors both run on the same scrape."""
        registry = MetricsRegistry()
        sync_spy = MagicMock()
        async_spy = AsyncMock()

        @registry.register(collect=True)
        def sync_collector():
            sync_spy()

        @registry.register(collect=True)
        async def async_collector():
            await async_spy()

        client = self._scrape_client(registry)
        client.get("/metrics")
        sync_spy.assert_called_once()
        async_spy.assert_called_once()

    def test_registered_metrics_appear_in_output(self):
        """Provider-created metrics show up in the scrape body."""
        registry = MetricsRegistry()

        @registry.register
        def my_gauge():
            gauge = Gauge("test_gauge_value", "A test gauge")
            gauge.set(42)
            return gauge

        body = self._scrape_client(registry).get("/metrics").content
        assert b"test_gauge_value" in body
        assert b"42.0" in body

    def test_endpoint_not_in_openapi_schema(self):
        """The metrics route stays out of the OpenAPI schema."""
        app = FastAPI()
        init_metrics(app, MetricsRegistry())
        assert "/metrics" not in app.openapi().get("paths", {})
class TestMultiProcessMode:
    """Tests for multi-process Prometheus mode.

    Both tests mutate ``PROMETHEUS_MULTIPROC_DIR`` in ``os.environ``; they
    save the pre-existing value and restore it on exit so that running these
    tests never leaks environment state into the rest of the suite (the
    original code unconditionally deleted/popped the variable, clobbering
    any value set by the outer environment).
    """

    def test_multiprocess_with_env_var(self):
        """Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
        previous = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir
            try:
                # Use a separate registry to avoid conflicts with default
                prom_registry = CollectorRegistry()
                app = FastAPI()
                registry = MetricsRegistry()

                @registry.register
                def mp_counter():
                    return Counter(
                        "mp_test_counter",
                        "A multiprocess counter",
                        registry=prom_registry,
                    )

                init_metrics(app, registry)
                client = TestClient(app)
                response = client.get("/metrics")
                assert response.status_code == 200
            finally:
                # Restore whatever the environment held before the test.
                if previous is None:
                    os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None)
                else:
                    os.environ["PROMETHEUS_MULTIPROC_DIR"] = previous

    def test_single_process_without_env_var(self):
        """Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
        previous = os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None)
        try:
            app = FastAPI()
            registry = MetricsRegistry()

            @registry.register
            def sp_gauge():
                g = Gauge("sp_test_gauge", "A single-process gauge")
                g.set(99)
                return g

            init_metrics(app, registry)
            client = TestClient(app)
            response = client.get("/metrics")
            assert response.status_code == 200
            assert b"sp_test_gauge" in response.content
        finally:
            # Put the variable back if the outer environment had set it.
            if previous is not None:
                os.environ["PROMETHEUS_MULTIPROC_DIR"] = previous
class TestMetricsIntegration:
    """Integration tests for the metrics module."""

    def test_full_workflow(self):
        """Registry, provider, collector, and endpoint all work together."""
        app = FastAPI()
        registry = MetricsRegistry()
        scrape_counter = {"value": 0}

        @registry.register
        def request_counter():
            return Counter(
                "integration_requests_total",
                "Total requests",
                ["method"],
            )

        @registry.register(collect=True)
        def collect_uptime():
            scrape_counter["value"] += 1

        init_metrics(app, registry)
        client = TestClient(app)

        first = client.get("/metrics")
        assert first.status_code == 200
        assert b"integration_requests_total" in first.content
        assert scrape_counter["value"] == 1

        client.get("/metrics")
        assert scrape_counter["value"] == 2

    def test_multiple_registries_merged(self):
        """Metrics from merged registries all appear in one scrape."""
        app = FastAPI()
        main = MetricsRegistry()
        sub = MetricsRegistry()

        @main.register
        def main_gauge():
            gauge = Gauge("main_gauge_val", "Main gauge")
            gauge.set(1)
            return gauge

        @sub.register
        def sub_gauge():
            gauge = Gauge("sub_gauge_val", "Sub gauge")
            gauge.set(2)
            return gauge

        main.include_registry(sub)
        init_metrics(app, main)
        body = TestClient(app).get("/metrics").content
        assert b"main_gauge_val" in body
        assert b"sub_gauge_val" in body

View File

@@ -3,7 +3,7 @@
import uuid import uuid
import pytest import pytest
from fastapi import FastAPI from fastapi import Depends, FastAPI
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy import select, text from sqlalchemy import select, text
from sqlalchemy.engine import make_url from sqlalchemy.engine import make_url
@@ -236,6 +236,30 @@ class TestCreateAsyncClient:
assert client_ref.is_closed assert client_ref.is_closed
@pytest.mark.anyio
async def test_dependency_overrides_applied_and_cleaned(self):
"""Dependency overrides are applied during the context and removed after."""
app = FastAPI()
async def original_dep() -> str:
return "original"
async def override_dep() -> str:
return "overridden"
@app.get("/dep")
async def dep_endpoint(value: str = Depends(original_dep)):
return {"value": value}
async with create_async_client(
app, dependency_overrides={original_dep: override_dep}
) as client:
response = await client.get("/dep")
assert response.json() == {"value": "overridden"}
# Overrides should be cleaned up
assert original_dep not in app.dependency_overrides
class TestCreateDbSession: class TestCreateDbSession:
"""Tests for create_db_session helper.""" """Tests for create_db_session helper."""
@@ -297,6 +321,22 @@ class TestCreateDbSession:
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _: async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass pass
@pytest.mark.anyio
async def test_cleanup_truncates_tables(self):
"""Tables are truncated after session closes when cleanup=True."""
role_id = uuid.uuid4()
async with create_db_session(
DATABASE_URL, Base, cleanup=True, drop_tables=False
) as session:
role = Role(id=role_id, name="will_be_cleaned")
session.add(role)
await session.commit()
# Data should have been truncated, but tables still exist
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
result = await session.execute(select(Role))
assert result.all() == []
class TestGetXdistWorker: class TestGetXdistWorker:
"""Tests for _get_xdist_worker helper.""" """Tests for _get_xdist_worker helper."""

View File

@@ -46,6 +46,31 @@ class TestApiError:
assert error.desc == "The resource was not found." assert error.desc == "The resource was not found."
assert error.err_code == "RES-404" assert error.err_code == "RES-404"
def test_data_defaults_to_none(self):
"""ApiError data field defaults to None."""
error = ApiError(
code=404,
msg="Not Found",
desc="The resource was not found.",
err_code="RES-404",
)
assert error.data is None
def test_create_with_data(self):
"""ApiError can be created with a data payload."""
error = ApiError(
code=422,
msg="Validation Error",
desc="2 validation error(s) detected",
err_code="VAL-422",
data={
"errors": [{"field": "name", "message": "required", "type": "missing"}]
},
)
assert error.data == {
"errors": [{"field": "name", "message": "required", "type": "missing"}]
}
def test_requires_all_fields(self): def test_requires_all_fields(self):
"""ApiError requires all fields.""" """ApiError requires all fields."""
with pytest.raises(ValidationError): with pytest.raises(ValidationError):

20
uv.lock generated
View File

@@ -242,7 +242,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.8.1" version = "0.10.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },
@@ -256,6 +256,7 @@ dependencies = [
[package.optional-dependencies] [package.optional-dependencies]
dev = [ dev = [
{ name = "coverage" }, { name = "coverage" },
{ name = "prometheus-client" },
{ name = "pytest" }, { name = "pytest" },
{ name = "pytest-anyio" }, { name = "pytest-anyio" },
{ name = "pytest-cov" }, { name = "pytest-cov" },
@@ -263,6 +264,9 @@ dev = [
{ name = "ruff" }, { name = "ruff" },
{ name = "ty" }, { name = "ty" },
] ]
metrics = [
{ name = "prometheus-client" },
]
test = [ test = [
{ name = "coverage" }, { name = "coverage" },
{ name = "pytest" }, { name = "pytest" },
@@ -276,8 +280,9 @@ requires-dist = [
{ name = "asyncpg", specifier = ">=0.29.0" }, { name = "asyncpg", specifier = ">=0.29.0" },
{ name = "coverage", marker = "extra == 'test'", specifier = ">=7.0.0" }, { name = "coverage", marker = "extra == 'test'", specifier = ">=7.0.0" },
{ name = "fastapi", specifier = ">=0.100.0" }, { name = "fastapi", specifier = ">=0.100.0" },
{ name = "fastapi-toolsets", extras = ["test"], marker = "extra == 'dev'" }, { name = "fastapi-toolsets", extras = ["metrics", "test"], marker = "extra == 'dev'" },
{ name = "httpx", specifier = ">=0.25.0" }, { name = "httpx", specifier = ">=0.25.0" },
{ name = "prometheus-client", marker = "extra == 'metrics'", specifier = ">=0.20.0" },
{ name = "pydantic", specifier = ">=2.0" }, { name = "pydantic", specifier = ">=2.0" },
{ name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" },
{ name = "pytest-anyio", marker = "extra == 'test'", specifier = ">=0.0.0" }, { name = "pytest-anyio", marker = "extra == 'test'", specifier = ">=0.0.0" },
@@ -288,7 +293,7 @@ requires-dist = [
{ name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.1a0" }, { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.1a0" },
{ name = "typer", specifier = ">=0.9.0" }, { name = "typer", specifier = ">=0.9.0" },
] ]
provides-extras = ["test", "dev"] provides-extras = ["metrics", "test", "dev"]
[[package]] [[package]]
name = "greenlet" name = "greenlet"
@@ -436,6 +441,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
] ]
[[package]]
name = "prometheus-client"
version = "0.24.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f0/58/a794d23feb6b00fc0c72787d7e87d872a6730dd9ed7c7b3e954637d8f280/prometheus_client-0.24.1.tar.gz", hash = "sha256:7e0ced7fbbd40f7b84962d5d2ab6f17ef88a72504dcf7c0b40737b43b2a461f9", size = 85616, upload-time = "2026-01-14T15:26:26.965Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/74/c3/24a2f845e3917201628ecaba4f18bab4d18a337834c1df2a159ee9d22a42/prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055", size = 64057, upload-time = "2026-01-14T15:26:24.42Z" },
]
[[package]] [[package]]
name = "pydantic" name = "pydantic"
version = "2.12.5" version = "2.12.5"