feat: add faceted search in CrudFactory (#97)

* feat: add faceted search in CrudFactory

* feat: add filter_params_schema in CrudFactory

* fix: add missing Raises in build_search_filters docstring

* fix: faceted search

* fix: cov

* fix: documentation/filter_params
This commit is contained in:
d3vyce
2026-02-26 15:23:07 +01:00
committed by GitHub
parent 433dc55fcd
commit 5a08ec2f57
11 changed files with 1026 additions and 9 deletions

View File

@@ -1,15 +1,18 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from .factory import CrudFactory, JoinType, M2MFieldType
from .search import (
FacetFieldType,
SearchConfig,
get_searchable_fields,
)
__all__ = [
"CrudFactory",
"FacetFieldType",
"get_searchable_fields",
"InvalidFacetFilterError",
"JoinType",
"M2MFieldType",
"NoSearchableFieldsError",

View File

@@ -3,14 +3,16 @@
from __future__ import annotations
import base64
import inspect
import json
import uuid as uuid_module
import warnings
from collections.abc import Mapping, Sequence
from collections.abc import Awaitable, Callable, Mapping, Sequence
from datetime import date, datetime
from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy import delete as sql_delete
@@ -24,7 +26,15 @@ from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import SearchConfig, SearchFieldType, build_search_filters
from .search import (
FacetFieldType,
SearchConfig,
SearchFieldType,
build_facets,
build_filter_by,
build_search_filters,
facet_keys,
)
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
@@ -50,6 +60,7 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None
@@ -119,6 +130,52 @@ class AsyncCrud(Generic[ModelType]):
return set()
return set(cls.m2m_fields.keys())
@classmethod
def filter_params(
cls: type[Self],
*,
facet_fields: Sequence[FacetFieldType] | None = None,
) -> Callable[..., Awaitable[dict[str, list[str]]]]:
"""Return a FastAPI dependency that collects facet filter values from query parameters.
Args:
facet_fields: Override the facet fields for this dependency. Falls back to the
class-level ``facet_fields`` if not provided.
Returns:
An async dependency function named ``{Model}FilterParams`` that resolves to a
``dict[str, list[str]]`` containing only the keys that were supplied in the
request (absent/``None`` parameters are excluded).
Raises:
ValueError: If no facet fields are configured on this CRUD class and none are
provided via ``facet_fields``.
"""
fields = facet_fields if facet_fields is not None else cls.facet_fields
if not fields:
raise ValueError(
f"{cls.__name__} has no facet_fields configured. "
"Pass facet_fields= or set them on CrudFactory."
)
keys = facet_keys(fields)
async def dependency(**kwargs: Any) -> dict[str, list[str]]:
return {k: v for k, v in kwargs.items() if v is not None}
dependency.__name__ = f"{cls.model.__name__}FilterParams"
dependency.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=[
inspect.Parameter(
k,
inspect.Parameter.KEYWORD_ONLY,
annotation=list[str] | None,
default=Query(default=None),
)
for k in keys
]
)
return dependency
@overload
@classmethod
async def create( # pragma: no cover
@@ -693,6 +750,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
@@ -712,6 +771,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@@ -729,6 +790,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results using offset-based pagination.
@@ -744,6 +807,10 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: Number of items per page
search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default)
facet_fields: Columns to compute distinct values for (overrides class default)
filter_by: Dict of {column_key: value} to filter by declared facet fields.
Keys must match the column.key of a facet field. Scalar → equality,
list → IN clause. Raises InvalidFacetFilterError for unknown keys.
schema: Optional Pydantic schema to serialize each item into.
Returns:
@@ -753,6 +820,20 @@ class AsyncCrud(Generic[ModelType]):
offset = (page - 1) * items_per_page
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
@@ -817,6 +898,20 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,
pagination=OffsetPagination(
@@ -825,6 +920,7 @@ class AsyncCrud(Generic[ModelType]):
page=page,
has_more=page * items_per_page < total_count,
),
filter_attributes=filter_attributes,
)
# Backward-compatible - will be removed in v2.0
@@ -845,6 +941,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
@@ -864,6 +962,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@@ -881,6 +981,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results using cursor-based pagination.
@@ -899,6 +1001,10 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: Number of items per page (default 20).
search: Search query string or SearchConfig object.
search_fields: Fields to search in (overrides class default).
facet_fields: Columns to compute distinct values for (overrides class default).
filter_by: Dict of {column_key: value} to filter by declared facet fields.
Keys must match the column.key of a facet field. Scalar → equality,
list → IN clause. Raises InvalidFacetFilterError for unknown keys.
schema: Optional Pydantic schema to serialize each item into.
Returns:
@@ -907,6 +1013,20 @@ class AsyncCrud(Generic[ModelType]):
filters = list(filters) if filters else []
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
if cls.cursor_column is None:
raise ValueError(
f"{cls.__name__}.cursor_column is not set. "
@@ -996,6 +1116,20 @@ class AsyncCrud(Generic[ModelType]):
else items_page
)
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,
pagination=CursorPagination(
@@ -1004,6 +1138,7 @@ class AsyncCrud(Generic[ModelType]):
items_per_page=items_per_page,
has_more=has_more,
),
filter_attributes=filter_attributes,
)
@@ -1011,6 +1146,7 @@ def CrudFactory(
model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None,
cursor_column: Any | None = None,
@@ -1020,6 +1156,9 @@ def CrudFactory(
Args:
model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
facet_fields: Optional list of columns to compute distinct values for in paginated
responses. Supports direct columns (``User.status``) and relationship tuples
(``(User.role, Role.name)``). Can be overridden per call.
m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes.
@@ -1056,6 +1195,12 @@ def CrudFactory(
m2m_fields={"tag_ids": Post.tags},
)
# With facet fields for filter dropdowns / faceted search:
UserCrud = CrudFactory(
User,
facet_fields=[User.status, User.country, (User.role, Role.name)],
)
# With a fixed cursor column for cursor_paginate:
PostCrud = CrudFactory(
Post,
@@ -1106,6 +1251,7 @@ def CrudFactory(
{
"model": model,
"searchable_fields": searchable_fields,
"facet_fields": facet_fields,
"m2m_fields": m2m_fields,
"default_load_options": default_load_options,
"cursor_column": cursor_column,

View File

@@ -1,19 +1,23 @@
"""Search utilities for AsyncCrud."""
import asyncio
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_
from sqlalchemy import String, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import NoSearchableFieldsError
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
@dataclass
@@ -89,6 +93,9 @@ def build_search_filters(
Returns:
Tuple of (filter_conditions, joins_needed)
Raises:
NoSearchableFieldsError: If no searchable field has been configured
"""
# Normalize input
if isinstance(search, str):
@@ -136,7 +143,7 @@ def build_search_filters(
else:
filters.append(column_as_string.ilike(f"%{query}%"))
if not filters:
if not filters: # pragma: no cover
return [], []
# Combine based on match_mode
@@ -144,3 +151,145 @@ def build_search_filters(
return [or_(*filters)], joins
else:
return filters, joins
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
"""Return a key for each facet field, disambiguating duplicate column keys.
Args:
facet_fields: Sequence of facet fields — either direct columns or
relationship tuples ``(rel, ..., column)``.
Returns:
A list of string keys, one per facet field, in the same order.
"""
raw: list[tuple[str, str | None]] = []
for field in facet_fields:
if isinstance(field, tuple):
rel = field[-2]
column = field[-1]
raw.append((column.key, rel.key))
else:
raw.append((field.key, None))
counts = Counter(col_key for col_key, _ in raw)
keys: list[str] = []
for col_key, rel_key in raw:
if counts[col_key] > 1 and rel_key is not None:
keys.append(f"{rel_key}__{col_key}")
else:
keys.append(col_key)
return keys
async def build_facets(
session: "AsyncSession",
model: type[DeclarativeBase],
facet_fields: Sequence[FacetFieldType],
*,
base_filters: "list[ColumnElement[bool]] | None" = None,
base_joins: list[InstrumentedAttribute[Any]] | None = None,
) -> dict[str, list[Any]]:
"""Return distinct values for each facet field, respecting current filters.
Args:
session: DB async session
model: SQLAlchemy model class
facet_fields: Columns or relationship tuples to facet on
base_filters: Filter conditions already applied to the main query (search + caller filters)
base_joins: Relationship joins already applied to the main query
Returns:
Dict mapping column key to sorted list of distinct non-None values
"""
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
keys = facet_keys(facet_fields)
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
if isinstance(field, tuple):
# Relationship chain: (User.role, Role.name) — last element is the column
rels = field[:-1]
column = field[-1]
else:
rels = ()
column = field
q = select(column).select_from(model).distinct()
# Apply base joins (already done on main query, but needed here independently)
for rel in base_joins or []:
q = q.outerjoin(rel)
# Add any extra joins required by this facet field that aren't already in base_joins
for rel in rels:
if str(rel) not in existing_join_keys:
q = q.outerjoin(rel)
if base_filters:
from sqlalchemy import and_
q = q.where(and_(*base_filters))
q = q.order_by(column)
result = await session.execute(q)
values = [row[0] for row in result.all() if row[0] is not None]
return key, values
pairs = await asyncio.gather(
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
)
return dict(pairs)
def build_filter_by(
filter_by: dict[str, Any],
facet_fields: Sequence[FacetFieldType],
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
Args:
filter_by: Mapping of column key to scalar value or list of values
facet_fields: Declared facet fields to validate keys against
Returns:
Tuple of (filter_conditions, joins_needed)
Raises:
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
"""
index: dict[
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
] = {}
for key, field in zip(facet_keys(facet_fields), facet_fields):
if isinstance(field, tuple):
rels = list(field[:-1])
column = field[-1]
else:
rels = []
column = field
index[key] = (column, rels)
valid_keys = set(index)
filters: list[ColumnElement[bool]] = []
joins: list[InstrumentedAttribute[Any]] = []
added_join_keys: set[str] = set()
for key, value in filter_by.items():
if key not in index:
raise InvalidFacetFilterError(key, valid_keys)
column, rels = index[key]
for rel in rels:
rel_key = str(rel)
if rel_key not in added_join_keys:
joins.append(rel)
added_join_keys.add(rel_key)
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column == value)
return filters, joins

View File

@@ -5,6 +5,7 @@ from .exceptions import (
ApiException,
ConflictError,
ForbiddenError,
InvalidFacetFilterError,
NoSearchableFieldsError,
NotFoundError,
UnauthorizedError,
@@ -19,6 +20,7 @@ __all__ = [
"ForbiddenError",
"generate_error_responses",
"init_exceptions_handlers",
"InvalidFacetFilterError",
"NoSearchableFieldsError",
"NotFoundError",
"UnauthorizedError",

View File

@@ -102,6 +102,32 @@ class NoSearchableFieldsError(ApiException):
super().__init__(detail)
class InvalidFacetFilterError(ApiException):
"""Raised when filter_by contains a key not declared in facet_fields."""
api_error = ApiError(
code=400,
msg="Invalid Facet Filter",
desc="One or more filter_by keys are not declared as facet fields.",
err_code="FACET-400",
)
def __init__(self, key: str, valid_keys: set[str]) -> None:
"""Initialize the exception.
Args:
key: The unknown filter key provided by the caller
valid_keys: Set of valid keys derived from the declared facet_fields
"""
self.key = key
self.valid_keys = valid_keys
detail = (
f"'{key}' is not a declared facet field. "
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
)
super().__init__(detail)
def generate_error_responses(
*errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]:

View File

@@ -133,3 +133,4 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
data: list[DataT]
pagination: OffsetPagination | CursorPagination
filter_attributes: dict[str, list[Any]] | None = None