From d5b22a72fd1583b1bd1bb9a32f9b4275ef74f072 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:24:53 +0100 Subject: [PATCH] feat: add a metrics module (#67) --- pyproject.toml | 5 +- src/fastapi_toolsets/metrics/__init__.py | 10 + src/fastapi_toolsets/metrics/handler.py | 73 ++++ src/fastapi_toolsets/metrics/registry.py | 122 ++++++ tests/test_metrics.py | 519 +++++++++++++++++++++++ uv.lock | 18 +- 6 files changed, 744 insertions(+), 3 deletions(-) create mode 100644 src/fastapi_toolsets/metrics/__init__.py create mode 100644 src/fastapi_toolsets/metrics/handler.py create mode 100644 src/fastapi_toolsets/metrics/registry.py create mode 100644 tests/test_metrics.py diff --git a/pyproject.toml b/pyproject.toml index 1bd416f..9811f7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,9 @@ Repository = "https://github.com/d3vyce/fastapi-toolsets" Issues = "https://github.com/d3vyce/fastapi-toolsets/issues" [project.optional-dependencies] +metrics = [ + "prometheus_client>=0.20.0", +] test = [ "pytest>=8.0.0", "pytest-anyio>=0.0.0", @@ -54,7 +57,7 @@ test = [ "pytest-cov>=4.0.0", ] dev = [ - "fastapi-toolsets[test]", + "fastapi-toolsets[metrics,test]", "ruff>=0.1.0", "ty>=0.0.1a0", ] diff --git a/src/fastapi_toolsets/metrics/__init__.py b/src/fastapi_toolsets/metrics/__init__.py new file mode 100644 index 0000000..c96ecf6 --- /dev/null +++ b/src/fastapi_toolsets/metrics/__init__.py @@ -0,0 +1,10 @@ +"""Prometheus metrics integration for FastAPI applications.""" + +from .handler import init_metrics +from .registry import Metric, MetricsRegistry + +__all__ = [ + "Metric", + "MetricsRegistry", + "init_metrics", +] diff --git a/src/fastapi_toolsets/metrics/handler.py b/src/fastapi_toolsets/metrics/handler.py new file mode 100644 index 0000000..d451f82 --- /dev/null +++ b/src/fastapi_toolsets/metrics/handler.py @@ -0,0 +1,73 @@ +"""Prometheus metrics endpoint for FastAPI applications.""" + +import asyncio +import os + +from fastapi import FastAPI +from fastapi.responses import Response +from prometheus_client import ( + CONTENT_TYPE_LATEST, + CollectorRegistry, + generate_latest, + multiprocess, +) + +from ..logger import get_logger +from .registry import MetricsRegistry + +logger = get_logger() + + +def _is_multiprocess() -> bool: + """Check if prometheus multi-process mode is enabled.""" + return "PROMETHEUS_MULTIPROC_DIR" in os.environ + + +def init_metrics( + app: FastAPI, + registry: MetricsRegistry, + *, + path: str = "/metrics", +) -> FastAPI: + """Register a Prometheus ``/metrics`` endpoint on a FastAPI app. + + Args: + app: FastAPI application instance. + registry: A :class:`MetricsRegistry` containing providers and collectors. + path: URL path for the metrics endpoint (default ``/metrics``). + + Returns: + The same FastAPI instance (for chaining). + + Example: + from fastapi import FastAPI + from fastapi_toolsets.metrics import MetricsRegistry, init_metrics + + metrics = MetricsRegistry() + app = FastAPI() + init_metrics(app, registry=metrics) + """ + for provider in registry.get_providers(): + logger.debug("Initialising metric provider '%s'", provider.name) + provider.func() + + collectors = registry.get_collectors() + + @app.get(path, include_in_schema=False) + async def metrics_endpoint() -> Response: + for collector in collectors: + if asyncio.iscoroutinefunction(collector.func): + await collector.func() + else: + collector.func() + + if _is_multiprocess(): + prom_registry = CollectorRegistry() + multiprocess.MultiProcessCollector(prom_registry) + output = generate_latest(prom_registry) + else: + output = generate_latest() + + return Response(content=output, media_type=CONTENT_TYPE_LATEST) + + return app diff --git a/src/fastapi_toolsets/metrics/registry.py b/src/fastapi_toolsets/metrics/registry.py new file mode 100644 index 0000000..54f76f5 --- /dev/null +++ b/src/fastapi_toolsets/metrics/registry.py @@ -0,0 +1,122 @@ +"""Metrics registry with decorator-based registration.""" + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, cast + +from ..logger import get_logger + +logger = get_logger() + + +@dataclass +class Metric: + """A metric definition with metadata.""" + + name: str + func: Callable[..., Any] + collect: bool = field(default=False) + + +class MetricsRegistry: + """Registry for managing Prometheus metric providers and collectors. + + Example: + from prometheus_client import Counter, Gauge + from fastapi_toolsets.metrics import MetricsRegistry + + metrics = MetricsRegistry() + + @metrics.register + def http_requests(): + return Counter("http_requests_total", "Total HTTP requests", ["method", "status"]) + + @metrics.register(name="db_pool") + def database_pool_size(): + return Gauge("db_pool_size", "Database connection pool size") + + @metrics.register(collect=True) + def collect_queue_depth(gauge=Gauge("queue_depth", "Current queue depth")): + gauge.set(get_current_queue_depth()) + """ + + def __init__(self) -> None: + self._metrics: dict[str, Metric] = {} + + def register( + self, + func: Callable[..., Any] | None = None, + *, + name: str | None = None, + collect: bool = False, + ) -> Callable[..., Any]: + """Register a metric provider or collector function. + + Can be used as a decorator with or without arguments. + + Args: + func: The metric function to register. + name: Metric name (defaults to function name). + collect: If ``True``, the function is called on every scrape. + If ``False`` (default), called once at init time. + + Example: + @metrics.register + def my_counter(): + return Counter("my_counter", "A counter") + + @metrics.register(collect=True, name="queue") + def collect_queue_depth(): + gauge.set(compute_depth()) + """ + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + metric_name = name or cast(Any, fn).__name__ + self._metrics[metric_name] = Metric( + name=metric_name, + func=fn, + collect=collect, + ) + return fn + + if func is not None: + return decorator(func) + return decorator + + def include_registry(self, registry: "MetricsRegistry") -> None: + """Include another :class:`MetricsRegistry` into this one. + + Args: + registry: The registry to merge in. + + Raises: + ValueError: If a metric name already exists in the current registry. + + Example: + main = MetricsRegistry() + sub = MetricsRegistry() + + @sub.register + def sub_metric(): + return Counter("sub_total", "Sub counter") + + main.include_registry(sub) + """ + for metric_name, definition in registry._metrics.items(): + if metric_name in self._metrics: + raise ValueError( + f"Metric '{metric_name}' already exists in the current registry" + ) + self._metrics[metric_name] = definition + + def get_all(self) -> list[Metric]: + """Get all registered metric definitions.""" + return list(self._metrics.values()) + + def get_providers(self) -> list[Metric]: + """Get metric providers (called once at init).""" + return [m for m in self._metrics.values() if not m.collect] + + def get_collectors(self) -> list[Metric]: + """Get collectors (called on each scrape).""" + return [m for m in self._metrics.values() if m.collect] diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..380c5db --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,519 @@ +"""Tests for fastapi_toolsets.metrics module.""" + +import os +import tempfile +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from prometheus_client import REGISTRY, CollectorRegistry, Counter, Gauge + +from fastapi_toolsets.metrics import Metric, MetricsRegistry, init_metrics + + +@pytest.fixture(autouse=True) +def _clean_prometheus_registry(): + """Unregister test collectors from the global registry after each test.""" + yield + collectors = list(REGISTRY._names_to_collectors.values()) + for collector in collectors: + try: + REGISTRY.unregister(collector) + except Exception: + pass + + +class TestMetric: + """Tests for Metric dataclass.""" + + def test_default_collect_is_false(self): + """Default collect is False (provider mode).""" + definition = Metric(name="test", func=lambda: None) + assert definition.collect is False + + def test_collect_true(self): + """Collect can be set to True (collector mode).""" + definition = Metric(name="test", func=lambda: None, collect=True) + assert definition.collect is True + + +class TestMetricsRegistry: + """Tests for MetricsRegistry class.""" + + def test_register_with_decorator(self): + """Register metric with bare decorator.""" + registry = MetricsRegistry() + + @registry.register + def my_counter(): + return Counter("test_counter", "A test counter") + + names = [m.name for m in registry.get_all()] + assert "my_counter" in names + + def test_register_with_custom_name(self): + """Register metric with custom name.""" + registry = MetricsRegistry() + + @registry.register(name="custom_name") + def my_counter(): + return Counter("test_counter_2", "A test counter") + + definition = registry.get_all()[0] + assert definition.name == "custom_name" + + def test_register_as_collector(self): + """Register metric with collect=True.""" + registry = MetricsRegistry() + + @registry.register(collect=True) + def collect_something(): + pass + + definition = registry.get_all()[0] + assert definition.collect is True + + def test_register_preserves_function(self): + """Decorator returns the original function unchanged.""" + registry = MetricsRegistry() + + def my_func(): + return "original" + + result = registry.register(my_func) + assert result is my_func + assert result() == "original" + + def test_register_parameterized_preserves_function(self): + """Parameterized decorator returns the original function unchanged.""" + registry = MetricsRegistry() + + def my_func(): + return "original" + + result = registry.register(name="custom")(my_func) + assert result is my_func + assert result() == "original" + + def test_get_all(self): + """Get all registered metrics.""" + registry = MetricsRegistry() + + @registry.register + def metric_a(): + pass + + @registry.register + def metric_b(): + pass + + names = {m.name for m in registry.get_all()} + assert names == {"metric_a", "metric_b"} + + def test_get_providers(self): + """Get only provider metrics (collect=False).""" + registry = MetricsRegistry() + + @registry.register + def provider(): + pass + + @registry.register(collect=True) + def collector(): + pass + + providers = registry.get_providers() + assert len(providers) == 1 + assert providers[0].name == "provider" + + def test_get_collectors(self): + """Get only collector metrics (collect=True).""" + registry = MetricsRegistry() + + @registry.register + def provider(): + pass + + @registry.register(collect=True) + def collector(): + pass + + collectors = registry.get_collectors() + assert len(collectors) == 1 + assert collectors[0].name == "collector" + + def test_register_overwrites_same_name(self): + """Registering with the same name overwrites the previous entry.""" + registry = MetricsRegistry() + + @registry.register(name="metric") + def first(): + pass + + @registry.register(name="metric") + def second(): + pass + + assert len(registry.get_all()) == 1 + assert registry.get_all()[0].func is second + + +class TestIncludeRegistry: + """Tests for MetricsRegistry.include_registry method.""" + + def test_include_empty_registry(self): + """Include an empty registry does nothing.""" + main = MetricsRegistry() + other = MetricsRegistry() + + @main.register + def metric_a(): + pass + + main.include_registry(other) + assert len(main.get_all()) == 1 + + def test_include_registry_adds_metrics(self): + """Include registry adds all metrics from the other registry.""" + main = MetricsRegistry() + other = MetricsRegistry() + + @main.register + def metric_a(): + pass + + @other.register + def metric_b(): + pass + + @other.register + def metric_c(): + pass + + main.include_registry(other) + names = {m.name for m in main.get_all()} + assert names == {"metric_a", "metric_b", "metric_c"} + + def test_include_registry_preserves_collect_flag(self): + """Include registry preserves the collect flag.""" + main = MetricsRegistry() + other = MetricsRegistry() + + @other.register(collect=True) + def collector(): + pass + + main.include_registry(other) + assert main.get_all()[0].collect is True + + def test_include_registry_raises_on_duplicate(self): + """Include registry raises ValueError on duplicate metric names.""" + main = MetricsRegistry() + other = MetricsRegistry() + + @main.register(name="metric") + def metric_main(): + pass + + @other.register(name="metric") + def metric_other(): + pass + + with pytest.raises(ValueError, match="already exists"): + main.include_registry(other) + + def test_include_multiple_registries(self): + """Include multiple registries sequentially.""" + main = MetricsRegistry() + sub1 = MetricsRegistry() + sub2 = MetricsRegistry() + + @main.register + def base(): + pass + + @sub1.register + def sub1_metric(): + pass + + @sub2.register + def sub2_metric(): + pass + + main.include_registry(sub1) + main.include_registry(sub2) + + names = {m.name for m in main.get_all()} + assert names == {"base", "sub1_metric", "sub2_metric"} + + +class TestInitMetrics: + """Tests for init_metrics function.""" + + def test_returns_app(self): + """Returns the FastAPI app.""" + app = FastAPI() + registry = MetricsRegistry() + result = init_metrics(app, registry) + assert result is app + + def test_metrics_endpoint_responds(self): + """The /metrics endpoint returns 200.""" + app = FastAPI() + registry = MetricsRegistry() + init_metrics(app, registry) + + client = TestClient(app) + response = client.get("/metrics") + + assert response.status_code == 200 + + def test_metrics_endpoint_content_type(self): + """The /metrics endpoint returns prometheus content type.""" + app = FastAPI() + registry = MetricsRegistry() + init_metrics(app, registry) + + client = TestClient(app) + response = client.get("/metrics") + + assert "text/plain" in response.headers["content-type"] + + def test_custom_path(self): + """Custom path is used for the metrics endpoint.""" + app = FastAPI() + registry = MetricsRegistry() + init_metrics(app, registry, path="/custom-metrics") + + client = TestClient(app) + assert client.get("/custom-metrics").status_code == 200 + assert client.get("/metrics").status_code == 404 + + def test_providers_called_at_init(self): + """Provider functions are called once at init time.""" + app = FastAPI() + registry = MetricsRegistry() + mock = MagicMock() + + @registry.register + def my_provider(): + mock() + + init_metrics(app, registry) + + mock.assert_called_once() + + def test_collectors_called_on_scrape(self): + """Collector functions are called on each scrape.""" + app = FastAPI() + registry = MetricsRegistry() + mock = MagicMock() + + @registry.register(collect=True) + def my_collector(): + mock() + + init_metrics(app, registry) + + client = TestClient(app) + client.get("/metrics") + client.get("/metrics") + + assert mock.call_count == 2 + + def test_collectors_not_called_at_init(self): + """Collector functions are not called at init time.""" + app = FastAPI() + registry = MetricsRegistry() + mock = MagicMock() + + @registry.register(collect=True) + def my_collector(): + mock() + + init_metrics(app, registry) + + mock.assert_not_called() + + def test_async_collectors_called_on_scrape(self): + """Async collector functions are awaited on each scrape.""" + app = FastAPI() + registry = MetricsRegistry() + mock = AsyncMock() + + @registry.register(collect=True) + async def my_async_collector(): + await mock() + + init_metrics(app, registry) + + client = TestClient(app) + client.get("/metrics") + client.get("/metrics") + + assert mock.call_count == 2 + + def test_mixed_sync_and_async_collectors(self): + """Both sync and async collectors are called on scrape.""" + app = FastAPI() + registry = MetricsRegistry() + sync_mock = MagicMock() + async_mock = AsyncMock() + + @registry.register(collect=True) + def sync_collector(): + sync_mock() + + @registry.register(collect=True) + async def async_collector(): + await async_mock() + + init_metrics(app, registry) + + client = TestClient(app) + client.get("/metrics") + + sync_mock.assert_called_once() + async_mock.assert_called_once() + + def test_registered_metrics_appear_in_output(self): + """Metrics created by providers appear in /metrics output.""" + app = FastAPI() + registry = MetricsRegistry() + + @registry.register + def my_gauge(): + g = Gauge("test_gauge_value", "A test gauge") + g.set(42) + return g + + init_metrics(app, registry) + + client = TestClient(app) + response = client.get("/metrics") + + assert b"test_gauge_value" in response.content + assert b"42.0" in response.content + + def test_endpoint_not_in_openapi_schema(self): + """The /metrics endpoint is not included in the OpenAPI schema.""" + app = FastAPI() + registry = MetricsRegistry() + init_metrics(app, registry) + + schema = app.openapi() + assert "/metrics" not in schema.get("paths", {}) + + +class TestMultiProcessMode: + """Tests for multi-process Prometheus mode.""" + + def test_multiprocess_with_env_var(self): + """Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set.""" + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir + try: + # Use a separate registry to avoid conflicts with default + prom_registry = CollectorRegistry() + app = FastAPI() + registry = MetricsRegistry() + + @registry.register + def mp_counter(): + return Counter( + "mp_test_counter", + "A multiprocess counter", + registry=prom_registry, + ) + + init_metrics(app, registry) + + client = TestClient(app) + response = client.get("/metrics") + + assert response.status_code == 200 + finally: + del os.environ["PROMETHEUS_MULTIPROC_DIR"] + + def test_single_process_without_env_var(self): + """Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set.""" + os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None) + + app = FastAPI() + registry = MetricsRegistry() + + @registry.register + def sp_gauge(): + g = Gauge("sp_test_gauge", "A single-process gauge") + g.set(99) + return g + + init_metrics(app, registry) + + client = TestClient(app) + response = client.get("/metrics") + + assert response.status_code == 200 + assert b"sp_test_gauge" in response.content + + +class TestMetricsIntegration: + """Integration tests for the metrics module.""" + + def test_full_workflow(self): + """Full workflow: registry, providers, collectors, endpoint.""" + app = FastAPI() + registry = MetricsRegistry() + call_count = {"value": 0} + + @registry.register + def request_counter(): + return Counter( + "integration_requests_total", + "Total requests", + ["method"], + ) + + @registry.register(collect=True) + def collect_uptime(): + call_count["value"] += 1 + + init_metrics(app, registry) + + client = TestClient(app) + + response = client.get("/metrics") + assert response.status_code == 200 + assert b"integration_requests_total" in response.content + assert call_count["value"] == 1 + + response = client.get("/metrics") + assert call_count["value"] == 2 + + def test_multiple_registries_merged(self): + """Multiple registries can be merged and used together.""" + app = FastAPI() + main = MetricsRegistry() + sub = MetricsRegistry() + + @main.register + def main_gauge(): + g = Gauge("main_gauge_val", "Main gauge") + g.set(1) + return g + + @sub.register + def sub_gauge(): + g = Gauge("sub_gauge_val", "Sub gauge") + g.set(2) + return g + + main.include_registry(sub) + init_metrics(app, main) + + client = TestClient(app) + response = client.get("/metrics") + + assert b"main_gauge_val" in response.content + assert b"sub_gauge_val" in response.content diff --git a/uv.lock b/uv.lock index ba56ce9..b50e812 100644 --- a/uv.lock +++ b/uv.lock @@ -256,6 +256,7 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "coverage" }, + { name = "prometheus-client" }, { name = "pytest" }, { name = "pytest-anyio" }, { name = "pytest-cov" }, @@ -263,6 +264,9 @@ dev = [ { name = "ruff" }, { name = "ty" }, ] +metrics = [ + { name = "prometheus-client" }, +] test = [ { name = "coverage" }, { name = "pytest" }, @@ -276,8 +280,9 @@ requires-dist = [ { name = "asyncpg", specifier = ">=0.29.0" }, { name = "coverage", marker = "extra == 'test'", specifier = ">=7.0.0" }, { name = "fastapi", specifier = ">=0.100.0" }, - { name = "fastapi-toolsets", extras = ["test"], marker = "extra == 'dev'" }, + { name = "fastapi-toolsets", extras = ["metrics", "test"], marker = "extra == 'dev'" }, { name = "httpx", specifier = ">=0.25.0" }, + { name = "prometheus-client", marker = "extra == 'metrics'", specifier = ">=0.20.0" }, { name = "pydantic", specifier = ">=2.0" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" }, { name = "pytest-anyio", marker = "extra == 'test'", specifier = ">=0.0.0" }, @@ -288,7 +293,7 @@ requires-dist = [ { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.1a0" }, { name = "typer", specifier = ">=0.9.0" }, ] -provides-extras = ["test", "dev"] +provides-extras = ["metrics", "test", "dev"] [[package]] name = "greenlet" @@ -436,6 +441,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prometheus-client" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/58/a794d23feb6b00fc0c72787d7e87d872a6730dd9ed7c7b3e954637d8f280/prometheus_client-0.24.1.tar.gz", hash = "sha256:7e0ced7fbbd40f7b84962d5d2ab6f17ef88a72504dcf7c0b40737b43b2a461f9", size = 85616, upload-time = "2026-01-14T15:26:26.965Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/c3/24a2f845e3917201628ecaba4f18bab4d18a337834c1df2a159ee9d22a42/prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055", size = 64057, upload-time = "2026-01-14T15:26:24.42Z" }, +] + [[package]] name = "pydantic" version = "2.12.5"