Compare commits

...

11 Commits

19 changed files with 1083 additions and 217 deletions

View File

@@ -72,6 +72,7 @@ GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&or
], ],
"pagination": { "pagination": {
"total_count": 42, "total_count": 42,
"pages": 5,
"page": 2, "page": 2,
"items_per_page": 10, "items_per_page": 10,
"has_more": true "has_more": true
@@ -85,6 +86,8 @@ GET /articles/offset?page=2&items_per_page=10&search=fastapi&status=published&or
`filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client. `filter_attributes` always reflects the values visible **after** applying the active filters. Use it to populate filter dropdowns on the client.
To skip the `COUNT(*)` query for better performance on large tables, pass `include_total=False`. `pagination.total_count` will be `null` in the response, while `has_more` remains accurate.
### Cursor pagination ### Cursor pagination
Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades. Best for feeds, infinite scroll, or any high-throughput API where offset performance degrades.
@@ -144,7 +147,7 @@ GET /articles/?pagination_type=offset&page=1&items_per_page=10
"status": "SUCCESS", "status": "SUCCESS",
"pagination_type": "offset", "pagination_type": "offset",
"data": ["..."], "data": ["..."],
"pagination": { "total_count": 42, "page": 1, "items_per_page": 10, "has_more": true } "pagination": { "total_count": 42, "pages": 5, "page": 1, "items_per_page": 10, "has_more": true }
} }
``` ```

View File

@@ -182,6 +182,7 @@ The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.Async
"data": ["..."], "data": ["..."],
"pagination": { "pagination": {
"total_count": 100, "total_count": 100,
"pages": 5,
"page": 1, "page": 1,
"items_per_page": 20, "items_per_page": 20,
"has_more": true "has_more": true
@@ -189,6 +190,40 @@ The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.Async
} }
``` ```
#### Skipping the COUNT query
!!! info "Added in `v2.4.1`"
By default `offset_paginate` runs two queries: one for the page items and one `COUNT(*)` for `total_count`. On large tables the `COUNT` can be expensive. Pass `include_total=False` to skip it:
```python
result = await UserCrud.offset_paginate(
session=session,
page=page,
items_per_page=items_per_page,
include_total=False,
schema=UserRead,
)
```
#### Pagination params dependency
!!! info "Added in `v2.4.1`"
Use [`offset_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_params) to generate a FastAPI dependency that injects `page` and `items_per_page` from query parameters with configurable defaults and a `max_page_size` cap:
```python
from typing import Annotated
from fastapi import Depends
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.offset_params(default_page_size=20, max_page_size=100))],
) -> OffsetPaginatedResponse[UserRead]:
return await UserCrud.offset_paginate(session=session, **params, schema=UserRead)
```
### Cursor pagination ### Cursor pagination
```python ```python
@@ -238,7 +273,7 @@ The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_to
!!! note !!! note
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`. `cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
The cursor value is base64-encoded when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported: The cursor value is URL-safe base64-encoded (no padding) when returned to the client and decoded back to the correct Python type on the next request. The following SQLAlchemy column types are supported:
| SQLAlchemy type | Python type | | SQLAlchemy type | Python type |
|---|---| |---|---|
@@ -256,6 +291,24 @@ PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at) PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
``` ```
#### Pagination params dependency
!!! info "Added in `v2.4.1`"
Use [`cursor_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_params) to inject `cursor` and `items_per_page` from query parameters with a `max_page_size` cap:
```python
from typing import Annotated
from fastapi import Depends
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.cursor_params(default_page_size=20, max_page_size=100))],
) -> CursorPaginatedResponse[UserRead]:
return await UserCrud.cursor_paginate(session=session, **params, schema=UserRead)
```
### Unified endpoint (both strategies) ### Unified endpoint (both strategies)
!!! info "Added in `v2.3.0`" !!! info "Added in `v2.3.0`"
@@ -289,7 +342,24 @@ GET /users?pagination_type=offset&page=2&items_per_page=10
GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10 GET /users?pagination_type=cursor&cursor=eyJ2YWx1ZSI6...&items_per_page=10
``` ```
Both `page` and `cursor` are always accepted by the endpoint — unused parameters are silently ignored by `paginate()`. #### Pagination params dependency
!!! info "Added in `v2.4.1`"
Use [`paginate_params()`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate_params) to inject all parameters at once with configurable defaults and a `max_page_size` cap:
```python
from typing import Annotated
from fastapi import Depends
from fastapi_toolsets.schemas import PaginatedResponse
@router.get("")
async def list_users(
session: SessionDep,
params: Annotated[dict, Depends(UserCrud.paginate_params(default_page_size=20, max_page_size=100))],
) -> PaginatedResponse[UserRead]:
return await UserCrud.paginate(session, **params, schema=UserRead)
```
## Search ## Search

View File

@@ -138,6 +138,23 @@ Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callba
| `@watch("status", "role")` | Only fires when `status` or `role` changes | | `@watch("status", "role")` | Only fires when `status` or `role` changes |
| *(no decorator)* | Fires when **any** mapped field changes | | *(no decorator)* | Fires when **any** mapped field changes |
`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter:
```python
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
...
class UrgentOrder(Order):
# inherits @watch("status") — on_update fires only for status changes
...
@watch("priority")
class PriorityOrder(Order):
# overrides parent — on_update fires only for priority changes
...
```
#### Option 1 — catch-all with `on_event` #### Option 1 — catch-all with `on_event`
Override `on_event` to handle all event types in one place. The specific methods delegate here by default: Override `on_event` to handle all event types in one place. The specific methods delegate here by default:
@@ -197,6 +214,25 @@ The `changes` dict maps each watched field that changed to `{"old": ..., "new":
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected." !!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
!!! warning "Callbacks fire after the **outermost** transaction commits."
If you create several related objects using `CrudFactory.create` and need
callbacks to see all of them (including associations), wrap the whole
operation in a single [`get_transaction`](db.md) block. Without it, each
`create` call commits independently and `on_create` fires before the
remaining objects exist.
```python
from fastapi_toolsets.db import get_transaction
async with get_transaction(session):
order = await OrderCrud.create(session, order_data)
item = await ItemCrud.create(session, item_data)
await session.refresh(order, attribute_names=["items"])
order.items.append(item)
# on_create fires here for both order and item,
# with the full association already committed.
```
## Composing mixins ## Composing mixins
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model. All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.

View File

@@ -1,8 +1,8 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends
from fastapi_toolsets.crud import OrderByClause, PaginationType from fastapi_toolsets.crud import OrderByClause
from fastapi_toolsets.schemas import ( from fastapi_toolsets.schemas import (
CursorPaginatedResponse, CursorPaginatedResponse,
OffsetPaginatedResponse, OffsetPaginatedResponse,
@@ -20,19 +20,20 @@ router = APIRouter(prefix="/articles")
@router.get("/offset") @router.get("/offset")
async def list_articles_offset( async def list_articles_offset(
session: SessionDep, session: SessionDep,
params: Annotated[
dict,
Depends(ArticleCrud.offset_params(default_page_size=20, max_page_size=100)),
],
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())], filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[ order_by: Annotated[
OrderByClause | None, OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)), Depends(ArticleCrud.order_params(default_field=Article.created_at)),
], ],
page: int = Query(1, ge=1),
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None, search: str | None = None,
) -> OffsetPaginatedResponse[ArticleRead]: ) -> OffsetPaginatedResponse[ArticleRead]:
return await ArticleCrud.offset_paginate( return await ArticleCrud.offset_paginate(
session=session, session=session,
page=page, **params,
items_per_page=items_per_page,
search=search, search=search,
filter_by=filter_by or None, filter_by=filter_by or None,
order_by=order_by, order_by=order_by,
@@ -43,19 +44,20 @@ async def list_articles_offset(
@router.get("/cursor") @router.get("/cursor")
async def list_articles_cursor( async def list_articles_cursor(
session: SessionDep, session: SessionDep,
params: Annotated[
dict,
Depends(ArticleCrud.cursor_params(default_page_size=20, max_page_size=100)),
],
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())], filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[ order_by: Annotated[
OrderByClause | None, OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)), Depends(ArticleCrud.order_params(default_field=Article.created_at)),
], ],
cursor: str | None = None,
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None, search: str | None = None,
) -> CursorPaginatedResponse[ArticleRead]: ) -> CursorPaginatedResponse[ArticleRead]:
return await ArticleCrud.cursor_paginate( return await ArticleCrud.cursor_paginate(
session=session, session=session,
cursor=cursor, **params,
items_per_page=items_per_page,
search=search, search=search,
filter_by=filter_by or None, filter_by=filter_by or None,
order_by=order_by, order_by=order_by,
@@ -66,23 +68,20 @@ async def list_articles_cursor(
@router.get("/") @router.get("/")
async def list_articles( async def list_articles(
session: SessionDep, session: SessionDep,
params: Annotated[
dict,
Depends(ArticleCrud.paginate_params(default_page_size=20, max_page_size=100)),
],
filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())], filter_by: Annotated[dict[str, list[str]], Depends(ArticleCrud.filter_params())],
order_by: Annotated[ order_by: Annotated[
OrderByClause | None, OrderByClause | None,
Depends(ArticleCrud.order_params(default_field=Article.created_at)), Depends(ArticleCrud.order_params(default_field=Article.created_at)),
], ],
pagination_type: PaginationType = PaginationType.OFFSET,
page: int = Query(1, ge=1),
cursor: str | None = None,
items_per_page: int = Query(20, ge=1, le=100),
search: str | None = None, search: str | None = None,
) -> PaginatedResponse[ArticleRead]: ) -> PaginatedResponse[ArticleRead]:
return await ArticleCrud.paginate( return await ArticleCrud.paginate(
session, session,
pagination_type=pagination_type, **params,
page=page,
cursor=cursor,
items_per_page=items_per_page,
search=search, search=search,
filter_by=filter_by or None, filter_by=filter_by or None,
order_by=order_by, order_by=order_by,

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "2.4.0" version = "2.4.2"
description = "Production-ready utilities for FastAPI applications" description = "Production-ready utilities for FastAPI applications"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "2.4.0" __version__ = "2.4.2"

View File

@@ -58,18 +58,33 @@ class _CursorDirection(str, Enum):
def _encode_cursor( def _encode_cursor(
value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
) -> str: ) -> str:
"""Encode a cursor column value and navigation direction as a base64 string.""" """Encode a cursor column value and navigation direction as a URL-safe base64 string."""
return base64.b64encode( return (
base64.urlsafe_b64encode(
json.dumps({"val": str(value), "dir": direction}).encode() json.dumps({"val": str(value), "dir": direction}).encode()
).decode() )
.decode()
.rstrip("=")
)
def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]: def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
"""Decode a cursor base64 string into ``(raw_value, direction)``.""" """Decode a URL-safe base64 cursor string into ``(raw_value, direction)``."""
payload = json.loads(base64.b64decode(cursor.encode()).decode()) padded = cursor + "=" * (-len(cursor) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded).decode())
return payload["val"], _CursorDirection(payload["dir"]) return payload["val"], _CursorDirection(payload["dir"])
def _page_size_query(default: int, max_size: int) -> int:
"""Return a FastAPI ``Query`` for the ``items_per_page`` parameter."""
return Query(
default,
ge=1,
le=max_size,
description=f"Number of items per page (max {max_size})",
)
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any: def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
"""Parse a raw cursor string value back into the appropriate Python type.""" """Parse a raw cursor string value back into the appropriate Python type."""
if isinstance(col_type, Integer): if isinstance(col_type, Integer):
@@ -254,6 +269,7 @@ class AsyncCrud(Generic[ModelType]):
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
) -> Callable[..., Awaitable[dict[str, list[str]]]]: ) -> Callable[..., Awaitable[dict[str, list[str]]]]:
"""Return a FastAPI dependency that collects facet filter values from query parameters. """Return a FastAPI dependency that collects facet filter values from query parameters.
Args: Args:
facet_fields: Override the facet fields for this dependency. Falls back to the facet_fields: Override the facet fields for this dependency. Falls back to the
class-level ``facet_fields`` if not provided. class-level ``facet_fields`` if not provided.
@@ -293,6 +309,121 @@ class AsyncCrud(Generic[ModelType]):
return dependency return dependency
@classmethod
def offset_params(
cls: type[Self],
*,
default_page_size: int = 20,
max_page_size: int = 100,
include_total: bool = True,
) -> Callable[..., Awaitable[dict[str, Any]]]:
"""Return a FastAPI dependency that collects offset pagination params from query params.
Args:
default_page_size: Default value for the ``items_per_page`` query parameter.
max_page_size: Maximum allowed value for ``items_per_page`` (enforced via
``le`` on the ``Query``).
include_total: Server-side flag forwarded as-is to ``include_total`` in
:meth:`offset_paginate`. Not exposed as a query parameter.
Returns:
An async dependency that resolves to a dict with ``page``,
``items_per_page``, and ``include_total`` keys, ready to be
unpacked into :meth:`offset_paginate`.
"""
async def dependency(
page: int = Query(1, ge=1, description="Page number (1-indexed)"),
items_per_page: int = _page_size_query(default_page_size, max_page_size),
) -> dict[str, Any]:
return {
"page": page,
"items_per_page": items_per_page,
"include_total": include_total,
}
dependency.__name__ = f"{cls.model.__name__}OffsetParams"
return dependency
@classmethod
def cursor_params(
cls: type[Self],
*,
default_page_size: int = 20,
max_page_size: int = 100,
) -> Callable[..., Awaitable[dict[str, Any]]]:
"""Return a FastAPI dependency that collects cursor pagination params from query params.
Args:
default_page_size: Default value for the ``items_per_page`` query parameter.
max_page_size: Maximum allowed value for ``items_per_page`` (enforced via
``le`` on the ``Query``).
Returns:
An async dependency that resolves to a dict with ``cursor`` and
``items_per_page`` keys, ready to be unpacked into
:meth:`cursor_paginate`.
"""
async def dependency(
cursor: str | None = Query(
None, description="Cursor token from a previous response"
),
items_per_page: int = _page_size_query(default_page_size, max_page_size),
) -> dict[str, Any]:
return {"cursor": cursor, "items_per_page": items_per_page}
dependency.__name__ = f"{cls.model.__name__}CursorParams"
return dependency
@classmethod
def paginate_params(
cls: type[Self],
*,
default_page_size: int = 20,
max_page_size: int = 100,
default_pagination_type: PaginationType = PaginationType.OFFSET,
include_total: bool = True,
) -> Callable[..., Awaitable[dict[str, Any]]]:
"""Return a FastAPI dependency that collects all pagination params from query params.
Args:
default_page_size: Default value for the ``items_per_page`` query parameter.
max_page_size: Maximum allowed value for ``items_per_page`` (enforced via
``le`` on the ``Query``).
default_pagination_type: Default pagination strategy.
include_total: Server-side flag forwarded as-is to ``include_total`` in
:meth:`paginate`. Not exposed as a query parameter.
Returns:
An async dependency that resolves to a dict with ``pagination_type``,
``page``, ``cursor``, ``items_per_page``, and ``include_total`` keys,
ready to be unpacked into :meth:`paginate`.
"""
async def dependency(
pagination_type: PaginationType = Query(
default_pagination_type, description="Pagination strategy"
),
page: int = Query(
1, ge=1, description="Page number (1-indexed, offset only)"
),
cursor: str | None = Query(
None, description="Cursor token from a previous response (cursor only)"
),
items_per_page: int = _page_size_query(default_page_size, max_page_size),
) -> dict[str, Any]:
return {
"pagination_type": pagination_type,
"page": page,
"cursor": cursor,
"items_per_page": items_per_page,
"include_total": include_total,
}
dependency.__name__ = f"{cls.model.__name__}PaginateParams"
return dependency
@classmethod @classmethod
def order_params( def order_params(
cls: type[Self], cls: type[Self],
@@ -922,6 +1053,7 @@ class AsyncCrud(Generic[ModelType]):
order_by: OrderByClause | None = None, order_by: OrderByClause | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
include_total: bool = True,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
@@ -939,6 +1071,8 @@ class AsyncCrud(Generic[ModelType]):
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
page: Page number (1-indexed) page: Page number (1-indexed)
items_per_page: Number of items per page items_per_page: Number of items per page
include_total: When ``False``, skip the ``COUNT`` query;
``pagination.total_count`` will be ``None``.
search: Search query string or SearchConfig object search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default) search_fields: Fields to search in (overrides class default)
facet_fields: Columns to compute distinct values for (overrides class default) facet_fields: Columns to compute distinct values for (overrides class default)
@@ -983,10 +1117,10 @@ class AsyncCrud(Generic[ModelType]):
if order_by is not None: if order_by is not None:
q = q.order_by(order_by) q = q.order_by(order_by)
if include_total:
q = q.offset(offset).limit(items_per_page) q = q.offset(offset).limit(items_per_page)
result = await session.execute(q) result = await session.execute(q)
raw_items = cast(list[ModelType], result.unique().scalars().all()) raw_items = cast(list[ModelType], result.unique().scalars().all())
items: list[Any] = [schema.model_validate(item) for item in raw_items]
# Count query (with same joins and filters) # Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0] pk_col = cls.model.__mapper__.primary_key[0]
@@ -1003,7 +1137,18 @@ class AsyncCrud(Generic[ModelType]):
count_q = count_q.where(and_(*filters)) count_q = count_q.where(and_(*filters))
count_result = await session.execute(count_q) count_result = await session.execute(count_q)
total_count = count_result.scalar_one() total_count: int | None = count_result.scalar_one()
has_more = page * items_per_page < total_count
else:
# Fetch one extra row to detect if a next page exists without COUNT
q = q.offset(offset).limit(items_per_page + 1)
result = await session.execute(q)
raw_items = cast(list[ModelType], result.unique().scalars().all())
has_more = len(raw_items) > items_per_page
raw_items = raw_items[:items_per_page]
total_count = None
items: list[Any] = [schema.model_validate(item) for item in raw_items]
filter_attributes = await cls._build_filter_attributes( filter_attributes = await cls._build_filter_attributes(
session, facet_fields, filters, search_joins session, facet_fields, filters, search_joins
@@ -1015,7 +1160,7 @@ class AsyncCrud(Generic[ModelType]):
total_count=total_count, total_count=total_count,
items_per_page=items_per_page, items_per_page=items_per_page,
page=page, page=page,
has_more=page * items_per_page < total_count, has_more=has_more,
), ),
filter_attributes=filter_attributes, filter_attributes=filter_attributes,
) )
@@ -1190,6 +1335,7 @@ class AsyncCrud(Generic[ModelType]):
page: int = ..., page: int = ...,
cursor: str | None = ..., cursor: str | None = ...,
items_per_page: int = ..., items_per_page: int = ...,
include_total: bool = ...,
search: str | SearchConfig | None = ..., search: str | SearchConfig | None = ...,
search_fields: Sequence[SearchFieldType] | None = ..., search_fields: Sequence[SearchFieldType] | None = ...,
facet_fields: Sequence[FacetFieldType] | None = ..., facet_fields: Sequence[FacetFieldType] | None = ...,
@@ -1212,6 +1358,7 @@ class AsyncCrud(Generic[ModelType]):
page: int = ..., page: int = ...,
cursor: str | None = ..., cursor: str | None = ...,
items_per_page: int = ..., items_per_page: int = ...,
include_total: bool = ...,
search: str | SearchConfig | None = ..., search: str | SearchConfig | None = ...,
search_fields: Sequence[SearchFieldType] | None = ..., search_fields: Sequence[SearchFieldType] | None = ...,
facet_fields: Sequence[FacetFieldType] | None = ..., facet_fields: Sequence[FacetFieldType] | None = ...,
@@ -1233,6 +1380,7 @@ class AsyncCrud(Generic[ModelType]):
page: int = 1, page: int = 1,
cursor: str | None = None, cursor: str | None = None,
items_per_page: int = 20, items_per_page: int = 20,
include_total: bool = True,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None,
@@ -1258,6 +1406,8 @@ class AsyncCrud(Generic[ModelType]):
:class:`.CursorPaginatedResponse`. Only used when :class:`.CursorPaginatedResponse`. Only used when
``pagination_type`` is ``CURSOR``. ``pagination_type`` is ``CURSOR``.
items_per_page: Number of items per page (default 20). items_per_page: Number of items per page (default 20).
include_total: When ``False``, skip the ``COUNT`` query;
only applies when ``pagination_type`` is ``OFFSET``.
search: Search query string or :class:`.SearchConfig` object. search: Search query string or :class:`.SearchConfig` object.
search_fields: Fields to search in (overrides class default). search_fields: Fields to search in (overrides class default).
facet_fields: Columns to compute distinct values for (overrides facet_fields: Columns to compute distinct values for (overrides
@@ -1304,6 +1454,7 @@ class AsyncCrud(Generic[ModelType]):
order_by=order_by, order_by=order_by,
page=page, page=page,
items_per_page=items_per_page, items_per_page=items_per_page,
include_total=include_total,
search=search, search=search,
search_fields=search_fields, search_fields=search_fields,
facet_fields=facet_fields, facet_fields=facet_fields,

View File

@@ -1,6 +1,7 @@
"""Field-change monitoring via SQLAlchemy session events.""" """Field-change monitoring via SQLAlchemy session events."""
import asyncio import asyncio
import inspect
import weakref import weakref
from collections.abc import Awaitable from collections.abc import Awaitable
from enum import Enum from enum import Enum
@@ -25,6 +26,7 @@ _SESSION_PENDING_NEW = "_ft_pending_new"
_SESSION_CREATES = "_ft_creates" _SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes" _SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates" _SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
class ModelEvent(str, Enum): class ModelEvent(str, Enum):
@@ -65,6 +67,14 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
} }
def _get_watched_fields(cls: type) -> list[str] | None:
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
for klass in cls.__mro__:
if klass in _WATCHED_FIELDS:
return _WATCHED_FIELDS[klass]
return None
def _upsert_changes( def _upsert_changes(
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]], pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
obj: Any, obj: Any,
@@ -83,6 +93,22 @@ def _upsert_changes(
pending[key] = (obj, changes) pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
def _after_transaction_create(session: Any, transaction: Any) -> None:
if transaction.nested:
session.info[_SESSION_SAVEPOINT_DEPTH] = (
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
def _after_transaction_end(session: Any, transaction: Any) -> None:
if transaction.nested:
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
if depth > 0: # pragma: no branch
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
@event.listens_for(AsyncSession.sync_session_class, "after_flush") @event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None: def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture references while session.new is still populated. # New objects: capture references while session.new is still populated.
@@ -102,7 +128,7 @@ def _after_flush(session: Any, flush_context: Any) -> None:
continue continue
# None = not in dict = watch all fields; list = specific fields only # None = not in dict = watch all fields; list = specific fields only
watched = _WATCHED_FIELDS.get(type(obj)) watched = _get_watched_fields(type(obj))
changes: dict[str, dict[str, Any]] = {} changes: dict[str, dict[str, Any]] = {}
attrs = ( attrs = (
@@ -169,7 +195,7 @@ def _schedule_with_snapshot(
_sa_set_committed_value(obj, key, value) _sa_set_committed_value(obj, key, value)
try: try:
result = fn(*args) result = fn(*args)
if asyncio.iscoroutine(result): if inspect.isawaitable(result):
await result await result
except Exception as exc: except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
@@ -180,12 +206,24 @@ def _schedule_with_snapshot(
@event.listens_for(AsyncSession.sync_session_class, "after_commit") @event.listens_for(AsyncSession.sync_session_class, "after_commit")
def _after_commit(session: Any) -> None: def _after_commit(session: Any) -> None:
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
return
creates: list[Any] = session.info.pop(_SESSION_CREATES, []) creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop( field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
_SESSION_UPDATES, {} _SESSION_UPDATES, {}
) )
if creates and deletes:
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [o for o in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
if not creates and not deletes and not field_changes: if not creates and not deletes and not field_changes:
return return

View File

@@ -1,9 +1,10 @@
"""Base Pydantic schemas for API responses.""" """Base Pydantic schemas for API responses."""
import math
from enum import Enum from enum import Enum
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, computed_field
from .types import DataT from .types import DataT
@@ -98,17 +99,29 @@ class OffsetPagination(PydanticBase):
"""Pagination metadata for offset-based list responses. """Pagination metadata for offset-based list responses.
Attributes: Attributes:
total_count: Total number of items across all pages total_count: Total number of items across all pages.
``None`` when ``include_total=False``.
items_per_page: Number of items per page items_per_page: Number of items per page
page: Current page number (1-indexed) page: Current page number (1-indexed)
has_more: Whether there are more pages has_more: Whether there are more pages
pages: Total number of pages
""" """
total_count: int total_count: int | None
items_per_page: int items_per_page: int
page: int page: int
has_more: bool has_more: bool
@computed_field
@property
def pages(self) -> int | None:
"""Total number of pages, or ``None`` when ``total_count`` is unknown."""
if self.total_count is None:
return None
if self.items_per_page == 0:
return 0
return math.ceil(self.total_count / self.items_per_page)
class CursorPagination(PydanticBase): class CursorPagination(PydanticBase):
"""Pagination metadata for cursor-based list responses. """Pagination metadata for cursor-based list responses.

View File

@@ -321,30 +321,3 @@ async def db_session(engine):
# Drop tables after test # Drop tables after test
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
def sample_role_data() -> RoleCreate:
"""Sample role creation data."""
return RoleCreate(name="admin")
@pytest.fixture
def sample_user_data() -> UserCreate:
"""Sample user creation data."""
return UserCreate(
username="testuser",
email="test@example.com",
is_active=True,
)
@pytest.fixture
def sample_post_data() -> PostCreate:
"""Sample post creation data."""
return PostCreate(
title="Test Post",
content="Test content",
is_published=True,
author_id=uuid.uuid4(),
)

View File

@@ -1759,6 +1759,52 @@ class TestSchemaResponse:
assert result.data[0].username == "pg_user" assert result.data[0].username == "pg_user"
assert not hasattr(result.data[0], "email") assert not hasattr(result.data[0], "email")
@pytest.mark.anyio
async def test_include_total_false_skips_count(self, db_session: AsyncSession):
"""offset_paginate with include_total=False returns total_count=None."""
from fastapi_toolsets.schemas import OffsetPagination
for i in range(5):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.offset_paginate(
db_session, items_per_page=10, include_total=False, schema=RoleRead
)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count is None
assert len(result.data) == 5
assert result.pagination.has_more is False
@pytest.mark.anyio
async def test_include_total_false_has_more_true(self, db_session: AsyncSession):
"""offset_paginate with include_total=False sets has_more via extra-row probe."""
for i in range(15):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.offset_paginate(
db_session, items_per_page=10, include_total=False, schema=RoleRead
)
assert result.pagination.total_count is None
assert result.pagination.has_more is True
assert len(result.data) == 10
@pytest.mark.anyio
async def test_include_total_false_exact_page_boundary(
self, db_session: AsyncSession
):
"""offset_paginate with include_total=False: has_more=False when items == page size."""
for i in range(10):
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
result = await RoleCrud.offset_paginate(
db_session, items_per_page=10, include_total=False, schema=RoleRead
)
assert result.pagination.has_more is False
assert len(result.data) == 10
class TestCursorPaginate: class TestCursorPaginate:
"""Tests for cursor-based pagination via cursor_paginate().""" """Tests for cursor-based pagination via cursor_paginate()."""
@@ -2521,3 +2567,20 @@ class TestPaginate:
pagination_type="unknown", pagination_type="unknown",
schema=RoleRead, schema=RoleRead,
) # type: ignore[no-matching-overload] ) # type: ignore[no-matching-overload]
@pytest.mark.anyio
async def test_offset_include_total_false(self, db_session: AsyncSession):
"""paginate() passes include_total=False through to offset_paginate."""
from fastapi_toolsets.schemas import OffsetPagination
await RoleCrud.create(db_session, RoleCreate(name="admin"))
result = await RoleCrud.paginate(
db_session,
pagination_type=PaginationType.OFFSET,
include_total=False,
schema=RoleRead,
)
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count is None

View File

@@ -14,12 +14,14 @@ from fastapi_toolsets.crud import (
get_searchable_fields, get_searchable_fields,
) )
from fastapi_toolsets.exceptions import InvalidOrderFieldError from fastapi_toolsets.exceptions import InvalidOrderFieldError
from fastapi_toolsets.schemas import OffsetPagination from fastapi_toolsets.schemas import OffsetPagination, PaginationType
from .conftest import ( from .conftest import (
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
RoleCursorCrud,
RoleRead,
User, User,
UserCreate, UserCreate,
UserCrud, UserCrud,
@@ -1193,3 +1195,245 @@ class TestOrderParamsSchema:
assert results[0].username == "alice" assert results[0].username == "alice"
assert results[1].username == "charlie" assert results[1].username == "charlie"
class TestOffsetParamsSchema:
"""Tests for AsyncCrud.offset_params()."""
def test_returns_page_and_items_per_page_params(self):
"""Returned dependency has page and items_per_page params only."""
dep = RoleCrud.offset_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"page", "items_per_page"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
dep = RoleCrud.offset_params()
assert getattr(dep, "__name__") == "RoleOffsetParams"
def test_default_page_size_reflected_in_items_per_page_default(self):
"""default_page_size is used as the default for items_per_page."""
dep = RoleCrud.offset_params(default_page_size=42)
sig = inspect.signature(dep)
assert sig.parameters["items_per_page"].default.default == 42
def test_max_page_size_reflected_in_items_per_page_le(self):
"""max_page_size is used as le constraint on items_per_page."""
dep = RoleCrud.offset_params(max_page_size=50)
sig = inspect.signature(dep)
le = next(
m.le
for m in sig.parameters["items_per_page"].default.metadata
if hasattr(m, "le")
)
assert le == 50
def test_include_total_not_a_query_param(self):
"""include_total is not exposed as a query parameter."""
dep = RoleCrud.offset_params()
param_names = set(inspect.signature(dep).parameters)
assert "include_total" not in param_names
@pytest.mark.anyio
async def test_include_total_true_forwarded_in_result(self):
"""include_total=True factory arg appears in the resolved dict."""
result = await RoleCrud.offset_params(include_total=True)(
page=1, items_per_page=10
)
assert result["include_total"] is True
@pytest.mark.anyio
async def test_include_total_false_forwarded_in_result(self):
"""include_total=False factory arg appears in the resolved dict."""
result = await RoleCrud.offset_params(include_total=False)(
page=1, items_per_page=10
)
assert result["include_total"] is False
@pytest.mark.anyio
async def test_awaiting_dep_returns_dict(self):
"""Awaiting the dependency returns a dict with page, items_per_page, include_total."""
dep = RoleCrud.offset_params(include_total=False)
result = await dep(page=2, items_per_page=10)
assert result == {"page": 2, "items_per_page": 10, "include_total": False}
@pytest.mark.anyio
async def test_integrates_with_offset_paginate(self, db_session: AsyncSession):
"""offset_params output can be unpacked directly into offset_paginate."""
await RoleCrud.create(db_session, RoleCreate(name="admin"))
dep = RoleCrud.offset_params()
params = await dep(page=1, items_per_page=10)
result = await RoleCrud.offset_paginate(db_session, **params, schema=RoleRead)
assert result.pagination.page == 1
assert result.pagination.items_per_page == 10
class TestCursorParamsSchema:
"""Tests for AsyncCrud.cursor_params()."""
def test_returns_cursor_and_items_per_page_params(self):
"""Returned dependency has cursor and items_per_page params."""
dep = RoleCursorCrud.cursor_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"cursor", "items_per_page"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
dep = RoleCursorCrud.cursor_params()
assert getattr(dep, "__name__") == "RoleCursorParams"
def test_default_page_size_reflected_in_items_per_page_default(self):
"""default_page_size is used as the default for items_per_page."""
dep = RoleCursorCrud.cursor_params(default_page_size=15)
sig = inspect.signature(dep)
assert sig.parameters["items_per_page"].default.default == 15
def test_max_page_size_reflected_in_items_per_page_le(self):
"""max_page_size is used as le constraint on items_per_page."""
dep = RoleCursorCrud.cursor_params(max_page_size=75)
sig = inspect.signature(dep)
le = next(
m.le
for m in sig.parameters["items_per_page"].default.metadata
if hasattr(m, "le")
)
assert le == 75
def test_cursor_defaults_to_none(self):
"""cursor defaults to None."""
dep = RoleCursorCrud.cursor_params()
sig = inspect.signature(dep)
assert sig.parameters["cursor"].default.default is None
@pytest.mark.anyio
async def test_awaiting_dep_returns_dict(self):
"""Awaiting the dependency returns a dict with cursor and items_per_page."""
dep = RoleCursorCrud.cursor_params()
result = await dep(cursor=None, items_per_page=5)
assert result == {"cursor": None, "items_per_page": 5}
@pytest.mark.anyio
async def test_integrates_with_cursor_paginate(self, db_session: AsyncSession):
"""cursor_params output can be unpacked directly into cursor_paginate."""
await RoleCrud.create(db_session, RoleCreate(name="admin"))
dep = RoleCursorCrud.cursor_params()
params = await dep(cursor=None, items_per_page=10)
result = await RoleCursorCrud.cursor_paginate(
db_session, **params, schema=RoleRead
)
assert result.pagination.items_per_page == 10
class TestPaginateParamsSchema:
"""Tests for AsyncCrud.paginate_params()."""
def test_returns_all_params(self):
"""Returned dependency has pagination_type, page, cursor, items_per_page (no include_total)."""
dep = RoleCursorCrud.paginate_params()
param_names = set(inspect.signature(dep).parameters)
assert param_names == {"pagination_type", "page", "cursor", "items_per_page"}
def test_dependency_name_includes_model_name(self):
"""Dependency function is named after the model."""
dep = RoleCursorCrud.paginate_params()
assert getattr(dep, "__name__") == "RolePaginateParams"
def test_default_pagination_type(self):
"""default_pagination_type is reflected in pagination_type default."""
from fastapi_toolsets.schemas import PaginationType
dep = RoleCursorCrud.paginate_params(
default_pagination_type=PaginationType.CURSOR
)
sig = inspect.signature(dep)
assert (
sig.parameters["pagination_type"].default.default == PaginationType.CURSOR
)
def test_default_page_size(self):
"""default_page_size is reflected in items_per_page default."""
dep = RoleCursorCrud.paginate_params(default_page_size=15)
sig = inspect.signature(dep)
assert sig.parameters["items_per_page"].default.default == 15
def test_max_page_size_le_constraint(self):
"""max_page_size is used as le constraint on items_per_page."""
dep = RoleCursorCrud.paginate_params(max_page_size=60)
sig = inspect.signature(dep)
le = next(
m.le
for m in sig.parameters["items_per_page"].default.metadata
if hasattr(m, "le")
)
assert le == 60
def test_include_total_not_a_query_param(self):
"""include_total is not exposed as a query parameter."""
dep = RoleCursorCrud.paginate_params()
assert "include_total" not in set(inspect.signature(dep).parameters)
@pytest.mark.anyio
async def test_include_total_forwarded_in_result(self):
"""include_total factory arg appears in the resolved dict."""
result_true = await RoleCursorCrud.paginate_params(include_total=True)(
pagination_type=PaginationType.OFFSET,
page=1,
cursor=None,
items_per_page=10,
)
result_false = await RoleCursorCrud.paginate_params(include_total=False)(
pagination_type=PaginationType.OFFSET,
page=1,
cursor=None,
items_per_page=10,
)
assert result_true["include_total"] is True
assert result_false["include_total"] is False
@pytest.mark.anyio
async def test_awaiting_dep_returns_dict(self):
"""Awaiting the dependency returns a dict with all pagination keys."""
dep = RoleCursorCrud.paginate_params()
result = await dep(
pagination_type=PaginationType.OFFSET,
page=2,
cursor=None,
items_per_page=10,
)
assert result == {
"pagination_type": PaginationType.OFFSET,
"page": 2,
"cursor": None,
"items_per_page": 10,
"include_total": True,
}
@pytest.mark.anyio
async def test_integrates_with_paginate_offset(self, db_session: AsyncSession):
"""paginate_params output unpacks into paginate() for offset strategy."""
from fastapi_toolsets.schemas import OffsetPagination
await RoleCrud.create(db_session, RoleCreate(name="admin"))
params = await RoleCursorCrud.paginate_params()(
pagination_type=PaginationType.OFFSET,
page=1,
cursor=None,
items_per_page=10,
)
result = await RoleCursorCrud.paginate(db_session, **params, schema=RoleRead)
assert isinstance(result.pagination, OffsetPagination)
@pytest.mark.anyio
async def test_integrates_with_paginate_cursor(self, db_session: AsyncSession):
"""paginate_params output unpacks into paginate() for cursor strategy."""
from fastapi_toolsets.schemas import CursorPagination
await RoleCrud.create(db_session, RoleCreate(name="admin"))
params = await RoleCursorCrud.paginate_params()(
pagination_type=PaginationType.CURSOR,
page=1,
cursor=None,
items_per_page=10,
)
result = await RoleCursorCrud.paginate(db_session, **params, schema=RoleRead)
assert isinstance(result.pagination, CursorPagination)

View File

@@ -10,12 +10,13 @@ import datetime
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession
from docs_src.examples.pagination_search.db import get_db from docs_src.examples.pagination_search.db import get_db
from docs_src.examples.pagination_search.models import Article, Base, Category from docs_src.examples.pagination_search.models import Article, Base, Category
from docs_src.examples.pagination_search.routes import router from docs_src.examples.pagination_search.routes import router
from fastapi_toolsets.exceptions import init_exceptions_handlers from fastapi_toolsets.exceptions import init_exceptions_handlers
from fastapi_toolsets.pytest import create_db_session
from .conftest import DATABASE_URL from .conftest import DATABASE_URL
@@ -35,20 +36,8 @@ def build_app(session: AsyncSession) -> FastAPI:
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def ex_db_session(): async def ex_db_session():
"""Isolated session for the example models (separate tables from conftest).""" """Isolated session for the example models (separate tables from conftest)."""
engine = create_async_engine(DATABASE_URL, echo=False) async with create_db_session(DATABASE_URL, Base) as session:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
session = session_factory()
try:
yield session yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest.fixture @pytest.fixture

View File

@@ -16,7 +16,7 @@ from fastapi_toolsets.fixtures import (
from fastapi_toolsets.fixtures.utils import _get_primary_key from fastapi_toolsets.fixtures.utils import _get_primary_key
from .conftest import IntRole, Permission, Role, User from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud
class TestContext: class TestContext:
@@ -447,8 +447,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert len(result["roles"]) == 2 assert len(result["roles"]) == 2
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -479,8 +477,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert "users" in result assert "users" in result
from .conftest import RoleCrud, UserCrud
assert await RoleCrud.count(db_session) == 1 assert await RoleCrud.count(db_session) == 1
assert await UserCrud.count(db_session) == 1 assert await UserCrud.count(db_session) == 1
@@ -497,8 +493,6 @@ class TestLoadFixtures:
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
@@ -526,8 +520,6 @@ class TestLoadFixtures:
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
) )
from .conftest import RoleCrud
role = await RoleCrud.first(db_session, [Role.id == role_id]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "original" assert role.name == "original"
@@ -553,8 +545,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert len(result["roles"]) == 2 assert len(result["roles"]) == 2
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -594,8 +584,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert "other_roles" in result assert "other_roles" in result
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -660,8 +648,6 @@ class TestLoadFixturesByContext:
await load_fixtures_by_context(db_session, registry, Context.BASE) await load_fixtures_by_context(db_session, registry, Context.BASE)
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
@@ -688,8 +674,6 @@ class TestLoadFixturesByContext:
db_session, registry, Context.BASE, Context.TESTING db_session, registry, Context.BASE, Context.TESTING
) )
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -717,8 +701,6 @@ class TestLoadFixturesByContext:
await load_fixtures_by_context(db_session, registry, Context.TESTING) await load_fixtures_by_context(db_session, registry, Context.TESTING)
from .conftest import RoleCrud, UserCrud
assert await RoleCrud.count(db_session) == 1 assert await RoleCrud.count(db_session) == 1
assert await UserCrud.count(db_session) == 1 assert await UserCrud.count(db_session) == 1

View File

@@ -171,8 +171,15 @@ class TestPytestImportGuard:
class TestCliImportGuard: class TestCliImportGuard:
"""Tests for CLI module import guard when typer is missing.""" """Tests for CLI module import guard when typer is missing."""
def test_import_raises_without_typer(self): @pytest.mark.parametrize(
"""Importing cli.app raises when typer is missing.""" "expected_match",
[
"typer",
r"pip install fastapi-toolsets\[cli\]",
],
)
def test_import_raises_without_typer(self, expected_match):
"""Importing cli.app raises when typer is missing, with an informative error message."""
saved, blocking_import = _reload_without_package( saved, blocking_import = _reload_without_package(
"fastapi_toolsets.cli.app", ["typer"] "fastapi_toolsets.cli.app", ["typer"]
) )
@@ -186,33 +193,7 @@ class TestCliImportGuard:
try: try:
with patch("builtins.__import__", side_effect=blocking_import): with patch("builtins.__import__", side_effect=blocking_import):
with pytest.raises(ImportError, match="typer"): with pytest.raises(ImportError, match=expected_match):
importlib.import_module("fastapi_toolsets.cli.app")
finally:
for key in list(sys.modules):
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
"fastapi_toolsets.cli.config"
):
sys.modules.pop(key, None)
sys.modules.update(saved)
def test_error_message_suggests_cli_extra(self):
"""Error message suggests installing the cli extra."""
saved, blocking_import = _reload_without_package(
"fastapi_toolsets.cli.app", ["typer"]
)
config_keys = [
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
]
for key in config_keys:
if key not in saved:
saved[key] = sys.modules.pop(key)
try:
with patch("builtins.__import__", side_effect=blocking_import):
with pytest.raises(
ImportError, match=r"pip install fastapi-toolsets\[cli\]"
):
importlib.import_module("fastapi_toolsets.cli.app") importlib.import_module("fastapi_toolsets.cli.app")
finally: finally:
for key in list(sys.modules): for key in list(sys.modules):

View File

@@ -1,6 +1,5 @@
"""Tests for fastapi_toolsets.metrics module.""" """Tests for fastapi_toolsets.metrics module."""
import os
import tempfile import tempfile
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@@ -287,6 +286,16 @@ class TestIncludeRegistry:
class TestInitMetrics: class TestInitMetrics:
"""Tests for init_metrics function.""" """Tests for init_metrics function."""
@pytest.fixture
def metrics_client(self):
"""Create a FastAPI app with MetricsRegistry and return a TestClient."""
app = FastAPI()
registry = MetricsRegistry()
init_metrics(app, registry)
client = TestClient(app)
yield client
client.close()
def test_returns_app(self): def test_returns_app(self):
"""Returns the FastAPI app.""" """Returns the FastAPI app."""
app = FastAPI() app = FastAPI()
@@ -294,26 +303,14 @@ class TestInitMetrics:
result = init_metrics(app, registry) result = init_metrics(app, registry)
assert result is app assert result is app
def test_metrics_endpoint_responds(self): def test_metrics_endpoint_responds(self, metrics_client):
"""The /metrics endpoint returns 200.""" """The /metrics endpoint returns 200."""
app = FastAPI() response = metrics_client.get("/metrics")
registry = MetricsRegistry()
init_metrics(app, registry)
client = TestClient(app)
response = client.get("/metrics")
assert response.status_code == 200 assert response.status_code == 200
def test_metrics_endpoint_content_type(self): def test_metrics_endpoint_content_type(self, metrics_client):
"""The /metrics endpoint returns prometheus content type.""" """The /metrics endpoint returns prometheus content type."""
app = FastAPI() response = metrics_client.get("/metrics")
registry = MetricsRegistry()
init_metrics(app, registry)
client = TestClient(app)
response = client.get("/metrics")
assert "text/plain" in response.headers["content-type"] assert "text/plain" in response.headers["content-type"]
def test_custom_path(self): def test_custom_path(self):
@@ -445,11 +442,10 @@ class TestInitMetrics:
class TestMultiProcessMode: class TestMultiProcessMode:
"""Tests for multi-process Prometheus mode.""" """Tests for multi-process Prometheus mode."""
def test_multiprocess_with_env_var(self): def test_multiprocess_with_env_var(self, monkeypatch):
"""Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set.""" """Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir monkeypatch.setenv("PROMETHEUS_MULTIPROC_DIR", tmpdir)
try:
# Use a separate registry to avoid conflicts with default # Use a separate registry to avoid conflicts with default
prom_registry = CollectorRegistry() prom_registry = CollectorRegistry()
app = FastAPI() app = FastAPI()
@@ -469,12 +465,10 @@ class TestMultiProcessMode:
response = client.get("/metrics") response = client.get("/metrics")
assert response.status_code == 200 assert response.status_code == 200
finally:
del os.environ["PROMETHEUS_MULTIPROC_DIR"]
def test_single_process_without_env_var(self): def test_single_process_without_env_var(self, monkeypatch):
"""Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set.""" """Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None) monkeypatch.delenv("PROMETHEUS_MULTIPROC_DIR", raising=False)
app = FastAPI() app = FastAPI()
registry = MetricsRegistry() registry = MetricsRegistry()

View File

@@ -6,27 +6,28 @@ from contextlib import suppress
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import patch
import fastapi_toolsets.models.watched as _watched_module
import pytest import pytest
from sqlalchemy import String from sqlalchemy import String
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from fastapi_toolsets.pytest import create_db_session
import fastapi_toolsets.models.watched as _watched_module
from fastapi_toolsets.models import ( from fastapi_toolsets.models import (
CreatedAtMixin, CreatedAtMixin,
ModelEvent, ModelEvent,
TimestampMixin, TimestampMixin,
UpdatedAtMixin,
UUIDMixin, UUIDMixin,
UUIDv7Mixin, UUIDv7Mixin,
UpdatedAtMixin,
WatchedFieldsMixin, WatchedFieldsMixin,
watch, watch,
) )
from fastapi_toolsets.models.watched import ( from fastapi_toolsets.models.watched import (
_SESSION_CREATES, _SESSION_CREATES,
_SESSION_DELETES, _SESSION_DELETES,
_SESSION_UPDATES,
_SESSION_PENDING_NEW, _SESSION_PENDING_NEW,
_SESSION_UPDATES,
_after_commit, _after_commit,
_after_flush, _after_flush,
_after_flush_postexec, _after_flush_postexec,
@@ -81,8 +82,6 @@ class FullMixinModel(MixinBase, UUIDMixin, UpdatedAtMixin):
name: Mapped[str] = mapped_column(String(50)) name: Mapped[str] = mapped_column(String(50))
# --- WatchedFieldsMixin test models ---
_test_events: list[dict] = [] _test_events: list[dict] = []
@@ -145,6 +144,66 @@ class NonWatchedModel(MixinBase):
value: Mapped[str] = mapped_column(String(50)) value: Mapped[str] = mapped_column(String(50))
_poly_events: list[dict] = []
class PolyAnimal(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Base class for STI polymorphism tests."""
__tablename__ = "mixin_poly_animals"
__mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "animal"}
kind: Mapped[str] = mapped_column(String(50))
name: Mapped[str] = mapped_column(String(50))
async def on_create(self) -> None:
_poly_events.append(
{"event": "create", "type": type(self).__name__, "obj_id": self.id}
)
async def on_delete(self) -> None:
_poly_events.append(
{"event": "delete", "type": type(self).__name__, "obj_id": self.id}
)
class PolyDog(PolyAnimal):
"""STI subclass — shares the same table as PolyAnimal."""
__mapper_args__ = {"polymorphic_identity": "dog"}
_watch_inherit_events: list[dict] = []
@watch("status")
class WatchParent(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Base class with @watch("status") — subclasses should inherit this filter."""
__tablename__ = "mixin_watch_parent"
__mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "parent"}
kind: Mapped[str] = mapped_column(String(50))
status: Mapped[str] = mapped_column(String(50))
other: Mapped[str] = mapped_column(String(50))
async def on_update(self, changes: dict) -> None:
_watch_inherit_events.append({"type": type(self).__name__, "changes": changes})
class WatchChild(WatchParent):
"""STI subclass that does NOT redeclare @watch — should inherit parent's filter."""
__mapper_args__ = {"polymorphic_identity": "child"}
@watch("other")
class WatchOverride(WatchParent):
"""STI subclass that overrides @watch with a different field."""
__mapper_args__ = {"polymorphic_identity": "override"}
_attr_access_events: list[dict] = [] _attr_access_events: list[dict] = []
@@ -172,6 +231,7 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events: list[dict] = [] _sync_events: list[dict] = []
_future_events: list[str] = []
@watch("status") @watch("status")
@@ -192,41 +252,33 @@ class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events.append({"event": "update", "changes": changes}) _sync_events.append({"event": "update", "changes": changes})
class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model whose on_create returns an asyncio.Task (awaitable, not a coroutine)."""
__tablename__ = "mixin_future_callback_models"
name: Mapped[str] = mapped_column(String(50))
def on_create(self) -> "asyncio.Task[None]":
async def _work() -> None:
_future_events.append("created")
return asyncio.ensure_future(_work())
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mixin_session(): async def mixin_session():
engine = create_async_engine(DATABASE_URL, echo=False) async with create_db_session(DATABASE_URL, MixinBase) as session:
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
session = session_factory()
try:
yield session yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.drop_all)
await engine.dispose()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mixin_session_expire(): async def mixin_session_expire():
"""Session with expire_on_commit=True (the default) to exercise attribute access after commit.""" """Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
engine = create_async_engine(DATABASE_URL, echo=False) async with create_db_session(
async with engine.begin() as conn: DATABASE_URL, MixinBase, expire_on_commit=True
await conn.run_sync(MixinBase.metadata.create_all) ) as session:
session_factory = async_sessionmaker(engine, expire_on_commit=True)
session = session_factory()
try:
yield session yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.drop_all)
await engine.dispose()
class TestUUIDMixin: class TestUUIDMixin:
@@ -473,6 +525,67 @@ class TestWatchDecorator:
watch() watch()
class TestWatchInheritance:
@pytest.fixture(autouse=True)
def clear_events(self):
_watch_inherit_events.clear()
yield
_watch_inherit_events.clear()
@pytest.mark.anyio
async def test_child_inherits_parent_watch_filter(self, mixin_session):
"""Subclass without @watch inherits the parent's field filter."""
obj = WatchChild(status="initial", other="x")
mixin_session.add(obj)
await mixin_session.commit()
await asyncio.sleep(0)
obj.other = "changed" # not watched by parent's @watch("status")
await mixin_session.commit()
await asyncio.sleep(0)
assert _watch_inherit_events == []
@pytest.mark.anyio
async def test_child_triggers_on_watched_field(self, mixin_session):
"""Subclass without @watch triggers on_update for the parent's watched field."""
obj = WatchChild(status="initial", other="x")
mixin_session.add(obj)
await mixin_session.commit()
await asyncio.sleep(0)
obj.status = "updated"
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_watch_inherit_events) == 1
assert _watch_inherit_events[0]["type"] == "WatchChild"
assert "status" in _watch_inherit_events[0]["changes"]
@pytest.mark.anyio
async def test_subclass_override_takes_precedence(self, mixin_session):
"""Subclass @watch overrides the parent's field filter."""
obj = WatchOverride(status="initial", other="x")
mixin_session.add(obj)
await mixin_session.commit()
await asyncio.sleep(0)
obj.status = (
"changed" # watched by parent but overridden by child's @watch("other")
)
await mixin_session.commit()
await asyncio.sleep(0)
assert _watch_inherit_events == []
obj.other = "changed"
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_watch_inherit_events) == 1
assert "other" in _watch_inherit_events[0]["changes"]
class TestUpsertChanges: class TestUpsertChanges:
def test_inserts_new_entry(self): def test_inserts_new_entry(self):
"""New key is inserted with the full changes dict.""" """New key is inserted with the full changes dict."""
@@ -871,6 +984,119 @@ class TestWatchedFieldsMixin:
} }
class TestTransientObject:
"""Create + delete within the same transaction should fire no events."""
@pytest.fixture(autouse=True)
def clear_events(self):
_test_events.clear()
yield
_test_events.clear()
@pytest.mark.anyio
async def test_no_events_when_created_and_deleted_in_same_transaction(
self, mixin_session
):
"""Neither on_create nor on_delete fires when the object never survives a commit."""
obj = WatchedModel(status="active", other="x")
mixin_session.add(obj)
await mixin_session.flush()
await mixin_session.delete(obj)
await mixin_session.commit()
await asyncio.sleep(0)
assert _test_events == []
@pytest.mark.anyio
async def test_other_objects_unaffected(self, mixin_session):
"""on_create still fires for objects that are not deleted in the same transaction."""
survivor = WatchedModel(status="active", other="x")
transient = WatchedModel(status="gone", other="y")
mixin_session.add(survivor)
mixin_session.add(transient)
await mixin_session.flush()
await mixin_session.delete(transient)
await mixin_session.commit()
await asyncio.sleep(0)
creates = [e for e in _test_events if e["event"] == "create"]
deletes = [e for e in _test_events if e["event"] == "delete"]
assert len(creates) == 1
assert creates[0]["obj_id"] == survivor.id
assert deletes == []
@pytest.mark.anyio
async def test_distinct_create_and_delete_both_fire(self, mixin_session):
"""on_create and on_delete both fire when different objects are created and deleted."""
existing = WatchedModel(status="old", other="x")
mixin_session.add(existing)
await mixin_session.commit()
await asyncio.sleep(0)
_test_events.clear()
new_obj = WatchedModel(status="new", other="y")
mixin_session.add(new_obj)
await mixin_session.delete(existing)
await mixin_session.commit()
await asyncio.sleep(0)
creates = [e for e in _test_events if e["event"] == "create"]
deletes = [e for e in _test_events if e["event"] == "delete"]
assert len(creates) == 1
assert len(deletes) == 1
class TestPolymorphism:
"""WatchedFieldsMixin with STI (Single Table Inheritance)."""
@pytest.fixture(autouse=True)
def clear_events(self):
_poly_events.clear()
yield
_poly_events.clear()
@pytest.mark.anyio
async def test_on_create_fires_once_for_subclass(self, mixin_session):
"""on_create fires exactly once for a STI subclass instance."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_poly_events) == 1
assert _poly_events[0]["event"] == "create"
assert _poly_events[0]["type"] == "PolyDog"
@pytest.mark.anyio
async def test_on_delete_fires_for_subclass(self, mixin_session):
"""on_delete fires for a STI subclass instance."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.commit()
await asyncio.sleep(0)
_poly_events.clear()
await mixin_session.delete(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_poly_events) == 1
assert _poly_events[0]["event"] == "delete"
assert _poly_events[0]["type"] == "PolyDog"
@pytest.mark.anyio
async def test_transient_subclass_fires_no_events(self, mixin_session):
"""Create + delete of a STI subclass in one transaction fires no events."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.flush()
await mixin_session.delete(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert _poly_events == []
class TestWatchAll: class TestWatchAll:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_events(self): def clear_events(self):
@@ -968,6 +1194,28 @@ class TestSyncCallbacks:
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
class TestFutureCallbacks:
"""Callbacks returning a non-coroutine awaitable (asyncio.Task / Future)."""
@pytest.fixture(autouse=True)
def clear_events(self):
_future_events.clear()
yield
_future_events.clear()
@pytest.mark.anyio
async def test_task_callback_is_awaited(self, mixin_session):
"""on_create returning an asyncio.Task is awaited and its work completes."""
obj = FutureCallbackModel(name="test")
mixin_session.add(obj)
await mixin_session.commit()
# Two turns: one for _run() to execute, one for the inner _work() task.
await asyncio.sleep(0)
await asyncio.sleep(0)
assert _future_events == ["created"]
class TestAttributeAccessInCallbacks: class TestAttributeAccessInCallbacks:
"""Verify that self attributes are accessible inside every callback type. """Verify that self attributes are accessible inside every callback type.

View File

@@ -201,6 +201,88 @@ class TestOffsetPagination:
assert data["page"] == 2 assert data["page"] == 2
assert data["has_more"] is True assert data["has_more"] is True
def test_total_count_can_be_none(self):
"""total_count accepts None (include_total=False mode)."""
pagination = OffsetPagination(
total_count=None,
items_per_page=20,
page=1,
has_more=True,
)
assert pagination.total_count is None
def test_serialization_with_none_total_count(self):
"""OffsetPagination serializes total_count=None correctly."""
pagination = OffsetPagination(
total_count=None,
items_per_page=20,
page=1,
has_more=False,
)
data = pagination.model_dump()
assert data["total_count"] is None
def test_pages_computed(self):
"""pages is ceil(total_count / items_per_page)."""
pagination = OffsetPagination(
total_count=42,
items_per_page=10,
page=1,
has_more=True,
)
assert pagination.pages == 5
def test_pages_exact_division(self):
"""pages is exact when total_count is evenly divisible."""
pagination = OffsetPagination(
total_count=40,
items_per_page=10,
page=1,
has_more=False,
)
assert pagination.pages == 4
def test_pages_zero_total(self):
"""pages is 0 when total_count is 0."""
pagination = OffsetPagination(
total_count=0,
items_per_page=10,
page=1,
has_more=False,
)
assert pagination.pages == 0
def test_pages_zero_items_per_page(self):
"""pages is 0 when items_per_page is 0."""
pagination = OffsetPagination(
total_count=100,
items_per_page=0,
page=1,
has_more=False,
)
assert pagination.pages == 0
def test_pages_none_when_total_count_none(self):
"""pages is None when total_count is None (include_total=False)."""
pagination = OffsetPagination(
total_count=None,
items_per_page=20,
page=1,
has_more=True,
)
assert pagination.pages is None
def test_pages_in_serialization(self):
"""pages appears in model_dump output."""
pagination = OffsetPagination(
total_count=25,
items_per_page=10,
page=1,
has_more=True,
)
data = pagination.model_dump()
assert data["pages"] == 3
class TestCursorPagination: class TestCursorPagination:
"""Tests for CursorPagination schema.""" """Tests for CursorPagination schema."""

2
uv.lock generated
View File

@@ -251,7 +251,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "2.4.0" version = "2.4.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },