From d14551781c1db96517c9e43d018288323194af11 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Thu, 29 Jan 2026 00:08:02 +0100 Subject: [PATCH] feat: add search to crud paginate function (#17) * feat: add search to crud paginate function * fixes: comments + tests import --- src/fastapi_toolsets/crud/__init__.py | 17 + .../{crud.py => crud/factory.py} | 96 +++-- src/fastapi_toolsets/crud/search.py | 145 +++++++ src/fastapi_toolsets/exceptions/__init__.py | 2 + src/fastapi_toolsets/exceptions/exceptions.py | 19 + tests/test_crud.py | 3 +- tests/test_crud_search.py | 392 ++++++++++++++++++ 7 files changed, 644 insertions(+), 30 deletions(-) create mode 100644 src/fastapi_toolsets/crud/__init__.py rename src/fastapi_toolsets/{crud.py => crud/factory.py} (81%) create mode 100644 src/fastapi_toolsets/crud/search.py create mode 100644 tests/test_crud_search.py diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py new file mode 100644 index 0000000..d95aaf2 --- /dev/null +++ b/src/fastapi_toolsets/crud/__init__.py @@ -0,0 +1,17 @@ +"""Generic async CRUD operations for SQLAlchemy models.""" + +from ..exceptions import NoSearchableFieldsError +from .factory import CrudFactory +from .search import ( + SearchConfig, + SearchFieldType, + get_searchable_fields, +) + +__all__ = [ + "CrudFactory", + "get_searchable_fields", + "NoSearchableFieldsError", + "SearchConfig", + "SearchFieldType", +] diff --git a/src/fastapi_toolsets/crud.py b/src/fastapi_toolsets/crud/factory.py similarity index 81% rename from src/fastapi_toolsets/crud.py rename to src/fastapi_toolsets/crud/factory.py index 5c54eb9..a7f5fdc 100644 --- a/src/fastapi_toolsets/crud.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -12,13 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql.roles import WhereHavingRole -from .db import get_transaction -from .exceptions import NotFoundError - -__all__ = [ - "AsyncCrud", - "CrudFactory", -] +from ..db import get_transaction +from ..exceptions import NotFoundError +from .search import SearchConfig, SearchFieldType, build_search_filters ModelType = TypeVar("ModelType", bound=DeclarativeBase) @@ -27,20 +23,10 @@ class AsyncCrud(Generic[ModelType]): """Generic async CRUD operations for SQLAlchemy models. Subclass this and set the `model` class variable, or use `CrudFactory`. - - Example: - class UserCrud(AsyncCrud[User]): - model = User - - # Or use the factory: - UserCrud = CrudFactory(User) - - # Then use it: - user = await UserCrud.get(session, [User.id == 1]) - users = await UserCrud.get_multi(session, limit=10) """ model: ClassVar[type[DeclarativeBase]] + searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None @classmethod async def create( @@ -313,6 +299,8 @@ class AsyncCrud(Generic[ModelType]): order_by: Any | None = None, page: int = 1, items_per_page: int = 20, + search: str | SearchConfig | None = None, + search_fields: Sequence[SearchFieldType] | None = None, ) -> dict[str, Any]: """Get paginated results with metadata. @@ -323,23 +311,54 @@ class AsyncCrud(Generic[ModelType]): order_by: Column or list of columns to order by page: Page number (1-indexed) items_per_page: Number of items per page + search: Search query string or SearchConfig object + search_fields: Fields to search in (overrides class default) Returns: Dict with 'data' and 'pagination' keys """ - filters = filters or [] + filters = list(filters) if filters else [] offset = (page - 1) * items_per_page + joins: list[Any] = [] - items = await cls.get_multi( - session, - filters=filters, - load_options=load_options, - order_by=order_by, - limit=items_per_page, - offset=offset, - ) + # Build search filters + if search: + search_filters, search_joins = build_search_filters( + cls.model, + search, + search_fields=search_fields, + default_fields=cls.searchable_fields, + ) + filters.extend(search_filters) + joins.extend(search_joins) - total_count = await cls.count(session, filters=filters) + # Build query with joins + q = select(cls.model) + for join_rel in joins: + q = q.outerjoin(join_rel) + + if filters: + q = q.where(and_(*filters)) + if load_options: + q = q.options(*load_options) + if order_by is not None: + q = q.order_by(order_by) + + q = q.offset(offset).limit(items_per_page) + result = await session.execute(q) + items = result.unique().scalars().all() + + # Count query (with same joins and filters) + pk_col = cls.model.__mapper__.primary_key[0] + count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name)))) + count_q = count_q.select_from(cls.model) + for join_rel in joins: + count_q = count_q.outerjoin(join_rel) + if filters: + count_q = count_q.where(and_(*filters)) + + count_result = await session.execute(count_q) + total_count = count_result.scalar_one() return { "data": items, @@ -354,11 +373,14 @@ class AsyncCrud(Generic[ModelType]): def CrudFactory( model: type[ModelType], + *, + searchable_fields: Sequence[SearchFieldType] | None = None, ) -> type[AsyncCrud[ModelType]]: """Create a CRUD class for a specific model. Args: model: SQLAlchemy model class + searchable_fields: Optional list of searchable fields Returns: AsyncCrud subclass bound to the model @@ -370,9 +392,25 @@ def CrudFactory( UserCrud = CrudFactory(User) PostCrud = CrudFactory(Post) + # With searchable fields: + UserCrud = CrudFactory( + User, + searchable_fields=[User.username, User.email, (User.role, Role.name)] + ) + # Usage user = await UserCrud.get(session, [User.id == 1]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) + + # With search + result = await UserCrud.paginate(session, search="john") """ - cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model}) + cls = type( + f"Async{model.__name__}Crud", + (AsyncCrud,), + { + "model": model, + "searchable_fields": searchable_fields, + }, + ) return cast(type[AsyncCrud[ModelType]], cls) diff --git a/src/fastapi_toolsets/crud/search.py b/src/fastapi_toolsets/crud/search.py new file mode 100644 index 0000000..60216e5 --- /dev/null +++ b/src/fastapi_toolsets/crud/search.py @@ -0,0 +1,145 @@ +"""Search utilities for AsyncCrud.""" + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal + +from sqlalchemy import String, or_ +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm.attributes import InstrumentedAttribute + +from ..exceptions import NoSearchableFieldsError + +if TYPE_CHECKING: + from sqlalchemy.sql.elements import ColumnElement + +SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...] + + +@dataclass +class SearchConfig: + """Advanced search configuration. + + Attributes: + query: The search string + fields: Fields to search (columns or tuples for relationships) + case_sensitive: Case-sensitive search (default: False) + match_mode: "any" (OR) or "all" (AND) to combine fields + """ + + query: str + fields: Sequence[SearchFieldType] | None = None + case_sensitive: bool = False + match_mode: Literal["any", "all"] = "any" + + +def get_searchable_fields( + model: type[DeclarativeBase], + *, + include_relationships: bool = True, + max_depth: int = 1, +) -> list[SearchFieldType]: + """Auto-detect String fields on a model and its relationships. + + Args: + model: SQLAlchemy model class + include_relationships: Include fields from many-to-one/one-to-one relationships + max_depth: Max depth for relationship traversal (default: 1) + + Returns: + List of columns and tuples (relationship, column) + """ + fields: list[SearchFieldType] = [] + mapper = model.__mapper__ + + # Direct String columns + for col in mapper.columns: + if isinstance(col.type, String): + fields.append(getattr(model, col.key)) + + # Relationships (one-to-one, many-to-one only) + if include_relationships and max_depth > 0: + for rel_name, rel_prop in mapper.relationships.items(): + if rel_prop.uselist: # Skip collections (one-to-many, many-to-many) + continue + + rel_attr = getattr(model, rel_name) + related_model = rel_prop.mapper.class_ + + for col in related_model.__mapper__.columns: + if isinstance(col.type, String): + fields.append((rel_attr, getattr(related_model, col.key))) + + return fields + + +def build_search_filters( + model: type[DeclarativeBase], + search: str | SearchConfig, + search_fields: Sequence[SearchFieldType] | None = None, + default_fields: Sequence[SearchFieldType] | None = None, +) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]: + """Build SQLAlchemy filter conditions for search. + + Args: + model: SQLAlchemy model class + search: Search string or SearchConfig + search_fields: Fields specified per-call (takes priority) + default_fields: Default fields (from ClassVar) + + Returns: + Tuple of (filter_conditions, joins_needed) + """ + # Normalize input + if isinstance(search, str): + config = SearchConfig(query=search, fields=search_fields) + else: + config = search + if search_fields is not None: + config = SearchConfig( + query=config.query, + fields=search_fields, + case_sensitive=config.case_sensitive, + match_mode=config.match_mode, + ) + + if not config.query or not config.query.strip(): + return [], [] + + # Determine which fields to search + fields = config.fields or default_fields or get_searchable_fields(model) + + if not fields: + raise NoSearchableFieldsError(model) + + query = config.query.strip() + filters: list[ColumnElement[bool]] = [] + joins: list[InstrumentedAttribute[Any]] = [] + added_joins: set[str] = set() + + for field in fields: + if isinstance(field, tuple): + # Relationship: (User.role, Role.name) or deeper + for rel in field[:-1]: + rel_key = str(rel) + if rel_key not in added_joins: + joins.append(rel) + added_joins.add(rel_key) + column = field[-1] + else: + column = field + + # Build the filter + if config.case_sensitive: + filters.append(column.like(f"%{query}%")) + else: + filters.append(column.ilike(f"%{query}%")) + + if not filters: + return [], [] + + # Combine based on match_mode + if config.match_mode == "any": + return [or_(*filters)], joins + else: + return filters, joins diff --git a/src/fastapi_toolsets/exceptions/__init__.py b/src/fastapi_toolsets/exceptions/__init__.py index 490175f..f6ea230 100644 --- a/src/fastapi_toolsets/exceptions/__init__.py +++ b/src/fastapi_toolsets/exceptions/__init__.py @@ -2,6 +2,7 @@ from .exceptions import ( ApiException, ConflictError, ForbiddenError, + NoSearchableFieldsError, NotFoundError, UnauthorizedError, generate_error_responses, @@ -14,6 +15,7 @@ __all__ = [ "ApiException", "ConflictError", "ForbiddenError", + "NoSearchableFieldsError", "NotFoundError", "UnauthorizedError", ] diff --git a/src/fastapi_toolsets/exceptions/exceptions.py b/src/fastapi_toolsets/exceptions/exceptions.py index 9a625dd..fa15153 100644 --- a/src/fastapi_toolsets/exceptions/exceptions.py +++ b/src/fastapi_toolsets/exceptions/exceptions.py @@ -119,6 +119,25 @@ class RoleNotFoundError(NotFoundError): ) +class NoSearchableFieldsError(ApiException): + """Raised when search is requested but no searchable fields are available.""" + + api_error = ApiError( + code=400, + msg="No Searchable Fields", + desc="No searchable fields configured for this resource.", + err_code="SEARCH-400", + ) + + def __init__(self, model: type) -> None: + self.model = model + detail = ( + f"No searchable fields found for model '{model.__name__}'. " + "Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class." + ) + super().__init__(detail) + + def generate_error_responses( *errors: type[ApiException], ) -> dict[int | str, dict[str, Any]]: diff --git a/tests/test_crud.py b/tests/test_crud.py index adf71f2..bb79bcc 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -3,7 +3,8 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession -from fastapi_toolsets.crud import AsyncCrud, CrudFactory +from fastapi_toolsets.crud import CrudFactory +from fastapi_toolsets.crud.factory import AsyncCrud from fastapi_toolsets.exceptions import NotFoundError from .conftest import ( diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py new file mode 100644 index 0000000..46b50bc --- /dev/null +++ b/tests/test_crud_search.py @@ -0,0 +1,392 @@ +"""Tests for CRUD search functionality.""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from fastapi_toolsets.crud import SearchConfig, get_searchable_fields + +from .conftest import ( + Role, + RoleCreate, + RoleCrud, + User, + UserCreate, + UserCrud, +) + + +class TestPaginateSearch: + """Tests for paginate() with search.""" + + @pytest.mark.anyio + async def test_search_single_column(self, db_session: AsyncSession): + """Search on a single direct column.""" + await UserCrud.create( + db_session, UserCreate(username="john_doe", email="john@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="jane_doe", email="jane@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob_smith", email="bob@test.com") + ) + + result = await UserCrud.paginate( + db_session, + search="doe", + search_fields=[User.username], + ) + + assert result["pagination"]["total_count"] == 2 + + @pytest.mark.anyio + async def test_search_multiple_columns(self, db_session: AsyncSession): + """Search across multiple columns (OR logic).""" + await UserCrud.create( + db_session, UserCreate(username="alice", email="alice@company.com") + ) + await UserCrud.create( + db_session, UserCreate(username="company_bob", email="bob@other.com") + ) + + result = await UserCrud.paginate( + db_session, + search="company", + search_fields=[User.username, User.email], + ) + + assert result["pagination"]["total_count"] == 2 + + @pytest.mark.anyio + async def test_search_relationship_depth1(self, db_session: AsyncSession): + """Search through a relationship (depth 1).""" + admin_role = await RoleCrud.create(db_session, RoleCreate(name="administrator")) + user_role = await RoleCrud.create(db_session, RoleCreate(name="basic_user")) + + await UserCrud.create( + db_session, + UserCreate(username="admin1", email="a1@test.com", role_id=admin_role.id), + ) + await UserCrud.create( + db_session, + UserCreate(username="admin2", email="a2@test.com", role_id=admin_role.id), + ) + await UserCrud.create( + db_session, + UserCreate(username="user1", email="u1@test.com", role_id=user_role.id), + ) + + result = await UserCrud.paginate( + db_session, + search="admin", + search_fields=[(User.role, Role.name)], + ) + + assert result["pagination"]["total_count"] == 2 + + @pytest.mark.anyio + async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession): + """Search combining direct columns and relationships.""" + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="john", email="john@test.com", role_id=role.id), + ) + + # Search "admin" in username OR role.name + result = await UserCrud.paginate( + db_session, + search="admin", + search_fields=[User.username, (User.role, Role.name)], + ) + + assert result["pagination"]["total_count"] == 1 + + @pytest.mark.anyio + async def test_search_case_insensitive(self, db_session: AsyncSession): + """Search is case-insensitive by default.""" + await UserCrud.create( + db_session, UserCreate(username="JohnDoe", email="j@test.com") + ) + + result = await UserCrud.paginate( + db_session, + search="johndoe", + search_fields=[User.username], + ) + + assert result["pagination"]["total_count"] == 1 + + @pytest.mark.anyio + async def test_search_case_sensitive(self, db_session: AsyncSession): + """Case-sensitive search with SearchConfig.""" + await UserCrud.create( + db_session, UserCreate(username="JohnDoe", email="j@test.com") + ) + + # Should not find (case mismatch) + result = await UserCrud.paginate( + db_session, + search=SearchConfig(query="johndoe", case_sensitive=True), + search_fields=[User.username], + ) + assert result["pagination"]["total_count"] == 0 + + # Should find (case match) + result = await UserCrud.paginate( + db_session, + search=SearchConfig(query="JohnDoe", case_sensitive=True), + search_fields=[User.username], + ) + assert result["pagination"]["total_count"] == 1 + + @pytest.mark.anyio + async def test_search_empty_query(self, db_session: AsyncSession): + """Empty search returns all results.""" + await UserCrud.create( + db_session, UserCreate(username="user1", email="u1@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="user2", email="u2@test.com") + ) + + result = await UserCrud.paginate(db_session, search="") + assert result["pagination"]["total_count"] == 2 + + result = await UserCrud.paginate(db_session, search=None) + assert result["pagination"]["total_count"] == 2 + + @pytest.mark.anyio + async def test_search_with_existing_filters(self, db_session: AsyncSession): + """Search combines with existing filters (AND).""" + await UserCrud.create( + db_session, + UserCreate(username="active_john", email="aj@test.com", is_active=True), + ) + await UserCrud.create( + db_session, + UserCreate(username="inactive_john", email="ij@test.com", is_active=False), + ) + + result = await UserCrud.paginate( + db_session, + filters=[User.is_active == True], # noqa: E712 + search="john", + search_fields=[User.username], + ) + + assert result["pagination"]["total_count"] == 1 + assert result["data"][0].username == "active_john" + + @pytest.mark.anyio + async def test_search_auto_detect_fields(self, db_session: AsyncSession): + """Auto-detect searchable fields when not specified.""" + await UserCrud.create( + db_session, UserCreate(username="findme", email="other@test.com") + ) + + result = await UserCrud.paginate(db_session, search="findme") + + assert result["pagination"]["total_count"] == 1 + + @pytest.mark.anyio + async def test_search_no_results(self, db_session: AsyncSession): + """Search with no matching results.""" + await UserCrud.create( + db_session, UserCreate(username="john", email="j@test.com") + ) + + result = await UserCrud.paginate( + db_session, + search="nonexistent", + search_fields=[User.username], + ) + + assert result["pagination"]["total_count"] == 0 + assert result["data"] == [] + + @pytest.mark.anyio + async def test_search_with_pagination(self, db_session: AsyncSession): + """Search respects pagination parameters.""" + for i in range(15): + await UserCrud.create( + db_session, + UserCreate(username=f"user_{i}", email=f"user{i}@test.com"), + ) + + result = await UserCrud.paginate( + db_session, + search="user_", + search_fields=[User.username], + page=1, + items_per_page=5, + ) + + assert result["pagination"]["total_count"] == 15 + assert len(result["data"]) == 5 + assert result["pagination"]["has_more"] is True + + @pytest.mark.anyio + async def test_search_null_relationship(self, db_session: AsyncSession): + """Users without relationship are included (outerjoin).""" + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="with_role", email="wr@test.com", role_id=role.id), + ) + await UserCrud.create( + db_session, + UserCreate(username="no_role", email="nr@test.com", role_id=None), + ) + + # Search in username, not in role + result = await UserCrud.paginate( + db_session, + search="role", + search_fields=[User.username], + ) + + assert result["pagination"]["total_count"] == 2 + + @pytest.mark.anyio + async def test_search_with_order_by(self, db_session: AsyncSession): + """Search works with order_by parameter.""" + await UserCrud.create( + db_session, UserCreate(username="charlie", email="c@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com") + ) + + result = await UserCrud.paginate( + db_session, + search="@test.com", + search_fields=[User.email], + order_by=User.username, + ) + + assert result["pagination"]["total_count"] == 3 + usernames = [u.username for u in result["data"]] + assert usernames == ["alice", "bob", "charlie"] + + +class TestSearchConfig: + """Tests for SearchConfig options.""" + + @pytest.mark.anyio + async def test_match_mode_all(self, db_session: AsyncSession): + """match_mode='all' requires all fields to match (AND).""" + await UserCrud.create( + db_session, + UserCreate(username="john_test", email="john_test@company.com"), + ) + await UserCrud.create( + db_session, + UserCreate(username="john_other", email="other@example.com"), + ) + + # 'john' must be in username AND email + result = await UserCrud.paginate( + db_session, + search=SearchConfig(query="john", match_mode="all"), + search_fields=[User.username, User.email], + ) + + assert result["pagination"]["total_count"] == 1 + assert result["data"][0].username == "john_test" + + @pytest.mark.anyio + async def test_search_config_with_fields(self, db_session: AsyncSession): + """SearchConfig can specify fields directly.""" + await UserCrud.create( + db_session, UserCreate(username="test", email="findme@test.com") + ) + + result = await UserCrud.paginate( + db_session, + search=SearchConfig(query="findme", fields=[User.email]), + ) + + assert result["pagination"]["total_count"] == 1 + + +class TestNoSearchableFieldsError: + """Tests for NoSearchableFieldsError exception.""" + + def test_error_is_api_exception(self): + """NoSearchableFieldsError inherits from ApiException.""" + from fastapi_toolsets.exceptions import ApiException, NoSearchableFieldsError + + assert issubclass(NoSearchableFieldsError, ApiException) + + def test_error_has_api_error_fields(self): + """NoSearchableFieldsError has proper api_error configuration.""" + from fastapi_toolsets.exceptions import NoSearchableFieldsError + + assert NoSearchableFieldsError.api_error.code == 400 + assert NoSearchableFieldsError.api_error.err_code == "SEARCH-400" + + def test_error_message_contains_model_name(self): + """Error message includes the model name.""" + from fastapi_toolsets.exceptions import NoSearchableFieldsError + + error = NoSearchableFieldsError(User) + assert "User" in str(error) + assert error.model is User + + def test_error_raised_when_no_fields(self): + """Error is raised when search has no searchable fields.""" + from sqlalchemy import Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + from fastapi_toolsets.crud.search import build_search_filters + from fastapi_toolsets.exceptions import NoSearchableFieldsError + + # Model with no String columns + class NoStringBase(DeclarativeBase): + pass + + class NoStringModel(NoStringBase): + __tablename__ = "no_strings" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + count: Mapped[int] = mapped_column(Integer, default=0) + + with pytest.raises(NoSearchableFieldsError) as exc_info: + build_search_filters(NoStringModel, "test") + + assert exc_info.value.model is NoStringModel + assert "NoStringModel" in str(exc_info.value) + + +class TestGetSearchableFields: + """Tests for auto-detection of searchable fields.""" + + def test_detects_string_columns(self): + """Detects String columns on the model.""" + fields = get_searchable_fields(User, include_relationships=False) + + # Should include username and email (String), not id or is_active + field_names = [str(f) for f in fields] + assert any("username" in f for f in field_names) + assert any("email" in f for f in field_names) + assert not any("id" in f and "role_id" not in f for f in field_names) + assert not any("is_active" in f for f in field_names) + + def test_detects_relationship_fields(self): + """Detects String fields on related models.""" + fields = get_searchable_fields(User, include_relationships=True) + + # Should include (User.role, Role.name) + has_role_name = any(isinstance(f, tuple) and len(f) == 2 for f in fields) + assert has_role_name + + def test_skips_collection_relationships(self): + """Skips one-to-many relationships.""" + fields = get_searchable_fields(Role, include_relationships=True) + + # Role.users is a collection, should not be included + field_strs = [str(f) for f in fields] + assert not any("users" in f for f in field_strs)