fix: handle boolean and ARRAY column types in filter_by facet filtering (#203)

This commit is contained in:
d3vyce
2026-03-31 21:36:54 +02:00
committed by GitHub
parent 4829cfba73
commit ebaa61525f
6 changed files with 243 additions and 8 deletions

View File

@@ -6,12 +6,27 @@ from collections.abc import Sequence
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, and_, or_, select
from sqlalchemy import String, and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.types import (
ARRAY,
Boolean,
Date,
DateTime,
Enum,
Integer,
Numeric,
Time,
Uuid,
)
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from ..exceptions import (
InvalidFacetFilterError,
NoSearchableFieldsError,
UnsupportedFacetTypeError,
)
from ..types import FacetFieldType, SearchFieldType
if TYPE_CHECKING:
@@ -201,7 +216,14 @@ async def build_facets(
rels = ()
column = field
q = select(column).select_from(model).distinct()
col_type = column.property.columns[0].type
is_array = isinstance(col_type, ARRAY)
if is_array:
unnested = func.unnest(column).label(column.key)
q = select(unnested).select_from(model).distinct()
else:
q = select(column).select_from(model).distinct()
# Apply base joins (already done on main query, but needed here independently)
for rel in base_joins or []:
@@ -215,7 +237,10 @@ async def build_facets(
if base_filters:
q = q.where(and_(*base_filters))
q = q.order_by(column)
if is_array:
q = q.order_by(unnested)
else:
q = q.order_by(column)
result = await session.execute(q)
values = [row[0] for row in result.all() if row[0] is not None]
return key, values
@@ -226,6 +251,10 @@ async def build_facets(
return dict(pairs)
_EQUALITY_TYPES = (String, Integer, Numeric, Date, DateTime, Time, Enum, Uuid)
"""Column types that support equality / IN filtering in build_filter_by."""
def build_filter_by(
filter_by: dict[str, Any],
facet_fields: Sequence[FacetFieldType],
@@ -271,9 +300,23 @@ def build_filter_by(
joins.append(rel)
added_join_keys.add(rel_key)
if isinstance(value, list):
filters.append(column.in_(value))
col_type = column.property.columns[0].type
if isinstance(col_type, ARRAY):
if isinstance(value, list):
filters.append(column.overlap(value))
else:
filters.append(column.any(value))
elif isinstance(col_type, Boolean):
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column.is_(value))
elif isinstance(col_type, _EQUALITY_TYPES):
if isinstance(value, list):
filters.append(column.in_(value))
else:
filters.append(column == value)
else:
filters.append(column == value)
raise UnsupportedFacetTypeError(key, type(col_type).__name__)
return filters, joins