refactor: simplify and deduplicate across crud, metrics, cli, and exceptions
This commit is contained in:
2026-03-01 09:45:49 -05:00
parent 858cca576e
commit f45e4f5f93
4 changed files with 33 additions and 15 deletions

View File

@@ -72,7 +72,7 @@ async def load(
registry = get_fixtures_registry() registry = get_fixtures_registry()
db_context = get_db_context() db_context = get_db_context()
context_list = [c.value for c in contexts] if contexts else [Context.BASE] context_list = list(contexts) if contexts else [Context.BASE]
ordered = registry.resolve_context_dependencies(*context_list) ordered = registry.resolve_context_dependencies(*context_list)

View File

@@ -148,6 +148,14 @@ class AsyncCrud(Generic[ModelType]):
return set() return set()
return set(cls.m2m_fields.keys()) return set(cls.m2m_fields.keys())
@classmethod
def _resolve_facet_fields(
    cls: type[Self],
    facet_fields: Sequence[FacetFieldType] | None,
) -> Sequence[FacetFieldType] | None:
    """Resolve the effective facet fields for a query.

    An explicitly supplied ``facet_fields`` argument takes precedence;
    when the caller passes ``None``, the class-level ``cls.facet_fields``
    default is returned instead (which may itself be ``None``).
    """
    if facet_fields is not None:
        return facet_fields
    return cls.facet_fields
@classmethod @classmethod
def _prepare_filter_by( def _prepare_filter_by(
cls: type[Self], cls: type[Self],
@@ -156,10 +164,10 @@ class AsyncCrud(Generic[ModelType]):
) -> tuple[list[Any], list[Any]]: ) -> tuple[list[Any], list[Any]]:
"""Normalize filter_by and return (filters, joins) to apply to the query.""" """Normalize filter_by and return (filters, joins) to apply to the query."""
if isinstance(filter_by, BaseModel): if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None filter_by = filter_by.model_dump(exclude_none=True)
if not filter_by: if not filter_by:
return [], [] return [], []
resolved = facet_fields if facet_fields is not None else cls.facet_fields resolved = cls._resolve_facet_fields(facet_fields)
return build_filter_by(filter_by, resolved or []) return build_filter_by(filter_by, resolved or [])
@classmethod @classmethod
@@ -171,15 +179,15 @@ class AsyncCrud(Generic[ModelType]):
search_joins: list[Any], search_joins: list[Any],
) -> dict[str, list[Any]] | None: ) -> dict[str, list[Any]] | None:
"""Build facet filter_attributes, or return None if no facet fields configured.""" """Build facet filter_attributes, or return None if no facet fields configured."""
resolved = facet_fields if facet_fields is not None else cls.facet_fields resolved = cls._resolve_facet_fields(facet_fields)
if not resolved: if not resolved:
return None return None
return await build_facets( return await build_facets(
session, session,
cls.model, cls.model,
resolved, resolved,
base_filters=filters or None, base_filters=filters,
base_joins=search_joins or None, base_joins=search_joins,
) )
@classmethod @classmethod
@@ -202,7 +210,7 @@ class AsyncCrud(Generic[ModelType]):
ValueError: If no facet fields are configured on this CRUD class and none are ValueError: If no facet fields are configured on this CRUD class and none are
provided via ``facet_fields``. provided via ``facet_fields``.
""" """
fields = facet_fields if facet_fields is not None else cls.facet_fields fields = cls._resolve_facet_fields(facet_fields)
if not fields: if not fields:
raise ValueError( raise ValueError(
f"{cls.__name__} has no facet_fields configured. " f"{cls.__name__} has no facet_fields configured. "

View File

@@ -14,6 +14,10 @@ from fastapi.responses import JSONResponse
from ..schemas import ErrorResponse, ResponseStatus from ..schemas import ErrorResponse, ResponseStatus
from .exceptions import ApiException from .exceptions import ApiException
_VALIDATION_LOCATION_PARAMS: frozenset[str] = frozenset(
{"body", "query", "path", "header", "cookie"}
)
def init_exceptions_handlers(app: FastAPI) -> FastAPI: def init_exceptions_handlers(app: FastAPI) -> FastAPI:
"""Register exception handlers and custom OpenAPI schema on a FastAPI app. """Register exception handlers and custom OpenAPI schema on a FastAPI app.
@@ -99,7 +103,7 @@ def _format_validation_error(
for error in errors: for error in errors:
locs = error["loc"] locs = error["loc"]
if locs and locs[0] in ("body", "query", "path", "header", "cookie"): if locs and locs[0] in _VALIDATION_LOCATION_PARAMS:
locs = locs[1:] locs = locs[1:]
field_path = ".".join(str(loc) for loc in locs) field_path = ".".join(str(loc) for loc in locs)
formatted_errors.append( formatted_errors.append(

View File

@@ -53,17 +53,23 @@ def init_metrics(
logger.debug("Initialising metric provider '%s'", provider.name) logger.debug("Initialising metric provider '%s'", provider.name)
provider.func() provider.func()
collectors = registry.get_collectors() # Partition collectors and cache env check at startup — both are stable for the app lifetime.
async_collectors = [
c for c in registry.get_collectors() if asyncio.iscoroutinefunction(c.func)
]
sync_collectors = [
c for c in registry.get_collectors() if not asyncio.iscoroutinefunction(c.func)
]
multiprocess_mode = _is_multiprocess()
@app.get(path, include_in_schema=False) @app.get(path, include_in_schema=False)
async def metrics_endpoint() -> Response: async def metrics_endpoint() -> Response:
for collector in collectors: for collector in sync_collectors:
if asyncio.iscoroutinefunction(collector.func):
await collector.func()
else:
collector.func() collector.func()
for collector in async_collectors:
await collector.func()
if _is_multiprocess(): if multiprocess_mode:
prom_registry = CollectorRegistry() prom_registry = CollectorRegistry()
multiprocess.MultiProcessCollector(prom_registry) multiprocess.MultiProcessCollector(prom_registry)
output = generate_latest(prom_registry) output = generate_latest(prom_registry)