mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add faceted search in CrudFactory (#97)
* feat: add faceted search in CrudFactory * feat: add filter_params_schema in CrudFactory * fix: add missing Raises in build_search_filters docstring * fix: faceted search * fix: cov * fix: documentation/filter_params
This commit is contained in:
@@ -168,7 +168,19 @@ PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
|
||||
|
||||
## Search
|
||||
|
||||
Declare searchable fields on the CRUD class. Relationship traversal is supported via tuples:
|
||||
Two search strategies are available, both compatible with [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate).
|
||||
|
||||
| | Full-text search | Filter attributes |
|
||||
|---|---|---|
|
||||
| Input | Free-text string | Exact column values |
|
||||
| Relationship support | Yes | Yes |
|
||||
| Use case | Search bars | Filter dropdowns |
|
||||
|
||||
!!! info "You can use both search strategies in the same endpoint!"
|
||||
|
||||
### Full-text search
|
||||
|
||||
Declare `searchable_fields` on the CRUD class. Relationship traversal is supported via tuples:
|
||||
|
||||
```python
|
||||
PostCrud = CrudFactory(
|
||||
@@ -181,6 +193,15 @@ PostCrud = CrudFactory(
|
||||
)
|
||||
```
|
||||
|
||||
You can override `searchable_fields` per call with `search_fields`:
|
||||
|
||||
```python
|
||||
result = await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
search_fields=[User.country],
|
||||
)
|
||||
```
|
||||
|
||||
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
|
||||
|
||||
```python
|
||||
@@ -221,6 +242,87 @@ async def get_users(
|
||||
)
|
||||
```
|
||||
|
||||
### Filter attributes
|
||||
|
||||
!!! info "Added in `v1.2`"
|
||||
|
||||
Declare `facet_fields` on the CRUD class to return distinct column values alongside paginated results. This is useful for populating filter dropdowns or building faceted search UIs.
|
||||
|
||||
Facet fields use the same syntax as `searchable_fields` — direct columns or relationship tuples:
|
||||
|
||||
```python
|
||||
UserCrud = CrudFactory(
|
||||
model=User,
|
||||
facet_fields=[
|
||||
User.status,
|
||||
User.country,
|
||||
(User.role, Role.name), # value from a related model
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
You can override `facet_fields` per call:
|
||||
|
||||
```python
|
||||
result = await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
facet_fields=[User.country],
|
||||
)
|
||||
```
|
||||
|
||||
The distinct values are returned in the `filter_attributes` field of [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"data": ["..."],
|
||||
"pagination": { "..." },
|
||||
"filter_attributes": {
|
||||
"status": ["active", "inactive"],
|
||||
"country": ["DE", "FR", "US"],
|
||||
"name": ["admin", "editor", "viewer"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
|
||||
|
||||
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
|
||||
Keys are normally the terminal `column.key` (e.g. `"name"` for `Role.name`). When two facet fields share the same column key (e.g. `(Build.project, Project.name)` and `(Build.os, Os.name)`), the relationship name is prepended automatically: `"project__name"` and `"os__name"`.
|
||||
|
||||
`filter_by` and `filters` can be combined — both are applied with AND logic.
|
||||
|
||||
Use [`filter_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.filter_params) to generate a dict with the facet filter values from the query parameters:
|
||||
|
||||
```python
|
||||
from fastapi import Depends
|
||||
|
||||
UserCrud = CrudFactory(
|
||||
model=User,
|
||||
facet_fields=[User.status, User.country, (User.role, Role.name)],
|
||||
)
|
||||
|
||||
@router.get("", response_model_exclude_none=True)
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
page: int = 1,
|
||||
filter_by: dict[str, list[str]] = Depends(UserCrud.filter_params()),
|
||||
) -> PaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
page=page,
|
||||
filter_by=filter_by,
|
||||
)
|
||||
```
|
||||
|
||||
Both single-value and multi-value query parameters work:
|
||||
|
||||
```
|
||||
GET /users?status=active → filter_by={"status": ["active"]}
|
||||
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
|
||||
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause)
|
||||
```
|
||||
|
||||
## Relationship loading
|
||||
|
||||
!!! info "Added in `v1.1`"
|
||||
|
||||
@@ -34,6 +34,7 @@ This registers handlers for:
|
||||
| [`NotFoundError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NotFoundError) | 404 | Not found |
|
||||
| [`ConflictError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.ConflictError) | 409 | Conflict |
|
||||
| [`NoSearchableFieldsError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError) | 400 | No searchable fields |
|
||||
| [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError) | 400 | Invalid facet filter |
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
|
||||
@@ -22,7 +22,7 @@ async def get_user(user: User = UserDep) -> Response[UserSchema]:
|
||||
|
||||
### [`PaginatedResponse[T]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse)
|
||||
|
||||
Wraps a list of items with pagination metadata.
|
||||
Wraps a list of items with pagination metadata and optional facet values.
|
||||
|
||||
```python
|
||||
from fastapi_toolsets.schemas import PaginatedResponse, Pagination
|
||||
@@ -40,6 +40,8 @@ async def list_users() -> PaginatedResponse[UserSchema]:
|
||||
)
|
||||
```
|
||||
|
||||
The optional `filter_attributes` field is populated when `facet_fields` are configured on the CRUD class (see [Filter attributes](crud.md#filter-attributes-facets)). It is `None` by default and can be hidden from API responses with `response_model_exclude_none=True`.
|
||||
|
||||
### [`ErrorResponse`](../reference/schemas.md#fastapi_toolsets.schemas.ErrorResponse)
|
||||
|
||||
Returned automatically by the exceptions handler.
|
||||
|
||||
@@ -12,6 +12,7 @@ from fastapi_toolsets.exceptions import (
|
||||
NotFoundError,
|
||||
ConflictError,
|
||||
NoSearchableFieldsError,
|
||||
InvalidFacetFilterError,
|
||||
generate_error_responses,
|
||||
init_exceptions_handlers,
|
||||
)
|
||||
@@ -29,6 +30,8 @@ from fastapi_toolsets.exceptions import (
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.NoSearchableFieldsError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from ..exceptions import NoSearchableFieldsError
|
||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||
from .factory import CrudFactory, JoinType, M2MFieldType
|
||||
from .search import (
|
||||
FacetFieldType,
|
||||
SearchConfig,
|
||||
get_searchable_fields,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CrudFactory",
|
||||
"FacetFieldType",
|
||||
"get_searchable_fields",
|
||||
"InvalidFacetFilterError",
|
||||
"JoinType",
|
||||
"M2MFieldType",
|
||||
"NoSearchableFieldsError",
|
||||
|
||||
@@ -3,14 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import uuid as uuid_module
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
||||
from sqlalchemy import delete as sql_delete
|
||||
@@ -24,7 +26,15 @@ from sqlalchemy.sql.roles import WhereHavingRole
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
from .search import (
|
||||
FacetFieldType,
|
||||
SearchConfig,
|
||||
SearchFieldType,
|
||||
build_facets,
|
||||
build_filter_by,
|
||||
build_search_filters,
|
||||
facet_keys,
|
||||
)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
@@ -50,6 +60,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
default_load_options: ClassVar[list[ExecutableOption] | None] = None
|
||||
cursor_column: ClassVar[Any | None] = None
|
||||
@@ -119,6 +130,52 @@ class AsyncCrud(Generic[ModelType]):
|
||||
return set()
|
||||
return set(cls.m2m_fields.keys())
|
||||
|
||||
@classmethod
|
||||
def filter_params(
|
||||
cls: type[Self],
|
||||
*,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
) -> Callable[..., Awaitable[dict[str, list[str]]]]:
|
||||
"""Return a FastAPI dependency that collects facet filter values from query parameters.
|
||||
Args:
|
||||
facet_fields: Override the facet fields for this dependency. Falls back to the
|
||||
class-level ``facet_fields`` if not provided.
|
||||
|
||||
Returns:
|
||||
An async dependency function named ``{Model}FilterParams`` that resolves to a
|
||||
``dict[str, list[str]]`` containing only the keys that were supplied in the
|
||||
request (absent/``None`` parameters are excluded).
|
||||
|
||||
Raises:
|
||||
ValueError: If no facet fields are configured on this CRUD class and none are
|
||||
provided via ``facet_fields``.
|
||||
"""
|
||||
fields = facet_fields if facet_fields is not None else cls.facet_fields
|
||||
if not fields:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} has no facet_fields configured. "
|
||||
"Pass facet_fields= or set them on CrudFactory."
|
||||
)
|
||||
keys = facet_keys(fields)
|
||||
|
||||
async def dependency(**kwargs: Any) -> dict[str, list[str]]:
|
||||
return {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
dependency.__name__ = f"{cls.model.__name__}FilterParams"
|
||||
dependency.__signature__ = inspect.Signature( # type: ignore[attr-defined]
|
||||
parameters=[
|
||||
inspect.Parameter(
|
||||
k,
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=list[str] | None,
|
||||
default=Query(default=None),
|
||||
)
|
||||
for k in keys
|
||||
]
|
||||
)
|
||||
|
||||
return dependency
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
@@ -693,6 +750,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
@@ -712,6 +771,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@@ -729,6 +790,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
"""Get paginated results using offset-based pagination.
|
||||
@@ -744,6 +807,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: Number of items per page
|
||||
search: Search query string or SearchConfig object
|
||||
search_fields: Fields to search in (overrides class default)
|
||||
facet_fields: Columns to compute distinct values for (overrides class default)
|
||||
filter_by: Dict of {column_key: value} to filter by declared facet fields.
|
||||
Keys must match the column.key of a facet field. Scalar → equality,
|
||||
list → IN clause. Raises InvalidFacetFilterError for unknown keys.
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
@@ -753,6 +820,20 @@ class AsyncCrud(Generic[ModelType]):
|
||||
offset = (page - 1) * items_per_page
|
||||
search_joins: list[Any] = []
|
||||
|
||||
if isinstance(filter_by, BaseModel):
|
||||
filter_by = filter_by.model_dump(exclude_none=True) or None
|
||||
|
||||
# Build filter_by conditions from declared facet fields
|
||||
if filter_by:
|
||||
resolved_facets_for_filter = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
)
|
||||
fb_filters, fb_joins = build_filter_by(
|
||||
filter_by, resolved_facets_for_filter or []
|
||||
)
|
||||
filters.extend(fb_filters)
|
||||
search_joins.extend(fb_joins)
|
||||
|
||||
# Build search filters
|
||||
if search:
|
||||
search_filters, search_joins = build_search_filters(
|
||||
@@ -817,6 +898,20 @@ class AsyncCrud(Generic[ModelType]):
|
||||
count_result = await session.execute(count_q)
|
||||
total_count = count_result.scalar_one()
|
||||
|
||||
# Build facets
|
||||
resolved_facet_fields = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
)
|
||||
filter_attributes: dict[str, list[Any]] | None = None
|
||||
if resolved_facet_fields:
|
||||
filter_attributes = await build_facets(
|
||||
session,
|
||||
cls.model,
|
||||
resolved_facet_fields,
|
||||
base_filters=filters or None,
|
||||
base_joins=search_joins or None,
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=OffsetPagination(
|
||||
@@ -825,6 +920,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
page=page,
|
||||
has_more=page * items_per_page < total_count,
|
||||
),
|
||||
filter_attributes=filter_attributes,
|
||||
)
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@@ -845,6 +941,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
@@ -864,6 +962,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@@ -881,6 +981,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
"""Get paginated results using cursor-based pagination.
|
||||
@@ -899,6 +1001,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: Number of items per page (default 20).
|
||||
search: Search query string or SearchConfig object.
|
||||
search_fields: Fields to search in (overrides class default).
|
||||
facet_fields: Columns to compute distinct values for (overrides class default).
|
||||
filter_by: Dict of {column_key: value} to filter by declared facet fields.
|
||||
Keys must match the column.key of a facet field. Scalar → equality,
|
||||
list → IN clause. Raises InvalidFacetFilterError for unknown keys.
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
@@ -907,6 +1013,20 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters = list(filters) if filters else []
|
||||
search_joins: list[Any] = []
|
||||
|
||||
if isinstance(filter_by, BaseModel):
|
||||
filter_by = filter_by.model_dump(exclude_none=True) or None
|
||||
|
||||
# Build filter_by conditions from declared facet fields
|
||||
if filter_by:
|
||||
resolved_facets_for_filter = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
)
|
||||
fb_filters, fb_joins = build_filter_by(
|
||||
filter_by, resolved_facets_for_filter or []
|
||||
)
|
||||
filters.extend(fb_filters)
|
||||
search_joins.extend(fb_joins)
|
||||
|
||||
if cls.cursor_column is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.cursor_column is not set. "
|
||||
@@ -996,6 +1116,20 @@ class AsyncCrud(Generic[ModelType]):
|
||||
else items_page
|
||||
)
|
||||
|
||||
# Build facets
|
||||
resolved_facet_fields = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
)
|
||||
filter_attributes: dict[str, list[Any]] | None = None
|
||||
if resolved_facet_fields:
|
||||
filter_attributes = await build_facets(
|
||||
session,
|
||||
cls.model,
|
||||
resolved_facet_fields,
|
||||
base_filters=filters or None,
|
||||
base_joins=search_joins or None,
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=CursorPagination(
|
||||
@@ -1004,6 +1138,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page=items_per_page,
|
||||
has_more=has_more,
|
||||
),
|
||||
filter_attributes=filter_attributes,
|
||||
)
|
||||
|
||||
|
||||
@@ -1011,6 +1146,7 @@ def CrudFactory(
|
||||
model: type[ModelType],
|
||||
*,
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
m2m_fields: M2MFieldType | None = None,
|
||||
default_load_options: list[ExecutableOption] | None = None,
|
||||
cursor_column: Any | None = None,
|
||||
@@ -1020,6 +1156,9 @@ def CrudFactory(
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
searchable_fields: Optional list of searchable fields
|
||||
facet_fields: Optional list of columns to compute distinct values for in paginated
|
||||
responses. Supports direct columns (``User.status``) and relationship tuples
|
||||
(``(User.role, Role.name)``). Can be overridden per call.
|
||||
m2m_fields: Optional mapping for many-to-many relationships.
|
||||
Maps schema field names (containing lists of IDs) to
|
||||
SQLAlchemy relationship attributes.
|
||||
@@ -1056,6 +1195,12 @@ def CrudFactory(
|
||||
m2m_fields={"tag_ids": Post.tags},
|
||||
)
|
||||
|
||||
# With facet fields for filter dropdowns / faceted search:
|
||||
UserCrud = CrudFactory(
|
||||
User,
|
||||
facet_fields=[User.status, User.country, (User.role, Role.name)],
|
||||
)
|
||||
|
||||
# With a fixed cursor column for cursor_paginate:
|
||||
PostCrud = CrudFactory(
|
||||
Post,
|
||||
@@ -1106,6 +1251,7 @@ def CrudFactory(
|
||||
{
|
||||
"model": model,
|
||||
"searchable_fields": searchable_fields,
|
||||
"facet_fields": facet_fields,
|
||||
"m2m_fields": m2m_fields,
|
||||
"default_load_options": default_load_options,
|
||||
"cursor_column": cursor_column,
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
"""Search utilities for AsyncCrud."""
|
||||
|
||||
import asyncio
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy import String, or_
|
||||
from sqlalchemy import String, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
from ..exceptions import NoSearchableFieldsError
|
||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||
FacetFieldType = SearchFieldType
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -89,6 +93,9 @@ def build_search_filters(
|
||||
|
||||
Returns:
|
||||
Tuple of (filter_conditions, joins_needed)
|
||||
|
||||
Raises:
|
||||
NoSearchableFieldsError: If no searchable field has been configured
|
||||
"""
|
||||
# Normalize input
|
||||
if isinstance(search, str):
|
||||
@@ -136,7 +143,7 @@ def build_search_filters(
|
||||
else:
|
||||
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||
|
||||
if not filters:
|
||||
if not filters: # pragma: no cover
|
||||
return [], []
|
||||
|
||||
# Combine based on match_mode
|
||||
@@ -144,3 +151,145 @@ def build_search_filters(
|
||||
return [or_(*filters)], joins
|
||||
else:
|
||||
return filters, joins
|
||||
|
||||
|
||||
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
||||
"""Return a key for each facet field, disambiguating duplicate column keys.
|
||||
|
||||
Args:
|
||||
facet_fields: Sequence of facet fields — either direct columns or
|
||||
relationship tuples ``(rel, ..., column)``.
|
||||
|
||||
Returns:
|
||||
A list of string keys, one per facet field, in the same order.
|
||||
"""
|
||||
raw: list[tuple[str, str | None]] = []
|
||||
for field in facet_fields:
|
||||
if isinstance(field, tuple):
|
||||
rel = field[-2]
|
||||
column = field[-1]
|
||||
raw.append((column.key, rel.key))
|
||||
else:
|
||||
raw.append((field.key, None))
|
||||
|
||||
counts = Counter(col_key for col_key, _ in raw)
|
||||
keys: list[str] = []
|
||||
for col_key, rel_key in raw:
|
||||
if counts[col_key] > 1 and rel_key is not None:
|
||||
keys.append(f"{rel_key}__{col_key}")
|
||||
else:
|
||||
keys.append(col_key)
|
||||
return keys
|
||||
|
||||
|
||||
async def build_facets(
|
||||
session: "AsyncSession",
|
||||
model: type[DeclarativeBase],
|
||||
facet_fields: Sequence[FacetFieldType],
|
||||
*,
|
||||
base_filters: "list[ColumnElement[bool]] | None" = None,
|
||||
base_joins: list[InstrumentedAttribute[Any]] | None = None,
|
||||
) -> dict[str, list[Any]]:
|
||||
"""Return distinct values for each facet field, respecting current filters.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
model: SQLAlchemy model class
|
||||
facet_fields: Columns or relationship tuples to facet on
|
||||
base_filters: Filter conditions already applied to the main query (search + caller filters)
|
||||
base_joins: Relationship joins already applied to the main query
|
||||
|
||||
Returns:
|
||||
Dict mapping column key to sorted list of distinct non-None values
|
||||
"""
|
||||
existing_join_keys: set[str] = {str(j) for j in (base_joins or [])}
|
||||
|
||||
keys = facet_keys(facet_fields)
|
||||
|
||||
async def _query_facet(field: FacetFieldType, key: str) -> tuple[str, list[Any]]:
|
||||
if isinstance(field, tuple):
|
||||
# Relationship chain: (User.role, Role.name) — last element is the column
|
||||
rels = field[:-1]
|
||||
column = field[-1]
|
||||
else:
|
||||
rels = ()
|
||||
column = field
|
||||
|
||||
q = select(column).select_from(model).distinct()
|
||||
|
||||
# Apply base joins (already done on main query, but needed here independently)
|
||||
for rel in base_joins or []:
|
||||
q = q.outerjoin(rel)
|
||||
|
||||
# Add any extra joins required by this facet field that aren't already in base_joins
|
||||
for rel in rels:
|
||||
if str(rel) not in existing_join_keys:
|
||||
q = q.outerjoin(rel)
|
||||
|
||||
if base_filters:
|
||||
from sqlalchemy import and_
|
||||
|
||||
q = q.where(and_(*base_filters))
|
||||
|
||||
q = q.order_by(column)
|
||||
result = await session.execute(q)
|
||||
values = [row[0] for row in result.all() if row[0] is not None]
|
||||
return key, values
|
||||
|
||||
pairs = await asyncio.gather(
|
||||
*[_query_facet(f, k) for f, k in zip(facet_fields, keys)]
|
||||
)
|
||||
return dict(pairs)
|
||||
|
||||
|
||||
def build_filter_by(
|
||||
filter_by: dict[str, Any],
|
||||
facet_fields: Sequence[FacetFieldType],
|
||||
) -> tuple["list[ColumnElement[bool]]", list[InstrumentedAttribute[Any]]]:
|
||||
"""Translate a {column_key: value} dict into SQLAlchemy filter conditions.
|
||||
|
||||
Args:
|
||||
filter_by: Mapping of column key to scalar value or list of values
|
||||
facet_fields: Declared facet fields to validate keys against
|
||||
|
||||
Returns:
|
||||
Tuple of (filter_conditions, joins_needed)
|
||||
|
||||
Raises:
|
||||
InvalidFacetFilterError: If a key in filter_by is not a declared facet field
|
||||
"""
|
||||
index: dict[
|
||||
str, tuple[InstrumentedAttribute[Any], list[InstrumentedAttribute[Any]]]
|
||||
] = {}
|
||||
for key, field in zip(facet_keys(facet_fields), facet_fields):
|
||||
if isinstance(field, tuple):
|
||||
rels = list(field[:-1])
|
||||
column = field[-1]
|
||||
else:
|
||||
rels = []
|
||||
column = field
|
||||
index[key] = (column, rels)
|
||||
|
||||
valid_keys = set(index)
|
||||
filters: list[ColumnElement[bool]] = []
|
||||
joins: list[InstrumentedAttribute[Any]] = []
|
||||
added_join_keys: set[str] = set()
|
||||
|
||||
for key, value in filter_by.items():
|
||||
if key not in index:
|
||||
raise InvalidFacetFilterError(key, valid_keys)
|
||||
|
||||
column, rels = index[key]
|
||||
|
||||
for rel in rels:
|
||||
rel_key = str(rel)
|
||||
if rel_key not in added_join_keys:
|
||||
joins.append(rel)
|
||||
added_join_keys.add(rel_key)
|
||||
|
||||
if isinstance(value, list):
|
||||
filters.append(column.in_(value))
|
||||
else:
|
||||
filters.append(column == value)
|
||||
|
||||
return filters, joins
|
||||
|
||||
@@ -5,6 +5,7 @@ from .exceptions import (
|
||||
ApiException,
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
InvalidFacetFilterError,
|
||||
NoSearchableFieldsError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
@@ -19,6 +20,7 @@ __all__ = [
|
||||
"ForbiddenError",
|
||||
"generate_error_responses",
|
||||
"init_exceptions_handlers",
|
||||
"InvalidFacetFilterError",
|
||||
"NoSearchableFieldsError",
|
||||
"NotFoundError",
|
||||
"UnauthorizedError",
|
||||
|
||||
@@ -102,6 +102,32 @@ class NoSearchableFieldsError(ApiException):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
class InvalidFacetFilterError(ApiException):
|
||||
"""Raised when filter_by contains a key not declared in facet_fields."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=400,
|
||||
msg="Invalid Facet Filter",
|
||||
desc="One or more filter_by keys are not declared as facet fields.",
|
||||
err_code="FACET-400",
|
||||
)
|
||||
|
||||
def __init__(self, key: str, valid_keys: set[str]) -> None:
|
||||
"""Initialize the exception.
|
||||
|
||||
Args:
|
||||
key: The unknown filter key provided by the caller
|
||||
valid_keys: Set of valid keys derived from the declared facet_fields
|
||||
"""
|
||||
self.key = key
|
||||
self.valid_keys = valid_keys
|
||||
detail = (
|
||||
f"'{key}' is not a declared facet field. "
|
||||
f"Valid keys: {sorted(valid_keys) or 'none — set facet_fields on the CRUD class'}."
|
||||
)
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
def generate_error_responses(
|
||||
*errors: type[ApiException],
|
||||
) -> dict[int | str, dict[str, Any]]:
|
||||
|
||||
@@ -133,3 +133,4 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||
|
||||
data: list[DataT]
|
||||
pagination: OffsetPagination | CursorPagination
|
||||
filter_attributes: dict[str, list[Any]] | None = None
|
||||
|
||||
@@ -5,7 +5,12 @@ import uuid
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
|
||||
from fastapi_toolsets.crud import (
|
||||
CrudFactory,
|
||||
InvalidFacetFilterError,
|
||||
SearchConfig,
|
||||
get_searchable_fields,
|
||||
)
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
from .conftest import (
|
||||
@@ -313,9 +318,35 @@ class TestPaginateSearch:
|
||||
assert result.data[0].id == user_id
|
||||
|
||||
|
||||
class TestBuildSearchFilters:
|
||||
"""Unit tests for build_search_filters."""
|
||||
|
||||
def test_deduplicates_relationship_join(self):
|
||||
"""Two tuple fields sharing the same relationship do not add the join twice."""
|
||||
from fastapi_toolsets.crud.search import build_search_filters
|
||||
|
||||
# Both fields traverse User.role — the second must not re-add the join.
|
||||
filters, joins = build_search_filters(
|
||||
User,
|
||||
"admin",
|
||||
search_fields=[(User.role, Role.name), (User.role, Role.id)],
|
||||
)
|
||||
|
||||
assert len(joins) == 1
|
||||
|
||||
|
||||
class TestSearchConfig:
|
||||
"""Tests for SearchConfig options."""
|
||||
|
||||
def test_search_config_empty_query_returns_empty(self):
|
||||
"""SearchConfig with an empty/blank query returns empty filters without hitting the DB."""
|
||||
from fastapi_toolsets.crud.search import build_search_filters
|
||||
|
||||
filters, joins = build_search_filters(User, SearchConfig(query=" "))
|
||||
|
||||
assert filters == []
|
||||
assert joins == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_match_mode_all(self, db_session: AsyncSession):
|
||||
"""match_mode='all' requires all fields to match (AND)."""
|
||||
@@ -432,3 +463,554 @@ class TestGetSearchableFields:
|
||||
# Role.users is a collection, should not be included
|
||||
field_strs = [str(f) for f in fields]
|
||||
assert not any("users" in f for f in field_strs)
|
||||
|
||||
|
||||
class TestFacetsNotSet:
|
||||
"""filter_attributes is None when no facet_fields are configured."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_no_facets(self, db_session: AsyncSession):
|
||||
"""filter_attributes is None when facet_fields not set on factory or call."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.offset_paginate(db_session)
|
||||
|
||||
assert result.filter_attributes is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_paginate_no_facets(self, db_session: AsyncSession):
|
||||
"""filter_attributes is None for cursor_paginate when facet_fields not set."""
|
||||
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
|
||||
result = await UserCursorCrud.cursor_paginate(db_session)
|
||||
|
||||
assert result.filter_attributes is None
|
||||
|
||||
|
||||
class TestFacetsDirectColumn:
|
||||
"""Facets on direct model columns."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_direct_column(self, db_session: AsyncSession):
|
||||
"""Returns distinct values for a direct column via factory default."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(db_session)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
# Distinct usernames, sorted
|
||||
assert result.filter_attributes["username"] == ["alice", "bob"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_paginate_direct_column(self, db_session: AsyncSession):
|
||||
"""Returns distinct values for a direct column in cursor_paginate."""
|
||||
UserFacetCursorCrud = CrudFactory(
|
||||
User, cursor_column=User.id, facet_fields=[User.email]
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCursorCrud.cursor_paginate(db_session)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert set(result.filter_attributes["email"]) == {"a@test.com", "b@test.com"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_facet_columns(self, db_session: AsyncSession):
|
||||
"""Returns distinct values for multiple columns."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(db_session)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert "username" in result.filter_attributes
|
||||
assert "email" in result.filter_attributes
|
||||
assert set(result.filter_attributes["username"]) == {"alice", "bob"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_per_call_override(self, db_session: AsyncSession):
|
||||
"""Per-call facet_fields overrides the factory default."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
|
||||
# Override: ask for email instead of username
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, facet_fields=[User.email]
|
||||
)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert "email" in result.filter_attributes
|
||||
assert "username" not in result.filter_attributes
|
||||
|
||||
|
||||
class TestFacetsRespectFilters:
|
||||
"""Facets reflect the active filter conditions."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_facets_respect_base_filters(self, db_session: AsyncSession):
|
||||
"""Facet values are scoped to the applied filters."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com", is_active=False)
|
||||
)
|
||||
|
||||
# Filter to active users only — facets should only see "alice"
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert result.filter_attributes["username"] == ["alice"]
|
||||
|
||||
|
||||
class TestFacetsRelationship:
|
||||
"""Facets on relationship columns via tuple syntax."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_relationship_facet(self, db_session: AsyncSession):
|
||||
"""Returns distinct values from a related model column."""
|
||||
UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="bob", email="b@test.com", role_id=editor.id),
|
||||
)
|
||||
# User without a role — their role.name should be excluded (None filtered out)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||
)
|
||||
|
||||
result = await UserRelFacetCrud.offset_paginate(db_session)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert set(result.filter_attributes["name"]) == {"admin", "editor"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_relationship_facet_none_excluded(self, db_session: AsyncSession):
|
||||
"""None values (e.g. NULL role) are excluded from facet results."""
|
||||
UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||
|
||||
# Only user with no role
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="norole", email="n@test.com")
|
||||
)
|
||||
|
||||
result = await UserRelFacetCrud.offset_paginate(db_session)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert result.filter_attributes["name"] == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_relationship_facet_deduplicates_join_with_search(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""Facet join is not duplicated when search already added the same relationship join."""
|
||||
# Both search and facet use (User.role, Role.name) — join should not be doubled
|
||||
UserSearchFacetCrud = CrudFactory(
|
||||
User,
|
||||
searchable_fields=[(User.role, Role.name)],
|
||||
facet_fields=[(User.role, Role.name)],
|
||||
)
|
||||
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||
)
|
||||
|
||||
result = await UserSearchFacetCrud.offset_paginate(
|
||||
db_session, search="admin", search_fields=[(User.role, Role.name)]
|
||||
)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert result.filter_attributes["name"] == ["admin"]
|
||||
|
||||
|
||||
class TestFilterBy:
|
||||
"""Tests for the filter_by parameter on offset_paginate and cursor_paginate."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_scalar_filter(self, db_session: AsyncSession):
|
||||
"""filter_by with a scalar value produces an equality filter."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"username": "alice"}
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].username == "alice"
|
||||
# facet also scoped to the filter
|
||||
assert result.filter_attributes == {"username": ["alice"]}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_filter_produces_in_clause(self, db_session: AsyncSession):
|
||||
"""filter_by with a list value produces an IN filter."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"username": ["alice", "bob"]}
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
returned_names = {u.username for u in result.data}
|
||||
assert returned_names == {"alice", "bob"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_relationship_filter_by(self, db_session: AsyncSession):
|
||||
"""filter_by works with relationship tuple facet fields."""
|
||||
UserRelFacetCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="bob", email="b@test.com", role_id=editor.id),
|
||||
)
|
||||
|
||||
result = await UserRelFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"name": "admin"}
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_combined_with_filters(self, db_session: AsyncSession):
|
||||
"""filter_by and filters= are combined (AND logic)."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com", is_active=True)
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice2", email="a2@test.com", is_active=False),
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
filter_by={"username": ["alice", "alice2"]},
|
||||
)
|
||||
|
||||
# Only alice passes both: is_active=True AND username IN [alice, alice2]
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_key_raises(self, db_session: AsyncSession):
|
||||
"""filter_by with an undeclared key raises InvalidFacetFilterError."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
|
||||
with pytest.raises(InvalidFacetFilterError) as exc_info:
|
||||
await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"nonexistent": "value"}
|
||||
)
|
||||
|
||||
assert exc_info.value.key == "nonexistent"
|
||||
assert "username" in exc_info.value.valid_keys
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_filter_by_deduplicates_relationship_join(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""Two filter_by keys through the same relationship do not duplicate the join."""
|
||||
# Both (User.role, Role.name) and (User.role, Role.id) traverse User.role —
|
||||
# the second key must not re-add the join (exercises the dedup branch in build_filter_by).
|
||||
UserRoleFacetCrud = CrudFactory(
|
||||
User,
|
||||
facet_fields=[(User.role, Role.name), (User.role, Role.id)],
|
||||
)
|
||||
|
||||
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
editor = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="a@test.com", role_id=admin.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="bob", email="b@test.com", role_id=editor.id),
|
||||
)
|
||||
|
||||
result = await UserRoleFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"name": "admin", "id": str(admin.id)},
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_paginate_filter_by(self, db_session: AsyncSession):
|
||||
"""filter_by works with cursor_paginate."""
|
||||
UserFacetCursorCrud = CrudFactory(
|
||||
User, cursor_column=User.id, facet_fields=[User.username]
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCursorCrud.cursor_paginate(
|
||||
db_session, filter_by={"username": "alice"}
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].username == "alice"
|
||||
assert result.filter_attributes == {"username": ["alice"]}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_basemodel_filter_by_offset_paginate(self, db_session: AsyncSession):
|
||||
"""filter_by accepts a BaseModel instance (model_dump path) in offset_paginate."""
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
class UserFilter(PydanticBaseModel):
|
||||
username: str | None = None
|
||||
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by=UserFilter(username="alice")
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_basemodel_filter_by_cursor_paginate(self, db_session: AsyncSession):
|
||||
"""filter_by accepts a BaseModel instance (model_dump path) in cursor_paginate."""
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
class UserFilter(PydanticBaseModel):
|
||||
username: str | None = None
|
||||
|
||||
UserFacetCursorCrud = CrudFactory(
|
||||
User, cursor_column=User.id, facet_fields=[User.username]
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCursorCrud.cursor_paginate(
|
||||
db_session, filter_by=UserFilter(username="alice")
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
|
||||
class TestFilterParamsSchema:
|
||||
"""Tests for AsyncCrud.filter_params()."""
|
||||
|
||||
def test_generates_fields_from_facet_fields(self):
|
||||
"""Returned dependency has one keyword param per facet field."""
|
||||
import inspect
|
||||
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||
dep = UserFacetCrud.filter_params()
|
||||
|
||||
param_names = set(inspect.signature(dep).parameters)
|
||||
assert param_names == {"username", "email"}
|
||||
|
||||
def test_relationship_facet_uses_column_key(self):
|
||||
"""Relationship tuple uses the terminal column's key."""
|
||||
import inspect
|
||||
|
||||
UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||
dep = UserRoleCrud.filter_params()
|
||||
|
||||
param_names = set(inspect.signature(dep).parameters)
|
||||
assert param_names == {"name"}
|
||||
|
||||
def test_raises_when_no_facet_fields(self):
|
||||
"""ValueError raised when no facet_fields are configured or provided."""
|
||||
with pytest.raises(ValueError, match="no facet_fields"):
|
||||
UserCrud.filter_params()
|
||||
|
||||
def test_facet_fields_override(self):
|
||||
"""facet_fields= parameter overrides the class-level default."""
|
||||
import inspect
|
||||
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||
dep = UserFacetCrud.filter_params(facet_fields=[User.email])
|
||||
|
||||
param_names = set(inspect.signature(dep).parameters)
|
||||
assert param_names == {"email"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_awaiting_dep_returns_dict_with_values(self):
|
||||
"""Awaiting the dependency returns a dict with only the supplied keys."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username, User.email])
|
||||
dep = UserFacetCrud.filter_params()
|
||||
|
||||
result = await dep(username=["alice"])
|
||||
assert result == {"username": ["alice"]}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multi_value_list_field(self):
|
||||
"""Multiple values are accepted as a list."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
dep = UserFacetCrud.filter_params()
|
||||
|
||||
result = await dep(username=["alice", "bob"])
|
||||
assert result == {"username": ["alice", "bob"]}
|
||||
|
||||
def test_disambiguates_duplicate_column_keys(self):
|
||||
"""Two relationship tuples sharing a terminal column key get prefixed names."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from fastapi_toolsets.crud.search import facet_keys
|
||||
|
||||
col_a = MagicMock()
|
||||
col_a.key = "name"
|
||||
rel_a = MagicMock()
|
||||
rel_a.key = "project"
|
||||
|
||||
col_b = MagicMock()
|
||||
col_b.key = "name"
|
||||
rel_b = MagicMock()
|
||||
rel_b.key = "os"
|
||||
|
||||
keys = facet_keys([(rel_a, col_a), (rel_b, col_b)])
|
||||
assert keys == ["project__name", "os__name"]
|
||||
|
||||
def test_unique_column_keys_kept_plain(self):
|
||||
"""Fields with unique column keys are not prefixed."""
|
||||
from fastapi_toolsets.crud.search import facet_keys
|
||||
|
||||
keys = facet_keys([User.username, User.email])
|
||||
assert keys == ["username", "email"]
|
||||
|
||||
def test_dependency_name_includes_model_name(self):
|
||||
"""Returned dependency is named {Model}FilterParams."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
dep = UserFacetCrud.filter_params()
|
||||
|
||||
assert dep.__name__ == "UserFilterParams" # type: ignore[union-attr]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_integration_with_offset_paginate(self, db_session: AsyncSession):
|
||||
"""Dependency result can be passed directly to offset_paginate via filter_by."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
dep = UserFacetCrud.filter_params()
|
||||
f = await dep(username=["alice"])
|
||||
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dep_result_passed_to_cursor_paginate(self, db_session: AsyncSession):
|
||||
"""Dependency result can be passed directly to cursor_paginate via filter_by."""
|
||||
UserFacetCursorCrud = CrudFactory(
|
||||
User, cursor_column=User.id, facet_fields=[User.username]
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
dep = UserFacetCursorCrud.filter_params()
|
||||
f = await dep(username=["alice"])
|
||||
result = await UserFacetCursorCrud.cursor_paginate(db_session, filter_by=f)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_all_none_dep_result_passes_no_filter(self, db_session: AsyncSession):
|
||||
"""All-None dependency result results in no filter (returns all rows)."""
|
||||
UserFacetCrud = CrudFactory(User, facet_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
dep = UserFacetCrud.filter_params()
|
||||
f = await dep() # all fields None
|
||||
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
Reference in New Issue
Block a user