From ebaa61525f6f0e2ae5bbad3b735cb467d0b7cae5 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:36:54 +0200 Subject: [PATCH] fix: handle boolean and ARRAY column types in filter_by facet filtering (#203) --- src/fastapi_toolsets/crud/__init__.py | 7 +- src/fastapi_toolsets/crud/search.py | 57 +++++++- src/fastapi_toolsets/exceptions/__init__.py | 2 + src/fastapi_toolsets/exceptions/exceptions.py | 28 ++++ tests/conftest.py | 30 +++++ tests/test_crud_search.py | 127 ++++++++++++++++++ 6 files changed, 243 insertions(+), 8 deletions(-) diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index bc4fb60..8d431b6 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -1,6 +1,10 @@ """Generic async CRUD operations for SQLAlchemy models.""" -from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError +from ..exceptions import ( + InvalidFacetFilterError, + NoSearchableFieldsError, + UnsupportedFacetTypeError, +) from ..schemas import PaginationType from ..types import ( FacetFieldType, @@ -25,4 +29,5 @@ __all__ = [ "PaginationType", "SearchConfig", "SearchFieldType", + "UnsupportedFacetTypeError", ] diff --git a/src/fastapi_toolsets/crud/search.py b/src/fastapi_toolsets/crud/search.py index 987dc5b..07abbaa 100644 --- a/src/fastapi_toolsets/crud/search.py +++ b/src/fastapi_toolsets/crud/search.py @@ -6,12 +6,27 @@ from collections.abc import Sequence from dataclasses import dataclass, replace from typing import TYPE_CHECKING, Any, Literal -from sqlalchemy import String, and_, or_, select +from sqlalchemy import String, and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.types import ( + ARRAY, + Boolean, + Date, + DateTime, + Enum, + Integer, + Numeric, + Time, + Uuid, +) -from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError +from ..exceptions import ( + InvalidFacetFilterError, + NoSearchableFieldsError, + UnsupportedFacetTypeError, +) from ..types import FacetFieldType, SearchFieldType if TYPE_CHECKING: @@ -201,7 +216,14 @@ async def build_facets( rels = () column = field - q = select(column).select_from(model).distinct() + col_type = column.property.columns[0].type + is_array = isinstance(col_type, ARRAY) + + if is_array: + unnested = func.unnest(column).label(column.key) + q = select(unnested).select_from(model).distinct() + else: + q = select(column).select_from(model).distinct() # Apply base joins (already done on main query, but needed here independently) for rel in base_joins or []: @@ -215,7 +237,10 @@ async def build_facets( if base_filters: q = q.where(and_(*base_filters)) - q = q.order_by(column) + if is_array: + q = q.order_by(unnested) + else: + q = q.order_by(column) result = await session.execute(q) values = [row[0] for row in result.all() if row[0] is not None] return key, values @@ -226,6 +251,10 @@ async def build_facets( return dict(pairs) +_EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid) +"""Column types that support equality / IN filtering in build_filter_by.""" + + def build_filter_by( filter_by: dict[str, Any], facet_fields: Sequence[FacetFieldType], @@ -271,9 +300,23 @@ def build_filter_by( joins.append(rel) added_join_keys.add(rel_key) - if isinstance(value, list): - filters.append(column.in_(value)) + col_type = column.property.columns[0].type + if isinstance(col_type, ARRAY): + if isinstance(value, list): + filters.append(column.overlap(value)) + else: + filters.append(column.any(value)) + elif isinstance(col_type, Boolean): + if isinstance(value, list): + filters.append(column.in_(value)) + else: + filters.append(column.is_(value)) + elif isinstance(col_type, _EQUALITY_TYPES): + if isinstance(value, list): + filters.append(column.in_(value)) + else: + filters.append(column == value) else: - filters.append(column == value) + raise UnsupportedFacetTypeError(key, type(col_type).__name__) return filters, joins diff --git a/src/fastapi_toolsets/exceptions/__init__.py b/src/fastapi_toolsets/exceptions/__init__.py index cc43a1c..fc5da21 100644 --- a/src/fastapi_toolsets/exceptions/__init__.py +++ b/src/fastapi_toolsets/exceptions/__init__.py @@ -10,6 +10,7 @@ from .exceptions import ( NoSearchableFieldsError, NotFoundError, UnauthorizedError, + UnsupportedFacetTypeError, generate_error_responses, ) from .handler import init_exceptions_handlers @@ -26,4 +27,5 @@ __all__ = [ "NoSearchableFieldsError", "NotFoundError", "UnauthorizedError", + "UnsupportedFacetTypeError", ] diff --git a/src/fastapi_toolsets/exceptions/exceptions.py b/src/fastapi_toolsets/exceptions/exceptions.py index be5a762..5822bb5 100644 --- a/src/fastapi_toolsets/exceptions/exceptions.py +++ b/src/fastapi_toolsets/exceptions/exceptions.py @@ -144,6 +144,34 @@ class InvalidFacetFilterError(ApiException): ) +class UnsupportedFacetTypeError(ApiException): + """Raised when a facet field has a column type not supported by filter_by.""" + + api_error = ApiError( + code=400, + msg="Unsupported Facet Type", + desc="The column type is not supported for facet filtering.", + err_code="FACET-TYPE-400", + ) + + def __init__(self, key: str, col_type: str) -> None: + """Initialize the exception. + + Args: + key: The facet field key. + col_type: The unsupported column type name. + """ + self.key = key + self.col_type = col_type + super().__init__( + desc=( + f"Facet field '{key}' has unsupported column type '{col_type}'. " + f"Supported types: String, Integer, Numeric, Boolean, " + f"Date, DateTime, Time, Enum, Uuid, ARRAY." + ) + ) + + class InvalidOrderFieldError(ApiException): """Raised when order_by contains a field not in the allowed order fields.""" diff --git a/tests/conftest.py b/tests/conftest.py index 18388b7..292a2d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,11 +14,13 @@ from sqlalchemy import ( DateTime, ForeignKey, Integer, + JSON, Numeric, String, Table, Uuid, ) +from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -137,6 +139,17 @@ class Post(Base): tags: Mapped[list[Tag]] = relationship(secondary=post_tags) +class Article(Base): + """Test article model with ARRAY and JSON columns.""" + + __tablename__ = "articles" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + title: Mapped[str] = mapped_column(String(200)) + labels: Mapped[list[str]] = mapped_column(ARRAY(String)) + metadata_: Mapped[dict | None] = mapped_column("metadata", JSON, nullable=True) + + class RoleCreate(BaseModel): """Schema for creating a role.""" @@ -271,6 +284,23 @@ class ProductCreate(BaseModel): price: decimal.Decimal +class ArticleCreate(BaseModel): + """Schema for creating an article.""" + + id: uuid.UUID | None = None + title: str + labels: list[str] = [] + + +class ArticleRead(PydanticBase): + """Schema for reading an article.""" + + id: uuid.UUID + title: str + labels: list[str] + + +ArticleCrud = CrudFactory(Article) RoleCrud = CrudFactory(Role) RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id) IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id) diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 0933a01..3c71be8 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -11,12 +11,17 @@ from fastapi_toolsets.crud import ( CrudFactory, InvalidFacetFilterError, SearchConfig, + UnsupportedFacetTypeError, get_searchable_fields, ) from fastapi_toolsets.exceptions import InvalidOrderFieldError from fastapi_toolsets.schemas import OffsetPagination, PaginationType from .conftest import ( + Article, + ArticleCreate, + ArticleCrud, + ArticleRead, Role, RoleCreate, RoleCrud, @@ -902,6 +907,128 @@ class TestFilterBy: assert len(result.data) == 1 assert result.data[0].username == "alice" + @pytest.mark.anyio + async def test_bool_filter_false(self, db_session: AsyncSession): + """filter_by with a boolean False value correctly filters rows.""" + UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com", is_active=True) + ) + await UserCrud.create( + db_session, + UserCreate(username="bob", email="b@test.com", is_active=False), + ) + + result = await UserBoolCrud.offset_paginate( + db_session, filter_by={"is_active": False}, schema=UserRead + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "bob" + + @pytest.mark.anyio + async def test_bool_filter_true(self, db_session: AsyncSession): + """filter_by with a boolean True value correctly filters rows.""" + UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com", is_active=True) + ) + await UserCrud.create( + db_session, + UserCreate(username="bob", email="b@test.com", is_active=False), + ) + + result = await UserBoolCrud.offset_paginate( + db_session, filter_by={"is_active": True}, schema=UserRead + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_bool_filter_list(self, db_session: AsyncSession): + """filter_by with a list of booleans produces an IN clause.""" + UserBoolCrud = CrudFactory(User, facet_fields=[User.is_active]) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com", is_active=True) + ) + await UserCrud.create( + db_session, + UserCreate(username="bob", email="b@test.com", is_active=False), + ) + + result = await UserBoolCrud.offset_paginate( + db_session, filter_by={"is_active": [True, False]}, schema=UserRead + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + + @pytest.mark.anyio + async def test_array_contains_single_value(self, db_session: AsyncSession): + """filter_by on an ARRAY column with a scalar checks containment.""" + ArticleFacetCrud = CrudFactory(Article, facet_fields=[Article.labels]) + await ArticleCrud.create( + db_session, ArticleCreate(title="Post 1", labels=["python", "fastapi"]) + ) + await ArticleCrud.create( + db_session, ArticleCreate(title="Post 2", labels=["rust", "axum"]) + ) + await ArticleCrud.create( + db_session, ArticleCreate(title="Post 3", labels=["python", "django"]) + ) + + result = await ArticleFacetCrud.offset_paginate( + db_session, filter_by={"labels": "python"}, schema=ArticleRead + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + titles = {a.title for a in result.data} + assert titles == {"Post 1", "Post 3"} + # facet returns individual unnested values, not whole arrays + assert result.filter_attributes == {"labels": ["django", "fastapi", "python"]} + + @pytest.mark.anyio + async def test_array_overlap_list_value(self, db_session: AsyncSession): + """filter_by on an ARRAY column with a list checks overlap.""" + ArticleFacetCrud = CrudFactory(Article, facet_fields=[Article.labels]) + await ArticleCrud.create( + db_session, ArticleCreate(title="Post 1", labels=["python", "fastapi"]) + ) + await ArticleCrud.create( + db_session, ArticleCreate(title="Post 2", labels=["rust", "axum"]) + ) + await ArticleCrud.create( + db_session, ArticleCreate(title="Post 3", labels=["python", "django"]) + ) + + result = await ArticleFacetCrud.offset_paginate( + db_session, filter_by={"labels": ["rust", "django"]}, schema=ArticleRead + ) + + assert isinstance(result.pagination, OffsetPagination) + assert result.pagination.total_count == 2 + titles = {a.title for a in result.data} + assert titles == {"Post 2", "Post 3"} + + @pytest.mark.anyio + async def test_unsupported_column_type_raises(self, db_session: AsyncSession): + """filter_by on a JSON column raises UnsupportedFacetTypeError.""" + ArticleJsonCrud = CrudFactory(Article, facet_fields=[Article.metadata_]) + + with pytest.raises(UnsupportedFacetTypeError) as exc_info: + await ArticleJsonCrud.offset_paginate( + db_session, + filter_by={"metadata_": {"key": "value"}}, + schema=ArticleRead, + ) + + assert exc_info.value.key == "metadata_" + assert "JSON" in exc_info.value.col_type + class TestFilterParamsSchema: """Tests for AsyncCrud.filter_params()."""