diff --git a/docs/module/crud.md b/docs/module/crud.md index d869b63..33eb88a 100644 --- a/docs/module/crud.md +++ b/docs/module/crud.md @@ -168,7 +168,19 @@ PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at) ## Search -Declare searchable fields on the CRUD class. Relationship traversal is supported via tuples: +Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate). + +| | Full-text search | Filter attributes | +|---|---|---| +| Input | Free-text string | Exact column values | +| Relationship support | Yes | Yes | +| Use case | Search bars | Filter dropdowns | + +!!! info "You can use both search strategies in the same endpoint!" + +### Full-text search + +Declare `searchable_fields` on the CRUD class. Relationship traversal is supported via tuples: ```python PostCrud = CrudFactory( @@ -181,6 +193,15 @@ PostCrud = CrudFactory( ) ``` +You can override `searchable_fields` per call with `search_fields`: + +```python +result = await UserCrud.offset_paginate( + session=session, + search_fields=[User.country], +) +``` + This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate): ```python @@ -221,6 +242,87 @@ async def get_users( ) ``` +### Filter attributes + +!!! info "Added in `v1.2`" + +Declare `facet_fields` on the CRUD class to return distinct column values alongside paginated results. This is useful for populating filter dropdowns or building faceted search UIs. + +Facet fields use the same syntax as `searchable_fields` — direct columns or relationship tuples: + +```python +UserCrud = CrudFactory( + model=User, + facet_fields=[ + User.status, + User.country, + (User.role, Role.name), # value from a related model + ], +) +``` + +You can override `facet_fields` per call: + +```python +result = await UserCrud.offset_paginate( + session=session, + facet_fields=[User.country], +) +``` + +The distinct values are returned in the `filter_attributes` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse): + +```json +{ + "status": "SUCCESS", + "data": ["..."], + "pagination": { "..." }, + "filter_attributes": { + "status": ["active", "inactive"], + "country": ["DE", "FR", "US"], + "name": ["admin", "editor", "viewer"] + } +} +``` + +Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError). + +!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`." + Keys are normally the terminal `column.key` (e.g. `"name"` for `Role.name`). When two facet fields share the same column key (e.g. `(Build.project, Project.name)` and `(Build.os, Os.name)`), the relationship name is prepended automatically: `"project__name"` and `"os__name"`. + +`filter_by` and `filters` can be combined — both are applied with AND logic. + +Use [`filter_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.filter_params) to generate a dict with the facet filter values from the query parameters: + +```python +from fastapi import Depends + +UserCrud = CrudFactory( + model=User, + facet_fields=[User.status, User.country, (User.role, Role.name)], +) + +@router.get("", response_model_exclude_none=True) +async def list_users( + session: SessionDep, + page: int = 1, + filter_by: dict[str, list[str]] = Depends(UserCrud.filter_params()), +) -> PaginatedResponse[UserRead]: + return await UserCrud.offset_paginate( + session=session, + page=page, + filter_by=filter_by, + ) +``` + +Both single-value and multi-value query parameters work: + +``` +GET /users?status=active → filter_by={"status": ["active"]} +GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]} +GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause) +``` + ## Relationship loading !!! info "Added in `v1.1`" diff --git a/docs/module/exceptions.md b/docs/module/exceptions.md index 171d5df..4f318d1 100644 --- a/docs/module/exceptions.md +++ b/docs/module/exceptions.md @@ -34,6 +34,7 @@ This registers handlers for: | [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not found | | [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict | | [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields | +| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid facet filter | ```python from fastapi_toolsets.exceptions import NotFoundError diff --git a/docs/module/schemas.md b/docs/module/schemas.md index 3e0b670..b92dc8c 100644 --- a/docs/module/schemas.md +++ b/docs/module/schemas.md @@ -22,7 +22,7 @@ async def get_user(user: User = UserDep) -> Response[UserSchema]: ### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) -Wraps a list of items with pagination metadata. +Wraps a list of items with pagination metadata and optional facet values. ```python from fastapi_toolsets.schemas import PaginatedResponse, Pagination @@ -40,6 +40,8 @@ async def list_users() -> PaginatedResponse[UserSchema]: ) ``` +The optional `filter_attributes` field is populated when `facet_fields` are configured on the CRUD class (see [Filter attributes](crud.md#filter-attributes-facets)). It is `None` by default and can be hidden from API responses with `response_model_exclude_none=True`. + ### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse) Returned automatically by the exceptions handler. diff --git a/docs/reference/exceptions.md b/docs/reference/exceptions.md index d3add67..6df730e 100644 --- a/docs/reference/exceptions.md +++ b/docs/reference/exceptions.md @@ -12,6 +12,7 @@ from fastapi_toolsets.exceptions import ( NotFoundError, ConflictError, NoSearchableFieldsError, + InvalidFacetFilterError, generate_error_responses, init_exceptions_handlers, ) @@ -29,6 +30,8 @@ from fastapi_toolsets.exceptions import ( ## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError +## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError + ## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses ## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index 3aaa5bf..3e311d1 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -1,15 +1,18 @@ """Generic async CRUD operations for SQLAlchemy models.""" -from ..exceptions import NoSearchableFieldsError +from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError from .factory import CrudFactory, JoinType, M2MFieldType from .search import ( + FacetFieldType, SearchConfig, get_searchable_fields, ) __all__ = [ "CrudFactory", + "FacetFieldType", "get_searchable_fields", + "InvalidFacetFilterError", "JoinType", "M2MFieldType", "NoSearchableFieldsError", diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 75b08ad..a2849ed 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -3,14 +3,16 @@ from __future__ import annotations import base64 +import inspect import json import uuid as uuid_module import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from datetime import date, datetime from decimal import Decimal from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload +from fastapi import Query from pydantic import BaseModel from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select from sqlalchemy import delete as sql_delete @@ -24,7 +26,15 @@ from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction from ..exceptions import NotFoundError from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response -from .search import SearchConfig, SearchFieldType, build_search_filters +from .search import ( + FacetFieldType, + SearchConfig, + SearchFieldType, + build_facets, + build_filter_by, + build_search_filters, + facet_keys, +) ModelType = TypeVar("ModelType", bound=DeclarativeBase) SchemaType = TypeVar("SchemaType", bound=BaseModel) @@ -50,6 +60,7 @@ class AsyncCrud(Generic[ModelType]): model: ClassVar[type[DeclarativeBase]] searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None + facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None default_load_options: ClassVar[list[ExecutableOption] | None] = None cursor_column: ClassVar[Any | None] = None @@ -119,6 +130,52 @@ class AsyncCrud(Generic[ModelType]): return set() return set(cls.m2m_fields.keys()) + @classmethod + def filter_params( + cls: type[Self], + *, + facet_fields: Sequence[FacetFieldType] | None = None, + ) -> Callable[..., Awaitable[dict[str, list[str]]]]: + """Return a FastAPI dependency that collects facet filter values from query parameters. + Args: + facet_fields: Override the facet fields for this dependency. Falls back to the + class-level ``facet_fields`` if not provided. + + Returns: + An async dependency function named ``{Model}FilterParams`` that resolves to a + ``dict[str, list[str]]`` containing only the keys that were supplied in the + request (absent/``None`` parameters are excluded). + + Raises: + ValueError: If no facet fields are configured on this CRUD class and none are + provided via ``facet_fields``. + """ + fields = facet_fields if facet_fields is not None else cls.facet_fields + if not fields: + raise ValueError( + f"{cls.__name__} has no facet_fields configured. " + "Pass facet_fields= or set them on CrudFactory." + ) + keys = facet_keys(fields) + + async def dependency(**kwargs: Any) -> dict[str, list[str]]: + return {k: v for k, v in kwargs.items() if v is not None} + + dependency.__name__ = f"{cls.model.__name__}FilterParams" + dependency.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=[ + inspect.Parameter( + k, + inspect.Parameter.KEYWORD_ONLY, + annotation=list[str] | None, + default=Query(default=None), + ) + for k in keys + ] + ) + + return dependency + @overload @classmethod async def create( # pragma: no cover @@ -693,6 +750,8 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, + filter_by: dict[str, Any] | BaseModel | None = None, schema: type[SchemaType], ) -> PaginatedResponse[SchemaType]: ... @@ -712,6 +771,8 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, + filter_by: dict[str, Any] | BaseModel | None = None, schema: None = ..., ) -> PaginatedResponse[ModelType]: ... @@ -729,6 +790,8 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, + filter_by: dict[str, Any] | BaseModel | None = None, schema: type[BaseModel] | None = None, ) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: """Get paginated results using offset-based pagination. @@ -744,6 +807,10 @@ class AsyncCrud(Generic[ModelType]): items_per_page: Number of items per page search: Search query string or SearchConfig object search_fields: Fields to search in (overrides class default) + facet_fields: Columns to compute distinct values for (overrides class default) + filter_by: Dict of {column_key: value} to filter by declared facet fields. + Keys must match the column.key of a facet field. Scalar → equality, + list → IN clause. Raises InvalidFacetFilterError for unknown keys. schema: Optional Pydantic schema to serialize each item into. Returns: @@ -753,6 +820,20 @@ class AsyncCrud(Generic[ModelType]): offset = (page - 1) * items_per_page search_joins: list[Any] = [] + if isinstance(filter_by, BaseModel): + filter_by = filter_by.model_dump(exclude_none=True) or None + + # Build filter_by conditions from declared facet fields + if filter_by: + resolved_facets_for_filter = ( + facet_fields if facet_fields is not None else cls.facet_fields + ) + fb_filters, fb_joins = build_filter_by( + filter_by, resolved_facets_for_filter or [] + ) + filters.extend(fb_filters) + search_joins.extend(fb_joins) + # Build search filters if search: search_filters, search_joins = build_search_filters( @@ -817,6 +898,20 @@ class AsyncCrud(Generic[ModelType]): count_result = await session.execute(count_q) total_count = count_result.scalar_one() + # Build facets + resolved_facet_fields = ( + facet_fields if facet_fields is not None else cls.facet_fields + ) + filter_attributes: dict[str, list[Any]] | None = None + if resolved_facet_fields: + filter_attributes = await build_facets( + session, + cls.model, + resolved_facet_fields, + base_filters=filters or None, + base_joins=search_joins or None, + ) + return PaginatedResponse( data=items, pagination=OffsetPagination( @@ -825,6 +920,7 @@ class AsyncCrud(Generic[ModelType]): page=page, has_more=page * items_per_page < total_count, ), + filter_attributes=filter_attributes, ) # Backward-compatible - will be removed in v2.0 @@ -845,6 +941,8 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, + filter_by: dict[str, Any] | BaseModel | None = None, schema: type[SchemaType], ) -> PaginatedResponse[SchemaType]: ... @@ -864,6 +962,8 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, + filter_by: dict[str, Any] | BaseModel | None = None, schema: None = ..., ) -> PaginatedResponse[ModelType]: ... @@ -881,6 +981,8 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, + filter_by: dict[str, Any] | BaseModel | None = None, schema: type[BaseModel] | None = None, ) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: """Get paginated results using cursor-based pagination. @@ -899,6 +1001,10 @@ class AsyncCrud(Generic[ModelType]): items_per_page: Number of items per page (default 20). search: Search query string or SearchConfig object. search_fields: Fields to search in (overrides class default). + facet_fields: Columns to compute distinct values for (overrides class default). + filter_by: Dict of {column_key: value} to filter by declared facet fields. + Keys must match the column.key of a facet field. Scalar → equality, + list → IN clause. Raises InvalidFacetFilterError for unknown keys. schema: Optional Pydantic schema to serialize each item into. Returns: @@ -907,6 +1013,20 @@ class AsyncCrud(Generic[ModelType]): filters = list(filters) if filters else [] search_joins: list[Any] = [] + if isinstance(filter_by, BaseModel): + filter_by = filter_by.model_dump(exclude_none=True) or None + + # Build filter_by conditions from declared facet fields + if filter_by: + resolved_facets_for_filter = ( + facet_fields if facet_fields is not None else cls.facet_fields + ) + fb_filters, fb_joins = build_filter_by( + filter_by, resolved_facets_for_filter or [] + ) + filters.extend(fb_filters) + search_joins.extend(fb_joins) + if cls.cursor_column is None: raise ValueError( f"{cls.__name__}.cursor_column is not set. " @@ -996,6 +1116,20 @@ class AsyncCrud(Generic[ModelType]): else items_page ) + # Build facets + resolved_facet_fields = ( + facet_fields if facet_fields is not None else cls.facet_fields + ) + filter_attributes: dict[str, list[Any]] | None = None + if resolved_facet_fields: + filter_attributes = await build_facets( + session, + cls.model, + resolved_facet_fields, + base_filters=filters or None, + base_joins=search_joins or None, + ) + return PaginatedResponse( data=items, pagination=CursorPagination( @@ -1004,6 +1138,7 @@ class AsyncCrud(Generic[ModelType]): items_per_page=items_per_page, has_more=has_more, ), + filter_attributes=filter_attributes, ) @@ -1011,6 +1146,7 @@ def CrudFactory( model: type[ModelType], *, searchable_fields: Sequence[SearchFieldType] | None = None, + facet_fields: Sequence[FacetFieldType] | None = None, m2m_fields: M2MFieldType | None = None, default_load_options: list[ExecutableOption] | None = None, cursor_column: Any | None = None, @@ -1020,6 +1156,9 @@ def CrudFactory( Args: model: SQLAlchemy model class searchable_fields: Optional list of searchable fields + facet_fields: Optional list of columns to compute distinct values for in paginated + responses. Supports direct columns (``User.status``) and relationship tuples + (``(User.role, Role.name)``). Can be overridden per call. m2m_fields: Optional mapping for many-to-many relationships. Maps schema field names (containing lists of IDs) to SQLAlchemy relationship attributes. @@ -1056,6 +1195,12 @@ def CrudFactory( m2m_fields={"tag_ids": Post.tags}, ) + # With facet fields for filter dropdowns / faceted search: + UserCrud = CrudFactory( + User, + facet_fields=[User.status, User.country, (User.role, Role.name)], + ) + # With a fixed cursor column for cursor_paginate: PostCrud = CrudFactory( Post, @@ -1106,6 +1251,7 @@ def CrudFactory( { "model": model, "searchable_fields": searchable_fields, + "facet_fields": facet_fields, "m2m_fields": m2m_fields, "default_load_options": default_load_options, "cursor_column": cursor_column, diff --git a/src/fastapi_toolsets/crud/search.py b/src/fastapi_toolsets/crud/search.py index 72c5a58..6e01df7 100644 --- a/src/fastapi_toolsets/crud/search.py +++ b/src/fastapi_toolsets/crud/search.py @@ -1,19 +1,23 @@ """Search utilities for AsyncCrud.""" +import asyncio +from collections import Counter from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal -from sqlalchemy import String, or_ +from sqlalchemy import String, or_, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.attributes import InstrumentedAttribute -from ..exceptions import NoSearchableFieldsError +from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...] +FacetFieldType = SearchFieldType @dataclass @@ -89,6 +93,9 @@ def build_search_filters( Returns: Tuple of (filter_conditions, joins_needed) + + Raises: + NoSearchableFieldsError: If no searchable field has been configured """ # Normalize input if isinstance(search, str): @@ -136,7 +143,7 @@ def build_search_filters( else: filters.append(column_as_string.ilike(f"%{query}%")) - if not filters: + if not filters: # pragma: no cover return [], [] # Combine based on match_mode @@ -144,3 +151,145 @@ def build_search_filters( return [or_(*filters)], joins else: return filters, joins + + +def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]: + """Return a key for each facet field, disambiguating duplicate column keys. + + Args: + facet_fields: Sequence of facet fields — either direct columns or + relationship tuples ``(rel, ..., column)``. + + Returns: + A list of string keys, one per facet field, in the same order. + """ + raw: list[tuple[str, str | None]] = [] + for field in facet_fields: + if isinstance(field, tuple): + rel = field[-2] + column = field[-1] + raw.append((column.key, rel.key)) + else: + raw.append((field.key, None)) + + counts = Counter(col_key for col_key, _ in raw) + keys: list[str] = [] + for col_key, rel_key in raw: + if counts[col_key] > 1 and rel_key is not None: + keys.append(f"{rel_key}__{col_key}") + else: + keys.append(col_key) + return keys + + +async def build_facets( + session: "AsyncSession", + model: type[DeclarativeBase], + facet_fields: Sequence[FacetFieldType], + *, + base_filters: "list[ColumnElement[bool]] | None" = None, + base_joins: list[InstrumentedAttribute[Any]] | None = None, +) -> dict[str, list[Any]]: + """Return distinct values for each facet field, respecting current filters. + + Args: + session: DB async session + model: SQLAlchemy model class + facet_fields: Columns or relationship tuples to facet on + base_filters: Filter conditions already applied to the main query (search + caller filters) + base_joins: Relationship joins already applied to the main query + + Returns: + Dict mapping column key to sorted list of distinct non-None values + """ + existing_join_keys: set[str] = {str(j) for j in (base_joins or [])} + + keys = facet_keys(facet_fields) + + async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]: + if isinstance(field, tuple): + # Relationship chain: (User.role, Role.name) — last element is the column + rels = field[:-1] + column = field[-1] + else: + rels = () + column = field + + q = select(column).select_from(model).distinct() + + # Apply base joins (already done on main query, but needed here independently) + for rel in base_joins or []: + q = q.outerjoin(rel) + + # Add any extra joins required by this facet field that aren't already in base_joins + for rel in rels: + if str(rel) not in existing_join_keys: + q = q.outerjoin(rel) + + if base_filters: + from sqlalchemy import and_ + + q = q.where(and_(*base_filters)) + + q = q.order_by(column) + result = await session.execute(q) + values = [row[0] for row in result.all() if row[0] is not None] + return key, values + + pairs = await asyncio.gather( + *[_query_facet(f, k) for f, k in zip(facet_fields, keys)] + ) + return dict(pairs) + + +def build_filter_by( + filter_by: dict[str, Any], + facet_fields: Sequence[FacetFieldType], +) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]: + """Translate a {column_key: value} dict into SQLAlchemy filter conditions. + + Args: + filter_by: Mapping of column key to scalar value or list of values + facet_fields: Declared facet fields to validate keys against + + Returns: + Tuple of (filter_conditions, joins_needed) + + Raises: + InvalidFacetFilterError: If a key in filter_by is not a declared facet field + """ + index: dict[ + str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]] + ] = {} + for key, field in zip(facet_keys(facet_fields), facet_fields): + if isinstance(field, tuple): + rels = list(field[:-1]) + column = field[-1] + else: + rels = [] + column = field + index[key] = (column, rels) + + valid_keys = set(index) + filters: list[ColumnElement[bool]] = [] + joins: list[InstrumentedAttribute[Any]] = [] + added_join_keys: set[str] = set() + + for key, value in filter_by.items(): + if key not in index: + raise InvalidFacetFilterError(key, valid_keys) + + column, rels = index[key] + + for rel in rels: + rel_key = str(rel) + if rel_key not in added_join_keys: + joins.append(rel) + added_join_keys.add(rel_key) + + if isinstance(value, list): + filters.append(column.in_(value)) + else: + filters.append(column == value) + + return filters, joins diff --git a/src/fastapi_toolsets/exceptions/__init__.py b/src/fastapi_toolsets/exceptions/__init__.py index 714b6cb..2bb2b65 100644 --- a/src/fastapi_toolsets/exceptions/__init__.py +++ b/src/fastapi_toolsets/exceptions/__init__.py @@ -5,6 +5,7 @@ from .exceptions import ( ApiException, ConflictError, ForbiddenError, + InvalidFacetFilterError, NoSearchableFieldsError, NotFoundError, UnauthorizedError, @@ -19,6 +20,7 @@ __all__ = [ "ForbiddenError", "generate_error_responses", "init_exceptions_handlers", + "InvalidFacetFilterError", "NoSearchableFieldsError", "NotFoundError", "UnauthorizedError", diff --git a/src/fastapi_toolsets/exceptions/exceptions.py b/src/fastapi_toolsets/exceptions/exceptions.py index 1755f40..87d34d9 100644 --- a/src/fastapi_toolsets/exceptions/exceptions.py +++ b/src/fastapi_toolsets/exceptions/exceptions.py @@ -102,6 +102,32 @@ class NoSearchableFieldsError(ApiException): super().__init__(detail) +class InvalidFacetFilterError(ApiException): + """Raised when filter_by contains a key not declared in facet_fields.""" + + api_error = ApiError( + code=400, + msg="Invalid Facet Filter", + desc="One or more filter_by keys are not declared as facet fields.", + err_code="FACET-400", + ) + + def __init__(self, key: str, valid_keys: set[str]) -> None: + """Initialize the exception. + + Args: + key: The unknown filter key provided by the caller + valid_keys: Set of valid keys derived from the declared facet_fields + """ + self.key = key + self.valid_keys = valid_keys + detail = ( + f"'{key}' is not a declared facet field. " + f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}." + ) + super().__init__(detail) + + def generate_error_responses( *errors: type[ApiException], ) -> dict[int | str, dict[str, Any]]: diff --git a/src/fastapi_toolsets/schemas.py b/src/fastapi_toolsets/schemas.py index 0a0070f..bcded4a 100644 --- a/src/fastapi_toolsets/schemas.py +++ b/src/fastapi_toolsets/schemas.py @@ -133,3 +133,4 @@ class PaginatedResponse(BaseResponse, Generic[DataT]): data: list[DataT] pagination: OffsetPagination | CursorPagination + filter_attributes: dict[str, list[Any]] | None = None diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 79e6eaf..f87e5f9 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -5,7 +5,12 @@ import uuid import pytest from sqlalchemy.ext.asyncio import AsyncSession -from fastapi_toolsets.crud import SearchConfig, get_searchable_fields +from fastapi_toolsets.crud import ( + CrudFactory, + InvalidFacetFilterError, + SearchConfig, + get_searchable_fields, +) from fastapi_toolsets.schemas import OffsetPagination from .conftest import ( @@ -313,9 +318,35 @@ class TestPaginateSearch: assert result.data[0].id == user_id +class TestBuildSearchFilters: + """Unit tests for build_search_filters.""" + + def test_deduplicates_relationship_join(self): + """Two tuple fields sharing the same relationship do not add the join twice.""" + from fastapi_toolsets.crud.search import build_search_filters + + # Both fields traverse User.role — the second must not re-add the join. + filters, joins = build_search_filters( + User, + "admin", + search_fields=[(User.role, Role.name), (User.role, Role.id)], + ) + + assert len(joins) == 1 + + class TestSearchConfig: """Tests for SearchConfig options.""" + def test_search_config_empty_query_returns_empty(self): + """SearchConfig with an empty/blank query returns empty filters without hitting the DB.""" + from fastapi_toolsets.crud.search import build_search_filters + + filters, joins = build_search_filters(User, SearchConfig(query=" ")) + + assert filters == [] + assert joins == [] + @pytest.mark.anyio async def test_match_mode_all(self, db_session: AsyncSession): """match_mode='all' requires all fields to match (AND).""" @@ -432,3 +463,554 @@ class TestGetSearchableFields: # Role.users is a collection, should not be included field_strs = [str(f) for f in fields] assert not any("users" in f for f in field_strs) + + +class TestFacetsNotSet: + """filter_attributes is None when no facet_fields are configured.""" + + @pytest.mark.anyio + async def test_offset_paginate_no_facets(self, db_session: AsyncSession): + """filter_attributes is None when facet_fields not set on factory or call.""" + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + + result = await UserCrud.offset_paginate(db_session) + + assert result.filter_attributes is None + + @pytest.mark.anyio + async def test_cursor_paginate_no_facets(self, db_session: AsyncSession): + """filter_attributes is None for cursor_paginate when facet_fields not set.""" + UserCursorCrud = CrudFactory(User, cursor_column=User.id) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + + result = await UserCursorCrud.cursor_paginate(db_session) + + assert result.filter_attributes is None + + +class TestFacetsDirectColumn: + """Facets on direct model columns.""" + + @pytest.mark.anyio + async def test_offset_paginate_direct_column(self, db_session: AsyncSession): + """Returns distinct values for a direct column via factory default.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserFacetCrud.offset_paginate(db_session) + + assert result.filter_attributes is not None + # Distinct usernames, sorted + assert result.filter_attributes["username"] == ["alice", "bob"] + + @pytest.mark.anyio + async def test_cursor_paginate_direct_column(self, db_session: AsyncSession): + """Returns distinct values for a direct column in cursor_paginate.""" + UserFacetCursorCrud = CrudFactory( + User, cursor_column=User.id, facet_fields=[User.email] + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserFacetCursorCrud.cursor_paginate(db_session) + + assert result.filter_attributes is not None + assert set(result.filter_attributes["email"]) == {"a@test.com", "b@test.com"} + + @pytest.mark.anyio + async def test_multiple_facet_columns(self, db_session: AsyncSession): + """Returns distinct values for multiple columns.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserFacetCrud.offset_paginate(db_session) + + assert result.filter_attributes is not None + assert "username" in result.filter_attributes + assert "email" in result.filter_attributes + assert set(result.filter_attributes["username"]) == {"alice", "bob"} + + @pytest.mark.anyio + async def test_per_call_override(self, db_session: AsyncSession): + """Per-call facet_fields overrides the factory default.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + + # Override: ask for email instead of username + result = await UserFacetCrud.offset_paginate( + db_session, facet_fields=[User.email] + ) + + assert result.filter_attributes is not None + assert "email" in result.filter_attributes + assert "username" not in result.filter_attributes + + +class TestFacetsRespectFilters: + """Facets reflect the active filter conditions.""" + + @pytest.mark.anyio + async def test_facets_respect_base_filters(self, db_session: AsyncSession): + """Facet values are scoped to the applied filters.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com", is_active=True) + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com", is_active=False) + ) + + # Filter to active users only — facets should only see "alice" + result = await UserFacetCrud.offset_paginate( + db_session, + filters=[User.is_active == True], # noqa: E712 + ) + + assert result.filter_attributes is not None + assert result.filter_attributes["username"] == ["alice"] + + +class TestFacetsRelationship: + """Facets on relationship columns via tuple syntax.""" + + @pytest.mark.anyio + async def test_relationship_facet(self, db_session: AsyncSession): + """Returns distinct values from a related model column.""" + UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)]) + + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + editor = await RoleCrud.create(db_session, RoleCreate(name="editor")) + + await UserCrud.create( + db_session, + UserCreate(username="alice", email="a@test.com", role_id=admin.id), + ) + await UserCrud.create( + db_session, + UserCreate(username="bob", email="b@test.com", role_id=editor.id), + ) + # User without a role — their role.name should be excluded (None filtered out) + await UserCrud.create( + db_session, UserCreate(username="charlie", email="c@test.com") + ) + + result = await UserRelFacetCrud.offset_paginate(db_session) + + assert result.filter_attributes is not None + assert set(result.filter_attributes["name"]) == {"admin", "editor"} + + @pytest.mark.anyio + async def test_relationship_facet_none_excluded(self, db_session: AsyncSession): + """None values (e.g. NULL role) are excluded from facet results.""" + UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)]) + + # Only user with no role + await UserCrud.create( + db_session, UserCreate(username="norole", email="n@test.com") + ) + + result = await UserRelFacetCrud.offset_paginate(db_session) + + assert result.filter_attributes is not None + assert result.filter_attributes["name"] == [] + + @pytest.mark.anyio + async def test_relationship_facet_deduplicates_join_with_search( + self, db_session: AsyncSession + ): + """Facet join is not duplicated when search already added the same relationship join.""" + # Both search and facet use (User.role, Role.name) — join should not be doubled + UserSearchFacetCrud = CrudFactory( + User, + searchable_fields=[(User.role, Role.name)], + facet_fields=[(User.role, Role.name)], + ) + + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="a@test.com", role_id=admin.id), + ) + + result = await UserSearchFacetCrud.offset_paginate( + db_session, search="admin", search_fields=[(User.role, Role.name)] + ) + + assert result.filter_attributes is not None + assert result.filter_attributes["name"] == ["admin"] + + +class TestFilterBy: + """Tests for the filter_by parameter on offset_paginate and cursor_paginate.""" + + @pytest.mark.anyio + async def test_scalar_filter(self, db_session: AsyncSession): + """filter_by with a scalar value produces an equality filter.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserFacetCrud.offset_paginate( + db_session, filter_by={"username": "alice"} + ) + + assert len(result.data) == 1 + assert result.data[0].username == "alice" + # facet also scoped to the filter + assert result.filter_attributes == {"username": ["alice"]} + + @pytest.mark.anyio + async def test_list_filter_produces_in_clause(self, db_session: AsyncSession): + """filter_by with a list value produces an IN filter.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="charlie", email="c@test.com") + ) + + result = await UserFacetCrud.offset_paginate( + db_session, filter_by={"username": ["alice", "bob"]} + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + returned_names = {u.username for u in result.data} + assert returned_names == {"alice", "bob"} + + @pytest.mark.anyio + async def test_relationship_filter_by(self, db_session: AsyncSession): + """filter_by works with relationship tuple facet fields.""" + UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)]) + + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + editor = await RoleCrud.create(db_session, RoleCreate(name="editor")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="a@test.com", role_id=admin.id), + ) + await UserCrud.create( + db_session, + UserCreate(username="bob", email="b@test.com", role_id=editor.id), + ) + + result = await UserRelFacetCrud.offset_paginate( + db_session, filter_by={"name": "admin"} + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_filter_by_combined_with_filters(self, db_session: AsyncSession): + """filter_by and filters= are combined (AND logic).""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com", is_active=True) + ) + await UserCrud.create( + db_session, + UserCreate(username="alice2", email="a2@test.com", is_active=False), + ) + + result = await UserFacetCrud.offset_paginate( + db_session, + filters=[User.is_active == True], # noqa: E712 + filter_by={"username": ["alice", "alice2"]}, + ) + + # Only alice passes both: is_active=True AND username IN [alice, alice2] + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_invalid_key_raises(self, db_session: AsyncSession): + """filter_by with an undeclared key raises InvalidFacetFilterError.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + + with pytest.raises(InvalidFacetFilterError) as exc_info: + await UserFacetCrud.offset_paginate( + db_session, filter_by={"nonexistent": "value"} + ) + + assert exc_info.value.key == "nonexistent" + assert "username" in exc_info.value.valid_keys + + @pytest.mark.anyio + async def test_filter_by_deduplicates_relationship_join( + self, db_session: AsyncSession + ): + """Two filter_by keys through the same relationship do not duplicate the join.""" + # Both (User.role, Role.name) and (User.role, Role.id) traverse User.role — + # the second key must not re-add the join (exercises the dedup branch in build_filter_by). + UserRoleFacetCrud = CrudFactory( + User, + facet_fields=[(User.role, Role.name), (User.role, Role.id)], + ) + + admin = await RoleCrud.create(db_session, RoleCreate(name="admin")) + editor = await RoleCrud.create(db_session, RoleCreate(name="editor")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="a@test.com", role_id=admin.id), + ) + await UserCrud.create( + db_session, + UserCreate(username="bob", email="b@test.com", role_id=editor.id), + ) + + result = await UserRoleFacetCrud.offset_paginate( + db_session, + filter_by={"name": "admin", "id": str(admin.id)}, + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_cursor_paginate_filter_by(self, db_session: AsyncSession): + """filter_by works with cursor_paginate.""" + UserFacetCursorCrud = CrudFactory( + User, cursor_column=User.id, facet_fields=[User.username] + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserFacetCursorCrud.cursor_paginate( + db_session, filter_by={"username": "alice"} + ) + + assert len(result.data) == 1 + assert result.data[0].username == "alice" + assert result.filter_attributes == {"username": ["alice"]} + + @pytest.mark.anyio + async def test_basemodel_filter_by_offset_paginate(self, db_session: AsyncSession): + """filter_by accepts a BaseModel instance (model_dump path) in offset_paginate.""" + from pydantic import BaseModel as PydanticBaseModel + + class UserFilter(PydanticBaseModel): + username: str | None = None + + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserFacetCrud.offset_paginate( + db_session, filter_by=UserFilter(username="alice") + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_basemodel_filter_by_cursor_paginate(self, db_session: AsyncSession): + """filter_by accepts a BaseModel instance (model_dump path) in cursor_paginate.""" + from pydantic import BaseModel as PydanticBaseModel + + class UserFilter(PydanticBaseModel): + username: str | None = None + + UserFacetCursorCrud = CrudFactory( + User, cursor_column=User.id, facet_fields=[User.username] + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserFacetCursorCrud.cursor_paginate( + db_session, filter_by=UserFilter(username="alice") + ) + + assert len(result.data) == 1 + assert result.data[0].username == "alice" + + +class TestFilterParamsSchema: + """Tests for AsyncCrud.filter_params().""" + + def test_generates_fields_from_facet_fields(self): + """Returned dependency has one keyword param per facet field.""" + import inspect + + UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email]) + dep = UserFacetCrud.filter_params() + + param_names = set(inspect.signature(dep).parameters) + assert param_names == {"username", "email"} + + def test_relationship_facet_uses_column_key(self): + """Relationship tuple uses the terminal column's key.""" + import inspect + + UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)]) + dep = UserRoleCrud.filter_params() + + param_names = set(inspect.signature(dep).parameters) + assert param_names == {"name"} + + def test_raises_when_no_facet_fields(self): + """ValueError raised when no facet_fields are configured or provided.""" + with pytest.raises(ValueError, match="no facet_fields"): + UserCrud.filter_params() + + def test_facet_fields_override(self): + """facet_fields= parameter overrides the class-level default.""" + import inspect + + UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email]) + dep = UserFacetCrud.filter_params(facet_fields=[User.email]) + + param_names = set(inspect.signature(dep).parameters) + assert param_names == {"email"} + + @pytest.mark.anyio + async def test_awaiting_dep_returns_dict_with_values(self): + """Awaiting the dependency returns a dict with only the supplied keys.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email]) + dep = UserFacetCrud.filter_params() + + result = await dep(username=["alice"]) + assert result == {"username": ["alice"]} + + @pytest.mark.anyio + async def test_multi_value_list_field(self): + """Multiple values are accepted as a list.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + dep = UserFacetCrud.filter_params() + + result = await dep(username=["alice", "bob"]) + assert result == {"username": ["alice", "bob"]} + + def test_disambiguates_duplicate_column_keys(self): + """Two relationship tuples sharing a terminal column key get prefixed names.""" + from unittest.mock import MagicMock + + from fastapi_toolsets.crud.search import facet_keys + + col_a = MagicMock() + col_a.key = "name" + rel_a = MagicMock() + rel_a.key = "project" + + col_b = MagicMock() + col_b.key = "name" + rel_b = MagicMock() + rel_b.key = "os" + + keys = facet_keys([(rel_a, col_a), (rel_b, col_b)]) + assert keys == ["project__name", "os__name"] + + def test_unique_column_keys_kept_plain(self): + """Fields with unique column keys are not prefixed.""" + from fastapi_toolsets.crud.search import facet_keys + + keys = facet_keys([User.username, User.email]) + assert keys == ["username", "email"] + + def test_dependency_name_includes_model_name(self): + """Returned dependency is named {Model}FilterParams.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + dep = UserFacetCrud.filter_params() + + assert dep.__name__ == "UserFilterParams" # type: ignore[union-attr] + + @pytest.mark.anyio + async def test_integration_with_offset_paginate(self, db_session: AsyncSession): + """Dependency result can be passed directly to offset_paginate via filter_by.""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + dep = UserFacetCrud.filter_params() + f = await dep(username=["alice"]) + result = await UserFacetCrud.offset_paginate(db_session, filter_by=f) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_dep_result_passed_to_cursor_paginate(self, db_session: AsyncSession): + """Dependency result can be passed directly to cursor_paginate via filter_by.""" + UserFacetCursorCrud = CrudFactory( + User, cursor_column=User.id, facet_fields=[User.username] + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + dep = UserFacetCursorCrud.filter_params() + f = await dep(username=["alice"]) + result = await UserFacetCursorCrud.cursor_paginate(db_session, filter_by=f) + + assert len(result.data) == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_all_none_dep_result_passes_no_filter(self, db_session: AsyncSession): + """All-None dependency result results in no filter (returns all rows).""" + UserFacetCrud = CrudFactory(User, facet_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + dep = UserFacetCrud.filter_params() + f = await dep() # all fields None + result = await UserFacetCrud.offset_paginate(db_session, filter_by=f) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2