mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add sort_params helper in CrudFactory
This commit is contained in:
@@ -13,6 +13,7 @@ from fastapi_toolsets.exceptions import (
|
|||||||
ConflictError,
|
ConflictError,
|
||||||
NoSearchableFieldsError,
|
NoSearchableFieldsError,
|
||||||
InvalidFacetFilterError,
|
InvalidFacetFilterError,
|
||||||
|
InvalidSortFieldError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
init_exceptions_handlers,
|
init_exceptions_handlers,
|
||||||
)
|
)
|
||||||
@@ -32,6 +33,8 @@ from fastapi_toolsets.exceptions import (
|
|||||||
|
|
||||||
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.exceptions.exceptions.InvalidSortFieldError
|
||||||
|
|
||||||
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||||
|
|
||||||
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||||
from .factory import CrudFactory, JoinType, M2MFieldType
|
from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause
|
||||||
from .search import (
|
from .search import (
|
||||||
FacetFieldType,
|
FacetFieldType,
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
@@ -16,5 +16,6 @@ __all__ = [
|
|||||||
"JoinType",
|
"JoinType",
|
||||||
"M2MFieldType",
|
"M2MFieldType",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
|
"OrderByClause",
|
||||||
"SearchConfig",
|
"SearchConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -21,10 +21,11 @@ from sqlalchemy.exc import NoResultFound
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||||
from sqlalchemy.sql.base import ExecutableOption
|
from sqlalchemy.sql.base import ExecutableOption
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
from sqlalchemy.sql.roles import WhereHavingRole
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from ..exceptions import NotFoundError
|
from ..exceptions import InvalidSortFieldError, NotFoundError
|
||||||
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
||||||
from .search import (
|
from .search import (
|
||||||
FacetFieldType,
|
FacetFieldType,
|
||||||
@@ -40,6 +41,7 @@ ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|||||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
|
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||||
|
|
||||||
|
|
||||||
def _encode_cursor(value: Any) -> str:
|
def _encode_cursor(value: Any) -> str:
|
||||||
@@ -61,6 +63,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
model: ClassVar[type[DeclarativeBase]]
|
model: ClassVar[type[DeclarativeBase]]
|
||||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||||
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
||||||
|
sort_fields: ClassVar[Sequence[QueryableAttribute[Any]] | None] = None
|
||||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||||
default_load_options: ClassVar[list[ExecutableOption] | None] = None
|
default_load_options: ClassVar[list[ExecutableOption] | None] = None
|
||||||
cursor_column: ClassVar[Any | None] = None
|
cursor_column: ClassVar[Any | None] = None
|
||||||
@@ -176,6 +179,63 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
return dependency
|
return dependency
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sort_params(
|
||||||
|
cls: type[Self],
|
||||||
|
*,
|
||||||
|
sort_fields: Sequence[QueryableAttribute[Any]] | None = None,
|
||||||
|
default_field: QueryableAttribute[Any] | None = None,
|
||||||
|
default_order: Literal["asc", "desc"] = "asc",
|
||||||
|
) -> Callable[..., Awaitable[OrderByClause | None]]:
|
||||||
|
"""Return a FastAPI dependency that resolves sort query params into an order_by clause.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sort_fields: Override the allowed sort fields. Falls back to the class-level
|
||||||
|
``sort_fields`` if not provided.
|
||||||
|
default_field: Field to sort by when ``sort_by`` query param is absent.
|
||||||
|
If ``None`` and no ``sort_by`` is provided, no ordering is applied.
|
||||||
|
default_order: Default sort direction when ``sort_order`` is absent
|
||||||
|
(``"asc"`` or ``"desc"``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An async dependency function named ``{Model}SortParams`` that resolves to an
|
||||||
|
``OrderByClause`` (or ``None``). Pass it to ``Depends()`` in your route.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no sort fields are configured on this CRUD class and none are
|
||||||
|
provided via ``sort_fields``.
|
||||||
|
InvalidSortFieldError: When the request provides an unknown ``sort_by`` value.
|
||||||
|
"""
|
||||||
|
fields = sort_fields if sort_fields is not None else cls.sort_fields
|
||||||
|
if not fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__} has no sort_fields configured. "
|
||||||
|
"Pass sort_fields= or set them on CrudFactory."
|
||||||
|
)
|
||||||
|
field_map: dict[str, QueryableAttribute[Any]] = {f.key: f for f in fields}
|
||||||
|
valid_keys = sorted(field_map.keys())
|
||||||
|
|
||||||
|
async def dependency(
|
||||||
|
sort_by: str | None = Query(
|
||||||
|
None, description=f"Field to sort by. Valid values: {valid_keys}"
|
||||||
|
),
|
||||||
|
sort_order: Literal["asc", "desc"] = Query(
|
||||||
|
default_order, description="Sort direction"
|
||||||
|
),
|
||||||
|
) -> OrderByClause | None:
|
||||||
|
if sort_by is None:
|
||||||
|
if default_field is None:
|
||||||
|
return None
|
||||||
|
field = default_field
|
||||||
|
elif sort_by not in field_map:
|
||||||
|
raise InvalidSortFieldError(sort_by, valid_keys)
|
||||||
|
else:
|
||||||
|
field = field_map[sort_by]
|
||||||
|
return field.asc() if sort_order == "asc" else field.desc()
|
||||||
|
|
||||||
|
dependency.__name__ = f"{cls.model.__name__}SortParams"
|
||||||
|
return dependency
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create( # pragma: no cover
|
async def create( # pragma: no cover
|
||||||
@@ -415,7 +475,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: OrderByClause | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
offset: int | None = None,
|
offset: int | None = None,
|
||||||
) -> Sequence[ModelType]:
|
) -> Sequence[ModelType]:
|
||||||
@@ -745,7 +805,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: OrderByClause | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
@@ -766,7 +826,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: OrderByClause | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
@@ -785,7 +845,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: OrderByClause | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
@@ -937,7 +997,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: OrderByClause | None = None,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
@@ -958,7 +1018,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: OrderByClause | None = None,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
@@ -977,7 +1037,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
joins: JoinType | None = None,
|
joins: JoinType | None = None,
|
||||||
outer_join: bool = False,
|
outer_join: bool = False,
|
||||||
load_options: list[ExecutableOption] | None = None,
|
load_options: list[ExecutableOption] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: OrderByClause | None = None,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
search: str | SearchConfig | None = None,
|
search: str | SearchConfig | None = None,
|
||||||
search_fields: Sequence[SearchFieldType] | None = None,
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
@@ -1147,6 +1207,7 @@ def CrudFactory(
|
|||||||
*,
|
*,
|
||||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||||
|
sort_fields: Sequence[QueryableAttribute[Any]] | None = None,
|
||||||
m2m_fields: M2MFieldType | None = None,
|
m2m_fields: M2MFieldType | None = None,
|
||||||
default_load_options: list[ExecutableOption] | None = None,
|
default_load_options: list[ExecutableOption] | None = None,
|
||||||
cursor_column: Any | None = None,
|
cursor_column: Any | None = None,
|
||||||
@@ -1159,6 +1220,8 @@ def CrudFactory(
|
|||||||
facet_fields: Optional list of columns to compute distinct values for in paginated
|
facet_fields: Optional list of columns to compute distinct values for in paginated
|
||||||
responses. Supports direct columns (``User.status``) and relationship tuples
|
responses. Supports direct columns (``User.status``) and relationship tuples
|
||||||
(``(User.role, Role.name)``). Can be overridden per call.
|
(``(User.role, Role.name)``). Can be overridden per call.
|
||||||
|
sort_fields: Optional list of model attributes that callers are allowed to sort by
|
||||||
|
via ``sort_params()``. Can be overridden per call.
|
||||||
m2m_fields: Optional mapping for many-to-many relationships.
|
m2m_fields: Optional mapping for many-to-many relationships.
|
||||||
Maps schema field names (containing lists of IDs) to
|
Maps schema field names (containing lists of IDs) to
|
||||||
SQLAlchemy relationship attributes.
|
SQLAlchemy relationship attributes.
|
||||||
@@ -1252,6 +1315,7 @@ def CrudFactory(
|
|||||||
"model": model,
|
"model": model,
|
||||||
"searchable_fields": searchable_fields,
|
"searchable_fields": searchable_fields,
|
||||||
"facet_fields": facet_fields,
|
"facet_fields": facet_fields,
|
||||||
|
"sort_fields": sort_fields,
|
||||||
"m2m_fields": m2m_fields,
|
"m2m_fields": m2m_fields,
|
||||||
"default_load_options": default_load_options,
|
"default_load_options": default_load_options,
|
||||||
"cursor_column": cursor_column,
|
"cursor_column": cursor_column,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .exceptions import (
|
|||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
InvalidFacetFilterError,
|
InvalidFacetFilterError,
|
||||||
|
InvalidSortFieldError,
|
||||||
NoSearchableFieldsError,
|
NoSearchableFieldsError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
@@ -21,6 +22,7 @@ __all__ = [
|
|||||||
"generate_error_responses",
|
"generate_error_responses",
|
||||||
"init_exceptions_handlers",
|
"init_exceptions_handlers",
|
||||||
"InvalidFacetFilterError",
|
"InvalidFacetFilterError",
|
||||||
|
"InvalidSortFieldError",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
"NotFoundError",
|
"NotFoundError",
|
||||||
"UnauthorizedError",
|
"UnauthorizedError",
|
||||||
|
|||||||
@@ -128,6 +128,31 @@ class InvalidFacetFilterError(ApiException):
|
|||||||
super().__init__(detail)
|
super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSortFieldError(ApiException):
|
||||||
|
"""Raised when sort_by contains a field not in the allowed sort fields."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=422,
|
||||||
|
msg="Invalid Sort Field",
|
||||||
|
desc="The requested sort field is not allowed for this resource.",
|
||||||
|
err_code="SORT-422",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, field: str, valid_fields: list[str]) -> None:
|
||||||
|
"""Initialize the exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: The unknown sort field provided by the caller
|
||||||
|
valid_fields: List of valid field names
|
||||||
|
"""
|
||||||
|
self.field = field
|
||||||
|
self.valid_fields = valid_fields
|
||||||
|
detail = (
|
||||||
|
f"'{field}' is not an allowed sort field. Valid fields: {valid_fields}."
|
||||||
|
)
|
||||||
|
super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
def generate_error_responses(
|
def generate_error_responses(
|
||||||
*errors: type[ApiException],
|
*errors: type[ApiException],
|
||||||
) -> dict[int | str, dict[str, Any]]:
|
) -> dict[int | str, dict[str, Any]]:
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
"""Tests for CRUD search functionality."""
|
"""Tests for CRUD search functionality."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
|
||||||
|
|
||||||
from fastapi_toolsets.crud import (
|
from fastapi_toolsets.crud import (
|
||||||
CrudFactory,
|
CrudFactory,
|
||||||
@@ -11,6 +13,7 @@ from fastapi_toolsets.crud import (
|
|||||||
SearchConfig,
|
SearchConfig,
|
||||||
get_searchable_fields,
|
get_searchable_fields,
|
||||||
)
|
)
|
||||||
|
from fastapi_toolsets.exceptions import InvalidSortFieldError
|
||||||
from fastapi_toolsets.schemas import OffsetPagination
|
from fastapi_toolsets.schemas import OffsetPagination
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
@@ -1014,3 +1017,144 @@ class TestFilterParamsSchema:
|
|||||||
|
|
||||||
assert isinstance(result.pagination, OffsetPagination)
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
assert result.pagination.total_count == 2
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestSortParamsSchema:
|
||||||
|
"""Tests for AsyncCrud.sort_params()."""
|
||||||
|
|
||||||
|
def test_generates_sort_by_and_sort_order_params(self):
|
||||||
|
"""Returned dependency has sort_by and sort_order query params."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert param_names == {"sort_by", "sort_order"}
|
||||||
|
|
||||||
|
def test_dependency_name_includes_model_name(self):
|
||||||
|
"""Dependency function is named after the model."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
assert getattr(dep, "__name__") == "UserSortParams"
|
||||||
|
|
||||||
|
def test_raises_when_no_sort_fields(self):
|
||||||
|
"""ValueError raised when no sort_fields are configured or provided."""
|
||||||
|
with pytest.raises(ValueError, match="no sort_fields"):
|
||||||
|
UserCrud.sort_params()
|
||||||
|
|
||||||
|
def test_sort_fields_override(self):
|
||||||
|
"""sort_fields= parameter overrides the class-level default."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
|
||||||
|
dep = UserSortCrud.sort_params(sort_fields=[User.email])
|
||||||
|
|
||||||
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
|
assert "sort_by" in param_names
|
||||||
|
# description should only mention email, not username
|
||||||
|
sig = inspect.signature(dep)
|
||||||
|
description = sig.parameters["sort_by"].default.description
|
||||||
|
assert "email" in description
|
||||||
|
assert "username" not in description
|
||||||
|
|
||||||
|
def test_sort_by_description_lists_valid_fields(self):
|
||||||
|
"""sort_by query param description mentions each allowed field."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
|
||||||
|
sig = inspect.signature(dep)
|
||||||
|
description = sig.parameters["sort_by"].default.description
|
||||||
|
assert "username" in description
|
||||||
|
assert "email" in description
|
||||||
|
|
||||||
|
def test_default_order_reflected_in_sort_order_default(self):
|
||||||
|
"""default_order is used as the default value for sort_order."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep_asc = UserSortCrud.sort_params(default_order="asc")
|
||||||
|
dep_desc = UserSortCrud.sort_params(default_order="desc")
|
||||||
|
|
||||||
|
sig_asc = inspect.signature(dep_asc)
|
||||||
|
sig_desc = inspect.signature(dep_desc)
|
||||||
|
assert sig_asc.parameters["sort_order"].default.default == "asc"
|
||||||
|
assert sig_desc.parameters["sort_order"].default.default == "desc"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_no_sort_by_no_default_returns_none(self):
|
||||||
|
"""Returns None when sort_by is absent and no default_field is set."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
result = await dep(sort_by=None, sort_order="asc")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_no_sort_by_with_default_field_returns_asc_expression(self):
|
||||||
|
"""Returns default_field.asc() when sort_by absent and sort_order=asc."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep = UserSortCrud.sort_params(default_field=User.username)
|
||||||
|
result = await dep(sort_by=None, sort_order="asc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "ASC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_no_sort_by_with_default_field_returns_desc_expression(self):
|
||||||
|
"""Returns default_field.desc() when sort_by absent and sort_order=desc."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep = UserSortCrud.sort_params(default_field=User.username)
|
||||||
|
result = await dep(sort_by=None, sort_order="desc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "DESC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_valid_sort_by_asc(self):
|
||||||
|
"""Returns field.asc() for a valid sort_by with sort_order=asc."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
result = await dep(sort_by="username", sort_order="asc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "ASC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_valid_sort_by_desc(self):
|
||||||
|
"""Returns field.desc() for a valid sort_by with sort_order=desc."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
result = await dep(sort_by="username", sort_order="desc")
|
||||||
|
assert isinstance(result, UnaryExpression)
|
||||||
|
assert "DESC" in str(result)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_sort_by_raises_invalid_sort_field_error(self):
|
||||||
|
"""Raises InvalidSortFieldError for an unknown sort_by value."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
with pytest.raises(InvalidSortFieldError) as exc_info:
|
||||||
|
await dep(sort_by="nonexistent", sort_order="asc")
|
||||||
|
assert exc_info.value.field == "nonexistent"
|
||||||
|
assert "username" in exc_info.value.valid_fields
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_fields_all_resolve(self):
|
||||||
|
"""All configured fields resolve correctly via sort_by."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email])
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
result_username = await dep(sort_by="username", sort_order="asc")
|
||||||
|
result_email = await dep(sort_by="email", sort_order="desc")
|
||||||
|
assert isinstance(result_username, ColumnElement)
|
||||||
|
assert isinstance(result_email, ColumnElement)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sort_params_integrates_with_get_multi(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""sort_params output is accepted by get_multi(order_by=...)."""
|
||||||
|
UserSortCrud = CrudFactory(User, sort_fields=[User.username])
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
dep = UserSortCrud.sort_params()
|
||||||
|
order_by = await dep(sort_by="username", sort_order="asc")
|
||||||
|
results = await UserSortCrud.get_multi(db_session, order_by=order_by)
|
||||||
|
|
||||||
|
assert results[0].username == "alice"
|
||||||
|
assert results[1].username == "charlie"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from fastapi_toolsets.exceptions import (
|
|||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
InvalidSortFieldError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
@@ -334,3 +335,43 @@ class TestExceptionIntegration:
|
|||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"id": 1}
|
assert response.json() == {"id": 1}
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvalidSortFieldError:
|
||||||
|
"""Tests for InvalidSortFieldError exception."""
|
||||||
|
|
||||||
|
def test_api_error_attributes(self):
|
||||||
|
"""InvalidSortFieldError has correct api_error metadata."""
|
||||||
|
assert InvalidSortFieldError.api_error.code == 422
|
||||||
|
assert InvalidSortFieldError.api_error.err_code == "SORT-422"
|
||||||
|
assert InvalidSortFieldError.api_error.msg == "Invalid Sort Field"
|
||||||
|
|
||||||
|
def test_stores_field_and_valid_fields(self):
|
||||||
|
"""InvalidSortFieldError stores field and valid_fields on the instance."""
|
||||||
|
error = InvalidSortFieldError("unknown", ["name", "created_at"])
|
||||||
|
assert error.field == "unknown"
|
||||||
|
assert error.valid_fields == ["name", "created_at"]
|
||||||
|
|
||||||
|
def test_message_contains_field_and_valid_fields(self):
|
||||||
|
"""Exception message mentions the bad field and valid options."""
|
||||||
|
error = InvalidSortFieldError("bad_field", ["name", "email"])
|
||||||
|
assert "bad_field" in str(error)
|
||||||
|
assert "name" in str(error)
|
||||||
|
assert "email" in str(error)
|
||||||
|
|
||||||
|
def test_handled_as_422_by_exception_handler(self):
|
||||||
|
"""init_exceptions_handlers turns InvalidSortFieldError into a 422 response."""
|
||||||
|
app = FastAPI()
|
||||||
|
init_exceptions_handlers(app)
|
||||||
|
|
||||||
|
@app.get("/items")
|
||||||
|
async def list_items():
|
||||||
|
raise InvalidSortFieldError("bad", ["name"])
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/items")
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
data = response.json()
|
||||||
|
assert data["error_code"] == "SORT-422"
|
||||||
|
assert data["status"] == "FAIL"
|
||||||
|
|||||||
Reference in New Issue
Block a user