fix: handle boolean and ARRAY column types in filter_by facet filtering (#203)

This commit is contained in:
d3vyce
2026-03-31 21:36:54 +02:00
committed by GitHub
parent 4829cfba73
commit ebaa61525f
6 changed files with 243 additions and 8 deletions

View File

@@ -1,6 +1,10 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from ..exceptions import (
InvalidFacetFilterError,
NoSearchableFieldsError,
UnsupportedFacetTypeError,
)
from ..schemas import PaginationType
from ..types import (
FacetFieldType,
@@ -25,4 +29,5 @@ __all__ = [
"PaginationType",
"SearchConfig",
"SearchFieldType",
"UnsupportedFacetTypeError",
]

View File

@@ -6,12 +6,27 @@ from collections.abc import Sequence
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, and_, or_, select
from sqlalchemy import String, and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.types import (
ARRAY,
Boolean,
Date,
DateTime,
Enum,
Integer,
Numeric,
Time,
Uuid,
)
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from ..exceptions import (
InvalidFacetFilterError,
NoSearchableFieldsError,
UnsupportedFacetTypeError,
)
from ..types import FacetFieldType, SearchFieldType
if TYPE_CHECKING:
@@ -201,7 +216,14 @@ async def build_facets(
rels = ()
column = field
q = select(column).select_from(model).distinct()
col_type = column.property.columns[0].type
is_array = isinstance(col_type, ARRAY)
if is_array:
unnested = func.unnest(column).label(column.key)
q = select(unnested).select_from(model).distinct()
else:
q = select(column).select_from(model).distinct()
# Apply base joins (already done on main query, but needed here independently)
for rel in base_joins or []:
@@ -215,7 +237,10 @@ async def build_facets(
if base_filters:
q = q.where(and_(*base_filters))
q = q.order_by(column)
if is_array:
q = q.order_by(unnested)
else:
q = q.order_by(column)
result = await session.execute(q)
values = [row[0] for row in result.all() if row[0] is not None]
return key, values
@@ -226,6 +251,10 @@ async def build_facets(
return dict(pairs)
_EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid)
"""Column types that support equality / IN filtering in build_filter_by."""
def build_filter_by(
filter_by: dict[str, Any],
facet_fields: Sequence[FacetFieldType],
@@ -271,9 +300,23 @@ def build_filter_by(
joins.append(rel)
added_join_keys.add(rel_key)
if isinstance(value, list):
filters.append(column.in_(value))
col_type = column.property.columns[0].type
if isinstance(col_type, ARRAY):
if isinstance(value, list):
filters.append(column.overlap(value))
else:
filters.append(column.any(value))
elif isinstance(col_type, Boolean):
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column.is_(value))
elif isinstance(col_type, _EQUALITY_TYPES):
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column == value)
else:
filters.append(column == value)
raise UnsupportedFacetTypeError(key, type(col_type).__name__)
return filters, joins

View File

@@ -10,6 +10,7 @@ from .exceptions import (
NoSearchableFieldsError,
NotFoundError,
UnauthorizedError,
UnsupportedFacetTypeError,
generate_error_responses,
)
from .handler import init_exceptions_handlers
@@ -26,4 +27,5 @@ __all__ = [
"NoSearchableFieldsError",
"NotFoundError",
"UnauthorizedError",
"UnsupportedFacetTypeError",
]

View File

@@ -144,6 +144,34 @@ class InvalidFacetFilterError(ApiException):
)
class UnsupportedFacetTypeError(ApiException):
    """Signal that a facet field's column type cannot be filtered by filter_by.

    Attributes:
        key: The facet field key that was requested.
        col_type: Name of the column type that is not supported.
    """

    api_error = ApiError(
        code=400,
        msg="Unsupported Facet Type",
        desc="The column type is not supported for facet filtering.",
        err_code="FACET-TYPE-400",
    )

    def __init__(self, key: str, col_type: str) -> None:
        """Record the offending field and type, then build the description.

        Args:
            key: The facet field key.
            col_type: The unsupported column type name.
        """
        self.key = key
        self.col_type = col_type
        # Human-readable list shown to the API client in the error description.
        supported = (
            "String, Integer, Numeric, Boolean, "
            "Date, DateTime, Time, Enum, Uuid, ARRAY"
        )
        detail = (
            f"Facet field '{key}' has unsupported column type '{col_type}'. "
            f"Supported types: {supported}."
        )
        super().__init__(desc=detail)
class InvalidOrderFieldError(ApiException):
"""Raised when order_by contains a field not in the allowed order fields."""

View File

@@ -14,11 +14,13 @@ from sqlalchemy import (
DateTime,
ForeignKey,
Integer,
JSON,
Numeric,
String,
Table,
Uuid,
)
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -137,6 +139,17 @@ class Post(Base):
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
class Article(Base):
"""Test article model with ARRAY and JSON columns.

Gives the test suite one ARRAY(String) column (``labels``) and one JSON
column (``metadata_``) to exercise facet filtering against.
"""
__tablename__ = "articles"
# Primary key; uuid.uuid4 supplies a value when none is given.
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(200))
# PostgreSQL ARRAY of strings (ARRAY is imported from the postgresql dialect).
labels: Mapped[list[str]] = mapped_column(ARRAY(String))
# Stored in the database column named "metadata"; the attribute carries a
# trailing underscore, presumably because ``metadata`` is reserved on
# declarative classes.
metadata_: Mapped[dict | None] = mapped_column("metadata", JSON, nullable=True)
class RoleCreate(BaseModel):
"""Schema for creating a role."""
@@ -271,6 +284,23 @@ class ProductCreate(BaseModel):
price: decimal.Decimal
class ArticleCreate(BaseModel):
"""Schema for creating an article.

``id`` and ``labels`` are optional; only ``title`` must be supplied.
"""
# Optional explicit id; defaults to None when omitted.
id: uuid.UUID | None = None
title: str
# Defaults to an empty list when no labels are given.
labels: list[str] = []
class ArticleRead(PydanticBase):
"""Schema for reading an article.

Exposes the id, title and labels fields; all three are required.
"""
id: uuid.UUID
title: str
labels: list[str]
# CRUD helpers for the test models, each built through CrudFactory.
ArticleCrud = CrudFactory(Article)
RoleCrud = CrudFactory(Role)
# Cursor variants: the model's id column is passed as cursor_column.
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)

View File

@@ -11,12 +11,17 @@ from fastapi_toolsets.crud import (
CrudFactory,
InvalidFacetFilterError,
SearchConfig,
UnsupportedFacetTypeError,
get_searchable_fields,
)
from fastapi_toolsets.exceptions import InvalidOrderFieldError
from fastapi_toolsets.schemas import OffsetPagination, PaginationType
from .conftest import (
Article,
ArticleCreate,
ArticleCrud,
ArticleRead,
Role,
RoleCreate,
RoleCrud,
@@ -902,6 +907,128 @@ class TestFilterBy:
assert len(result.data) == 1
assert result.data[0].username == "alice"
@pytest.mark.anyio
async def test_bool_filter_false(self, db_session: AsyncSession):
    """Filtering on ``is_active=False`` returns only the inactive user."""
    bool_crud = CrudFactory(User, facet_fields=[User.is_active])
    # One active and one inactive user, created in a fixed order.
    for username, email, active in (
        ("alice", "a@test.com", True),
        ("bob", "b@test.com", False),
    ):
        await UserCrud.create(
            db_session,
            UserCreate(username=username, email=email, is_active=active),
        )

    page = await bool_crud.offset_paginate(
        db_session, filter_by={"is_active": False}, schema=UserRead
    )

    assert isinstance(page.pagination, OffsetPagination)
    assert page.pagination.total_count == 1
    assert page.data[0].username == "bob"
@pytest.mark.anyio
async def test_bool_filter_true(self, db_session: AsyncSession):
    """Filtering on ``is_active=True`` returns only the active user."""
    bool_crud = CrudFactory(User, facet_fields=[User.is_active])
    # One active and one inactive user, created in a fixed order.
    for username, email, active in (
        ("alice", "a@test.com", True),
        ("bob", "b@test.com", False),
    ):
        await UserCrud.create(
            db_session,
            UserCreate(username=username, email=email, is_active=active),
        )

    page = await bool_crud.offset_paginate(
        db_session, filter_by={"is_active": True}, schema=UserRead
    )

    assert isinstance(page.pagination, OffsetPagination)
    assert page.pagination.total_count == 1
    assert page.data[0].username == "alice"
@pytest.mark.anyio
async def test_bool_filter_list(self, db_session: AsyncSession):
    """A list of booleans in filter_by matches rows with either value."""
    bool_crud = CrudFactory(User, facet_fields=[User.is_active])
    # One active and one inactive user, created in a fixed order.
    for username, email, active in (
        ("alice", "a@test.com", True),
        ("bob", "b@test.com", False),
    ):
        await UserCrud.create(
            db_session,
            UserCreate(username=username, email=email, is_active=active),
        )

    page = await bool_crud.offset_paginate(
        db_session, filter_by={"is_active": [True, False]}, schema=UserRead
    )

    assert isinstance(page.pagination, OffsetPagination)
    # Both users match because both boolean values are in the list.
    assert page.pagination.total_count == 2
@pytest.mark.anyio
async def test_array_contains_single_value(self, db_session: AsyncSession):
    """A scalar filter on an ARRAY column matches rows containing that value."""
    facet_crud = CrudFactory(Article, facet_fields=[Article.labels])
    # Insertion order of the dict is the creation order of the articles.
    labels_by_title = {
        "Post 1": ["python", "fastapi"],
        "Post 2": ["rust", "axum"],
        "Post 3": ["python", "django"],
    }
    for title, labels in labels_by_title.items():
        await ArticleCrud.create(
            db_session, ArticleCreate(title=title, labels=labels)
        )

    page = await facet_crud.offset_paginate(
        db_session, filter_by={"labels": "python"}, schema=ArticleRead
    )

    assert isinstance(page.pagination, OffsetPagination)
    assert page.pagination.total_count == 2
    assert {article.title for article in page.data} == {"Post 1", "Post 3"}
    # facet returns individual unnested values, not whole arrays
    assert page.filter_attributes == {"labels": ["django", "fastapi", "python"]}
@pytest.mark.anyio
async def test_array_overlap_list_value(self, db_session: AsyncSession):
    """A list filter on an ARRAY column matches rows sharing any listed value."""
    facet_crud = CrudFactory(Article, facet_fields=[Article.labels])
    # Insertion order of the dict is the creation order of the articles.
    labels_by_title = {
        "Post 1": ["python", "fastapi"],
        "Post 2": ["rust", "axum"],
        "Post 3": ["python", "django"],
    }
    for title, labels in labels_by_title.items():
        await ArticleCrud.create(
            db_session, ArticleCreate(title=title, labels=labels)
        )

    page = await facet_crud.offset_paginate(
        db_session, filter_by={"labels": ["rust", "django"]}, schema=ArticleRead
    )

    assert isinstance(page.pagination, OffsetPagination)
    assert page.pagination.total_count == 2
    assert {article.title for article in page.data} == {"Post 2", "Post 3"}
@pytest.mark.anyio
async def test_unsupported_column_type_raises(self, db_session: AsyncSession):
    """Filtering on a JSON column is rejected with UnsupportedFacetTypeError."""
    json_crud = CrudFactory(Article, facet_fields=[Article.metadata_])
    json_filter = {"metadata_": {"key": "value"}}

    with pytest.raises(UnsupportedFacetTypeError) as exc_info:
        await json_crud.offset_paginate(
            db_session, filter_by=json_filter, schema=ArticleRead
        )

    # The exception carries the offending facet key and column type name.
    raised = exc_info.value
    assert raised.key == "metadata_"
    assert "JSON" in raised.col_type
class TestFilterParamsSchema:
"""Tests for AsyncCrud.filter_params()."""