feat: add sort_params helper in CrudFactory

This commit is contained in:
2026-02-28 12:48:06 -05:00
parent 117675d02f
commit 32ac3dc127
7 changed files with 289 additions and 9 deletions

View File

@@ -13,6 +13,7 @@ from fastapi_toolsets.exceptions import (
ConflictError, ConflictError,
NoSearchableFieldsError, NoSearchableFieldsError,
InvalidFacetFilterError, InvalidFacetFilterError,
InvalidSortFieldError,
generate_error_responses, generate_error_responses,
init_exceptions_handlers, init_exceptions_handlers,
) )
@@ -32,6 +33,8 @@ from fastapi_toolsets.exceptions import (
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError ## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
## ::: fastapi_toolsets.exceptions.exceptions.InvalidSortFieldError
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses ## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers ## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers

View File

@@ -1,7 +1,7 @@
"""Generic async CRUD operations for SQLAlchemy models.""" """Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from .factory import CrudFactory, JoinType, M2MFieldType from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause
from .search import ( from .search import (
FacetFieldType, FacetFieldType,
SearchConfig, SearchConfig,
@@ -16,5 +16,6 @@ __all__ = [
"JoinType", "JoinType",
"M2MFieldType", "M2MFieldType",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"OrderByClause",
"SearchConfig", "SearchConfig",
] ]

View File

@@ -21,10 +21,11 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
from ..exceptions import NotFoundError from ..exceptions import InvalidSortFieldError, NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import ( from .search import (
FacetFieldType, FacetFieldType,
@@ -40,6 +41,7 @@ ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel) SchemaType = TypeVar("SchemaType", bound=BaseModel)
JoinType = list[tuple[type[DeclarativeBase], Any]] JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
def _encode_cursor(value: Any) -> str: def _encode_cursor(value: Any) -> str:
@@ -61,6 +63,7 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
sort_fields: ClassVar[Sequence[QueryableAttribute[Any]] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None default_load_options: ClassVar[list[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None cursor_column: ClassVar[Any | None] = None
@@ -176,6 +179,63 @@ class AsyncCrud(Generic[ModelType]):
return dependency return dependency
@classmethod
def sort_params(
cls: type[Self],
*,
sort_fields: Sequence[QueryableAttribute[Any]] | None = None,
default_field: QueryableAttribute[Any] | None = None,
default_order: Literal["asc", "desc"] = "asc",
) -> Callable[..., Awaitable[OrderByClause | None]]:
"""Return a FastAPI dependency that resolves sort query params into an order_by clause.
Args:
sort_fields: Override the allowed sort fields. Falls back to the class-level
``sort_fields`` if not provided.
default_field: Field to sort by when ``sort_by`` query param is absent.
If ``None`` and no ``sort_by`` is provided, no ordering is applied.
default_order: Default sort direction when ``sort_order`` is absent
(``"asc"`` or ``"desc"``).
Returns:
An async dependency function named ``{Model}SortParams`` that resolves to an
``OrderByClause`` (or ``None``). Pass it to ``Depends()`` in your route.
Raises:
ValueError: If no sort fields are configured on this CRUD class and none are
provided via ``sort_fields``.
InvalidSortFieldError: When the request provides an unknown ``sort_by`` value.
"""
fields = sort_fields if sort_fields is not None else cls.sort_fields
if not fields:
raise ValueError(
f"{cls.__name__} has no sort_fields configured. "
"Pass sort_fields= or set them on CrudFactory."
)
field_map: dict[str, QueryableAttribute[Any]] = {f.key: f for f in fields}
valid_keys = sorted(field_map.keys())
async def dependency(
sort_by: str | None = Query(
None, description=f"Field to sort by. Valid values: {valid_keys}"
),
sort_order: Literal["asc", "desc"] = Query(
default_order, description="Sort direction"
),
) -> OrderByClause | None:
if sort_by is None:
if default_field is None:
return None
field = default_field
elif sort_by not in field_map:
raise InvalidSortFieldError(sort_by, valid_keys)
else:
field = field_map[sort_by]
return field.asc() if sort_order == "asc" else field.desc()
dependency.__name__ = f"{cls.model.__name__}SortParams"
return dependency
@overload @overload
@classmethod @classmethod
async def create( # pragma: no cover async def create( # pragma: no cover
@@ -415,7 +475,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
limit: int | None = None, limit: int | None = None,
offset: int | None = None, offset: int | None = None,
) -> Sequence[ModelType]: ) -> Sequence[ModelType]:
@@ -745,7 +805,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
@@ -766,7 +826,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
@@ -785,7 +845,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
@@ -937,7 +997,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
@@ -958,7 +1018,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
@@ -977,7 +1037,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[ExecutableOption] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: OrderByClause | None = None,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
@@ -1147,6 +1207,7 @@ def CrudFactory(
*, *,
searchable_fields: Sequence[SearchFieldType] | None = None, searchable_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
sort_fields: Sequence[QueryableAttribute[Any]] | None = None,
m2m_fields: M2MFieldType | None = None, m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None, default_load_options: list[ExecutableOption] | None = None,
cursor_column: Any | None = None, cursor_column: Any | None = None,
@@ -1159,6 +1220,8 @@ def CrudFactory(
facet_fields: Optional list of columns to compute distinct values for in paginated facet_fields: Optional list of columns to compute distinct values for in paginated
responses. Supports direct columns (``User.status``) and relationship tuples responses. Supports direct columns (``User.status``) and relationship tuples
(``(User.role, Role.name)``). Can be overridden per call. (``(User.role, Role.name)``). Can be overridden per call.
sort_fields: Optional list of model attributes that callers are allowed to sort by
via ``sort_params()``. Can be overridden per call.
m2m_fields: Optional mapping for many-to-many relationships. m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes. SQLAlchemy relationship attributes.
@@ -1252,6 +1315,7 @@ def CrudFactory(
"model": model, "model": model,
"searchable_fields": searchable_fields, "searchable_fields": searchable_fields,
"facet_fields": facet_fields, "facet_fields": facet_fields,
"sort_fields": sort_fields,
"m2m_fields": m2m_fields, "m2m_fields": m2m_fields,
"default_load_options": default_load_options, "default_load_options": default_load_options,
"cursor_column": cursor_column, "cursor_column": cursor_column,

View File

@@ -6,6 +6,7 @@ from .exceptions import (
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
InvalidFacetFilterError, InvalidFacetFilterError,
InvalidSortFieldError,
NoSearchableFieldsError, NoSearchableFieldsError,
NotFoundError, NotFoundError,
UnauthorizedError, UnauthorizedError,
@@ -21,6 +22,7 @@ __all__ = [
"generate_error_responses", "generate_error_responses",
"init_exceptions_handlers", "init_exceptions_handlers",
"InvalidFacetFilterError", "InvalidFacetFilterError",
"InvalidSortFieldError",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"NotFoundError", "NotFoundError",
"UnauthorizedError", "UnauthorizedError",

View File

@@ -128,6 +128,31 @@ class InvalidFacetFilterError(ApiException):
super().__init__(detail) super().__init__(detail)
class InvalidSortFieldError(ApiException):
"""Raised when sort_by contains a field not in the allowed sort fields."""
api_error = ApiError(
code=422,
msg="Invalid Sort Field",
desc="The requested sort field is not allowed for this resource.",
err_code="SORT-422",
)
def __init__(self, field: str, valid_fields: list[str]) -> None:
"""Initialize the exception.
Args:
field: The unknown sort field provided by the caller
valid_fields: List of valid field names
"""
self.field = field
self.valid_fields = valid_fields
detail = (
f"'{field}' is not an allowed sort field. Valid fields: {valid_fields}."
)
super().__init__(detail)
def generate_error_responses( def generate_error_responses(
*errors: type[ApiException], *errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]: ) -> dict[int | str, dict[str, Any]]:

View File

@@ -1,9 +1,11 @@
"""Tests for CRUD search functionality.""" """Tests for CRUD search functionality."""
import inspect
import uuid import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
from fastapi_toolsets.crud import ( from fastapi_toolsets.crud import (
CrudFactory, CrudFactory,
@@ -11,6 +13,7 @@ from fastapi_toolsets.crud import (
SearchConfig, SearchConfig,
get_searchable_fields, get_searchable_fields,
) )
from fastapi_toolsets.exceptions import InvalidSortFieldError
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination
from .conftest import ( from .conftest import (
@@ -1014,3 +1017,144 @@ class TestFilterParamsSchema:
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count == 2 assert result.pagination.total_count == 2
class TestSortParamsSchema:
"""Tests for AsyncCrud.sort_params()."""
def test_generates_sort_by_and_sort_order_params(self):
"""Returned dependency has sort_by and sort_order query params."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
dep = UserSortCrud.sort_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"sort_by", "sort_order"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep = UserSortCrud.sort_params()
assert getattr(dep, "__name__") == "UserSortParams"
def test_raises_when_no_sort_fields(self):
"""ValueError raised when no sort_fields are configured or provided."""
with pytest.raises(ValueError, match="no sort_fields"):
UserCrud.sort_params()
def test_sort_fields_override(self):
"""sort_fields= parameter overrides the class-level default."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
dep = UserSortCrud.sort_params(sort_fields=[User.email])
param_names = set(inspect.signature(dep).parameters)
assert "sort_by" in param_names
# description should only mention email, not username
sig = inspect.signature(dep)
description = sig.parameters["sort_by"].default.description
assert "email" in description
assert "username" not in description
def test_sort_by_description_lists_valid_fields(self):
"""sort_by query param description mentions each allowed field."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
dep = UserSortCrud.sort_params()
sig = inspect.signature(dep)
description = sig.parameters["sort_by"].default.description
assert "username" in description
assert "email" in description
def test_default_order_reflected_in_sort_order_default(self):
"""default_order is used as the default value for sort_order."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep_asc = UserSortCrud.sort_params(default_order="asc")
dep_desc = UserSortCrud.sort_params(default_order="desc")
sig_asc = inspect.signature(dep_asc)
sig_desc = inspect.signature(dep_desc)
assert sig_asc.parameters["sort_order"].default.default == "asc"
assert sig_desc.parameters["sort_order"].default.default == "desc"
@pytest.mark.anyio
async def test_no_sort_by_no_default_returns_none(self):
"""Returns None when sort_by is absent and no default_field is set."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep = UserSortCrud.sort_params()
result = await dep(sort_by=None, sort_order="asc")
assert result is None
@pytest.mark.anyio
async def test_no_sort_by_with_default_field_returns_asc_expression(self):
"""Returns default_field.asc() when sort_by absent and sort_order=asc."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep = UserSortCrud.sort_params(default_field=User.username)
result = await dep(sort_by=None, sort_order="asc")
assert isinstance(result, UnaryExpression)
assert "ASC" in str(result)
@pytest.mark.anyio
async def test_no_sort_by_with_default_field_returns_desc_expression(self):
"""Returns default_field.desc() when sort_by absent and sort_order=desc."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep = UserSortCrud.sort_params(default_field=User.username)
result = await dep(sort_by=None, sort_order="desc")
assert isinstance(result, UnaryExpression)
assert "DESC" in str(result)
@pytest.mark.anyio
async def test_valid_sort_by_asc(self):
"""Returns field.asc() for a valid sort_by with sort_order=asc."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep = UserSortCrud.sort_params()
result = await dep(sort_by="username", sort_order="asc")
assert isinstance(result, UnaryExpression)
assert "ASC" in str(result)
@pytest.mark.anyio
async def test_valid_sort_by_desc(self):
"""Returns field.desc() for a valid sort_by with sort_order=desc."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep = UserSortCrud.sort_params()
result = await dep(sort_by="username", sort_order="desc")
assert isinstance(result, UnaryExpression)
assert "DESC" in str(result)
@pytest.mark.anyio
async def test_invalid_sort_by_raises_invalid_sort_field_error(self):
"""Raises InvalidSortFieldError for an unknown sort_by value."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
dep = UserSortCrud.sort_params()
with pytest.raises(InvalidSortFieldError) as exc_info:
await dep(sort_by="nonexistent", sort_order="asc")
assert exc_info.value.field == "nonexistent"
assert "username" in exc_info.value.valid_fields
@pytest.mark.anyio
async def test_multiple_fields_all_resolve(self):
"""All configured fields resolve correctly via sort_by."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
dep = UserSortCrud.sort_params()
result_username = await dep(sort_by="username", sort_order="asc")
result_email = await dep(sort_by="email", sort_order="desc")
assert isinstance(result_username, ColumnElement)
assert isinstance(result_email, ColumnElement)
@pytest.mark.anyio
async def test_sort_params_integrates_with_get_multi(
self, db_session: AsyncSession
):
"""sort_params output is accepted by get_multi(order_by=...)."""
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
await UserCrud.create(
db_session, UserCreate(username="charlie", email="c@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com")
)
dep = UserSortCrud.sort_params()
order_by = await dep(sort_by="username", sort_order="asc")
results = await UserSortCrud.get_multi(db_session, order_by=order_by)
assert results[0].username == "alice"
assert results[1].username == "charlie"

View File

@@ -8,6 +8,7 @@ from fastapi_toolsets.exceptions import (
ApiException, ApiException,
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
InvalidSortFieldError,
NotFoundError, NotFoundError,
UnauthorizedError, UnauthorizedError,
generate_error_responses, generate_error_responses,
@@ -334,3 +335,43 @@ class TestExceptionIntegration:
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"id": 1} assert response.json() == {"id": 1}
class TestInvalidSortFieldError:
"""Tests for InvalidSortFieldError exception."""
def test_api_error_attributes(self):
"""InvalidSortFieldError has correct api_error metadata."""
assert InvalidSortFieldError.api_error.code == 422
assert InvalidSortFieldError.api_error.err_code == "SORT-422"
assert InvalidSortFieldError.api_error.msg == "Invalid Sort Field"
def test_stores_field_and_valid_fields(self):
"""InvalidSortFieldError stores field and valid_fields on the instance."""
error = InvalidSortFieldError("unknown", ["name", "created_at"])
assert error.field == "unknown"
assert error.valid_fields == ["name", "created_at"]
def test_message_contains_field_and_valid_fields(self):
"""Exception message mentions the bad field and valid options."""
error = InvalidSortFieldError("bad_field", ["name", "email"])
assert "bad_field" in str(error)
assert "name" in str(error)
assert "email" in str(error)
def test_handled_as_422_by_exception_handler(self):
"""init_exceptions_handlers turns InvalidSortFieldError into a 422 response."""
app = FastAPI()
init_exceptions_handlers(app)
@app.get("/items")
async def list_items():
raise InvalidSortFieldError("bad", ["name"])
client = TestClient(app)
response = client.get("/items")
assert response.status_code == 422
data = response.json()
assert data["error_code"] == "SORT-422"
assert data["status"] == "FAIL"