mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: add search_column parameter and search_columns response field for targeted search (#207)
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from ..exceptions import (
|
from ..exceptions import (
|
||||||
InvalidFacetFilterError,
|
InvalidFacetFilterError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
NoSearchableFieldsError,
|
NoSearchableFieldsError,
|
||||||
UnsupportedFacetTypeError,
|
UnsupportedFacetTypeError,
|
||||||
)
|
)
|
||||||
@@ -22,6 +23,7 @@ __all__ = [
|
|||||||
"FacetFieldType",
|
"FacetFieldType",
|
||||||
"get_searchable_fields",
|
"get_searchable_fields",
|
||||||
"InvalidFacetFilterError",
|
"InvalidFacetFilterError",
|
||||||
|
"InvalidSearchColumnError",
|
||||||
"JoinType",
|
"JoinType",
|
||||||
"M2MFieldType",
|
"M2MFieldType",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ from .search import (
|
|||||||
build_filter_by,
|
build_filter_by,
|
||||||
build_search_filters,
|
build_search_filters,
|
||||||
facet_keys,
|
facet_keys,
|
||||||
|
search_field_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -309,6 +310,69 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
return dependency
|
return dependency
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def search_params(
|
||||||
|
cls: type[Self],
|
||||||
|
*,
|
||||||
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
) -> Callable[..., Awaitable[dict[str, Any]]]:
|
||||||
|
"""Return a FastAPI dependency that collects search params from query parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_fields: Override search fields for this dependency.
|
||||||
|
Falls back to the class-level ``searchable_fields``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An async dependency function named ``{Model}SearchParams`` that
|
||||||
|
resolves to a ``dict`` with ``search`` and ``search_column`` keys
|
||||||
|
(absent keys are excluded).
|
||||||
|
"""
|
||||||
|
fields = search_fields if search_fields is not None else cls.searchable_fields
|
||||||
|
if not fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__} has no searchable_fields configured. "
|
||||||
|
"Pass search_fields= or set them on CrudFactory."
|
||||||
|
)
|
||||||
|
keys = search_field_keys(fields)
|
||||||
|
|
||||||
|
async def dependency(**kwargs: Any) -> dict[str, Any]:
|
||||||
|
return {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
|
dependency.__name__ = f"{cls.model.__name__}SearchParams"
|
||||||
|
dependency.__signature__ = inspect.Signature( # type: ignore[attr-defined] # ty:ignore[unresolved-attribute]
|
||||||
|
parameters=[
|
||||||
|
inspect.Parameter(
|
||||||
|
"search",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=str | None,
|
||||||
|
default=Query(default=None, description="Search query string"),
|
||||||
|
),
|
||||||
|
inspect.Parameter(
|
||||||
|
"search_column",
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
annotation=str | None,
|
||||||
|
default=Query(
|
||||||
|
default=None,
|
||||||
|
description="Restrict search to a single column",
|
||||||
|
enum=keys,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return dependency
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _resolve_search_columns(
|
||||||
|
cls: type[Self],
|
||||||
|
search_fields: Sequence[SearchFieldType] | None,
|
||||||
|
) -> list[str] | None:
|
||||||
|
"""Return search column keys, or None if no searchable fields configured."""
|
||||||
|
fields = search_fields if search_fields is not None else cls.searchable_fields
|
||||||
|
if not fields:
|
||||||
|
return None
|
||||||
|
return search_field_keys(fields)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def offset_params(
|
def offset_params(
|
||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
@@ -1056,6 +1120,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
include_total: bool = True,
|
include_total: bool = True,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
search_column: str | None = None,
|
||||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||||
schema: type[BaseModel],
|
schema: type[BaseModel],
|
||||||
@@ -1075,6 +1140,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
``pagination.total_count`` will be ``None``.
|
``pagination.total_count`` will be ``None``.
|
||||||
search: Search query string or SearchConfig object
|
search: Search query string or SearchConfig object
|
||||||
search_fields: Fields to search in (overrides class default)
|
search_fields: Fields to search in (overrides class default)
|
||||||
|
search_column: Restrict search to a single column key.
|
||||||
facet_fields: Columns to compute distinct values for (overrides class default)
|
facet_fields: Columns to compute distinct values for (overrides class default)
|
||||||
filter_by: Dict of {column_key: value} to filter by declared facet fields.
|
filter_by: Dict of {column_key: value} to filter by declared facet fields.
|
||||||
Keys must match the column.key of a facet field. Scalar → equality,
|
Keys must match the column.key of a facet field. Scalar → equality,
|
||||||
@@ -1097,6 +1163,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
search,
|
search,
|
||||||
search_fields=search_fields,
|
search_fields=search_fields,
|
||||||
default_fields=cls.searchable_fields,
|
default_fields=cls.searchable_fields,
|
||||||
|
search_column=search_column,
|
||||||
)
|
)
|
||||||
filters.extend(search_filters)
|
filters.extend(search_filters)
|
||||||
search_joins.extend(new_search_joins)
|
search_joins.extend(new_search_joins)
|
||||||
@@ -1153,6 +1220,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
filter_attributes = await cls._build_filter_attributes(
|
filter_attributes = await cls._build_filter_attributes(
|
||||||
session, facet_fields, filters, search_joins
|
session, facet_fields, filters, search_joins
|
||||||
)
|
)
|
||||||
|
search_columns = cls._resolve_search_columns(search_fields)
|
||||||
|
|
||||||
return OffsetPaginatedResponse(
|
return OffsetPaginatedResponse(
|
||||||
data=items,
|
data=items,
|
||||||
@@ -1163,6 +1231,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
has_more=has_more,
|
has_more=has_more,
|
||||||
),
|
),
|
||||||
filter_attributes=filter_attributes,
|
filter_attributes=filter_attributes,
|
||||||
|
search_columns=search_columns,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1179,6 +1248,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
search_column: str | None = None,
|
||||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||||
schema: type[BaseModel],
|
schema: type[BaseModel],
|
||||||
@@ -1199,6 +1269,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
items_per_page: Number of items per page (default 20).
|
items_per_page: Number of items per page (default 20).
|
||||||
search: Search query string or SearchConfig object.
|
search: Search query string or SearchConfig object.
|
||||||
search_fields: Fields to search in (overrides class default).
|
search_fields: Fields to search in (overrides class default).
|
||||||
|
search_column: Restrict search to a single column key.
|
||||||
facet_fields: Columns to compute distinct values for (overrides class default).
|
facet_fields: Columns to compute distinct values for (overrides class default).
|
||||||
filter_by: Dict of {column_key: value} to filter by declared facet fields.
|
filter_by: Dict of {column_key: value} to filter by declared facet fields.
|
||||||
Keys must match the column.key of a facet field. Scalar → equality,
|
Keys must match the column.key of a facet field. Scalar → equality,
|
||||||
@@ -1238,6 +1309,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
search,
|
search,
|
||||||
search_fields=search_fields,
|
search_fields=search_fields,
|
||||||
default_fields=cls.searchable_fields,
|
default_fields=cls.searchable_fields,
|
||||||
|
search_column=search_column,
|
||||||
)
|
)
|
||||||
filters.extend(search_filters)
|
filters.extend(search_filters)
|
||||||
search_joins.extend(new_search_joins)
|
search_joins.extend(new_search_joins)
|
||||||
@@ -1308,6 +1380,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
filter_attributes = await cls._build_filter_attributes(
|
filter_attributes = await cls._build_filter_attributes(
|
||||||
session, facet_fields, filters, search_joins
|
session, facet_fields, filters, search_joins
|
||||||
)
|
)
|
||||||
|
search_columns = cls._resolve_search_columns(search_fields)
|
||||||
|
|
||||||
return CursorPaginatedResponse(
|
return CursorPaginatedResponse(
|
||||||
data=items,
|
data=items,
|
||||||
@@ -1318,6 +1391,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
has_more=has_more,
|
has_more=has_more,
|
||||||
),
|
),
|
||||||
filter_attributes=filter_attributes,
|
filter_attributes=filter_attributes,
|
||||||
|
search_columns=search_columns,
|
||||||
)
|
)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -1338,6 +1412,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
include_total: bool = ...,
|
include_total: bool = ...,
|
||||||
search: str | SearchConfig | None = ...,
|
search: str | SearchConfig | None = ...,
|
||||||
search_fields: Sequence[SearchFieldType] | None = ...,
|
search_fields: Sequence[SearchFieldType] | None = ...,
|
||||||
|
search_column: str | None = ...,
|
||||||
facet_fields: Sequence[FacetFieldType] | None = ...,
|
facet_fields: Sequence[FacetFieldType] | None = ...,
|
||||||
filter_by: dict[str, Any] | BaseModel | None = ...,
|
filter_by: dict[str, Any] | BaseModel | None = ...,
|
||||||
schema: type[BaseModel],
|
schema: type[BaseModel],
|
||||||
@@ -1361,6 +1436,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
include_total: bool = ...,
|
include_total: bool = ...,
|
||||||
search: str | SearchConfig | None = ...,
|
search: str | SearchConfig | None = ...,
|
||||||
search_fields: Sequence[SearchFieldType] | None = ...,
|
search_fields: Sequence[SearchFieldType] | None = ...,
|
||||||
|
search_column: str | None = ...,
|
||||||
facet_fields: Sequence[FacetFieldType] | None = ...,
|
facet_fields: Sequence[FacetFieldType] | None = ...,
|
||||||
filter_by: dict[str, Any] | BaseModel | None = ...,
|
filter_by: dict[str, Any] | BaseModel | None = ...,
|
||||||
schema: type[BaseModel],
|
schema: type[BaseModel],
|
||||||
@@ -1383,6 +1459,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
include_total: bool = True,
|
include_total: bool = True,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
search_column: str | None = None,
|
||||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||||
schema: type[BaseModel],
|
schema: type[BaseModel],
|
||||||
@@ -1410,6 +1487,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
only applies when ``pagination_type`` is ``OFFSET``.
|
only applies when ``pagination_type`` is ``OFFSET``.
|
||||||
search: Search query string or :class:`.SearchConfig` object.
|
search: Search query string or :class:`.SearchConfig` object.
|
||||||
search_fields: Fields to search in (overrides class default).
|
search_fields: Fields to search in (overrides class default).
|
||||||
|
search_column: Restrict search to a single column key.
|
||||||
facet_fields: Columns to compute distinct values for (overrides
|
facet_fields: Columns to compute distinct values for (overrides
|
||||||
class default).
|
class default).
|
||||||
filter_by: Dict of ``{column_key: value}`` to filter by declared
|
filter_by: Dict of ``{column_key: value}`` to filter by declared
|
||||||
@@ -1438,6 +1516,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
items_per_page=items_per_page,
|
items_per_page=items_per_page,
|
||||||
search=search,
|
search=search,
|
||||||
search_fields=search_fields,
|
search_fields=search_fields,
|
||||||
|
search_column=search_column,
|
||||||
facet_fields=facet_fields,
|
facet_fields=facet_fields,
|
||||||
filter_by=filter_by,
|
filter_by=filter_by,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
@@ -1457,6 +1536,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
include_total=include_total,
|
include_total=include_total,
|
||||||
search=search,
|
search=search,
|
||||||
search_fields=search_fields,
|
search_fields=search_fields,
|
||||||
|
search_column=search_column,
|
||||||
facet_fields=facet_fields,
|
facet_fields=facet_fields,
|
||||||
filter_by=filter_by,
|
filter_by=filter_by,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from sqlalchemy.types import (
|
|||||||
|
|
||||||
from ..exceptions import (
|
from ..exceptions import (
|
||||||
InvalidFacetFilterError,
|
InvalidFacetFilterError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
NoSearchableFieldsError,
|
NoSearchableFieldsError,
|
||||||
UnsupportedFacetTypeError,
|
UnsupportedFacetTypeError,
|
||||||
)
|
)
|
||||||
@@ -96,6 +97,7 @@ def build_search_filters(
|
|||||||
search: str | SearchConfig,
|
search: str | SearchConfig,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
default_fields: Sequence[SearchFieldType] | None = None,
|
default_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
search_column: str | None = None,
|
||||||
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
||||||
"""Build SQLAlchemy filter conditions for search.
|
"""Build SQLAlchemy filter conditions for search.
|
||||||
|
|
||||||
@@ -104,6 +106,8 @@ def build_search_filters(
|
|||||||
search: Search string or SearchConfig
|
search: Search string or SearchConfig
|
||||||
search_fields: Fields specified per-call (takes priority)
|
search_fields: Fields specified per-call (takes priority)
|
||||||
default_fields: Default fields (from ClassVar)
|
default_fields: Default fields (from ClassVar)
|
||||||
|
search_column: Optional key to narrow search to a single field.
|
||||||
|
Must match one of the resolved search field keys.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (filter_conditions, joins_needed)
|
Tuple of (filter_conditions, joins_needed)
|
||||||
@@ -130,6 +134,14 @@ def build_search_filters(
|
|||||||
if not fields:
|
if not fields:
|
||||||
raise NoSearchableFieldsError(model)
|
raise NoSearchableFieldsError(model)
|
||||||
|
|
||||||
|
# Narrow to a single column when search_column is specified
|
||||||
|
if search_column is not None:
|
||||||
|
keys = search_field_keys(fields)
|
||||||
|
index = {k: f for k, f in zip(keys, fields)}
|
||||||
|
if search_column not in index:
|
||||||
|
raise InvalidSearchColumnError(search_column, sorted(index))
|
||||||
|
fields = [index[search_column]]
|
||||||
|
|
||||||
query = config.query.strip()
|
query = config.query.strip()
|
||||||
filters: list[ColumnElement[bool]] = []
|
filters: list[ColumnElement[bool]] = []
|
||||||
joins: list[InstrumentedAttribute[Any]] = []
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
@@ -164,6 +176,11 @@ def build_search_filters(
|
|||||||
return filters, joins
|
return filters, joins
|
||||||
|
|
||||||
|
|
||||||
|
def search_field_keys(fields: Sequence[SearchFieldType]) -> list[str]:
|
||||||
|
"""Return a human-readable key for each search field."""
|
||||||
|
return facet_keys(fields)
|
||||||
|
|
||||||
|
|
||||||
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
||||||
"""Return a key for each facet field.
|
"""Return a key for each facet field.
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .exceptions import (
|
|||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
InvalidFacetFilterError,
|
InvalidFacetFilterError,
|
||||||
InvalidOrderFieldError,
|
InvalidOrderFieldError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
NoSearchableFieldsError,
|
NoSearchableFieldsError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
@@ -24,6 +25,7 @@ __all__ = [
|
|||||||
"init_exceptions_handlers",
|
"init_exceptions_handlers",
|
||||||
"InvalidFacetFilterError",
|
"InvalidFacetFilterError",
|
||||||
"InvalidOrderFieldError",
|
"InvalidOrderFieldError",
|
||||||
|
"InvalidSearchColumnError",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
"NotFoundError",
|
"NotFoundError",
|
||||||
"UnauthorizedError",
|
"UnauthorizedError",
|
||||||
|
|||||||
@@ -172,6 +172,33 @@ class UnsupportedFacetTypeError(ApiException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSearchColumnError(ApiException):
|
||||||
|
"""Raised when search_column is not one of the configured searchable fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="Invalid Search Column",
|
||||||
|
desc="The requested search column is not a configured searchable field.",
|
||||||
|
err_code="SEARCH-COL-400",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, column: str, valid_columns: list[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: The unknown search column provided by the caller.
|
||||||
|
valid_columns: List of valid search column keys.
|
||||||
|
"""
|
||||||
|
self.column = column
|
||||||
|
self.valid_columns = valid_columns
|
||||||
|
super().__init__(
|
||||||
|
desc=(
|
||||||
|
f"'{column}' is not a searchable column. "
|
||||||
|
f"Valid columns: {valid_columns}."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidOrderFieldError(ApiException):
|
class InvalidOrderFieldError(ApiException):
|
||||||
"""Raised when order_by contains a field not in the allowed order fields."""
|
"""Raised when order_by contains a field not in the allowed order fields."""
|
||||||
|
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ class PaginatedResponse(BaseResponse, Generic[DataT]):
|
|||||||
pagination: OffsetPagination | CursorPagination
|
pagination: OffsetPagination | CursorPagination
|
||||||
pagination_type: PaginationType | None = None
|
pagination_type: PaginationType | None = None
|
||||||
filter_attributes: dict[str, list[Any]] | None = None
|
filter_attributes: dict[str, list[Any]] | None = None
|
||||||
|
search_columns: list[str] | None = None
|
||||||
|
|
||||||
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
|
_discriminated_union_cache: ClassVar[dict[Any, Any]] = {}
|
||||||
|
|
||||||
|
|||||||
@@ -211,6 +211,38 @@ class TestResolveLoadOptions:
|
|||||||
assert crud._resolve_load_options([]) == []
|
assert crud._resolve_load_options([]) == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveSearchColumns:
|
||||||
|
"""Tests for _resolve_search_columns logic."""
|
||||||
|
|
||||||
|
def test_returns_none_when_no_searchable_fields(self):
|
||||||
|
"""Returns None when cls.searchable_fields is None and no search_fields passed."""
|
||||||
|
|
||||||
|
class AbstractCrud(AsyncCrud[User]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert AbstractCrud._resolve_search_columns(None) is None
|
||||||
|
|
||||||
|
def test_returns_none_when_empty_search_fields_passed(self):
|
||||||
|
"""Returns None when an empty list is passed explicitly."""
|
||||||
|
crud = CrudFactory(User)
|
||||||
|
assert crud._resolve_search_columns([]) is None
|
||||||
|
|
||||||
|
def test_returns_keys_from_class_searchable_fields(self):
|
||||||
|
"""Returns column keys from cls.searchable_fields when no override passed."""
|
||||||
|
crud = CrudFactory(User, searchable_fields=[User.username])
|
||||||
|
result = crud._resolve_search_columns(None)
|
||||||
|
assert result is not None
|
||||||
|
assert "username" in result
|
||||||
|
|
||||||
|
def test_search_fields_override_takes_priority(self):
|
||||||
|
"""Explicit search_fields override cls.searchable_fields."""
|
||||||
|
crud = CrudFactory(User, searchable_fields=[User.username])
|
||||||
|
result = crud._resolve_search_columns([User.email])
|
||||||
|
assert result is not None
|
||||||
|
assert "email" in result
|
||||||
|
assert "username" not in result
|
||||||
|
|
||||||
|
|
||||||
class TestDefaultLoadOptionsIntegration:
|
class TestDefaultLoadOptionsIntegration:
|
||||||
"""Integration tests for default_load_options with real DB queries."""
|
"""Integration tests for default_load_options with real DB queries."""
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
|
|||||||
from fastapi_toolsets.crud import (
|
from fastapi_toolsets.crud import (
|
||||||
CrudFactory,
|
CrudFactory,
|
||||||
InvalidFacetFilterError,
|
InvalidFacetFilterError,
|
||||||
|
InvalidSearchColumnError,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
UnsupportedFacetTypeError,
|
UnsupportedFacetTypeError,
|
||||||
get_searchable_fields,
|
get_searchable_fields,
|
||||||
@@ -1199,6 +1200,208 @@ class TestFilterParamsSchema:
|
|||||||
assert result.pagination.total_count == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchParamsSchema:
|
||||||
|
"""Tests for AsyncCrud.search_params()."""
|
||||||
|
|
||||||
|
def test_generates_search_and_search_column_params(self):
|
||||||
|
"""Returned dependency has search and search_column query params."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
dep = UserSearchCrud.search_params()
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert param_names == {"search", "search_column"}
|
||||||
|
|
||||||
|
def test_dependency_name_includes_model_name(self):
|
||||||
|
"""Dependency function is named {Model}SearchParams."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
dep = UserSearchCrud.search_params()
|
||||||
|
assert dep.__name__ == "UserSearchParams" # type: ignore[union-attr] # ty:ignore[unresolved-attribute]
|
||||||
|
|
||||||
|
def test_raises_when_no_searchable_fields(self):
|
||||||
|
"""ValueError raised when overriding with empty search_fields."""
|
||||||
|
UserSearchCrud = CrudFactory(User, searchable_fields=[User.username])
|
||||||
|
with pytest.raises(ValueError, match="no searchable_fields"):
|
||||||
|
UserSearchCrud.search_params(search_fields=[])
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awaiting_dep_with_search_only(self):
|
||||||
|
"""Awaiting the dependency with only search returns a dict with search key."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
dep = UserSearchCrud.search_params()
|
||||||
|
|
||||||
|
result = await dep(search="alice")
|
||||||
|
assert result == {"search": "alice"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awaiting_dep_with_search_and_column(self):
|
||||||
|
"""Awaiting the dependency with both params returns both keys."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
dep = UserSearchCrud.search_params()
|
||||||
|
|
||||||
|
result = await dep(search="alice", search_column="username")
|
||||||
|
assert result == {"search": "alice", "search_column": "username"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_awaiting_dep_with_no_values(self):
|
||||||
|
"""Awaiting the dependency with no values returns an empty dict."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
dep = UserSearchCrud.search_params()
|
||||||
|
|
||||||
|
result = await dep()
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_relationship_search_field_key(self):
|
||||||
|
"""Relationship tuple search fields use __ joined keys."""
|
||||||
|
UserRelSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, (User.role, Role.name)]
|
||||||
|
)
|
||||||
|
dep = UserRelSearchCrud.search_params()
|
||||||
|
|
||||||
|
params = inspect.signature(dep).parameters
|
||||||
|
search_column_param = params["search_column"]
|
||||||
|
assert search_column_param.default.json_schema_extra.get("enum") == [
|
||||||
|
"id",
|
||||||
|
"username",
|
||||||
|
"role__name",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchColumns:
|
||||||
|
"""Tests for search_columns in paginated responses."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_columns_returned_in_offset_paginate(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""offset_paginate response includes search_columns."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserSearchCrud.offset_paginate(db_session, schema=UserRead)
|
||||||
|
|
||||||
|
assert result.search_columns is not None
|
||||||
|
assert "username" in result.search_columns
|
||||||
|
assert "email" in result.search_columns
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_columns_returned_in_cursor_paginate(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""cursor_paginate response includes search_columns."""
|
||||||
|
UserSearchCursorCrud = CrudFactory(
|
||||||
|
User,
|
||||||
|
cursor_column=User.id,
|
||||||
|
searchable_fields=[User.username, User.email],
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserSearchCursorCrud.cursor_paginate(db_session, schema=UserRead)
|
||||||
|
|
||||||
|
assert result.search_columns is not None
|
||||||
|
assert "username" in result.search_columns
|
||||||
|
assert "email" in result.search_columns
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_column_narrows_search(self, db_session: AsyncSession):
|
||||||
|
"""search_column restricts search to a single field."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="bob@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="alice@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search "alice" in username only — should return only alice
|
||||||
|
result = await UserSearchCrud.offset_paginate(
|
||||||
|
db_session, search="alice", search_column="username", schema=UserRead
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_column_invalid_raises(self, db_session: AsyncSession):
|
||||||
|
"""search_column with an invalid key raises InvalidSearchColumnError."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidSearchColumnError) as exc_info:
|
||||||
|
await UserSearchCrud.offset_paginate(
|
||||||
|
db_session,
|
||||||
|
search="alice",
|
||||||
|
search_column="nonexistent",
|
||||||
|
schema=UserRead,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.column == "nonexistent"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_without_search_column_searches_all(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""search without search_column searches across all configured fields."""
|
||||||
|
UserSearchCrud = CrudFactory(
|
||||||
|
User, searchable_fields=[User.username, User.email]
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="bob@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="alice@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search "alice" across all fields — should return both
|
||||||
|
result = await UserSearchCrud.offset_paginate(
|
||||||
|
db_session, search="alice", schema=UserRead
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_column_with_cursor_paginate(self, db_session: AsyncSession):
|
||||||
|
"""search_column works with cursor_paginate."""
|
||||||
|
UserSearchCursorCrud = CrudFactory(
|
||||||
|
User,
|
||||||
|
cursor_column=User.id,
|
||||||
|
searchable_fields=[User.username, User.email],
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="bob@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="alice@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserSearchCursorCrud.cursor_paginate(
|
||||||
|
db_session, search="alice", search_column="email", schema=UserRead
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0].username == "bob"
|
||||||
|
|
||||||
|
|
||||||
class TestOrderParamsSchema:
|
class TestOrderParamsSchema:
|
||||||
"""Tests for AsyncCrud.order_params()."""
|
"""Tests for AsyncCrud.order_params()."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user