feat: add faceted search in CrudFactory (#97)

* feat: add faceted search in CrudFactory

* feat: add filter_params_schema in CrudFactory

* fix: add missing Raises in build_search_filters docstring

* fix: faceted search

* fix: cov

* fix: documentation/filter_params
This commit is contained in:
d3vyce
2026-02-26 15:23:07 +01:00
committed by GitHub
parent 433dc55fcd
commit 5a08ec2f57
11 changed files with 1026 additions and 9 deletions

View File

@@ -1,19 +1,23 @@
"""Search utilities for AsyncCrud."""
import asyncio
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_
from sqlalchemy import String, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import NoSearchableFieldsError
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
@dataclass
@@ -89,6 +93,9 @@ def build_search_filters(
Returns:
Tuple of (filter_conditions, joins_needed)
Raises:
NoSearchableFieldsError: If no searchable field has been configured
"""
# Normalize input
if isinstance(search, str):
@@ -136,7 +143,7 @@ def build_search_filters(
else:
filters.append(column_as_string.ilike(f"%{query}%"))
if not filters:
if not filters: # pragma: no cover
return [], []
# Combine based on match_mode
@@ -144,3 +151,145 @@ def build_search_filters(
return [or_(*filters)], joins
else:
return filters, joins
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
"""Return a key for each facet field, disambiguating duplicate column keys.
Args:
facet_fields: Sequence of facet fields — either direct columns or
relationship tuples ``(rel, ..., column)``.
Returns:
A list of string keys, one per facet field, in the same order.
"""
raw: list[tuple[str, str | None]] = []
for field in facet_fields:
if isinstance(field, tuple):
rel = field[-2]
column = field[-1]
raw.append((column.key, rel.key))
else:
raw.append((field.key, None))
counts = Counter(col_key for col_key, _ in raw)
keys: list[str] = []
for col_key, rel_key in raw:
if counts[col_key] > 1 and rel_key is not None:
keys.append(f"{rel_key}__{col_key}")
else:
keys.append(col_key)
return keys
async def build_facets(
session: "AsyncSession",
model: type[DeclarativeBase],
facet_fields: Sequence[FacetFieldType],
*,
base_filters: "list[ColumnElement[bool]] | None" = None,
base_joins: list[InstrumentedAttribute[Any]] | None = None,
) -> dict[str, list[Any]]:
"""Return distinct values for each facet field, respecting current filters.
Args:
session: DB async session
model: SQLAlchemy model class
facet_fields: Columns or relationship tuples to facet on
base_filters: Filter conditions already applied to the main query (search + caller filters)
base_joins: Relationship joins already applied to the main query
Returns:
Dict mapping column key to sorted list of distinct non-None values
"""
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
keys = facet_keys(facet_fields)
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
if isinstance(field, tuple):
# Relationship chain: (User.role, Role.name) — last element is the column
rels = field[:-1]
column = field[-1]
else:
rels = ()
column = field
q = select(column).select_from(model).distinct()
# Apply base joins (already done on main query, but needed here independently)
for rel in base_joins or []:
q = q.outerjoin(rel)
# Add any extra joins required by this facet field that aren't already in base_joins
for rel in rels:
if str(rel) not in existing_join_keys:
q = q.outerjoin(rel)
if base_filters:
from sqlalchemy import and_
q = q.where(and_(*base_filters))
q = q.order_by(column)
result = await session.execute(q)
values = [row[0] for row in result.all() if row[0] is not None]
return key, values
pairs = await asyncio.gather(
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
)
return dict(pairs)
def build_filter_by(
filter_by: dict[str, Any],
facet_fields: Sequence[FacetFieldType],
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
Args:
filter_by: Mapping of column key to scalar value or list of values
facet_fields: Declared facet fields to validate keys against
Returns:
Tuple of (filter_conditions, joins_needed)
Raises:
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
"""
index: dict[
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
] = {}
for key, field in zip(facet_keys(facet_fields), facet_fields):
if isinstance(field, tuple):
rels = list(field[:-1])
column = field[-1]
else:
rels = []
column = field
index[key] = (column, rels)
valid_keys = set(index)
filters: list[ColumnElement[bool]] = []
joins: list[InstrumentedAttribute[Any]] = []
added_join_keys: set[str] = set()
for key, value in filter_by.items():
if key not in index:
raise InvalidFacetFilterError(key, valid_keys)
column, rels = index[key]
for rel in rels:
rel_key = str(rel)
if rel_key not in added_join_keys:
joins.append(rel)
added_join_keys.add(rel_key)
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column == value)
return filters, joins