diff --git a/docs/reference/exceptions.md b/docs/reference/exceptions.md index 6df730e..577a038 100644 --- a/docs/reference/exceptions.md +++ b/docs/reference/exceptions.md @@ -13,6 +13,7 @@ from fastapi_toolsets.exceptions import ( ConflictError, NoSearchableFieldsError, InvalidFacetFilterError, + InvalidSortFieldError, generate_error_responses, init_exceptions_handlers, ) @@ -32,6 +33,8 @@ from fastapi_toolsets.exceptions import ( ## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError +## ::: fastapi_toolsets.exceptions.exceptions.InvalidSortFieldError + ## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses ## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index 3e311d1..68c6fe5 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -1,7 +1,7 @@ """Generic async CRUD operations for SQLAlchemy models.""" from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError -from .factory import CrudFactory, JoinType, M2MFieldType +from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause from .search import ( FacetFieldType, SearchConfig, @@ -16,5 +16,6 @@ __all__ = [ "JoinType", "M2MFieldType", "NoSearchableFieldsError", + "OrderByClause", "SearchConfig", ] diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index a2849ed..f6890d7 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -21,10 +21,11 @@ from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload from sqlalchemy.sql.base import ExecutableOption +from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction -from ..exceptions import NotFoundError +from ..exceptions import InvalidSortFieldError, NotFoundError from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response from .search import ( FacetFieldType, @@ -40,6 +41,7 @@ ModelType = TypeVar("ModelType", bound=DeclarativeBase) SchemaType = TypeVar("SchemaType", bound=BaseModel) JoinType = list[tuple[type[DeclarativeBase], Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]] +OrderByClause = ColumnElement[Any] | QueryableAttribute[Any] def _encode_cursor(value: Any) -> str: @@ -61,6 +63,7 @@ class AsyncCrud(Generic[ModelType]): model: ClassVar[type[DeclarativeBase]] searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None + sort_fields: ClassVar[Sequence[QueryableAttribute[Any]] | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None default_load_options: ClassVar[list[ExecutableOption] | None] = None cursor_column: ClassVar[Any | None] = None @@ -176,6 +179,63 @@ class AsyncCrud(Generic[ModelType]): return dependency + @classmethod + def sort_params( + cls: type[Self], + *, + sort_fields: Sequence[QueryableAttribute[Any]] | None = None, + default_field: QueryableAttribute[Any] | None = None, + default_order: Literal["asc", "desc"] = "asc", + ) -> Callable[..., Awaitable[OrderByClause | None]]: + """Return a FastAPI dependency that resolves sort query params into an order_by clause. + + Args: + sort_fields: Override the allowed sort fields. Falls back to the class-level + ``sort_fields`` if not provided. + default_field: Field to sort by when ``sort_by`` query param is absent. + If ``None`` and no ``sort_by`` is provided, no ordering is applied. + default_order: Default sort direction when ``sort_order`` is absent + (``"asc"`` or ``"desc"``). + + Returns: + An async dependency function named ``{Model}SortParams`` that resolves to an + ``OrderByClause`` (or ``None``). Pass it to ``Depends()`` in your route. + + Raises: + ValueError: If no sort fields are configured on this CRUD class and none are + provided via ``sort_fields``. + InvalidSortFieldError: When the request provides an unknown ``sort_by`` value. + """ + fields = sort_fields if sort_fields is not None else cls.sort_fields + if not fields: + raise ValueError( + f"{cls.__name__} has no sort_fields configured. " + "Pass sort_fields= or set them on CrudFactory." + ) + field_map: dict[str, QueryableAttribute[Any]] = {f.key: f for f in fields} + valid_keys = sorted(field_map.keys()) + + async def dependency( + sort_by: str | None = Query( + None, description=f"Field to sort by. Valid values: {valid_keys}" + ), + sort_order: Literal["asc", "desc"] = Query( + default_order, description="Sort direction" + ), + ) -> OrderByClause | None: + if sort_by is None: + if default_field is None: + return None + field = default_field + elif sort_by not in field_map: + raise InvalidSortFieldError(sort_by, valid_keys) + else: + field = field_map[sort_by] + return field.asc() if sort_order == "asc" else field.desc() + + dependency.__name__ = f"{cls.model.__name__}SortParams" + return dependency + @overload @classmethod async def create( # pragma: no cover @@ -415,7 +475,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, load_options: list[ExecutableOption] | None = None, - order_by: Any | None = None, + order_by: OrderByClause | None = None, limit: int | None = None, offset: int | None = None, ) -> Sequence[ModelType]: @@ -745,7 +805,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, load_options: list[ExecutableOption] | None = None, - order_by: Any | None = None, + order_by: OrderByClause | None = None, page: int = 1, items_per_page: int = 20, search: str | SearchConfig | None = None, @@ -766,7 +826,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, load_options: list[ExecutableOption] | None = None, - order_by: Any | None = None, + order_by: OrderByClause | None = None, page: int = 1, items_per_page: int = 20, search: str | SearchConfig | None = None, @@ -785,7 +845,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, load_options: list[ExecutableOption] | None = None, - order_by: Any | None = None, + order_by: OrderByClause | None = None, page: int = 1, items_per_page: int = 20, search: str | SearchConfig | None = None, @@ -937,7 +997,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, load_options: list[ExecutableOption] | None = None, - order_by: Any | None = None, + order_by: OrderByClause | None = None, items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, @@ -958,7 +1018,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, load_options: list[ExecutableOption] | None = None, - order_by: Any | None = None, + order_by: OrderByClause | None = None, items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, @@ -977,7 +1037,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, load_options: list[ExecutableOption] | None = None, - order_by: Any | None = None, + order_by: OrderByClause | None = None, items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, @@ -1147,6 +1207,7 @@ def CrudFactory( *, searchable_fields: Sequence[SearchFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None, + sort_fields: Sequence[QueryableAttribute[Any]] | None = None, m2m_fields: M2MFieldType | None = None, default_load_options: list[ExecutableOption] | None = None, cursor_column: Any | None = None, @@ -1159,6 +1220,8 @@ def CrudFactory( facet_fields: Optional list of columns to compute distinct values for in paginated responses. Supports direct columns (``User.status``) and relationship tuples (``(User.role, Role.name)``). Can be overridden per call. + sort_fields: Optional list of model attributes that callers are allowed to sort by + via ``sort_params()``. Can be overridden per call. m2m_fields: Optional mapping for many-to-many relationships. Maps schema field names (containing lists of IDs) to SQLAlchemy relationship attributes. @@ -1252,6 +1315,7 @@ def CrudFactory( "model": model, "searchable_fields": searchable_fields, "facet_fields": facet_fields, + "sort_fields": sort_fields, "m2m_fields": m2m_fields, "default_load_options": default_load_options, "cursor_column": cursor_column, diff --git a/src/fastapi_toolsets/exceptions/__init__.py b/src/fastapi_toolsets/exceptions/__init__.py index 2bb2b65..5ae8356 100644 --- a/src/fastapi_toolsets/exceptions/__init__.py +++ b/src/fastapi_toolsets/exceptions/__init__.py @@ -6,6 +6,7 @@ from .exceptions import ( ConflictError, ForbiddenError, InvalidFacetFilterError, + InvalidSortFieldError, NoSearchableFieldsError, NotFoundError, UnauthorizedError, @@ -21,6 +22,7 @@ __all__ = [ "generate_error_responses", "init_exceptions_handlers", "InvalidFacetFilterError", + "InvalidSortFieldError", "NoSearchableFieldsError", "NotFoundError", "UnauthorizedError", diff --git a/src/fastapi_toolsets/exceptions/exceptions.py b/src/fastapi_toolsets/exceptions/exceptions.py index 87d34d9..706f26f 100644 --- a/src/fastapi_toolsets/exceptions/exceptions.py +++ b/src/fastapi_toolsets/exceptions/exceptions.py @@ -128,6 +128,31 @@ class InvalidFacetFilterError(ApiException): super().__init__(detail) +class InvalidSortFieldError(ApiException): + """Raised when sort_by contains a field not in the allowed sort fields.""" + + api_error = ApiError( + code=422, + msg="Invalid Sort Field", + desc="The requested sort field is not allowed for this resource.", + err_code="SORT-422", + ) + + def __init__(self, field: str, valid_fields: list[str]) -> None: + """Initialize the exception. + + Args: + field: The unknown sort field provided by the caller + valid_fields: List of valid field names + """ + self.field = field + self.valid_fields = valid_fields + detail = ( + f"'{field}' is not an allowed sort field. Valid fields: {valid_fields}." + ) + super().__init__(detail) + + def generate_error_responses( *errors: type[ApiException], ) -> dict[int | str, dict[str, Any]]: diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index f87e5f9..5f22363 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -1,9 +1,11 @@ """Tests for CRUD search functionality.""" +import inspect import uuid import pytest from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from fastapi_toolsets.crud import ( CrudFactory, @@ -11,6 +13,7 @@ from fastapi_toolsets.crud import ( SearchConfig, get_searchable_fields, ) +from fastapi_toolsets.exceptions import InvalidSortFieldError from fastapi_toolsets.schemas import OffsetPagination from .conftest import ( @@ -1014,3 +1017,144 @@ class TestFilterParamsSchema: assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count == 2 + + +class TestSortParamsSchema: + """Tests for AsyncCrud.sort_params().""" + + def test_generates_sort_by_and_sort_order_params(self): + """Returned dependency has sort_by and sort_order query params.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email]) + dep = UserSortCrud.sort_params() + + param_names = set(inspect.signature(dep).parameters) + assert param_names == {"sort_by", "sort_order"} + + def test_dependency_name_includes_model_name(self): + """Dependency function is named after the model.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep = UserSortCrud.sort_params() + assert getattr(dep, "__name__") == "UserSortParams" + + def test_raises_when_no_sort_fields(self): + """ValueError raised when no sort_fields are configured or provided.""" + with pytest.raises(ValueError, match="no sort_fields"): + UserCrud.sort_params() + + def test_sort_fields_override(self): + """sort_fields= parameter overrides the class-level default.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email]) + dep = UserSortCrud.sort_params(sort_fields=[User.email]) + + param_names = set(inspect.signature(dep).parameters) + assert "sort_by" in param_names + # description should only mention email, not username + sig = inspect.signature(dep) + description = sig.parameters["sort_by"].default.description + assert "email" in description + assert "username" not in description + + def test_sort_by_description_lists_valid_fields(self): + """sort_by query param description mentions each allowed field.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email]) + dep = UserSortCrud.sort_params() + + sig = inspect.signature(dep) + description = sig.parameters["sort_by"].default.description + assert "username" in description + assert "email" in description + + def test_default_order_reflected_in_sort_order_default(self): + """default_order is used as the default value for sort_order.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep_asc = UserSortCrud.sort_params(default_order="asc") + dep_desc = UserSortCrud.sort_params(default_order="desc") + + sig_asc = inspect.signature(dep_asc) + sig_desc = inspect.signature(dep_desc) + assert sig_asc.parameters["sort_order"].default.default == "asc" + assert sig_desc.parameters["sort_order"].default.default == "desc" + + @pytest.mark.anyio + async def test_no_sort_by_no_default_returns_none(self): + """Returns None when sort_by is absent and no default_field is set.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep = UserSortCrud.sort_params() + result = await dep(sort_by=None, sort_order="asc") + assert result is None + + @pytest.mark.anyio + async def test_no_sort_by_with_default_field_returns_asc_expression(self): + """Returns default_field.asc() when sort_by absent and sort_order=asc.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep = UserSortCrud.sort_params(default_field=User.username) + result = await dep(sort_by=None, sort_order="asc") + assert isinstance(result, UnaryExpression) + assert "ASC" in str(result) + + @pytest.mark.anyio + async def test_no_sort_by_with_default_field_returns_desc_expression(self): + """Returns default_field.desc() when sort_by absent and sort_order=desc.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep = UserSortCrud.sort_params(default_field=User.username) + result = await dep(sort_by=None, sort_order="desc") + assert isinstance(result, UnaryExpression) + assert "DESC" in str(result) + + @pytest.mark.anyio + async def test_valid_sort_by_asc(self): + """Returns field.asc() for a valid sort_by with sort_order=asc.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep = UserSortCrud.sort_params() + result = await dep(sort_by="username", sort_order="asc") + assert isinstance(result, UnaryExpression) + assert "ASC" in str(result) + + @pytest.mark.anyio + async def test_valid_sort_by_desc(self): + """Returns field.desc() for a valid sort_by with sort_order=desc.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep = UserSortCrud.sort_params() + result = await dep(sort_by="username", sort_order="desc") + assert isinstance(result, UnaryExpression) + assert "DESC" in str(result) + + @pytest.mark.anyio + async def test_invalid_sort_by_raises_invalid_sort_field_error(self): + """Raises InvalidSortFieldError for an unknown sort_by value.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + dep = UserSortCrud.sort_params() + with pytest.raises(InvalidSortFieldError) as exc_info: + await dep(sort_by="nonexistent", sort_order="asc") + assert exc_info.value.field == "nonexistent" + assert "username" in exc_info.value.valid_fields + + @pytest.mark.anyio + async def test_multiple_fields_all_resolve(self): + """All configured fields resolve correctly via sort_by.""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username, User.email]) + dep = UserSortCrud.sort_params() + result_username = await dep(sort_by="username", sort_order="asc") + result_email = await dep(sort_by="email", sort_order="desc") + assert isinstance(result_username, ColumnElement) + assert isinstance(result_email, ColumnElement) + + @pytest.mark.anyio + async def test_sort_params_integrates_with_get_multi( + self, db_session: AsyncSession + ): + """sort_params output is accepted by get_multi(order_by=...).""" + UserSortCrud = CrudFactory(User, sort_fields=[User.username]) + await UserCrud.create( + db_session, UserCreate(username="charlie", email="c@test.com") + ) + await UserCrud.create( + db_session, UserCreate(username="alice", email="a@test.com") + ) + + dep = UserSortCrud.sort_params() + order_by = await dep(sort_by="username", sort_order="asc") + results = await UserSortCrud.get_multi(db_session, order_by=order_by) + + assert results[0].username == "alice" + assert results[1].username == "charlie" diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 6b7ae25..0dc1670 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -8,6 +8,7 @@ from fastapi_toolsets.exceptions import ( ApiException, ConflictError, ForbiddenError, + InvalidSortFieldError, NotFoundError, UnauthorizedError, generate_error_responses, @@ -334,3 +335,43 @@ class TestExceptionIntegration: assert response.status_code == 200 assert response.json() == {"id": 1} + + +class TestInvalidSortFieldError: + """Tests for InvalidSortFieldError exception.""" + + def test_api_error_attributes(self): + """InvalidSortFieldError has correct api_error metadata.""" + assert InvalidSortFieldError.api_error.code == 422 + assert InvalidSortFieldError.api_error.err_code == "SORT-422" + assert InvalidSortFieldError.api_error.msg == "Invalid Sort Field" + + def test_stores_field_and_valid_fields(self): + """InvalidSortFieldError stores field and valid_fields on the instance.""" + error = InvalidSortFieldError("unknown", ["name", "created_at"]) + assert error.field == "unknown" + assert error.valid_fields == ["name", "created_at"] + + def test_message_contains_field_and_valid_fields(self): + """Exception message mentions the bad field and valid options.""" + error = InvalidSortFieldError("bad_field", ["name", "email"]) + assert "bad_field" in str(error) + assert "name" in str(error) + assert "email" in str(error) + + def test_handled_as_422_by_exception_handler(self): + """init_exceptions_handlers turns InvalidSortFieldError into a 422 response.""" + app = FastAPI() + init_exceptions_handlers(app) + + @app.get("/items") + async def list_items(): + raise InvalidSortFieldError("bad", ["name"]) + + client = TestClient(app) + response = client.get("/items") + + assert response.status_code == 422 + data = response.json() + assert data["error_code"] == "SORT-422" + assert data["status"] == "FAIL"