diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index 8d431b6..23f7d97 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -2,6 +2,7 @@ from ..exceptions import ( InvalidFacetFilterError, + InvalidSearchColumnError, NoSearchableFieldsError, UnsupportedFacetTypeError, ) @@ -22,6 +23,7 @@ __all__ = [ "FacetFieldType", "get_searchable_fields", "InvalidFacetFilterError", + "InvalidSearchColumnError", "JoinType", "M2MFieldType", "NoSearchableFieldsError", diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 0a47d79..5fde829 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -47,6 +47,7 @@ from .search import ( build_filter_by, build_search_filters, facet_keys, + search_field_keys, ) @@ -309,6 +310,69 @@ class AsyncCrud(Generic[ModelType]): return dependency + @classmethod + def search_params( + cls: type[Self], + *, + search_fields: Sequence[SearchFieldType] | None = None, + ) -> Callable[..., Awaitable[dict[str, Any]]]: + """Return a FastAPI dependency that collects search params from query parameters. + + Args: + search_fields: Override search fields for this dependency. + Falls back to the class-level ``searchable_fields``. + + Returns: + An async dependency function named ``{Model}SearchParams`` that + resolves to a ``dict`` with ``search`` and ``search_column`` keys + (absent keys are excluded). + """ + fields = search_fields if search_fields is not None else cls.searchable_fields + if not fields: + raise ValueError( + f"{cls.__name__} has no searchable_fields configured. " + "Pass search_fields= or set them on CrudFactory." + ) + keys = search_field_keys(fields) + + async def dependency(**kwargs: Any) -> dict[str, Any]: + return {k: v for k, v in kwargs.items() if v is not None} + + dependency.__name__ = f"{cls.model.__name__}SearchParams" + dependency.__signature__ = inspect.Signature( # type: ignore[attr-defined] # ty:ignore[unresolved-attribute] + parameters=[ + inspect.Parameter( + "search", + inspect.Parameter.KEYWORD_ONLY, + annotation=str | None, + default=Query(default=None, description="Search query string"), + ), + inspect.Parameter( + "search_column", + inspect.Parameter.KEYWORD_ONLY, + annotation=str | None, + default=Query( + default=None, + description="Restrict search to a single column", + enum=keys, + ), + ), + ] + ) + + return dependency + + @classmethod + def _resolve_search_columns( + cls: type[Self], + search_fields: Sequence[SearchFieldType] | None, + ) -> list[str] | None: + """Return search column keys, or None if no searchable fields configured.""" + fields = search_fields if search_fields is not None else cls.searchable_fields + if not fields: + return None + return search_field_keys(fields) + @classmethod def offset_params( cls: type[Self], @@ -1056,6 +1120,7 @@ class AsyncCrud(Generic[ModelType]): include_total: bool = True, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + search_column: str | None = None, facet_fields: Sequence[FacetFieldType] | None = None, filter_by: dict[str, Any] | BaseModel | None = None, schema: type[BaseModel], @@ -1075,6 +1140,7 @@ class AsyncCrud(Generic[ModelType]): ``pagination.total_count`` will be ``None``. search: Search query string or SearchConfig object search_fields: Fields to search in (overrides class default) + search_column: Restrict search to a single column key. facet_fields: Columns to compute distinct values for (overrides class default) filter_by: Dict of {column_key: value} to filter by declared facet fields. Keys must match the column.key of a facet field. Scalar → equality, @@ -1097,6 +1163,7 @@ class AsyncCrud(Generic[ModelType]): search, search_fields=search_fields, default_fields=cls.searchable_fields, + search_column=search_column, ) filters.extend(search_filters) search_joins.extend(new_search_joins) @@ -1153,6 +1220,7 @@ class AsyncCrud(Generic[ModelType]): filter_attributes = await cls._build_filter_attributes( session, facet_fields, filters, search_joins ) + search_columns = cls._resolve_search_columns(search_fields) return OffsetPaginatedResponse( data=items, @@ -1163,6 +1231,7 @@ class AsyncCrud(Generic[ModelType]): has_more=has_more, ), filter_attributes=filter_attributes, + search_columns=search_columns, ) @classmethod @@ -1179,6 +1248,7 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + search_column: str | None = None, facet_fields: Sequence[FacetFieldType] | None = None, filter_by: dict[str, Any] | BaseModel | None = None, schema: type[BaseModel], @@ -1199,6 +1269,7 @@ class AsyncCrud(Generic[ModelType]): items_per_page: Number of items per page (default 20). search: Search query string or SearchConfig object. search_fields: Fields to search in (overrides class default). + search_column: Restrict search to a single column key. facet_fields: Columns to compute distinct values for (overrides class default). filter_by: Dict of {column_key: value} to filter by declared facet fields. Keys must match the column.key of a facet field. Scalar → equality, @@ -1238,6 +1309,7 @@ class AsyncCrud(Generic[ModelType]): search, search_fields=search_fields, default_fields=cls.searchable_fields, + search_column=search_column, ) filters.extend(search_filters) search_joins.extend(new_search_joins) @@ -1308,6 +1380,7 @@ class AsyncCrud(Generic[ModelType]): filter_attributes = await cls._build_filter_attributes( session, facet_fields, filters, search_joins ) + search_columns = cls._resolve_search_columns(search_fields) return CursorPaginatedResponse( data=items, @@ -1318,6 +1391,7 @@ class AsyncCrud(Generic[ModelType]): has_more=has_more, ), filter_attributes=filter_attributes, + search_columns=search_columns, ) @overload @@ -1338,6 +1412,7 @@ class AsyncCrud(Generic[ModelType]): include_total: bool = ..., search: str | SearchConfig | None = ..., search_fields: Sequence[SearchFieldType] | None = ..., + search_column: str | None = ..., facet_fields: Sequence[FacetFieldType] | None = ..., filter_by: dict[str, Any] | BaseModel | None = ..., schema: type[BaseModel], @@ -1361,6 +1436,7 @@ class AsyncCrud(Generic[ModelType]): include_total: bool = ..., search: str | SearchConfig | None = ..., search_fields: Sequence[SearchFieldType] | None = ..., + search_column: str | None = ..., facet_fields: Sequence[FacetFieldType] | None = ..., filter_by: dict[str, Any] | BaseModel | None = ..., schema: type[BaseModel], @@ -1383,6 +1459,7 @@ class AsyncCrud(Generic[ModelType]): include_total: bool = True, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, + search_column: str | None = None, facet_fields: Sequence[FacetFieldType] | None = None, filter_by: dict[str, Any] | BaseModel | None = None, schema: type[BaseModel], @@ -1410,6 +1487,7 @@ class AsyncCrud(Generic[ModelType]): only applies when ``pagination_type`` is ``OFFSET``. search: Search query string or :class:`.SearchConfig` object. search_fields: Fields to search in (overrides class default). + search_column: Restrict search to a single column key. facet_fields: Columns to compute distinct values for (overrides class default). filter_by: Dict of ``{column_key: value}`` to filter by declared @@ -1438,6 +1516,7 @@ class AsyncCrud(Generic[ModelType]): items_per_page=items_per_page, search=search, search_fields=search_fields, + search_column=search_column, facet_fields=facet_fields, filter_by=filter_by, schema=schema, @@ -1457,6 +1536,7 @@ class AsyncCrud(Generic[ModelType]): include_total=include_total, search=search, search_fields=search_fields, + search_column=search_column, facet_fields=facet_fields, filter_by=filter_by, schema=schema, diff --git a/src/fastapi_toolsets/crud/search.py b/src/fastapi_toolsets/crud/search.py index 07abbaa..5a95b12 100644 --- a/src/fastapi_toolsets/crud/search.py +++ b/src/fastapi_toolsets/crud/search.py @@ -24,6 +24,7 @@ from sqlalchemy.types import ( from ..exceptions import ( InvalidFacetFilterError, + InvalidSearchColumnError, NoSearchableFieldsError, UnsupportedFacetTypeError, ) @@ -96,6 +97,7 @@ def build_search_filters( search: str | SearchConfig, search_fields: Sequence[SearchFieldType] | None = None, default_fields: Sequence[SearchFieldType] | None = None, + search_column: str | None = None, ) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]: """Build SQLAlchemy filter conditions for search. @@ -104,6 +106,8 @@ def build_search_filters( search: Search string or SearchConfig search_fields: Fields specified per-call (takes priority) default_fields: Default fields (from ClassVar) + search_column: Optional key to narrow search to a single field. + Must match one of the resolved search field keys. Returns: Tuple of (filter_conditions, joins_needed) @@ -130,6 +134,14 @@ def build_search_filters( if not fields: raise NoSearchableFieldsError(model) + # Narrow to a single column when search_column is specified + if search_column is not None: + keys = search_field_keys(fields) + index = {k: f for k, f in zip(keys, fields)} + if search_column not in index: + raise InvalidSearchColumnError(search_column, sorted(index)) + fields = [index[search_column]] + query = config.query.strip() filters: list[ColumnElement[bool]] = [] joins: list[InstrumentedAttribute[Any]] = [] @@ -164,6 +176,11 @@ def build_search_filters( return filters, joins +def search_field_keys(fields: Sequence[SearchFieldType]) -> list[str]: + """Return a human-readable key for each search field.""" + return facet_keys(fields) + + def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]: """Return a key for each facet field. diff --git a/src/fastapi_toolsets/exceptions/__init__.py b/src/fastapi_toolsets/exceptions/__init__.py index fc5da21..2ee2d26 100644 --- a/src/fastapi_toolsets/exceptions/__init__.py +++ b/src/fastapi_toolsets/exceptions/__init__.py @@ -7,6 +7,7 @@ from .exceptions import ( ForbiddenError, InvalidFacetFilterError, InvalidOrderFieldError, + InvalidSearchColumnError, NoSearchableFieldsError, NotFoundError, UnauthorizedError, @@ -24,6 +25,7 @@ __all__ = [ "init_exceptions_handlers", "InvalidFacetFilterError", "InvalidOrderFieldError", + "InvalidSearchColumnError", "NoSearchableFieldsError", "NotFoundError", "UnauthorizedError", diff --git a/src/fastapi_toolsets/exceptions/exceptions.py b/src/fastapi_toolsets/exceptions/exceptions.py index 5822bb5..772aec3 100644 --- a/src/fastapi_toolsets/exceptions/exceptions.py +++ b/src/fastapi_toolsets/exceptions/exceptions.py @@ -172,6 +172,33 @@ class UnsupportedFacetTypeError(ApiException): ) +class InvalidSearchColumnError(ApiException): + """Raised when search_column is not one of the configured searchable fields.""" + + api_error = ApiError( + code=400, + msg="Invalid Search Column", + desc="The requested search column is not a configured searchable field.", + err_code="SEARCH-COL-400", + ) + + def __init__(self, column: str, valid_columns: list[str]) -> None: + """Initialize the exception. + + Args: + column: The unknown search column provided by the caller. + valid_columns: List of valid search column keys. + """ + self.column = column + self.valid_columns = valid_columns + super().__init__( + desc=( + f"'{column}' is not a searchable column. " + f"Valid columns: {valid_columns}." + ) + ) + + class InvalidOrderFieldError(ApiException): """Raised when order_by contains a field not in the allowed order fields.""" diff --git a/src/fastapi_toolsets/schemas.py b/src/fastapi_toolsets/schemas.py index 214e460..608ebbc 100644 --- a/src/fastapi_toolsets/schemas.py +++ b/src/fastapi_toolsets/schemas.py @@ -162,6 +162,7 @@ class PaginatedResponse(BaseResponse, Generic[DataT]): pagination: OffsetPagination | CursorPagination pagination_type: PaginationType | None = None filter_attributes: dict[str, list[Any]] | None = None + search_columns: list[str] | None = None _discriminated_union_cache: ClassVar[dict[Any, Any]] = {} diff --git a/tests/test_crud.py b/tests/test_crud.py index 2e0b539..aa5708a 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -211,6 +211,38 @@ class TestResolveLoadOptions: assert crud._resolve_load_options([]) == [] +class TestResolveSearchColumns: + """Tests for _resolve_search_columns logic.""" + + def test_returns_none_when_no_searchable_fields(self): + """Returns None when cls.searchable_fields is None and no search_fields passed.""" + + class AbstractCrud(AsyncCrud[User]): + pass + + assert AbstractCrud._resolve_search_columns(None) is None + + def test_returns_none_when_empty_search_fields_passed(self): + """Returns None when an empty list is passed explicitly.""" + crud = CrudFactory(User) + assert crud._resolve_search_columns([]) is None + + def test_returns_keys_from_class_searchable_fields(self): + """Returns column keys from cls.searchable_fields when no override passed.""" + crud = CrudFactory(User, searchable_fields=[User.username]) + result = crud._resolve_search_columns(None) + assert result is not None + assert "username" in result + + def test_search_fields_override_takes_priority(self): + """Explicit search_fields override cls.searchable_fields.""" + crud = CrudFactory(User, searchable_fields=[User.username]) + result = crud._resolve_search_columns([User.email]) + assert result is not None + assert "email" in result + assert "username" not in result + + class TestDefaultLoadOptionsIntegration: """Integration tests for default_load_options with real DB queries.""" diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 3c71be8..9e88171 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -10,6 +10,7 @@ from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from fastapi_toolsets.crud import ( CrudFactory, InvalidFacetFilterError, + InvalidSearchColumnError, SearchConfig, UnsupportedFacetTypeError, get_searchable_fields, @@ -1199,6 +1200,208 @@ class TestFilterParamsSchema: assert result.pagination.total_count == 2 +class TestSearchParamsSchema: + """Tests for AsyncCrud.search_params().""" + + def test_generates_search_and_search_column_params(self): + """Returned dependency has search and search_column query params.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + dep = UserSearchCrud.search_params() + + param_names = set(inspect.signature(dep).parameters) + assert param_names == {"search", "search_column"} + + def test_dependency_name_includes_model_name(self): + """Dependency function is named {Model}SearchParams.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + dep = UserSearchCrud.search_params() + assert dep.__name__ == "UserSearchParams" # type: ignore[union-attr] # ty:ignore[unresolved-attribute] + + def test_raises_when_no_searchable_fields(self): + """ValueError raised when overriding with empty search_fields.""" + UserSearchCrud = CrudFactory(User, searchable_fields=[User.username]) + with pytest.raises(ValueError, match="no searchable_fields"): + UserSearchCrud.search_params(search_fields=[]) + + @pytest.mark.anyio + async def test_awaiting_dep_with_search_only(self): + """Awaiting the dependency with only search returns a dict with search key.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + dep = UserSearchCrud.search_params() + + result = await dep(search="alice") + assert result == {"search": "alice"} + + @pytest.mark.anyio + async def test_awaiting_dep_with_search_and_column(self): + """Awaiting the dependency with both params returns both keys.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + dep = UserSearchCrud.search_params() + + result = await dep(search="alice", search_column="username") + assert result == {"search": "alice", "search_column": "username"} + + @pytest.mark.anyio + async def test_awaiting_dep_with_no_values(self): + """Awaiting the dependency with no values returns an empty dict.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + dep = UserSearchCrud.search_params() + + result = await dep() + assert result == {} + + def test_relationship_search_field_key(self): + """Relationship tuple search fields use __ joined keys.""" + UserRelSearchCrud = CrudFactory( + User, searchable_fields=[User.username, (User.role, Role.name)] + ) + dep = UserRelSearchCrud.search_params() + + params = inspect.signature(dep).parameters + search_column_param = params["search_column"] + assert search_column_param.default.json_schema_extra.get("enum") == [ + "id", + "username", + "role__name", + ] + + +class TestSearchColumns: + """Tests for search_columns in paginated responses.""" + + @pytest.mark.anyio + async def test_search_columns_returned_in_offset_paginate( + self, db_session: AsyncSession + ): + """offset_paginate response includes search_columns.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + + result = await UserSearchCrud.offset_paginate(db_session, schema=UserRead) + + assert result.search_columns is not None + assert "username" in result.search_columns + assert "email" in result.search_columns + + @pytest.mark.anyio + async def test_search_columns_returned_in_cursor_paginate( + self, db_session: AsyncSession + ): + """cursor_paginate response includes search_columns.""" + UserSearchCursorCrud = CrudFactory( + User, + cursor_column=User.id, + searchable_fields=[User.username, User.email], + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + + result = await UserSearchCursorCrud.cursor_paginate(db_session, schema=UserRead) + + assert result.search_columns is not None + assert "username" in result.search_columns + assert "email" in result.search_columns + + @pytest.mark.anyio + async def test_search_column_narrows_search(self, db_session: AsyncSession): + """search_column restricts search to a single field.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="bob@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="alice@test.com") + ) + + # Search "alice" in username only — should return only alice + result = await UserSearchCrud.offset_paginate( + db_session, search="alice", search_column="username", schema=UserRead + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_search_column_invalid_raises(self, db_session: AsyncSession): + """search_column with an invalid key raises InvalidSearchColumnError.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + + with pytest.raises(InvalidSearchColumnError) as exc_info: + await UserSearchCrud.offset_paginate( + db_session, + search="alice", + search_column="nonexistent", + schema=UserRead, + ) + + assert exc_info.value.column == "nonexistent" + + @pytest.mark.anyio + async def test_search_without_search_column_searches_all( + self, db_session: AsyncSession + ): + """search without search_column searches across all configured fields.""" + UserSearchCrud = CrudFactory( + User, searchable_fields=[User.username, User.email] + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="bob@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="alice@test.com") + ) + + # Search "alice" across all fields — should return both + result = await UserSearchCrud.offset_paginate( + db_session, search="alice", schema=UserRead + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + + @pytest.mark.anyio + async def test_search_column_with_cursor_paginate(self, db_session: AsyncSession): + """search_column works with cursor_paginate.""" + UserSearchCursorCrud = CrudFactory( + User, + cursor_column=User.id, + searchable_fields=[User.username, User.email], + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="bob@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="alice@test.com") + ) + + result = await UserSearchCursorCrud.cursor_paginate( + db_session, search="alice", search_column="email", schema=UserRead + ) + + assert len(result.data) == 1 + assert result.data[0].username == "bob" + + class TestOrderParamsSchema: """Tests for AsyncCrud.order_params()."""