mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
6 Commits
117675d02f
...
chore/vers
| Author | SHA1 | Date | |
|---|---|---|---|
|
0fc86d3c34
|
|||
|
82ef96082e
|
|||
|
e0828c7e71
|
|||
|
59d028d00e
|
|||
|
56d365d14b
|
|||
|
|
a257d85d45 |
@@ -1,6 +1,6 @@
|
||||
# Pagination & search
|
||||
|
||||
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, and **faceted filtering** — all from a single `CrudFactory` definition.
|
||||
This example builds an articles listing endpoint that supports **offset pagination**, **cursor pagination**, **full-text search**, **faceted filtering**, and **sorting** — all from a single `CrudFactory` definition.
|
||||
|
||||
## Models
|
||||
|
||||
@@ -16,7 +16,7 @@ This example builds an articles listing endpoint that supports **offset paginati
|
||||
|
||||
## Crud
|
||||
|
||||
Declare `facet_fields` and `searchable_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
|
||||
Declare `searchable_fields`, `facet_fields`, and `order_fields` once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory). All endpoints built from this class share the same defaults and can override them per call.
|
||||
|
||||
```python title="crud.py"
|
||||
--8<-- "docs_src/examples/pagination_search/crud.py"
|
||||
@@ -46,14 +46,14 @@ Declare `facet_fields` and `searchable_fields` once on [`CrudFactory`](../refere
|
||||
|
||||
Best for admin panels or any UI that needs a total item count and numbered pages.
|
||||
|
||||
```python title="routes.py:1:27"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:1:27"
|
||||
```python title="routes.py:1:36"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:1:36"
|
||||
```
|
||||
|
||||
**Example request**
|
||||
|
||||
```
|
||||
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published
|
||||
GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&order_by=title&order=asc
|
||||
```
|
||||
|
||||
**Example response**
|
||||
@@ -83,14 +83,14 @@ GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published
|
||||
|
||||
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
|
||||
|
||||
```python title="routes.py:30:45"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:30:45"
|
||||
```python title="routes.py:39:59"
|
||||
--8<-- "docs_src/examples/pagination_search/routes.py:39:59"
|
||||
```
|
||||
|
||||
**Example request**
|
||||
|
||||
```
|
||||
GET /articles/cursor?items_per_page=10&status=published
|
||||
GET /articles/cursor?items_per_page=10&status=published&order_by=created_at&order=desc
|
||||
```
|
||||
|
||||
**Example response**
|
||||
|
||||
@@ -95,9 +95,6 @@ The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.Async
|
||||
}
|
||||
```
|
||||
|
||||
!!! warning "Deprecated: `paginate`"
|
||||
The `paginate` function is a backward-compatible alias for `offset_paginate`. This function is **deprecated** and will be removed in **v2.0**.
|
||||
|
||||
### Cursor pagination
|
||||
|
||||
```python
|
||||
@@ -295,6 +292,8 @@ Use `filter_by` to pass the client's chosen filter values directly — no need t
|
||||
Use [`filter_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.filter_params) to generate a dict with the facet filter values from the query parameters:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
UserCrud = CrudFactory(
|
||||
@@ -306,7 +305,7 @@ UserCrud = CrudFactory(
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
page: int = 1,
|
||||
filter_by: dict[str, list[str]] = Depends(UserCrud.filter_params()),
|
||||
filter_by: Annotated[dict[str, list[str]], Depends(UserCrud.filter_params())],
|
||||
) -> PaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(
|
||||
session=session,
|
||||
@@ -323,6 +322,58 @@ GET /users?status=active&country=FR → filter_by={"status": ["active"], "coun
|
||||
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause)
|
||||
```
|
||||
|
||||
## Sorting
|
||||
|
||||
!!! info "Added in `v1.3`"
|
||||
|
||||
Declare `order_fields` on the CRUD class to expose client-driven column ordering via `order_by` and `order` query parameters.
|
||||
|
||||
```python
|
||||
UserCrud = CrudFactory(
|
||||
model=User,
|
||||
order_fields=[
|
||||
User.name,
|
||||
User.created_at,
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
Call [`order_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.order_params) to generate a FastAPI dependency that maps the query parameters to an [`OrderByClause`](../reference/crud.md#fastapi_toolsets.crud.factory.OrderByClause) expression:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_toolsets.crud import OrderByClause
|
||||
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
order_by: Annotated[OrderByClause | None, Depends(UserCrud.order_params())],
|
||||
) -> PaginatedResponse[UserRead]:
|
||||
return await UserCrud.offset_paginate(session=session, order_by=order_by)
|
||||
```
|
||||
|
||||
The dependency adds two query parameters to the endpoint:
|
||||
|
||||
| Parameter | Type |
|
||||
| ---------- | --------------- |
|
||||
| `order_by` | `str | null` |
|
||||
| `order` | `asc` or `desc` |
|
||||
|
||||
```
|
||||
GET /users?order_by=name&order=asc → ORDER BY users.name ASC
|
||||
GET /users?order_by=name&order=desc → ORDER BY users.name DESC
|
||||
```
|
||||
|
||||
An unknown `order_by` value raises [`InvalidOrderFieldError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError) (HTTP 422).
|
||||
|
||||
You can also pass `order_fields` directly to `order_params()` to override the class-level defaults without modifying them:
|
||||
|
||||
```python
|
||||
UserOrderParams = UserCrud.order_params(order_fields=[User.name])
|
||||
```
|
||||
|
||||
## Relationship loading
|
||||
|
||||
!!! info "Added in `v1.1`"
|
||||
@@ -417,9 +468,6 @@ async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[Us
|
||||
|
||||
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||
|
||||
!!! warning "Deprecated: `as_response`"
|
||||
The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`.
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/crud.md)
|
||||
|
||||
@@ -13,6 +13,7 @@ from fastapi_toolsets.exceptions import (
|
||||
ConflictError,
|
||||
NoSearchableFieldsError,
|
||||
InvalidFacetFilterError,
|
||||
InvalidOrderFieldError,
|
||||
generate_error_responses,
|
||||
init_exceptions_handlers,
|
||||
)
|
||||
@@ -32,6 +33,8 @@ from fastapi_toolsets.exceptions import (
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.InvalidOrderFieldError
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.exceptions.generate_error_responses
|
||||
|
||||
## ::: fastapi_toolsets.exceptions.handler.init_exceptions_handlers
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
from .routes import router
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app=app)
|
||||
app.include_router(router=router)
|
||||
|
||||
@@ -14,6 +14,8 @@ ArticleCrud = CrudFactory(
|
||||
Article.status,
|
||||
(Article.category, Category.name),
|
||||
],
|
||||
order_fields=[ # fields exposed for client-driven ordering
|
||||
Article.title,
|
||||
Article.created_at,
|
||||
],
|
||||
)
|
||||
|
||||
ArticleFilters = ArticleCrud.filter_params()
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from fastapi_toolsets.crud import OrderByClause
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
from .crud import ArticleCrud
|
||||
from .db import SessionDep
|
||||
from .models import Article
|
||||
from .schemas import ArticleRead
|
||||
|
||||
router = APIRouter(prefix="/articles")
|
||||
@@ -12,10 +16,14 @@ router = APIRouter(prefix="/articles")
|
||||
@router.get("/offset")
|
||||
async def list_articles_offset(
|
||||
session: SessionDep,
|
||||
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
|
||||
order_by: Annotated[
|
||||
OrderByClause | None,
|
||||
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
|
||||
],
|
||||
page: int = Query(1, ge=1),
|
||||
items_per_page: int = Query(20, ge=1, le=100),
|
||||
search: str | None = None,
|
||||
filter_by: dict[str, list[str]] = Depends(ArticleCrud.filter_params()),
|
||||
) -> PaginatedResponse[ArticleRead]:
|
||||
return await ArticleCrud.offset_paginate(
|
||||
session=session,
|
||||
@@ -23,6 +31,7 @@ async def list_articles_offset(
|
||||
items_per_page=items_per_page,
|
||||
search=search,
|
||||
filter_by=filter_by or None,
|
||||
order_by=order_by,
|
||||
schema=ArticleRead,
|
||||
)
|
||||
|
||||
@@ -30,10 +39,14 @@ async def list_articles_offset(
|
||||
@router.get("/cursor")
|
||||
async def list_articles_cursor(
|
||||
session: SessionDep,
|
||||
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
|
||||
order_by: Annotated[
|
||||
OrderByClause | None,
|
||||
Depends(ArticleCrud.order_params(default_field=Article.created_at)),
|
||||
],
|
||||
cursor: str | None = None,
|
||||
items_per_page: int = Query(20, ge=1, le=100),
|
||||
search: str | None = None,
|
||||
filter_by: dict[str, list[str]] = Depends(ArticleCrud.filter_params()),
|
||||
) -> PaginatedResponse[ArticleRead]:
|
||||
return await ArticleCrud.cursor_paginate(
|
||||
session=session,
|
||||
@@ -41,5 +54,6 @@ async def list_articles_cursor(
|
||||
items_per_page=items_per_page,
|
||||
search=search,
|
||||
filter_by=filter_by or None,
|
||||
order_by=order_by,
|
||||
schema=ArticleRead,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "1.2.1"
|
||||
version = "1.3.0"
|
||||
description = "Production-ready utilities for FastAPI applications"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "1.2.1"
|
||||
__version__ = "1.3.0"
|
||||
|
||||
@@ -72,7 +72,7 @@ async def load(
|
||||
registry = get_fixtures_registry()
|
||||
db_context = get_db_context()
|
||||
|
||||
context_list = [c.value for c in contexts] if contexts else [Context.BASE]
|
||||
context_list = list(contexts) if contexts else [Context.BASE]
|
||||
|
||||
ordered = registry.resolve_context_dependencies(*context_list)
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||
from .factory import CrudFactory, JoinType, M2MFieldType
|
||||
from .search import (
|
||||
FacetFieldType,
|
||||
SearchConfig,
|
||||
get_searchable_fields,
|
||||
)
|
||||
from ..types import FacetFieldType, JoinType, M2MFieldType, OrderByClause
|
||||
from .factory import CrudFactory
|
||||
from .search import SearchConfig, get_searchable_fields
|
||||
|
||||
__all__ = [
|
||||
"CrudFactory",
|
||||
@@ -16,5 +13,6 @@ __all__ = [
|
||||
"JoinType",
|
||||
"M2MFieldType",
|
||||
"NoSearchableFieldsError",
|
||||
"OrderByClause",
|
||||
"SearchConfig",
|
||||
]
|
||||
|
||||
@@ -6,11 +6,10 @@ import base64
|
||||
import inspect
|
||||
import json
|
||||
import uuid as uuid_module
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
||||
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
@@ -24,23 +23,25 @@ from sqlalchemy.sql.base import ExecutableOption
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from ..exceptions import InvalidOrderFieldError, NotFoundError
|
||||
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
||||
from .search import (
|
||||
from ..types import (
|
||||
FacetFieldType,
|
||||
SearchConfig,
|
||||
JoinType,
|
||||
M2MFieldType,
|
||||
ModelType,
|
||||
OrderByClause,
|
||||
SchemaType,
|
||||
SearchFieldType,
|
||||
)
|
||||
from .search import (
|
||||
SearchConfig,
|
||||
build_facets,
|
||||
build_filter_by,
|
||||
build_search_filters,
|
||||
facet_keys,
|
||||
)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
|
||||
|
||||
def _encode_cursor(value: Any) -> str:
|
||||
"""Encode cursor column value as an base64 string."""
|
||||
@@ -52,6 +53,22 @@ def _decode_cursor(cursor: str) -> str:
|
||||
return json.loads(base64.b64decode(cursor.encode()).decode())
|
||||
|
||||
|
||||
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
||||
"""Apply a list of (model, condition) joins to a SQLAlchemy select query."""
|
||||
if not joins:
|
||||
return q
|
||||
for model, condition in joins:
|
||||
q = q.outerjoin(model, condition) if outer_join else q.join(model, condition)
|
||||
return q
|
||||
|
||||
|
||||
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
|
||||
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
|
||||
for join_rel in search_joins:
|
||||
q = q.outerjoin(join_rel)
|
||||
return q
|
||||
|
||||
|
||||
class AsyncCrud(Generic[ModelType]):
|
||||
"""Generic async CRUD operations for SQLAlchemy models.
|
||||
|
||||
@@ -61,6 +78,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
||||
order_fields: ClassVar[Sequence[QueryableAttribute[Any]] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
default_load_options: ClassVar[list[ExecutableOption] | None] = None
|
||||
cursor_column: ClassVar[Any | None] = None
|
||||
@@ -130,6 +148,48 @@ class AsyncCrud(Generic[ModelType]):
|
||||
return set()
|
||||
return set(cls.m2m_fields.keys())
|
||||
|
||||
@classmethod
|
||||
def _resolve_facet_fields(
|
||||
cls: type[Self],
|
||||
facet_fields: Sequence[FacetFieldType] | None,
|
||||
) -> Sequence[FacetFieldType] | None:
|
||||
"""Return facet_fields if given, otherwise fall back to the class-level default."""
|
||||
return facet_fields if facet_fields is not None else cls.facet_fields
|
||||
|
||||
@classmethod
|
||||
def _prepare_filter_by(
|
||||
cls: type[Self],
|
||||
filter_by: dict[str, Any] | BaseModel | None,
|
||||
facet_fields: Sequence[FacetFieldType] | None,
|
||||
) -> tuple[list[Any], list[Any]]:
|
||||
"""Normalize filter_by and return (filters, joins) to apply to the query."""
|
||||
if isinstance(filter_by, BaseModel):
|
||||
filter_by = filter_by.model_dump(exclude_none=True)
|
||||
if not filter_by:
|
||||
return [], []
|
||||
resolved = cls._resolve_facet_fields(facet_fields)
|
||||
return build_filter_by(filter_by, resolved or [])
|
||||
|
||||
@classmethod
|
||||
async def _build_filter_attributes(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
facet_fields: Sequence[FacetFieldType] | None,
|
||||
filters: list[Any],
|
||||
search_joins: list[Any],
|
||||
) -> dict[str, list[Any]] | None:
|
||||
"""Build facet filter_attributes, or return None if no facet fields configured."""
|
||||
resolved = cls._resolve_facet_fields(facet_fields)
|
||||
if not resolved:
|
||||
return None
|
||||
return await build_facets(
|
||||
session,
|
||||
cls.model,
|
||||
resolved,
|
||||
base_filters=filters,
|
||||
base_joins=search_joins,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def filter_params(
|
||||
cls: type[Self],
|
||||
@@ -150,7 +210,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
ValueError: If no facet fields are configured on this CRUD class and none are
|
||||
provided via ``facet_fields``.
|
||||
"""
|
||||
fields = facet_fields if facet_fields is not None else cls.facet_fields
|
||||
fields = cls._resolve_facet_fields(facet_fields)
|
||||
if not fields:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} has no facet_fields configured. "
|
||||
@@ -176,6 +236,63 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
return dependency
|
||||
|
||||
@classmethod
|
||||
def order_params(
|
||||
cls: type[Self],
|
||||
*,
|
||||
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
|
||||
default_field: QueryableAttribute[Any] | None = None,
|
||||
default_order: Literal["asc", "desc"] = "asc",
|
||||
) -> Callable[..., Awaitable[OrderByClause | None]]:
|
||||
"""Return a FastAPI dependency that resolves order query params into an order_by clause.
|
||||
|
||||
Args:
|
||||
order_fields: Override the allowed order fields. Falls back to the class-level
|
||||
``order_fields`` if not provided.
|
||||
default_field: Field to order by when ``order_by`` query param is absent.
|
||||
If ``None`` and no ``order_by`` is provided, no ordering is applied.
|
||||
default_order: Default order direction when ``order`` is absent
|
||||
(``"asc"`` or ``"desc"``).
|
||||
|
||||
Returns:
|
||||
An async dependency function named ``{Model}OrderParams`` that resolves to an
|
||||
``OrderByClause`` (or ``None``). Pass it to ``Depends()`` in your route.
|
||||
|
||||
Raises:
|
||||
ValueError: If no order fields are configured on this CRUD class and none are
|
||||
provided via ``order_fields``.
|
||||
InvalidOrderFieldError: When the request provides an unknown ``order_by`` value.
|
||||
"""
|
||||
fields = order_fields if order_fields is not None else cls.order_fields
|
||||
if not fields:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} has no order_fields configured. "
|
||||
"Pass order_fields= or set them on CrudFactory."
|
||||
)
|
||||
field_map: dict[str, QueryableAttribute[Any]] = {f.key: f for f in fields}
|
||||
valid_keys = sorted(field_map.keys())
|
||||
|
||||
async def dependency(
|
||||
order_by: str | None = Query(
|
||||
None, description=f"Field to order by. Valid values: {valid_keys}"
|
||||
),
|
||||
order: Literal["asc", "desc"] = Query(
|
||||
default_order, description="Sort direction"
|
||||
),
|
||||
) -> OrderByClause | None:
|
||||
if order_by is None:
|
||||
if default_field is None:
|
||||
return None
|
||||
field = default_field
|
||||
elif order_by not in field_map:
|
||||
raise InvalidOrderFieldError(order_by, valid_keys)
|
||||
else:
|
||||
field = field_map[order_by]
|
||||
return field.asc() if order == "asc" else field.desc()
|
||||
|
||||
dependency.__name__ = f"{cls.model.__name__}OrderParams"
|
||||
return dependency
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
@@ -184,10 +301,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
obj: BaseModel,
|
||||
*,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
@@ -195,18 +310,6 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@@ -216,29 +319,19 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: bool = False,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
) -> ModelType | Response[Any]:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Created model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
Created model instance, or ``Response[schema]`` when ``schema`` is given.
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
data = (
|
||||
@@ -254,9 +347,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
result = cast(ModelType, db_model)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(result) if schema else result
|
||||
return Response(data=data_out)
|
||||
if schema:
|
||||
return Response(data=schema.model_validate(result))
|
||||
return result
|
||||
|
||||
@overload
|
||||
@@ -271,10 +363,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
@@ -286,22 +376,6 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@@ -315,9 +389,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: bool = False,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
) -> ModelType | Response[Any]:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
@@ -327,33 +400,18 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
Model instance, or ``Response[schema]`` when ``schema`` is given.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
MultipleResultsFound: If more than one record found
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
@@ -364,9 +422,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
result = cast(ModelType, item)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(result) if schema else result
|
||||
return Response(data=data_out)
|
||||
if schema:
|
||||
return Response(data=schema.model_validate(result))
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -392,13 +449,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Model instance or None
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
@@ -415,7 +466,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
order_by: OrderByClause | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> Sequence[ModelType]:
|
||||
@@ -435,13 +486,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
List of model instances
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
@@ -466,10 +511,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
@@ -480,21 +523,6 @@ class AsyncCrud(Generic[ModelType]):
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@@ -507,9 +535,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: bool = False,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
) -> ModelType | Response[Any]:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
@@ -518,24 +545,15 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Updated model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
Updated model instance, or ``Response[schema]`` when ``schema`` is given.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
|
||||
@@ -565,9 +583,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
for rel_attr, related_instances in m2m_resolved.items():
|
||||
setattr(db_model, rel_attr, related_instances)
|
||||
await session.refresh(db_model)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(db_model) if schema else db_model
|
||||
return Response(data=data_out)
|
||||
if schema:
|
||||
return Response(data=schema.model_validate(db_model))
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
@@ -623,7 +640,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
return_response: Literal[True],
|
||||
) -> Response[None]: ...
|
||||
|
||||
@overload
|
||||
@@ -633,8 +650,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> bool: ...
|
||||
return_response: Literal[False] = ...,
|
||||
) -> None: ...
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
@@ -642,33 +659,26 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> bool | Response[None]:
|
||||
return_response: bool = False,
|
||||
) -> None | Response[None]:
|
||||
"""Delete records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
as_response: Deprecated. Will be removed in v2.0. When ``True``,
|
||||
returns ``Response[None]`` instead of ``bool``.
|
||||
return_response: When ``True``, returns ``Response[None]`` instead
|
||||
of ``None``. Useful for API endpoints that expect a consistent
|
||||
response envelope.
|
||||
|
||||
Returns:
|
||||
``True`` if deletion was executed, or ``Response[None]`` when
|
||||
``as_response=True`` (deprecated).
|
||||
``None``, or ``Response[None]`` when ``return_response=True``.
|
||||
"""
|
||||
if as_response:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
if as_response:
|
||||
if return_response:
|
||||
return Response(data=None)
|
||||
return True
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def count(
|
||||
@@ -691,13 +701,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Number of matching records
|
||||
"""
|
||||
q = select(func.count()).select_from(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
result = await session.execute(q)
|
||||
@@ -724,58 +728,11 @@ class AsyncCrud(Generic[ModelType]):
|
||||
True if at least one record matches
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
q = q.where(and_(*filters)).exists().select()
|
||||
result = await session.execute(q)
|
||||
return bool(result.scalar())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def offset_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def offset_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@classmethod
|
||||
async def offset_paginate(
|
||||
cls: type[Self],
|
||||
@@ -785,15 +742,15 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
order_by: OrderByClause | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
schema: type[BaseModel],
|
||||
) -> PaginatedResponse[Any]:
|
||||
"""Get paginated results using offset-based pagination.
|
||||
|
||||
Args:
|
||||
@@ -811,54 +768,36 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filter_by: Dict of {column_key: value} to filter by declared facet fields.
|
||||
Keys must match the column.key of a facet field. Scalar → equality,
|
||||
list → IN clause. Raises InvalidFacetFilterError for unknown keys.
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
schema: Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
PaginatedResponse with OffsetPagination metadata
|
||||
"""
|
||||
filters = list(filters) if filters else []
|
||||
offset = (page - 1) * items_per_page
|
||||
search_joins: list[Any] = []
|
||||
|
||||
if isinstance(filter_by, BaseModel):
|
||||
filter_by = filter_by.model_dump(exclude_none=True) or None
|
||||
|
||||
# Build filter_by conditions from declared facet fields
|
||||
if filter_by:
|
||||
resolved_facets_for_filter = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
)
|
||||
fb_filters, fb_joins = build_filter_by(
|
||||
filter_by, resolved_facets_for_filter or []
|
||||
)
|
||||
filters.extend(fb_filters)
|
||||
search_joins.extend(fb_joins)
|
||||
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
|
||||
filters.extend(fb_filters)
|
||||
|
||||
# Build search filters
|
||||
if search:
|
||||
search_filters, search_joins = build_search_filters(
|
||||
search_filters, new_search_joins = build_search_filters(
|
||||
cls.model,
|
||||
search,
|
||||
search_fields=search_fields,
|
||||
default_fields=cls.searchable_fields,
|
||||
)
|
||||
filters.extend(search_filters)
|
||||
search_joins.extend(new_search_joins)
|
||||
|
||||
# Build query with joins
|
||||
q = select(cls.model)
|
||||
|
||||
# Apply explicit joins
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
|
||||
# Apply search joins (always outer joins for search)
|
||||
for join_rel in search_joins:
|
||||
q = q.outerjoin(join_rel)
|
||||
q = _apply_search_joins(q, search_joins)
|
||||
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
@@ -870,9 +809,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
q = q.offset(offset).limit(items_per_page)
|
||||
result = await session.execute(q)
|
||||
raw_items = cast(list[ModelType], result.unique().scalars().all())
|
||||
items: list[Any] = (
|
||||
[schema.model_validate(item) for item in raw_items] if schema else raw_items
|
||||
)
|
||||
items: list[Any] = [schema.model_validate(item) for item in raw_items]
|
||||
|
||||
# Count query (with same joins and filters)
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
@@ -880,17 +817,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
count_q = count_q.select_from(cls.model)
|
||||
|
||||
# Apply explicit joins to count query
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
count_q = (
|
||||
count_q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else count_q.join(model, condition)
|
||||
)
|
||||
count_q = _apply_joins(count_q, joins, outer_join)
|
||||
|
||||
# Apply search joins to count query
|
||||
for join_rel in search_joins:
|
||||
count_q = count_q.outerjoin(join_rel)
|
||||
count_q = _apply_search_joins(count_q, search_joins)
|
||||
|
||||
if filters:
|
||||
count_q = count_q.where(and_(*filters))
|
||||
@@ -898,19 +828,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
count_result = await session.execute(count_q)
|
||||
total_count = count_result.scalar_one()
|
||||
|
||||
# Build facets
|
||||
resolved_facet_fields = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
filter_attributes = await cls._build_filter_attributes(
|
||||
session, facet_fields, filters, search_joins
|
||||
)
|
||||
filter_attributes: dict[str, list[Any]] | None = None
|
||||
if resolved_facet_fields:
|
||||
filter_attributes = await build_facets(
|
||||
session,
|
||||
cls.model,
|
||||
resolved_facet_fields,
|
||||
base_filters=filters or None,
|
||||
base_joins=search_joins or None,
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
@@ -923,50 +843,6 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filter_attributes=filter_attributes,
|
||||
)
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
paginate = offset_paginate
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def cursor_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def cursor_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@classmethod
|
||||
async def cursor_paginate(
|
||||
cls: type[Self],
|
||||
@@ -977,14 +853,14 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
order_by: OrderByClause | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
filter_by: dict[str, Any] | BaseModel | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
schema: type[BaseModel],
|
||||
) -> PaginatedResponse[Any]:
|
||||
"""Get paginated results using cursor-based pagination.
|
||||
|
||||
Args:
|
||||
@@ -1011,21 +887,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
PaginatedResponse with CursorPagination metadata
|
||||
"""
|
||||
filters = list(filters) if filters else []
|
||||
search_joins: list[Any] = []
|
||||
|
||||
if isinstance(filter_by, BaseModel):
|
||||
filter_by = filter_by.model_dump(exclude_none=True) or None
|
||||
|
||||
# Build filter_by conditions from declared facet fields
|
||||
if filter_by:
|
||||
resolved_facets_for_filter = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
)
|
||||
fb_filters, fb_joins = build_filter_by(
|
||||
filter_by, resolved_facets_for_filter or []
|
||||
)
|
||||
filters.extend(fb_filters)
|
||||
search_joins.extend(fb_joins)
|
||||
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
|
||||
filters.extend(fb_filters)
|
||||
|
||||
if cls.cursor_column is None:
|
||||
raise ValueError(
|
||||
@@ -1058,29 +922,23 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
# Build search filters
|
||||
if search:
|
||||
search_filters, search_joins = build_search_filters(
|
||||
search_filters, new_search_joins = build_search_filters(
|
||||
cls.model,
|
||||
search,
|
||||
search_fields=search_fields,
|
||||
default_fields=cls.searchable_fields,
|
||||
)
|
||||
filters.extend(search_filters)
|
||||
search_joins.extend(new_search_joins)
|
||||
|
||||
# Build query
|
||||
q = select(cls.model)
|
||||
|
||||
# Apply explicit joins
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
|
||||
# Apply search joins (always outer joins)
|
||||
for join_rel in search_joins:
|
||||
q = q.outerjoin(join_rel)
|
||||
q = _apply_search_joins(q, search_joins)
|
||||
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
@@ -1110,25 +968,11 @@ class AsyncCrud(Generic[ModelType]):
|
||||
if cursor is not None and items_page:
|
||||
prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))
|
||||
|
||||
items: list[Any] = (
|
||||
[schema.model_validate(item) for item in items_page]
|
||||
if schema
|
||||
else items_page
|
||||
)
|
||||
items: list[Any] = [schema.model_validate(item) for item in items_page]
|
||||
|
||||
# Build facets
|
||||
resolved_facet_fields = (
|
||||
facet_fields if facet_fields is not None else cls.facet_fields
|
||||
filter_attributes = await cls._build_filter_attributes(
|
||||
session, facet_fields, filters, search_joins
|
||||
)
|
||||
filter_attributes: dict[str, list[Any]] | None = None
|
||||
if resolved_facet_fields:
|
||||
filter_attributes = await build_facets(
|
||||
session,
|
||||
cls.model,
|
||||
resolved_facet_fields,
|
||||
base_filters=filters or None,
|
||||
base_joins=search_joins or None,
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
@@ -1147,6 +991,7 @@ def CrudFactory(
|
||||
*,
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
|
||||
m2m_fields: M2MFieldType | None = None,
|
||||
default_load_options: list[ExecutableOption] | None = None,
|
||||
cursor_column: Any | None = None,
|
||||
@@ -1159,6 +1004,8 @@ def CrudFactory(
|
||||
facet_fields: Optional list of columns to compute distinct values for in paginated
|
||||
responses. Supports direct columns (``User.status``) and relationship tuples
|
||||
(``(User.role, Role.name)``). Can be overridden per call.
|
||||
order_fields: Optional list of model attributes that callers are allowed to order by
|
||||
via ``order_params()``. Can be overridden per call.
|
||||
m2m_fields: Optional mapping for many-to-many relationships.
|
||||
Maps schema field names (containing lists of IDs) to
|
||||
SQLAlchemy relationship attributes.
|
||||
@@ -1252,6 +1099,7 @@ def CrudFactory(
|
||||
"model": model,
|
||||
"searchable_fields": searchable_fields,
|
||||
"facet_fields": facet_fields,
|
||||
"order_fields": order_fields,
|
||||
"m2m_fields": m2m_fields,
|
||||
"default_load_options": default_load_options,
|
||||
"cursor_column": cursor_column,
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
"""Search utilities for AsyncCrud."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy import String, or_, select
|
||||
from sqlalchemy import String, and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||
from ..types import FacetFieldType, SearchFieldType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||
FacetFieldType = SearchFieldType
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchConfig:
|
||||
@@ -37,6 +36,7 @@ class SearchConfig:
|
||||
match_mode: Literal["any", "all"] = "any"
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def get_searchable_fields(
|
||||
model: type[DeclarativeBase],
|
||||
*,
|
||||
@@ -101,14 +101,11 @@ def build_search_filters(
|
||||
if isinstance(search, str):
|
||||
config = SearchConfig(query=search, fields=search_fields)
|
||||
else:
|
||||
config = search
|
||||
if search_fields is not None:
|
||||
config = SearchConfig(
|
||||
query=config.query,
|
||||
fields=search_fields,
|
||||
case_sensitive=config.case_sensitive,
|
||||
match_mode=config.match_mode,
|
||||
)
|
||||
config = (
|
||||
replace(search, fields=search_fields)
|
||||
if search_fields is not None
|
||||
else search
|
||||
)
|
||||
|
||||
if not config.query or not config.query.strip():
|
||||
return [], []
|
||||
@@ -227,8 +224,6 @@ async def build_facets(
|
||||
q = q.outerjoin(rel)
|
||||
|
||||
if base_filters:
|
||||
from sqlalchemy import and_
|
||||
|
||||
q = q.where(and_(*base_filters))
|
||||
|
||||
q = q.order_by(column)
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""Dependency factories for FastAPI routes."""
|
||||
|
||||
import inspect
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from typing import Any, TypeVar, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from .crud import CrudFactory
|
||||
from .types import ModelType, SessionDependency
|
||||
|
||||
__all__ = ["BodyDependency", "PathDependency"]
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||
|
||||
|
||||
def PathDependency(
|
||||
model: type[ModelType],
|
||||
|
||||
@@ -6,6 +6,7 @@ from .exceptions import (
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
InvalidFacetFilterError,
|
||||
InvalidOrderFieldError,
|
||||
NoSearchableFieldsError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"generate_error_responses",
|
||||
"init_exceptions_handlers",
|
||||
"InvalidFacetFilterError",
|
||||
"InvalidOrderFieldError",
|
||||
"NoSearchableFieldsError",
|
||||
"NotFoundError",
|
||||
"UnauthorizedError",
|
||||
|
||||
@@ -128,6 +128,31 @@ class InvalidFacetFilterError(ApiException):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
class InvalidOrderFieldError(ApiException):
|
||||
"""Raised when order_by contains a field not in the allowed order fields."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=422,
|
||||
msg="Invalid Order Field",
|
||||
desc="The requested order field is not allowed for this resource.",
|
||||
err_code="SORT-422",
|
||||
)
|
||||
|
||||
def __init__(self, field: str, valid_fields: list[str]) -> None:
|
||||
"""Initialize the exception.
|
||||
|
||||
Args:
|
||||
field: The unknown order field provided by the caller
|
||||
valid_fields: List of valid field names
|
||||
"""
|
||||
self.field = field
|
||||
self.valid_fields = valid_fields
|
||||
detail = (
|
||||
f"'{field}' is not an allowed order field. Valid fields: {valid_fields}."
|
||||
)
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
def generate_error_responses(
|
||||
*errors: type[ApiException],
|
||||
) -> dict[int | str, dict[str, Any]]:
|
||||
|
||||
@@ -10,6 +10,10 @@ from fastapi.responses import JSONResponse
|
||||
from ..schemas import ErrorResponse, ResponseStatus
|
||||
from .exceptions import ApiException
|
||||
|
||||
_VALIDATION_LOCATION_PARAMS: frozenset[str] = frozenset(
|
||||
{"body", "query", "path", "header", "cookie"}
|
||||
)
|
||||
|
||||
|
||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||
"""Register exception handlers and custom OpenAPI schema on a FastAPI app.
|
||||
@@ -106,9 +110,7 @@ def _format_validation_error(
|
||||
|
||||
for error in errors:
|
||||
field_path = ".".join(
|
||||
str(loc)
|
||||
for loc in error["loc"]
|
||||
if loc not in ("body", "query", "path", "header", "cookie")
|
||||
str(loc) for loc in error["loc"] if loc not in _VALIDATION_LOCATION_PARAMS
|
||||
)
|
||||
formatted_errors.append(
|
||||
{
|
||||
|
||||
@@ -1,24 +1,84 @@
|
||||
"""Fixture loading utilities for database seeding."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..logger import get_logger
|
||||
from ..types import ModelType
|
||||
from .enum import LoadStrategy
|
||||
from .registry import Context, FixtureRegistry
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
T = TypeVar("T", bound=DeclarativeBase)
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
ordered_names: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
|
||||
if not instances:
|
||||
results[name] = []
|
||||
continue
|
||||
|
||||
model_name = type(instances[0]).__name__
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
else: # LoadStrategy.SKIP_EXISTING
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
|
||||
|
||||
def get_obj_by_attr(
|
||||
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any
|
||||
) -> T:
|
||||
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
|
||||
) -> ModelType:
|
||||
"""Get a SQLAlchemy model instance by matching an attribute value.
|
||||
|
||||
Args:
|
||||
@@ -57,13 +117,6 @@ async def load_fixtures(
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Loads 'roles' first (dependency), then 'users'
|
||||
result = await load_fixtures(session, fixtures, "users")
|
||||
print(result["users"]) # [User(...), ...]
|
||||
```
|
||||
"""
|
||||
ordered = registry.resolve_dependencies(*names)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
@@ -85,76 +138,6 @@ async def load_fixtures_by_context(
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Load base + testing fixtures
|
||||
await load_fixtures_by_context(
|
||||
session, fixtures,
|
||||
Context.BASE, Context.TESTING
|
||||
)
|
||||
```
|
||||
"""
|
||||
ordered = registry.resolve_context_dependencies(*contexts)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
ordered_names: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
|
||||
if not instances:
|
||||
results[name] = []
|
||||
continue
|
||||
|
||||
model_name = type(instances[0]).__name__
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
|
||||
@@ -53,17 +53,23 @@ def init_metrics(
|
||||
logger.debug("Initialising metric provider '%s'", provider.name)
|
||||
provider.func()
|
||||
|
||||
collectors = registry.get_collectors()
|
||||
# Partition collectors and cache env check at startup — both are stable for the app lifetime.
|
||||
async_collectors = [
|
||||
c for c in registry.get_collectors() if asyncio.iscoroutinefunction(c.func)
|
||||
]
|
||||
sync_collectors = [
|
||||
c for c in registry.get_collectors() if not asyncio.iscoroutinefunction(c.func)
|
||||
]
|
||||
multiprocess_mode = _is_multiprocess()
|
||||
|
||||
@app.get(path, include_in_schema=False)
|
||||
async def metrics_endpoint() -> Response:
|
||||
for collector in collectors:
|
||||
if asyncio.iscoroutinefunction(collector.func):
|
||||
await collector.func()
|
||||
else:
|
||||
collector.func()
|
||||
for collector in sync_collectors:
|
||||
collector.func()
|
||||
for collector in async_collectors:
|
||||
await collector.func()
|
||||
|
||||
if _is_multiprocess():
|
||||
if multiprocess_mode:
|
||||
prom_registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(prom_registry)
|
||||
output = generate_latest(prom_registry)
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
"""Base Pydantic schemas for API responses."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Generic, TypeVar
|
||||
from typing import Any, ClassVar, Generic
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .types import DataT
|
||||
|
||||
__all__ = [
|
||||
"ApiError",
|
||||
"CursorPagination",
|
||||
"ErrorResponse",
|
||||
"OffsetPagination",
|
||||
"Pagination",
|
||||
"PaginatedResponse",
|
||||
"PydanticBase",
|
||||
"Response",
|
||||
"ResponseStatus",
|
||||
]
|
||||
|
||||
DataT = TypeVar("DataT")
|
||||
|
||||
|
||||
class PydanticBase(BaseModel):
|
||||
"""Base class for all Pydantic models with common configuration."""
|
||||
@@ -108,10 +107,6 @@ class OffsetPagination(PydanticBase):
|
||||
has_more: bool
|
||||
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
Pagination = OffsetPagination
|
||||
|
||||
|
||||
class CursorPagination(PydanticBase):
|
||||
"""Pagination metadata for cursor-based list responses.
|
||||
|
||||
|
||||
27
src/fastapi_toolsets/types.py
Normal file
27
src/fastapi_toolsets/types.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Shared type aliases for the fastapi-toolsets package."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
# Generic TypeVars
|
||||
DataT = TypeVar("DataT")
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
|
||||
# CRUD type aliases
|
||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||
|
||||
# Search / facet type aliases
|
||||
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||
FacetFieldType = SearchFieldType
|
||||
|
||||
# Dependency type aliases
|
||||
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||
@@ -92,6 +92,15 @@ class IntRole(Base):
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
|
||||
class Permission(Base):
|
||||
"""Test model with composite primary key."""
|
||||
|
||||
__tablename__ = "permissions"
|
||||
|
||||
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
action: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
|
||||
|
||||
class Event(Base):
|
||||
"""Test model with DateTime and Date cursor columns."""
|
||||
|
||||
@@ -162,6 +171,7 @@ class UserRead(PydanticBase):
|
||||
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
@@ -218,12 +228,26 @@ class PostM2MUpdate(BaseModel):
|
||||
tag_ids: list[uuid.UUID] | None = None
|
||||
|
||||
|
||||
class IntRoleRead(PydanticBase):
|
||||
"""Schema for reading an IntRole."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class IntRoleCreate(BaseModel):
|
||||
"""Schema for creating an IntRole."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class EventRead(PydanticBase):
|
||||
"""Schema for reading an Event."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class EventCreate(BaseModel):
|
||||
"""Schema for creating an Event."""
|
||||
|
||||
@@ -232,6 +256,13 @@ class EventCreate(BaseModel):
|
||||
scheduled_date: datetime.date
|
||||
|
||||
|
||||
class ProductRead(PydanticBase):
|
||||
"""Schema for reading a Product."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class ProductCreate(BaseModel):
|
||||
"""Schema for creating a Product."""
|
||||
|
||||
|
||||
@@ -15,8 +15,10 @@ from .conftest import (
|
||||
EventCrud,
|
||||
EventDateCursorCrud,
|
||||
EventDateTimeCursorCrud,
|
||||
EventRead,
|
||||
IntRoleCreate,
|
||||
IntRoleCursorCrud,
|
||||
IntRoleRead,
|
||||
Post,
|
||||
PostCreate,
|
||||
PostCrud,
|
||||
@@ -26,6 +28,7 @@ from .conftest import (
|
||||
ProductCreate,
|
||||
ProductCrud,
|
||||
ProductNumericCursorCrud,
|
||||
ProductRead,
|
||||
Role,
|
||||
RoleCreate,
|
||||
RoleCrud,
|
||||
@@ -169,7 +172,14 @@ class TestDefaultLoadOptionsIntegration:
|
||||
async def test_default_load_options_applied_to_paginate(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""default_load_options loads relationships automatically on paginate()."""
|
||||
"""default_load_options loads relationships automatically on offset_paginate()."""
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
|
||||
UserWithDefaultLoad = CrudFactory(
|
||||
User, default_load_options=[selectinload(User.role)]
|
||||
)
|
||||
@@ -178,7 +188,9 @@ class TestDefaultLoadOptionsIntegration:
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
result = await UserWithDefaultLoad.paginate(db_session)
|
||||
result = await UserWithDefaultLoad.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead
|
||||
)
|
||||
assert result.data[0].role is not None
|
||||
assert result.data[0].role.name == "admin"
|
||||
|
||||
@@ -430,7 +442,7 @@ class TestCrudDelete:
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
|
||||
result = await RoleCrud.delete(db_session, [Role.id == role.id])
|
||||
|
||||
assert result is True
|
||||
assert result is None
|
||||
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -454,6 +466,20 @@ class TestCrudDelete:
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].username == "u3"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_return_response(self, db_session: AsyncSession):
|
||||
"""Delete with return_response=True returns Response[None]."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="to_delete_resp"))
|
||||
result = await RoleCrud.delete(
|
||||
db_session, [Role.id == role.id], return_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is None
|
||||
assert await RoleCrud.first(db_session, [Role.id == role.id]) is None
|
||||
|
||||
|
||||
class TestCrudExists:
|
||||
"""Tests for CRUD exists operations."""
|
||||
@@ -594,7 +620,9 @@ class TestCrudPaginate:
|
||||
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
||||
result = await RoleCrud.offset_paginate(
|
||||
db_session, page=1, items_per_page=10, schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert len(result.data) == 10
|
||||
@@ -609,7 +637,9 @@ class TestCrudPaginate:
|
||||
for i in range(25):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10)
|
||||
result = await RoleCrud.offset_paginate(
|
||||
db_session, page=3, items_per_page=10, schema=RoleRead
|
||||
)
|
||||
|
||||
assert len(result.data) == 5
|
||||
assert result.pagination.has_more is False
|
||||
@@ -629,11 +659,12 @@ class TestCrudPaginate:
|
||||
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
page=1,
|
||||
items_per_page=10,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -646,11 +677,12 @@ class TestCrudPaginate:
|
||||
await RoleCrud.create(db_session, RoleCreate(name="alpha"))
|
||||
await RoleCrud.create(db_session, RoleCreate(name="bravo"))
|
||||
|
||||
result = await RoleCrud.paginate(
|
||||
result = await RoleCrud.offset_paginate(
|
||||
db_session,
|
||||
order_by=Role.name,
|
||||
page=1,
|
||||
items_per_page=10,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
names = [r.name for r in result.data]
|
||||
@@ -855,12 +887,13 @@ class TestCrudJoins:
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
# Paginate users with published posts
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
joins=[(Post, Post.author_id == User.id)],
|
||||
filters=[Post.is_published == True], # noqa: E712
|
||||
page=1,
|
||||
items_per_page=10,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -889,12 +922,13 @@ class TestCrudJoins:
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
# Paginate with outer join
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
joins=[(Post, Post.author_id == User.id)],
|
||||
outer_join=True,
|
||||
page=1,
|
||||
items_per_page=10,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -931,70 +965,6 @@ class TestCrudJoins:
|
||||
assert users[0].username == "multi_join"
|
||||
|
||||
|
||||
class TestAsResponse:
|
||||
"""Tests for as_response parameter (deprecated, kept for backward compat)."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_as_response(self, db_session: AsyncSession):
|
||||
"""Create with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
data = RoleCreate(name="response_role")
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.create(db_session, data, as_response=True)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.name == "response_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_as_response(self, db_session: AsyncSession):
|
||||
"""Get with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.id == created.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_as_response(self, db_session: AsyncSession):
|
||||
"""Update with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="new_name"),
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.name == "new_name"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_as_response(self, db_session: AsyncSession):
|
||||
"""Delete with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.delete(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is None
|
||||
|
||||
|
||||
class TestCrudFactoryM2M:
|
||||
"""Tests for CrudFactory with m2m_fields parameter."""
|
||||
|
||||
@@ -1475,92 +1445,35 @@ class TestSchemaResponse:
|
||||
assert isinstance(result, Response)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_schema(self, db_session: AsyncSession):
|
||||
"""paginate with schema returns PaginatedResponse[SchemaType]."""
|
||||
async def test_offset_paginate_with_schema(self, db_session: AsyncSession):
|
||||
"""offset_paginate with schema returns PaginatedResponse[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
await RoleCrud.create(db_session, RoleCreate(name="p_role1"))
|
||||
await RoleCrud.create(db_session, RoleCreate(name="p_role2"))
|
||||
|
||||
result = await RoleCrud.paginate(db_session, schema=RoleRead)
|
||||
result = await RoleCrud.offset_paginate(db_session, schema=RoleRead)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert len(result.data) == 2
|
||||
assert all(isinstance(item, RoleRead) for item in result.data)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_schema_filters_fields(self, db_session: AsyncSession):
|
||||
"""paginate with schema only exposes schema fields per item."""
|
||||
async def test_offset_paginate_schema_filters_fields(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""offset_paginate with schema only exposes schema fields per item."""
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="pg_user", email="pg@test.com"),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, schema=UserRead)
|
||||
result = await UserCrud.offset_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert isinstance(result.data[0], UserRead)
|
||||
assert result.data[0].username == "pg_user"
|
||||
assert not hasattr(result.data[0], "email")
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_as_response_true_without_schema_unchanged(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""as_response=True without schema still returns Response[ModelType] with a warning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="compat"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, Role)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_schema_with_explicit_as_response_true(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""schema combined with explicit as_response=True works correctly."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="combined"))
|
||||
result = await RoleCrud.get(
|
||||
db_session,
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
|
||||
|
||||
class TestPaginateAlias:
|
||||
"""Tests that paginate is a backward-compatible alias for offset_paginate."""
|
||||
|
||||
def test_paginate_is_alias_of_offset_paginate(self):
|
||||
"""paginate and offset_paginate are the same underlying function."""
|
||||
assert RoleCrud.paginate.__func__ is RoleCrud.offset_paginate.__func__
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_alias_returns_offset_pagination(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""paginate() still works and returns PaginatedResponse with OffsetPagination."""
|
||||
from fastapi_toolsets.schemas import OffsetPagination, PaginatedResponse
|
||||
|
||||
for i in range(3):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 3
|
||||
assert result.pagination.page == 1
|
||||
|
||||
|
||||
class TestCursorPaginate:
|
||||
"""Tests for cursor-based pagination via cursor_paginate()."""
|
||||
@@ -1573,7 +1486,9 @@ class TestCursorPaginate:
|
||||
for i in range(25):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
result = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=10, schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
@@ -1591,7 +1506,9 @@ class TestCursorPaginate:
|
||||
for i in range(5):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
result = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=10, schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 5
|
||||
@@ -1606,14 +1523,16 @@ class TestCursorPaginate:
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
page1 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=10, schema=RoleRead
|
||||
)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
assert len(page1.data) == 10
|
||||
assert page1.pagination.has_more is True
|
||||
|
||||
cursor = page1.pagination.next_cursor
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, cursor=cursor, items_per_page=10
|
||||
db_session, cursor=cursor, items_per_page=10, schema=RoleRead
|
||||
)
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert len(page2.data) == 5
|
||||
@@ -1628,12 +1547,15 @@ class TestCursorPaginate:
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=4)
|
||||
page1 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=4, schema=RoleRead
|
||||
)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=4,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
ids_page1 = {r.id for r in page1.data}
|
||||
@@ -1646,7 +1568,9 @@ class TestCursorPaginate:
|
||||
"""cursor_paginate on an empty table returns empty data with no cursor."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
result = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=10, schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert result.data == []
|
||||
@@ -1671,6 +1595,7 @@ class TestCursorPaginate:
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
items_per_page=20,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert len(result.data) == 5
|
||||
@@ -1703,7 +1628,9 @@ class TestCursorPaginate:
|
||||
for i in range(5):
|
||||
await RoleNameCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleNameCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
result = await RoleNameCrud.cursor_paginate(
|
||||
db_session, items_per_page=3, schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 3
|
||||
@@ -1714,7 +1641,7 @@ class TestCursorPaginate:
|
||||
async def test_raises_without_cursor_column(self, db_session: AsyncSession):
|
||||
"""cursor_paginate raises ValueError when cursor_column is not configured."""
|
||||
with pytest.raises(ValueError, match="cursor_column is not set"):
|
||||
await RoleCrud.cursor_paginate(db_session)
|
||||
await RoleCrud.cursor_paginate(db_session, schema=RoleRead)
|
||||
|
||||
|
||||
class TestCursorPaginatePrevCursor:
|
||||
@@ -1728,7 +1655,9 @@ class TestCursorPaginatePrevCursor:
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
result = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=3, schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert result.pagination.prev_cursor is None
|
||||
@@ -1741,12 +1670,15 @@ class TestCursorPaginatePrevCursor:
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5)
|
||||
page1 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=5, schema=RoleRead
|
||||
)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=5,
|
||||
schema=RoleRead,
|
||||
)
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert page2.pagination.prev_cursor is not None
|
||||
@@ -1762,12 +1694,15 @@ class TestCursorPaginatePrevCursor:
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5)
|
||||
page1 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=5, schema=RoleRead
|
||||
)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=5,
|
||||
schema=RoleRead,
|
||||
)
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert page2.pagination.prev_cursor is not None
|
||||
@@ -1802,6 +1737,7 @@ class TestCursorPaginateWithSearch:
|
||||
db_session,
|
||||
search="admin",
|
||||
items_per_page=20,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert len(result.data) == 5
|
||||
@@ -1836,6 +1772,7 @@ class TestCursorPaginateExtraOptions:
|
||||
db_session,
|
||||
joins=[(Role, User.role_id == Role.id)],
|
||||
items_per_page=20,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
@@ -1867,6 +1804,7 @@ class TestCursorPaginateExtraOptions:
|
||||
joins=[(Role, User.role_id == Role.id)],
|
||||
outer_join=True,
|
||||
items_per_page=20,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
@@ -1876,7 +1814,12 @@ class TestCursorPaginateExtraOptions:
|
||||
@pytest.mark.anyio
|
||||
async def test_with_load_options(self, db_session: AsyncSession):
|
||||
"""cursor_paginate passes load_options to the query."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
from fastapi_toolsets.schemas import CursorPagination, PydanticBase
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
|
||||
for i in range(3):
|
||||
@@ -1893,6 +1836,7 @@ class TestCursorPaginateExtraOptions:
|
||||
db_session,
|
||||
load_options=[selectinload(User.role)],
|
||||
items_per_page=20,
|
||||
schema=UserWithRoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
@@ -1912,6 +1856,7 @@ class TestCursorPaginateExtraOptions:
|
||||
db_session,
|
||||
order_by=Role.name.desc(),
|
||||
items_per_page=3,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
@@ -1925,7 +1870,9 @@ class TestCursorPaginateExtraOptions:
|
||||
for i in range(5):
|
||||
await IntRoleCursorCrud.create(db_session, IntRoleCreate(name=f"role{i}"))
|
||||
|
||||
page1 = await IntRoleCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
page1 = await IntRoleCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=3, schema=IntRoleRead
|
||||
)
|
||||
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
assert len(page1.data) == 3
|
||||
@@ -1935,6 +1882,7 @@ class TestCursorPaginateExtraOptions:
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=3,
|
||||
schema=IntRoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
@@ -1955,7 +1903,9 @@ class TestCursorPaginateExtraOptions:
|
||||
await RoleCrud.create(db_session, RoleCreate(name="role01"))
|
||||
|
||||
# First page succeeds (no cursor to decode)
|
||||
page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=1)
|
||||
page1 = await RoleNameCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=1, schema=RoleRead
|
||||
)
|
||||
assert page1.pagination.has_more is True
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
|
||||
@@ -1965,6 +1915,7 @@ class TestCursorPaginateExtraOptions:
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=1,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
|
||||
@@ -2003,6 +1954,7 @@ class TestCursorPaginateSearchJoins:
|
||||
search="administrator",
|
||||
search_fields=[(User.role, Role.name)],
|
||||
items_per_page=20,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
@@ -2049,7 +2001,7 @@ class TestCursorPaginateColumnTypes:
|
||||
)
|
||||
|
||||
page1 = await EventDateTimeCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=3
|
||||
db_session, items_per_page=3, schema=EventRead
|
||||
)
|
||||
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
@@ -2060,6 +2012,7 @@ class TestCursorPaginateColumnTypes:
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=3,
|
||||
schema=EventRead,
|
||||
)
|
||||
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
@@ -2087,7 +2040,9 @@ class TestCursorPaginateColumnTypes:
|
||||
),
|
||||
)
|
||||
|
||||
page1 = await EventDateCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
page1 = await EventDateCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=3, schema=EventRead
|
||||
)
|
||||
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
assert len(page1.data) == 3
|
||||
@@ -2097,6 +2052,7 @@ class TestCursorPaginateColumnTypes:
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=3,
|
||||
schema=EventRead,
|
||||
)
|
||||
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
@@ -2123,7 +2079,7 @@ class TestCursorPaginateColumnTypes:
|
||||
)
|
||||
|
||||
page1 = await ProductNumericCursorCrud.cursor_paginate(
|
||||
db_session, items_per_page=3
|
||||
db_session, items_per_page=3, schema=ProductRead
|
||||
)
|
||||
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
@@ -2134,6 +2090,7 @@ class TestCursorPaginateColumnTypes:
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=3,
|
||||
schema=ProductRead,
|
||||
)
|
||||
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for CRUD search functionality."""
|
||||
|
||||
import inspect
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
|
||||
|
||||
from fastapi_toolsets.crud import (
|
||||
CrudFactory,
|
||||
@@ -11,6 +13,7 @@ from fastapi_toolsets.crud import (
|
||||
SearchConfig,
|
||||
get_searchable_fields,
|
||||
)
|
||||
from fastapi_toolsets.exceptions import InvalidOrderFieldError
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
from .conftest import (
|
||||
@@ -20,6 +23,7 @@ from .conftest import (
|
||||
User,
|
||||
UserCreate,
|
||||
UserCrud,
|
||||
UserRead,
|
||||
)
|
||||
|
||||
|
||||
@@ -39,10 +43,11 @@ class TestPaginateSearch:
|
||||
db_session, UserCreate(username="bob_smith", email="bob@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="doe",
|
||||
search_fields=[User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -58,10 +63,11 @@ class TestPaginateSearch:
|
||||
db_session, UserCreate(username="company_bob", email="bob@other.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="company",
|
||||
search_fields=[User.username, User.email],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -86,10 +92,11 @@ class TestPaginateSearch:
|
||||
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="admin",
|
||||
search_fields=[(User.role, Role.name)],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -105,10 +112,11 @@ class TestPaginateSearch:
|
||||
)
|
||||
|
||||
# Search "admin" in username OR role.name
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="admin",
|
||||
search_fields=[User.username, (User.role, Role.name)],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -121,10 +129,11 @@ class TestPaginateSearch:
|
||||
db_session, UserCreate(username="JohnDoe", email="j@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="johndoe",
|
||||
search_fields=[User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -138,19 +147,21 @@ class TestPaginateSearch:
|
||||
)
|
||||
|
||||
# Should not find (case mismatch)
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 0
|
||||
|
||||
# Should find (case match)
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
@@ -165,11 +176,13 @@ class TestPaginateSearch:
|
||||
db_session, UserCreate(username="user2", email="u2@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="")
|
||||
result = await UserCrud.offset_paginate(db_session, search="", schema=UserRead)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
result = await UserCrud.paginate(db_session, search=None)
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session, search=None, schema=UserRead
|
||||
)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@@ -185,11 +198,12 @@ class TestPaginateSearch:
|
||||
UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
search="john",
|
||||
search_fields=[User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -203,7 +217,9 @@ class TestPaginateSearch:
|
||||
db_session, UserCreate(username="findme", email="other@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="findme")
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session, search="findme", schema=UserRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
@@ -215,10 +231,11 @@ class TestPaginateSearch:
|
||||
db_session, UserCreate(username="john", email="j@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="nonexistent",
|
||||
search_fields=[User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -234,12 +251,13 @@ class TestPaginateSearch:
|
||||
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="user_",
|
||||
search_fields=[User.username],
|
||||
page=1,
|
||||
items_per_page=5,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -261,10 +279,11 @@ class TestPaginateSearch:
|
||||
)
|
||||
|
||||
# Search in username, not in role
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="role",
|
||||
search_fields=[User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -283,11 +302,12 @@ class TestPaginateSearch:
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="@test.com",
|
||||
search_fields=[User.email],
|
||||
order_by=User.username,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -307,10 +327,11 @@ class TestPaginateSearch:
|
||||
)
|
||||
|
||||
# Search by UUID (partial match)
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search="12345678",
|
||||
search_fields=[User.id, User.username],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -360,10 +381,11 @@ class TestSearchConfig:
|
||||
)
|
||||
|
||||
# 'john' must be in username AND email
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="john", match_mode="all"),
|
||||
search_fields=[User.username, User.email],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -377,9 +399,10 @@ class TestSearchConfig:
|
||||
db_session, UserCreate(username="test", email="findme@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
result = await UserCrud.offset_paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="findme", fields=[User.email]),
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -475,7 +498,7 @@ class TestFacetsNotSet:
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.offset_paginate(db_session)
|
||||
result = await UserCrud.offset_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert result.filter_attributes is None
|
||||
|
||||
@@ -487,7 +510,7 @@ class TestFacetsNotSet:
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
|
||||
result = await UserCursorCrud.cursor_paginate(db_session)
|
||||
result = await UserCursorCrud.cursor_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert result.filter_attributes is None
|
||||
|
||||
@@ -506,7 +529,7 @@ class TestFacetsDirectColumn:
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(db_session)
|
||||
result = await UserFacetCrud.offset_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
# Distinct usernames, sorted
|
||||
@@ -525,7 +548,7 @@ class TestFacetsDirectColumn:
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCursorCrud.cursor_paginate(db_session)
|
||||
result = await UserFacetCursorCrud.cursor_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert set(result.filter_attributes["email"]) == {"a@test.com", "b@test.com"}
|
||||
@@ -541,7 +564,7 @@ class TestFacetsDirectColumn:
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(db_session)
|
||||
result = await UserFacetCrud.offset_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert "username" in result.filter_attributes
|
||||
@@ -558,7 +581,7 @@ class TestFacetsDirectColumn:
|
||||
|
||||
# Override: ask for email instead of username
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, facet_fields=[User.email]
|
||||
db_session, facet_fields=[User.email], schema=UserRead
|
||||
)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
@@ -584,6 +607,7 @@ class TestFacetsRespectFilters:
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
@@ -614,7 +638,7 @@ class TestFacetsRelationship:
|
||||
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||
)
|
||||
|
||||
result = await UserRelFacetCrud.offset_paginate(db_session)
|
||||
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert set(result.filter_attributes["name"]) == {"admin", "editor"}
|
||||
@@ -629,7 +653,7 @@ class TestFacetsRelationship:
|
||||
db_session, UserCreate(username="norole", email="n@test.com")
|
||||
)
|
||||
|
||||
result = await UserRelFacetCrud.offset_paginate(db_session)
|
||||
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
assert result.filter_attributes["name"] == []
|
||||
@@ -653,7 +677,10 @@ class TestFacetsRelationship:
|
||||
)
|
||||
|
||||
result = await UserSearchFacetCrud.offset_paginate(
|
||||
db_session, search="admin", search_fields=[(User.role, Role.name)]
|
||||
db_session,
|
||||
search="admin",
|
||||
search_fields=[(User.role, Role.name)],
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert result.filter_attributes is not None
|
||||
@@ -675,7 +702,7 @@ class TestFilterBy:
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"username": "alice"}
|
||||
db_session, filter_by={"username": "alice"}, schema=UserRead
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
@@ -698,7 +725,7 @@ class TestFilterBy:
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"username": ["alice", "bob"]}
|
||||
db_session, filter_by={"username": ["alice", "bob"]}, schema=UserRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -723,7 +750,7 @@ class TestFilterBy:
|
||||
)
|
||||
|
||||
result = await UserRelFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"name": "admin"}
|
||||
db_session, filter_by={"name": "admin"}, schema=UserRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -746,6 +773,7 @@ class TestFilterBy:
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
filter_by={"username": ["alice", "alice2"]},
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
# Only alice passes both: is_active=True AND username IN [alice, alice2]
|
||||
@@ -760,7 +788,7 @@ class TestFilterBy:
|
||||
|
||||
with pytest.raises(InvalidFacetFilterError) as exc_info:
|
||||
await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by={"nonexistent": "value"}
|
||||
db_session, filter_by={"nonexistent": "value"}, schema=UserRead
|
||||
)
|
||||
|
||||
assert exc_info.value.key == "nonexistent"
|
||||
@@ -792,6 +820,7 @@ class TestFilterBy:
|
||||
result = await UserRoleFacetCrud.offset_paginate(
|
||||
db_session,
|
||||
filter_by={"name": "admin", "id": str(admin.id)},
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -812,7 +841,7 @@ class TestFilterBy:
|
||||
)
|
||||
|
||||
result = await UserFacetCursorCrud.cursor_paginate(
|
||||
db_session, filter_by={"username": "alice"}
|
||||
db_session, filter_by={"username": "alice"}, schema=UserRead
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
@@ -836,7 +865,7 @@ class TestFilterBy:
|
||||
)
|
||||
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by=UserFilter(username="alice")
|
||||
db_session, filter_by=UserFilter(username="alice"), schema=UserRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
@@ -862,7 +891,7 @@ class TestFilterBy:
|
||||
)
|
||||
|
||||
result = await UserFacetCursorCrud.cursor_paginate(
|
||||
db_session, filter_by=UserFilter(username="alice")
|
||||
db_session, filter_by=UserFilter(username="alice"), schema=UserRead
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
@@ -971,7 +1000,9 @@ class TestFilterParamsSchema:
|
||||
|
||||
dep = UserFacetCrud.filter_params()
|
||||
f = await dep(username=["alice"])
|
||||
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f)
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by=f, schema=UserRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
@@ -992,7 +1023,9 @@ class TestFilterParamsSchema:
|
||||
|
||||
dep = UserFacetCursorCrud.filter_params()
|
||||
f = await dep(username=["alice"])
|
||||
result = await UserFacetCursorCrud.cursor_paginate(db_session, filter_by=f)
|
||||
result = await UserFacetCursorCrud.cursor_paginate(
|
||||
db_session, filter_by=f, schema=UserRead
|
||||
)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].username == "alice"
|
||||
@@ -1010,7 +1043,150 @@ class TestFilterParamsSchema:
|
||||
|
||||
dep = UserFacetCrud.filter_params()
|
||||
f = await dep() # all fields None
|
||||
result = await UserFacetCrud.offset_paginate(db_session, filter_by=f)
|
||||
result = await UserFacetCrud.offset_paginate(
|
||||
db_session, filter_by=f, schema=UserRead
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
|
||||
class TestOrderParamsSchema:
|
||||
"""Tests for AsyncCrud.order_params()."""
|
||||
|
||||
def test_generates_order_by_and_order_params(self):
|
||||
"""Returned dependency has order_by and order query params."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||
dep = UserOrderCrud.order_params()
|
||||
|
||||
param_names = set(inspect.signature(dep).parameters)
|
||||
assert param_names == {"order_by", "order"}
|
||||
|
||||
def test_dependency_name_includes_model_name(self):
|
||||
"""Dependency function is named after the model."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep = UserOrderCrud.order_params()
|
||||
assert getattr(dep, "__name__") == "UserOrderParams"
|
||||
|
||||
def test_raises_when_no_order_fields(self):
|
||||
"""ValueError raised when no order_fields are configured or provided."""
|
||||
with pytest.raises(ValueError, match="no order_fields"):
|
||||
UserCrud.order_params()
|
||||
|
||||
def test_order_fields_override(self):
|
||||
"""order_fields= parameter overrides the class-level default."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||
dep = UserOrderCrud.order_params(order_fields=[User.email])
|
||||
|
||||
param_names = set(inspect.signature(dep).parameters)
|
||||
assert "order_by" in param_names
|
||||
# description should only mention email, not username
|
||||
sig = inspect.signature(dep)
|
||||
description = sig.parameters["order_by"].default.description
|
||||
assert "email" in description
|
||||
assert "username" not in description
|
||||
|
||||
def test_order_by_description_lists_valid_fields(self):
|
||||
"""order_by query param description mentions each allowed field."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||
dep = UserOrderCrud.order_params()
|
||||
|
||||
sig = inspect.signature(dep)
|
||||
description = sig.parameters["order_by"].default.description
|
||||
assert "username" in description
|
||||
assert "email" in description
|
||||
|
||||
def test_default_order_reflected_in_order_default(self):
|
||||
"""default_order is used as the default value for order."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep_asc = UserOrderCrud.order_params(default_order="asc")
|
||||
dep_desc = UserOrderCrud.order_params(default_order="desc")
|
||||
|
||||
sig_asc = inspect.signature(dep_asc)
|
||||
sig_desc = inspect.signature(dep_desc)
|
||||
assert sig_asc.parameters["order"].default.default == "asc"
|
||||
assert sig_desc.parameters["order"].default.default == "desc"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_no_order_by_no_default_returns_none(self):
|
||||
"""Returns None when order_by is absent and no default_field is set."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep = UserOrderCrud.order_params()
|
||||
result = await dep(order_by=None, order="asc")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_no_order_by_with_default_field_returns_asc_expression(self):
|
||||
"""Returns default_field.asc() when order_by absent and order=asc."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep = UserOrderCrud.order_params(default_field=User.username)
|
||||
result = await dep(order_by=None, order="asc")
|
||||
assert isinstance(result, UnaryExpression)
|
||||
assert "ASC" in str(result)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_no_order_by_with_default_field_returns_desc_expression(self):
|
||||
"""Returns default_field.desc() when order_by absent and order=desc."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep = UserOrderCrud.order_params(default_field=User.username)
|
||||
result = await dep(order_by=None, order="desc")
|
||||
assert isinstance(result, UnaryExpression)
|
||||
assert "DESC" in str(result)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_valid_order_by_asc(self):
|
||||
"""Returns field.asc() for a valid order_by with order=asc."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep = UserOrderCrud.order_params()
|
||||
result = await dep(order_by="username", order="asc")
|
||||
assert isinstance(result, UnaryExpression)
|
||||
assert "ASC" in str(result)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_valid_order_by_desc(self):
|
||||
"""Returns field.desc() for a valid order_by with order=desc."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep = UserOrderCrud.order_params()
|
||||
result = await dep(order_by="username", order="desc")
|
||||
assert isinstance(result, UnaryExpression)
|
||||
assert "DESC" in str(result)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_order_by_raises_invalid_order_field_error(self):
|
||||
"""Raises InvalidOrderFieldError for an unknown order_by value."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
dep = UserOrderCrud.order_params()
|
||||
with pytest.raises(InvalidOrderFieldError) as exc_info:
|
||||
await dep(order_by="nonexistent", order="asc")
|
||||
assert exc_info.value.field == "nonexistent"
|
||||
assert "username" in exc_info.value.valid_fields
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_fields_all_resolve(self):
|
||||
"""All configured fields resolve correctly via order_by."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username, User.email])
|
||||
dep = UserOrderCrud.order_params()
|
||||
result_username = await dep(order_by="username", order="asc")
|
||||
result_email = await dep(order_by="email", order="desc")
|
||||
assert isinstance(result_username, ColumnElement)
|
||||
assert isinstance(result_email, ColumnElement)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_params_integrates_with_get_multi(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""order_params output is accepted by get_multi(order_by=...)."""
|
||||
UserOrderCrud = CrudFactory(User, order_fields=[User.username])
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
|
||||
dep = UserOrderCrud.order_params()
|
||||
order_by = await dep(order_by="username", order="asc")
|
||||
results = await UserOrderCrud.get_multi(db_session, order_by=order_by)
|
||||
|
||||
assert results[0].username == "alice"
|
||||
assert results[1].username == "charlie"
|
||||
|
||||
@@ -15,12 +15,14 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
||||
from docs_src.examples.pagination_search.db import get_db
|
||||
from docs_src.examples.pagination_search.models import Article, Base, Category
|
||||
from docs_src.examples.pagination_search.routes import router
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
from .conftest import DATABASE_URL
|
||||
|
||||
|
||||
def build_app(session: AsyncSession) -> FastAPI:
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
async def override_get_db():
|
||||
yield session
|
||||
@@ -269,3 +271,125 @@ class TestCursorPagination:
|
||||
body = resp.json()
|
||||
assert len(body["data"]) == 1
|
||||
assert body["data"][0]["title"] == "SQLAlchemy async"
|
||||
|
||||
|
||||
class TestOffsetSorting:
|
||||
"""Tests for order_by / order query parameters on the offset endpoint."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_order_uses_created_at_asc(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""No order_by → default field (created_at) ASC."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_asc(self, client: AsyncClient, ex_db_session):
|
||||
"""order_by=title&order=asc returns alphabetical order."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?order_by=title&order=asc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["Draft notes", "FastAPI tips", "SQLAlchemy async"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_desc(self, client: AsyncClient, ex_db_session):
|
||||
"""order_by=title&order=desc returns reverse alphabetical order."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?order_by=title&order=desc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["SQLAlchemy async", "FastAPI tips", "Draft notes"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_created_at_desc(self, client: AsyncClient, ex_db_session):
|
||||
"""order_by=created_at&order=desc returns newest-first."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/offset?order_by=created_at&order=desc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["Draft notes", "SQLAlchemy async", "FastAPI tips"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_order_by_returns_422(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||
resp = await client.get("/articles/offset?order_by=nonexistent_field")
|
||||
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "SORT-422"
|
||||
assert body["status"] == "FAIL"
|
||||
|
||||
|
||||
class TestCursorSorting:
|
||||
"""Tests for order_by / order query parameters on the cursor endpoint.
|
||||
|
||||
In cursor_paginate the cursor_column is always the primary sort; order_by
|
||||
acts as a secondary tiebreaker. With the seeded articles (all having unique
|
||||
created_at values) the overall ordering is always created_at ASC regardless
|
||||
of the order_by value — only the valid/invalid field check and the response
|
||||
shape are meaningful here.
|
||||
"""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_order_uses_created_at_asc(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""No order_by → default field (created_at) ASC."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor")
|
||||
|
||||
assert resp.status_code == 200
|
||||
titles = [a["title"] for a in resp.json()["data"]]
|
||||
assert titles == ["FastAPI tips", "SQLAlchemy async", "Draft notes"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_asc_accepted(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""order_by=title is a valid field — request succeeds and returns all articles."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor?order_by=title&order=asc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_order_by_title_desc_accepted(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""order_by=title&order=desc is valid — request succeeds and returns all articles."""
|
||||
await seed(ex_db_session)
|
||||
|
||||
resp = await client.get("/articles/cursor?order_by=title&order=desc")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["data"]) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_order_by_returns_422(
|
||||
self, client: AsyncClient, ex_db_session
|
||||
):
|
||||
"""Unknown order_by field returns 422 with SORT-422 error code."""
|
||||
resp = await client.get("/articles/cursor?order_by=nonexistent_field")
|
||||
|
||||
assert resp.status_code == 422
|
||||
body = resp.json()
|
||||
assert body["error_code"] == "SORT-422"
|
||||
assert body["status"] == "FAIL"
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi_toolsets.exceptions import (
|
||||
ApiException,
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
InvalidOrderFieldError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
generate_error_responses,
|
||||
@@ -334,3 +335,43 @@ class TestExceptionIntegration:
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"id": 1}
|
||||
|
||||
|
||||
class TestInvalidOrderFieldError:
|
||||
"""Tests for InvalidOrderFieldError exception."""
|
||||
|
||||
def test_api_error_attributes(self):
|
||||
"""InvalidOrderFieldError has correct api_error metadata."""
|
||||
assert InvalidOrderFieldError.api_error.code == 422
|
||||
assert InvalidOrderFieldError.api_error.err_code == "SORT-422"
|
||||
assert InvalidOrderFieldError.api_error.msg == "Invalid Order Field"
|
||||
|
||||
def test_stores_field_and_valid_fields(self):
|
||||
"""InvalidOrderFieldError stores field and valid_fields on the instance."""
|
||||
error = InvalidOrderFieldError("unknown", ["name", "created_at"])
|
||||
assert error.field == "unknown"
|
||||
assert error.valid_fields == ["name", "created_at"]
|
||||
|
||||
def test_message_contains_field_and_valid_fields(self):
|
||||
"""Exception message mentions the bad field and valid options."""
|
||||
error = InvalidOrderFieldError("bad_field", ["name", "email"])
|
||||
assert "bad_field" in str(error)
|
||||
assert "name" in str(error)
|
||||
assert "email" in str(error)
|
||||
|
||||
def test_handled_as_422_by_exception_handler(self):
|
||||
"""init_exceptions_handlers turns InvalidOrderFieldError into a 422 response."""
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
@app.get("/items")
|
||||
async def list_items():
|
||||
raise InvalidOrderFieldError("bad", ["name"])
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/items")
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["error_code"] == "SORT-422"
|
||||
assert data["status"] == "FAIL"
|
||||
|
||||
@@ -14,7 +14,9 @@ from fastapi_toolsets.fixtures import (
|
||||
load_fixtures_by_context,
|
||||
)
|
||||
|
||||
from .conftest import Role, User
|
||||
from fastapi_toolsets.fixtures.utils import _get_primary_key
|
||||
|
||||
from .conftest import IntRole, Permission, Role, User
|
||||
|
||||
|
||||
class TestContext:
|
||||
@@ -597,6 +599,46 @@ class TestLoadFixtures:
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_skips_if_record_exists(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING returns empty loaded list when the record already exists."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
# First load — inserts the record.
|
||||
result1 = await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert len(result1["roles"]) == 1
|
||||
|
||||
# Remove from identity map so session.get() queries the DB in the second load.
|
||||
db_session.expunge_all()
|
||||
|
||||
# Second load — record exists in DB, nothing should be added.
|
||||
result2 = await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert result2["roles"] == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING inserts when the instance has no PK set (auto-increment)."""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register
|
||||
def int_roles():
|
||||
# No id provided — PK is None before INSERT (autoincrement).
|
||||
return [IntRole(name="member")]
|
||||
|
||||
result = await load_fixtures(
|
||||
db_session, registry, "int_roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert len(result["int_roles"]) == 1
|
||||
|
||||
|
||||
class TestLoadFixturesByContext:
|
||||
"""Tests for load_fixtures_by_context function."""
|
||||
@@ -755,3 +797,19 @@ class TestGetObjByAttr:
|
||||
"""Raises StopIteration when value type doesn't match."""
|
||||
with pytest.raises(StopIteration):
|
||||
get_obj_by_attr(self.roles, "id", "not-a-uuid")
|
||||
|
||||
|
||||
class TestGetPrimaryKey:
|
||||
"""Unit tests for the _get_primary_key helper (composite PK paths)."""
|
||||
|
||||
def test_composite_pk_all_set(self):
|
||||
"""Returns a tuple when all composite PK values are set."""
|
||||
instance = Permission(subject="post", action="read")
|
||||
pk = _get_primary_key(instance)
|
||||
assert pk == ("post", "read")
|
||||
|
||||
def test_composite_pk_partial_none(self):
|
||||
"""Returns None when any composite PK value is None."""
|
||||
instance = Permission(subject="post") # action is None
|
||||
pk = _get_primary_key(instance)
|
||||
assert pk is None
|
||||
|
||||
@@ -9,7 +9,6 @@ from fastapi_toolsets.schemas import (
|
||||
ErrorResponse,
|
||||
OffsetPagination,
|
||||
PaginatedResponse,
|
||||
Pagination,
|
||||
Response,
|
||||
ResponseStatus,
|
||||
)
|
||||
@@ -199,20 +198,6 @@ class TestOffsetPagination:
|
||||
assert data["page"] == 2
|
||||
assert data["has_more"] is True
|
||||
|
||||
def test_pagination_alias_is_offset_pagination(self):
|
||||
"""Pagination is a backward-compatible alias for OffsetPagination."""
|
||||
assert Pagination is OffsetPagination
|
||||
|
||||
def test_pagination_alias_constructs_offset_pagination(self):
|
||||
"""Code using Pagination(...) still works unchanged."""
|
||||
pagination = Pagination(
|
||||
total_count=10,
|
||||
items_per_page=5,
|
||||
page=2,
|
||||
has_more=False,
|
||||
)
|
||||
assert isinstance(pagination, OffsetPagination)
|
||||
|
||||
|
||||
class TestCursorPagination:
|
||||
"""Tests for CursorPagination schema."""
|
||||
@@ -276,7 +261,7 @@ class TestPaginatedResponse:
|
||||
|
||||
def test_create_paginated_response(self):
|
||||
"""Create PaginatedResponse with data and pagination."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=30,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -294,7 +279,7 @@ class TestPaginatedResponse:
|
||||
|
||||
def test_with_custom_message(self):
|
||||
"""PaginatedResponse with custom message."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=5,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -310,7 +295,7 @@ class TestPaginatedResponse:
|
||||
|
||||
def test_empty_data(self):
|
||||
"""PaginatedResponse with empty data."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=0,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -332,7 +317,7 @@ class TestPaginatedResponse:
|
||||
id: int
|
||||
name: str
|
||||
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=1,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -347,7 +332,7 @@ class TestPaginatedResponse:
|
||||
|
||||
def test_serialization(self):
|
||||
"""PaginatedResponse serializes correctly."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=100,
|
||||
items_per_page=10,
|
||||
page=5,
|
||||
@@ -385,16 +370,6 @@ class TestPaginatedResponse:
|
||||
)
|
||||
assert isinstance(response.pagination, CursorPagination)
|
||||
|
||||
def test_pagination_alias_accepted(self):
|
||||
"""Constructing PaginatedResponse with Pagination (alias) still works."""
|
||||
response = PaginatedResponse(
|
||||
data=[],
|
||||
pagination=Pagination(
|
||||
total_count=0, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
|
||||
|
||||
class TestFromAttributes:
|
||||
"""Tests for from_attributes config (ORM mode)."""
|
||||
|
||||
Reference in New Issue
Block a user