feat: add faceted search in CrudFactory (#97)

* feat: add faceted search in CrudFactory

* feat: add filter_params_schema in CrudFactory

* fix: add missing Raises in build_search_filters docstring

* fix: faceted search

* fix: cov

* fix: documentation/filter_params
This commit is contained in:
d3vyce
2026-02-26 15:23:07 +01:00
committed by GitHub
parent 433dc55fcd
commit 5a08ec2f57
11 changed files with 1026 additions and 9 deletions

View File

@@ -3,14 +3,16 @@
from __future__ import annotations
import base64
import inspect
import json
import uuid as uuid_module
import warnings
from collections.abc import Mapping, Sequence
from collections.abc import Awaitable, Callable, Mapping, Sequence
from datetime import date, datetime
from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy import delete as sql_delete
@@ -24,7 +26,15 @@ from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import SearchConfig, SearchFieldType, build_search_filters
from .search import (
FacetFieldType,
SearchConfig,
SearchFieldType,
build_facets,
build_filter_by,
build_search_filters,
facet_keys,
)
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
@@ -50,6 +60,7 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None
@@ -119,6 +130,52 @@ class AsyncCrud(Generic[ModelType]):
return set()
return set(cls.m2m_fields.keys())
@classmethod
def filter_params(
cls: type[Self],
*,
facet_fields: Sequence[FacetFieldType] | None = None,
) -> Callable[..., Awaitable[dict[str, list[str]]]]:
"""Return a FastAPI dependency that collects facet filter values from query parameters.
Args:
facet_fields: Override the facet fields for this dependency. Falls back to the
class-level ``facet_fields`` if not provided.
Returns:
An async dependency function named ``{Model}FilterParams`` that resolves to a
``dict[str, list[str]]`` containing only the keys that were supplied in the
request (absent/``None`` parameters are excluded).
Raises:
ValueError: If no facet fields are configured on this CRUD class and none are
provided via ``facet_fields``.
"""
fields = facet_fields if facet_fields is not None else cls.facet_fields
if not fields:
raise ValueError(
f"{cls.__name__} has no facet_fields configured. "
"Pass facet_fields= or set them on CrudFactory."
)
keys = facet_keys(fields)
async def dependency(**kwargs: Any) -> dict[str, list[str]]:
return {k: v for k, v in kwargs.items() if v is not None}
dependency.__name__ = f"{cls.model.__name__}FilterParams"
dependency.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=[
inspect.Parameter(
k,
inspect.Parameter.KEYWORD_ONLY,
annotation=list[str] | None,
default=Query(default=None),
)
for k in keys
]
)
return dependency
@overload
@classmethod
async def create( # pragma: no cover
@@ -693,6 +750,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
@@ -712,6 +771,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@@ -729,6 +790,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results using offset-based pagination.
@@ -744,6 +807,10 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: Number of items per page
search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default)
facet_fields: Columns to compute distinct values for (overrides class default)
filter_by: Dict of {column_key: value} to filter by declared facet fields.
Keys must match the column.key of a facet field. Scalar → equality,
list → IN clause. Raises InvalidFacetFilterError for unknown keys.
schema: Optional Pydantic schema to serialize each item into.
Returns:
@@ -753,6 +820,20 @@ class AsyncCrud(Generic[ModelType]):
offset = (page - 1) * items_per_page
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
@@ -817,6 +898,20 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,
pagination=OffsetPagination(
@@ -825,6 +920,7 @@ class AsyncCrud(Generic[ModelType]):
page=page,
has_more=page * items_per_page < total_count,
),
filter_attributes=filter_attributes,
)
# Backward-compatible - will be removed in v2.0
@@ -845,6 +941,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[SchemaType],
) -> PaginatedResponse[SchemaType]: ...
@@ -864,6 +962,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: None = ...,
) -> PaginatedResponse[ModelType]: ...
@@ -881,6 +981,8 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
filter_by: dict[str, Any] | BaseModel | None = None,
schema: type[BaseModel] | None = None,
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
"""Get paginated results using cursor-based pagination.
@@ -899,6 +1001,10 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: Number of items per page (default 20).
search: Search query string or SearchConfig object.
search_fields: Fields to search in (overrides class default).
facet_fields: Columns to compute distinct values for (overrides class default).
filter_by: Dict of {column_key: value} to filter by declared facet fields.
Keys must match the column.key of a facet field. Scalar → equality,
list → IN clause. Raises InvalidFacetFilterError for unknown keys.
schema: Optional Pydantic schema to serialize each item into.
Returns:
@@ -907,6 +1013,20 @@ class AsyncCrud(Generic[ModelType]):
filters = list(filters) if filters else []
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
if cls.cursor_column is None:
raise ValueError(
f"{cls.__name__}.cursor_column is not set. "
@@ -996,6 +1116,20 @@ class AsyncCrud(Generic[ModelType]):
else items_page
)
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,
pagination=CursorPagination(
@@ -1004,6 +1138,7 @@ class AsyncCrud(Generic[ModelType]):
items_per_page=items_per_page,
has_more=has_more,
),
filter_attributes=filter_attributes,
)
@@ -1011,6 +1146,7 @@ def CrudFactory(
model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None,
cursor_column: Any | None = None,
@@ -1020,6 +1156,9 @@ def CrudFactory(
Args:
model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
facet_fields: Optional list of columns to compute distinct values for in paginated
responses. Supports direct columns (``User.status``) and relationship tuples
(``(User.role, Role.name)``). Can be overridden per call.
m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes.
@@ -1056,6 +1195,12 @@ def CrudFactory(
m2m_fields={"tag_ids": Post.tags},
)
# With facet fields for filter dropdowns / faceted search:
UserCrud = CrudFactory(
User,
facet_fields=[User.status, User.country, (User.role, Role.name)],
)
# With a fixed cursor column for cursor_paginate:
PostCrud = CrudFactory(
Post,
@@ -1106,6 +1251,7 @@ def CrudFactory(
{
"model": model,
"searchable_fields": searchable_fields,
"facet_fields": facet_fields,
"m2m_fields": m2m_fields,
"default_load_options": default_load_options,
"cursor_column": cursor_column,