mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
e17d385910
|
|||
|
|
6cf7df55ef | ||
|
|
7482bc5dad | ||
|
|
9d07dfea85 |
@@ -1,6 +1,6 @@
|
||||
# CRUD
|
||||
|
||||
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support. This module has features that are only compatible with Postgres.
|
||||
Generic async CRUD operations for SQLAlchemy models with search, pagination, and many-to-many support.
|
||||
|
||||
!!! info
|
||||
This module has been coded and tested to be compatible with PostgreSQL only.
|
||||
@@ -48,6 +48,21 @@ exists = await UserCrud.exists(session=session, filters=[User.email == email])
|
||||
|
||||
## Pagination
|
||||
|
||||
!!! info "Added in `v1.1` (only offset_pagination via `paginate` if `<v1.1`)"
|
||||
|
||||
Two pagination strategies are available. Both return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) but differ in how they navigate through results.
|
||||
|
||||
| | `offset_paginate` | `cursor_paginate` |
|
||||
|---|---|---|
|
||||
| Total count | Yes | No |
|
||||
| Jump to arbitrary page | Yes | No |
|
||||
| Performance on deep pages | Degrades | Constant |
|
||||
| Stable under concurrent inserts | No | Yes |
|
||||
| Search compatible | Yes | Yes |
|
||||
| Use case | Admin panels, numbered pagination | Feeds, APIs, infinite scroll |
|
||||
|
||||
### Offset pagination
|
||||
|
||||
```python
|
||||
@router.get(
|
||||
"",
|
||||
@@ -58,14 +73,88 @@ async def get_users(
|
||||
items_per_page: int = 50,
|
||||
page: int = 1,
|
||||
):
|
||||
return await crud.UserCrud.paginate(
|
||||
return await crud.UserCrud.offset_paginate(
|
||||
session=session,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
)
|
||||
```
|
||||
|
||||
The [`paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) function will return a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse).
|
||||
The [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is an [`OffsetPagination`](../reference/schemas.md#fastapi_toolsets.schemas.OffsetPagination) object:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"data": ["..."],
|
||||
"pagination": {
|
||||
"total_count": 100,
|
||||
"page": 1,
|
||||
"items_per_page": 20,
|
||||
"has_more": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
!!! warning "Deprecated: `paginate`"
|
||||
The `paginate` function is a backward-compatible alias for `offset_paginate`. This function is **deprecated** and will be removed in **v2.0**.
|
||||
|
||||
### Cursor pagination
|
||||
|
||||
```python
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedResponse[UserRead],
|
||||
)
|
||||
async def list_users(
|
||||
session: SessionDep,
|
||||
cursor: str | None = None,
|
||||
items_per_page: int = 20,
|
||||
):
|
||||
return await UserCrud.cursor_paginate(
|
||||
session=session,
|
||||
cursor=cursor,
|
||||
items_per_page=items_per_page,
|
||||
)
|
||||
```
|
||||
|
||||
The [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) method returns a [`PaginatedResponse`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse) whose `pagination` field is a [`CursorPagination`](../reference/schemas.md#fastapi_toolsets.schemas.CursorPagination) object:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"data": ["..."],
|
||||
"pagination": {
|
||||
"next_cursor": "eyJ2YWx1ZSI6ICIzZjQ3YWM2OS0uLi4ifQ==",
|
||||
"prev_cursor": null,
|
||||
"items_per_page": 20,
|
||||
"has_more": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Pass `next_cursor` as the `cursor` query parameter on the next request to advance to the next page. `prev_cursor` is set on pages 2+ and points back to the first item of the current page. Both are `null` when there is no adjacent page.
|
||||
|
||||
#### Choosing a cursor column
|
||||
|
||||
The cursor column is set once on [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) via the `cursor_column` parameter. It must be monotonically ordered for stable results:
|
||||
|
||||
- Auto-increment integer PKs
|
||||
- UUID v7 PKs
|
||||
- Timestamps
|
||||
|
||||
!!! warning
|
||||
Random UUID v4 PKs are **not** suitable as cursor columns because their ordering is non-deterministic.
|
||||
|
||||
!!! note
|
||||
`cursor_column` is required. Calling [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate) on a CRUD class that has no `cursor_column` configured raises a `ValueError`.
|
||||
|
||||
```python
|
||||
# Paginate by the primary key
|
||||
PostCrud = CrudFactory(model=Post, cursor_column=Post.id)
|
||||
|
||||
# Paginate by a timestamp column instead
|
||||
PostCrud = CrudFactory(model=Post, cursor_column=Post.created_at)
|
||||
```
|
||||
|
||||
## Search
|
||||
|
||||
@@ -82,7 +171,7 @@ PostCrud = CrudFactory(
|
||||
)
|
||||
```
|
||||
|
||||
This allow to do a search with the [`paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.paginate) function:
|
||||
This allows searching with both [`offset_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.offset_paginate) and [`cursor_paginate`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud.cursor_paginate):
|
||||
|
||||
```python
|
||||
@router.get(
|
||||
@@ -95,7 +184,7 @@ async def get_users(
|
||||
page: int = 1,
|
||||
search: str | None = None,
|
||||
):
|
||||
return await crud.UserCrud.paginate(
|
||||
return await crud.UserCrud.offset_paginate(
|
||||
session=session,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
@@ -103,6 +192,60 @@ async def get_users(
|
||||
)
|
||||
```
|
||||
|
||||
```python
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedResponse[User],
|
||||
)
|
||||
async def get_users(
|
||||
session: SessionDep,
|
||||
cursor: str | None = None,
|
||||
items_per_page: int = 50,
|
||||
search: str | None = None,
|
||||
):
|
||||
return await crud.UserCrud.cursor_paginate(
|
||||
session=session,
|
||||
items_per_page=items_per_page,
|
||||
cursor=cursor,
|
||||
search=search,
|
||||
)
|
||||
```
|
||||
|
||||
## Relationship loading
|
||||
|
||||
!!! info "Added in `v1.1`"
|
||||
|
||||
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
|
||||
|
||||
!!! warning
|
||||
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
|
||||
|
||||
```python
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
ArticleCrud = CrudFactory(
|
||||
model=Article,
|
||||
default_load_options=[
|
||||
selectinload(Article.category),
|
||||
selectinload(Article.tags),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `offset_paginate`, `cursor_paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
|
||||
|
||||
```python
|
||||
# Only loads category, tags are not loaded
|
||||
article = await ArticleCrud.get(
|
||||
session=session,
|
||||
filters=[Article.id == article_id],
|
||||
load_options=[selectinload(Article.category)],
|
||||
)
|
||||
|
||||
# Loads nothing — useful for write-then-refresh flows or lightweight checks
|
||||
articles = await ArticleCrud.get_multi(session=session, load_options=[])
|
||||
```
|
||||
|
||||
## Many-to-many relationships
|
||||
|
||||
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:
|
||||
@@ -129,24 +272,42 @@ await UserCrud.upsert(
|
||||
)
|
||||
```
|
||||
|
||||
## `as_response`
|
||||
## `schema` — typed response serialization
|
||||
|
||||
Pass `as_response=True` to any write operation to get a [`Response[ModelType]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) back directly for API usage:
|
||||
!!! info "Added in `v1.1`"
|
||||
|
||||
Pass a Pydantic schema class to `create`, `get`, `update`, or `offset_paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||
|
||||
```python
|
||||
class UserRead(PydanticBase):
|
||||
id: UUID
|
||||
username: str
|
||||
|
||||
@router.get(
|
||||
"/{uuid}",
|
||||
response_model=Response[User],
|
||||
responses=generate_error_responses(NotFoundError),
|
||||
)
|
||||
async def get_user(session: SessionDep, uuid: UUID):
|
||||
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
|
||||
return await crud.UserCrud.get(
|
||||
session=session,
|
||||
filters=[User.id == uuid],
|
||||
as_response=True,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
@router.get("")
|
||||
async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]:
|
||||
return await crud.UserCrud.offset_paginate(
|
||||
session=session,
|
||||
page=page,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||
|
||||
!!! warning "Deprecated: `as_response`"
|
||||
The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`.
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/crud.md)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "1.0.0"
|
||||
version = "1.1.0"
|
||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__version__ = "1.1.0"
|
||||
|
||||
@@ -2,28 +2,44 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid as uuid_module
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy import Integer, Uuid, and_, func, select
|
||||
from sqlalchemy import delete as sql_delete
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||
from sqlalchemy.sql.base import ExecutableOption
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from ..schemas import PaginatedResponse, Pagination, Response
|
||||
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
|
||||
|
||||
def _encode_cursor(value: Any) -> str:
|
||||
"""Encode cursor column value as an base64 string."""
|
||||
return base64.b64encode(json.dumps(str(value)).encode()).decode()
|
||||
|
||||
|
||||
def _decode_cursor(cursor: str) -> str:
|
||||
"""Decode cursor base64 string."""
|
||||
return json.loads(base64.b64decode(cursor.encode()).decode())
|
||||
|
||||
|
||||
class AsyncCrud(Generic[ModelType]):
|
||||
"""Generic async CRUD operations for SQLAlchemy models.
|
||||
|
||||
@@ -33,26 +49,17 @@ class AsyncCrud(Generic[ModelType]):
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
default_load_options: ClassVar[list[ExecutableOption] | None] = None
|
||||
cursor_column: ClassVar[Any | None] = None
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
def _resolve_load_options(
|
||||
cls, load_options: list[ExecutableOption] | None
|
||||
) -> list[ExecutableOption] | None:
|
||||
"""Return load_options if provided, else fall back to default_load_options."""
|
||||
if load_options is not None:
|
||||
return load_options
|
||||
return cls.default_load_options
|
||||
|
||||
@classmethod
|
||||
async def _resolve_m2m(
|
||||
@@ -110,6 +117,40 @@ class AsyncCrud(Generic[ModelType]):
|
||||
return set()
|
||||
return set(cls.m2m_fields.keys())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
@@ -117,17 +158,28 @@ class AsyncCrud(Generic[ModelType]):
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Created model instance or Response wrapping it
|
||||
Created model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
data = (
|
||||
@@ -143,8 +195,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
result = cast(ModelType, db_model)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(result) if schema else result
|
||||
return Response(data=data_out)
|
||||
return result
|
||||
|
||||
@overload
|
||||
@@ -157,8 +210,25 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@@ -171,8 +241,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
@@ -184,9 +255,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
@@ -196,15 +268,25 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Model instance or Response wrapping it
|
||||
Model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
MultipleResultsFound: If more than one record found
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
@@ -214,8 +296,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
if with_for_update:
|
||||
q = q.with_for_update()
|
||||
result = await session.execute(q)
|
||||
@@ -223,8 +305,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
result = cast(ModelType, item)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(result) if schema else result
|
||||
return Response(data=data_out)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -235,7 +318,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
) -> ModelType | None:
|
||||
"""Get the first matching record, or None.
|
||||
|
||||
@@ -259,8 +342,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
result = await session.execute(q)
|
||||
return cast(ModelType | None, result.unique().scalars().first())
|
||||
|
||||
@@ -272,7 +355,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
@@ -302,8 +385,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
if offset is not None:
|
||||
@@ -313,6 +396,21 @@ class AsyncCrud(Generic[ModelType]):
|
||||
result = await session.execute(q)
|
||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
@@ -324,6 +422,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@@ -337,6 +436,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
@@ -349,7 +449,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
@@ -358,20 +459,30 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Updated model instance or Response wrapping it
|
||||
Updated model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
|
||||
# Eagerly load M2M relationships that will be updated so that
|
||||
# setattr does not trigger a lazy load (which fails in async).
|
||||
m2m_load_options: list[Any] = []
|
||||
m2m_load_options: list[ExecutableOption] = []
|
||||
if m2m_exclude and cls.m2m_fields:
|
||||
for schema_field, rel in cls.m2m_fields.items():
|
||||
if schema_field in obj.model_fields_set:
|
||||
@@ -395,8 +506,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
for rel_attr, related_instances in m2m_resolved.items():
|
||||
setattr(db_model, rel_attr, related_instances)
|
||||
await session.refresh(db_model)
|
||||
if as_response:
|
||||
return Response(data=db_model)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(db_model) if schema else db_model
|
||||
return Response(data=data_out)
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
@@ -478,11 +590,20 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Will be removed in v2.0. When ``True``,
|
||||
returns ``Response[None]`` instead of ``bool``.
|
||||
|
||||
Returns:
|
||||
True if deletion was executed, or Response wrapping it
|
||||
``True`` if deletion was executed, or ``Response[None]`` when
|
||||
``as_response=True`` (deprecated).
|
||||
"""
|
||||
if as_response:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
@@ -555,22 +676,60 @@ class AsyncCrud(Generic[ModelType]):
|
||||
result = await session.execute(q)
|
||||
return bool(result.scalar())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def paginate(
|
||||
async def offset_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> PaginatedResponse[ModelType]:
|
||||
"""Get paginated results with metadata.
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def offset_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@classmethod
|
||||
async def offset_paginate(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
"""Get paginated results using offset-based pagination.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
@@ -583,9 +742,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: Number of items per page
|
||||
search: Search query string or SearchConfig object
|
||||
search_fields: Fields to search in (overrides class default)
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
Dict with 'data' and 'pagination' keys
|
||||
PaginatedResponse with OffsetPagination metadata
|
||||
"""
|
||||
filters = list(filters) if filters else []
|
||||
offset = (page - 1) * items_per_page
|
||||
@@ -619,14 +779,17 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
|
||||
q = q.offset(offset).limit(items_per_page)
|
||||
result = await session.execute(q)
|
||||
items = cast(list[ModelType], result.unique().scalars().all())
|
||||
raw_items = cast(list[ModelType], result.unique().scalars().all())
|
||||
items: list[Any] = (
|
||||
[schema.model_validate(item) for item in raw_items] if schema else raw_items
|
||||
)
|
||||
|
||||
# Count query (with same joins and filters)
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
@@ -654,7 +817,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=Pagination(
|
||||
pagination=OffsetPagination(
|
||||
total_count=total_count,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
@@ -662,12 +825,183 @@ class AsyncCrud(Generic[ModelType]):
|
||||
),
|
||||
)
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
paginate = offset_paginate
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def cursor_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def cursor_paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@classmethod
|
||||
async def cursor_paginate(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
cursor: str | None = None,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
"""Get paginated results using cursor-based pagination.
|
||||
|
||||
Args:
|
||||
session: DB async session.
|
||||
cursor: Cursor string from a previous ``CursorPagination``.
|
||||
Omit (or pass ``None``) to start from the beginning.
|
||||
filters: List of SQLAlchemy filter conditions.
|
||||
joins: List of ``(model, condition)`` tuples for joining related
|
||||
tables.
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
|
||||
load_options: SQLAlchemy loader options. Falls back to
|
||||
``default_load_options`` when not provided.
|
||||
order_by: Additional ordering applied after the cursor column.
|
||||
items_per_page: Number of items per page (default 20).
|
||||
search: Search query string or SearchConfig object.
|
||||
search_fields: Fields to search in (overrides class default).
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
PaginatedResponse with CursorPagination metadata
|
||||
"""
|
||||
filters = list(filters) if filters else []
|
||||
search_joins: list[Any] = []
|
||||
|
||||
if cls.cursor_column is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.cursor_column is not set. "
|
||||
"Pass cursor_column=<column> to CrudFactory() to use cursor_paginate."
|
||||
)
|
||||
cursor_column: Any = cls.cursor_column
|
||||
cursor_col_name: str = cursor_column.key
|
||||
|
||||
if cursor is not None:
|
||||
raw_val = _decode_cursor(cursor)
|
||||
col_type = cursor_column.property.columns[0].type
|
||||
if isinstance(col_type, Integer):
|
||||
cursor_val: Any = int(raw_val)
|
||||
elif isinstance(col_type, Uuid):
|
||||
cursor_val = uuid_module.UUID(raw_val)
|
||||
else:
|
||||
cursor_val = raw_val
|
||||
filters.append(cursor_column > cursor_val)
|
||||
|
||||
# Build search filters
|
||||
if search:
|
||||
search_filters, search_joins = build_search_filters(
|
||||
cls.model,
|
||||
search,
|
||||
search_fields=search_fields,
|
||||
default_fields=cls.searchable_fields,
|
||||
)
|
||||
filters.extend(search_filters)
|
||||
|
||||
# Build query
|
||||
q = select(cls.model)
|
||||
|
||||
# Apply explicit joins
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
|
||||
# Apply search joins (always outer joins)
|
||||
for join_rel in search_joins:
|
||||
q = q.outerjoin(join_rel)
|
||||
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
|
||||
# Cursor column is always the primary sort
|
||||
q = q.order_by(cursor_column)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
|
||||
# Fetch one extra to detect whether a next page exists
|
||||
q = q.limit(items_per_page + 1)
|
||||
result = await session.execute(q)
|
||||
raw_items = cast(list[ModelType], result.unique().scalars().all())
|
||||
|
||||
has_more = len(raw_items) > items_per_page
|
||||
items_page = raw_items[:items_per_page]
|
||||
|
||||
# next_cursor points past the last item on this page
|
||||
next_cursor: str | None = None
|
||||
if has_more and items_page:
|
||||
next_cursor = _encode_cursor(getattr(items_page[-1], cursor_col_name))
|
||||
|
||||
# prev_cursor points to the first item on this page or None when on the first page
|
||||
prev_cursor: str | None = None
|
||||
if cursor is not None and items_page:
|
||||
prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))
|
||||
|
||||
items: list[Any] = (
|
||||
[schema.model_validate(item) for item in items_page]
|
||||
if schema
|
||||
else items_page
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=CursorPagination(
|
||||
next_cursor=next_cursor,
|
||||
prev_cursor=prev_cursor,
|
||||
items_per_page=items_per_page,
|
||||
has_more=has_more,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def CrudFactory(
|
||||
model: type[ModelType],
|
||||
*,
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
m2m_fields: M2MFieldType | None = None,
|
||||
default_load_options: list[ExecutableOption] | None = None,
|
||||
cursor_column: Any | None = None,
|
||||
) -> type[AsyncCrud[ModelType]]:
|
||||
"""Create a CRUD class for a specific model.
|
||||
|
||||
@@ -677,6 +1011,13 @@ def CrudFactory(
|
||||
m2m_fields: Optional mapping for many-to-many relationships.
|
||||
Maps schema field names (containing lists of IDs) to
|
||||
SQLAlchemy relationship attributes.
|
||||
default_load_options: Default SQLAlchemy loader options applied to all read
|
||||
queries when no explicit ``load_options`` are passed. Use this
|
||||
instead of ``lazy="selectin"`` on the model so that loading
|
||||
strategy is explicit and per-CRUD. Overridden entirely (not
|
||||
merged) when ``load_options`` is provided at call-site.
|
||||
cursor_column: Required to call ``cursor_paginate``
|
||||
Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
|
||||
|
||||
Returns:
|
||||
AsyncCrud subclass bound to the model
|
||||
@@ -702,6 +1043,25 @@ def CrudFactory(
|
||||
m2m_fields={"tag_ids": Post.tags},
|
||||
)
|
||||
|
||||
# With a fixed cursor column for cursor_paginate:
|
||||
PostCrud = CrudFactory(
|
||||
Post,
|
||||
cursor_column=Post.created_at,
|
||||
)
|
||||
|
||||
# With default load strategy (replaces lazy="selectin" on the model):
|
||||
ArticleCrud = CrudFactory(
|
||||
Article,
|
||||
default_load_options=[selectinload(Article.category), selectinload(Article.tags)],
|
||||
)
|
||||
|
||||
# Override default_load_options for a specific call:
|
||||
article = await ArticleCrud.get(
|
||||
session,
|
||||
[Article.id == 1],
|
||||
load_options=[selectinload(Article.category)], # tags won't load
|
||||
)
|
||||
|
||||
# Usage
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||
@@ -710,7 +1070,7 @@ def CrudFactory(
|
||||
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
|
||||
|
||||
# With search
|
||||
result = await UserCrud.paginate(session, search="john")
|
||||
result = await UserCrud.offset_paginate(session, search="john")
|
||||
|
||||
# With joins (inner join by default):
|
||||
users = await UserCrud.get_multi(
|
||||
@@ -734,6 +1094,8 @@ def CrudFactory(
|
||||
"model": model,
|
||||
"searchable_fields": searchable_fields,
|
||||
"m2m_fields": m2m_fields,
|
||||
"default_load_options": default_load_options,
|
||||
"cursor_column": cursor_column,
|
||||
},
|
||||
)
|
||||
return cast(type[AsyncCrud[ModelType]], cls)
|
||||
|
||||
@@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
__all__ = [
|
||||
"ApiError",
|
||||
"CursorPagination",
|
||||
"ErrorResponse",
|
||||
"OffsetPagination",
|
||||
"Pagination",
|
||||
"PaginatedResponse",
|
||||
"PydanticBase",
|
||||
@@ -90,8 +92,8 @@ class ErrorResponse(BaseResponse):
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class Pagination(PydanticBase):
|
||||
"""Pagination metadata for list responses.
|
||||
class OffsetPagination(PydanticBase):
|
||||
"""Pagination metadata for offset-based list responses.
|
||||
|
||||
Attributes:
|
||||
total_count: Total number of items across all pages
|
||||
@@ -106,17 +108,28 @@ class Pagination(PydanticBase):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||
"""Paginated API response for list endpoints.
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
Pagination = OffsetPagination
|
||||
|
||||
Example:
|
||||
```python
|
||||
PaginatedResponse[UserRead](
|
||||
data=users,
|
||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
||||
)
|
||||
```
|
||||
|
||||
class CursorPagination(PydanticBase):
|
||||
"""Pagination metadata for cursor-based list responses.
|
||||
|
||||
Attributes:
|
||||
next_cursor: Encoded cursor for the next page, or None on the last page.
|
||||
prev_cursor: Encoded cursor for the previous page, or None on the first page.
|
||||
items_per_page: Number of items requested per page.
|
||||
has_more: Whether there is at least one more page after this one.
|
||||
"""
|
||||
|
||||
next_cursor: str | None
|
||||
prev_cursor: str | None = None
|
||||
items_per_page: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||
"""Paginated API response for list endpoints."""
|
||||
|
||||
data: list[DataT]
|
||||
pagination: Pagination
|
||||
pagination: OffsetPagination | CursorPagination
|
||||
|
||||
@@ -5,11 +5,12 @@ import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column, ForeignKey, String, Table, Uuid
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, Table, Uuid
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
DATABASE_URL = os.getenv(
|
||||
key="DATABASE_URL",
|
||||
@@ -69,6 +70,15 @@ post_tags = Table(
|
||||
)
|
||||
|
||||
|
||||
class IntRole(Base):
|
||||
"""Test role model with auto-increment integer PK."""
|
||||
|
||||
__tablename__ = "int_roles"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
|
||||
class Post(Base):
|
||||
"""Test post model."""
|
||||
|
||||
@@ -90,6 +100,13 @@ class RoleCreate(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class RoleRead(PydanticBase):
|
||||
"""Schema for reading a role."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class RoleUpdate(BaseModel):
|
||||
"""Schema for updating a role."""
|
||||
|
||||
@@ -106,6 +123,13 @@ class UserCreate(BaseModel):
|
||||
role_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class UserRead(PydanticBase):
|
||||
"""Schema for reading a user (subset of fields — no email)."""
|
||||
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating a user."""
|
||||
|
||||
@@ -160,8 +184,17 @@ class PostM2MUpdate(BaseModel):
|
||||
tag_ids: list[uuid.UUID] | None = None
|
||||
|
||||
|
||||
class IntRoleCreate(BaseModel):
|
||||
"""Schema for creating an IntRole."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
RoleCrud = CrudFactory(Role)
|
||||
RoleCursorCrud = CrudFactory(Role, cursor_column=Role.id)
|
||||
IntRoleCursorCrud = CrudFactory(IntRole, cursor_column=IntRole.id)
|
||||
UserCrud = CrudFactory(User)
|
||||
UserCursorCrud = CrudFactory(User, cursor_column=User.id)
|
||||
PostCrud = CrudFactory(Post)
|
||||
TagCrud = CrudFactory(Tag)
|
||||
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||
|
||||
@@ -11,6 +11,8 @@ from fastapi_toolsets.crud.factory import AsyncCrud
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
|
||||
from .conftest import (
|
||||
IntRoleCreate,
|
||||
IntRoleCursorCrud,
|
||||
Post,
|
||||
PostCreate,
|
||||
PostCrud,
|
||||
@@ -20,12 +22,16 @@ from .conftest import (
|
||||
Role,
|
||||
RoleCreate,
|
||||
RoleCrud,
|
||||
RoleCursorCrud,
|
||||
RoleRead,
|
||||
RoleUpdate,
|
||||
TagCreate,
|
||||
TagCrud,
|
||||
User,
|
||||
UserCreate,
|
||||
UserCrud,
|
||||
UserCursorCrud,
|
||||
UserRead,
|
||||
UserUpdate,
|
||||
)
|
||||
|
||||
@@ -50,6 +56,152 @@ class TestCrudFactory:
|
||||
crud = CrudFactory(User)
|
||||
assert "User" in crud.__name__
|
||||
|
||||
def test_default_load_options_none_by_default(self):
|
||||
"""default_load_options is None when not specified."""
|
||||
crud = CrudFactory(User)
|
||||
assert crud.default_load_options is None
|
||||
|
||||
def test_default_load_options_set(self):
|
||||
"""default_load_options is stored on the class."""
|
||||
options = [selectinload(User.role)]
|
||||
crud = CrudFactory(User, default_load_options=options)
|
||||
assert crud.default_load_options == options
|
||||
|
||||
def test_default_load_options_not_shared_between_classes(self):
|
||||
"""default_load_options is isolated per factory call."""
|
||||
options = [selectinload(User.role)]
|
||||
crud_with = CrudFactory(User, default_load_options=options)
|
||||
crud_without = CrudFactory(User)
|
||||
assert crud_with.default_load_options == options
|
||||
assert crud_without.default_load_options is None
|
||||
|
||||
|
||||
class TestResolveLoadOptions:
|
||||
"""Tests for _resolve_load_options logic."""
|
||||
|
||||
def test_returns_load_options_when_provided(self):
|
||||
"""Explicit load_options takes priority over default_load_options."""
|
||||
options = [selectinload(User.role)]
|
||||
default = [selectinload(Post.tags)]
|
||||
crud = CrudFactory(User, default_load_options=default)
|
||||
assert crud._resolve_load_options(options) == options
|
||||
|
||||
def test_returns_default_when_load_options_is_none(self):
|
||||
"""Falls back to default_load_options when load_options is None."""
|
||||
default = [selectinload(User.role)]
|
||||
crud = CrudFactory(User, default_load_options=default)
|
||||
assert crud._resolve_load_options(None) == default
|
||||
|
||||
def test_returns_none_when_both_are_none(self):
|
||||
"""Returns None when neither load_options nor default_load_options set."""
|
||||
crud = CrudFactory(User)
|
||||
assert crud._resolve_load_options(None) is None
|
||||
|
||||
def test_empty_list_overrides_default(self):
|
||||
"""An empty list is a valid override and disables default_load_options."""
|
||||
default = [selectinload(User.role)]
|
||||
crud = CrudFactory(User, default_load_options=default)
|
||||
# Empty list is not None, so it should replace default
|
||||
assert crud._resolve_load_options([]) == []
|
||||
|
||||
|
||||
class TestDefaultLoadOptionsIntegration:
|
||||
"""Integration tests for default_load_options with real DB queries."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_load_options_applied_to_get(self, db_session: AsyncSession):
|
||||
"""default_load_options loads relationships automatically on get()."""
|
||||
UserWithDefaultLoad = CrudFactory(
|
||||
User, default_load_options=[selectinload(User.role)]
|
||||
)
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
user = await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
fetched = await UserWithDefaultLoad.get(db_session, [User.id == user.id])
|
||||
assert fetched.role is not None
|
||||
assert fetched.role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_load_options_applied_to_get_multi(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""default_load_options loads relationships automatically on get_multi()."""
|
||||
UserWithDefaultLoad = CrudFactory(
|
||||
User, default_load_options=[selectinload(User.role)]
|
||||
)
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
users = await UserWithDefaultLoad.get_multi(db_session)
|
||||
assert users[0].role is not None
|
||||
assert users[0].role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_load_options_applied_to_first(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""default_load_options loads relationships automatically on first()."""
|
||||
UserWithDefaultLoad = CrudFactory(
|
||||
User, default_load_options=[selectinload(User.role)]
|
||||
)
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
user = await UserWithDefaultLoad.first(db_session)
|
||||
assert user is not None
|
||||
assert user.role is not None
|
||||
assert user.role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_default_load_options_applied_to_paginate(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""default_load_options loads relationships automatically on paginate()."""
|
||||
UserWithDefaultLoad = CrudFactory(
|
||||
User, default_load_options=[selectinload(User.role)]
|
||||
)
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
result = await UserWithDefaultLoad.paginate(db_session)
|
||||
assert result.data[0].role is not None
|
||||
assert result.data[0].role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_load_options_overrides_default_load_options(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""Explicit load_options fully replaces default_load_options."""
|
||||
PostWithDefaultLoad = CrudFactory(
|
||||
Post,
|
||||
default_load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
user = await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com"),
|
||||
)
|
||||
post = await PostCrud.create(
|
||||
db_session,
|
||||
PostCreate(title="Hello", author_id=user.id),
|
||||
)
|
||||
# Pass empty load_options to override default — tags should not load
|
||||
fetched = await PostWithDefaultLoad.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[],
|
||||
)
|
||||
# tags were not loaded — accessing them would lazy-load or return empty
|
||||
# We just assert the fetch itself succeeded with the override
|
||||
assert fetched.id == post.id
|
||||
|
||||
|
||||
class TestCrudCreate:
|
||||
"""Tests for CRUD create operations."""
|
||||
@@ -433,8 +585,11 @@ class TestCrudPaginate:
|
||||
for i in range(25):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert len(result.data) == 10
|
||||
assert result.pagination.total_count == 25
|
||||
assert result.pagination.page == 1
|
||||
@@ -465,6 +620,8 @@ class TestCrudPaginate:
|
||||
),
|
||||
)
|
||||
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
@@ -472,6 +629,7 @@ class TestCrudPaginate:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -687,6 +845,8 @@ class TestCrudJoins:
|
||||
),
|
||||
)
|
||||
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
# Paginate users with published posts
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
@@ -696,6 +856,7 @@ class TestCrudJoins:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 3
|
||||
assert len(result.data) == 3
|
||||
|
||||
@@ -718,6 +879,8 @@ class TestCrudJoins:
|
||||
UserCreate(username="without_post", email="without@test.com"),
|
||||
)
|
||||
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
# Paginate with outer join
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
@@ -727,6 +890,7 @@ class TestCrudJoins:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
assert len(result.data) == 2
|
||||
|
||||
@@ -761,15 +925,16 @@ class TestCrudJoins:
|
||||
|
||||
|
||||
class TestAsResponse:
|
||||
"""Tests for as_response parameter."""
|
||||
"""Tests for as_response parameter (deprecated, kept for backward compat)."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_as_response(self, db_session: AsyncSession):
|
||||
"""Create with as_response=True returns Response."""
|
||||
"""Create with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
data = RoleCreate(name="response_role")
|
||||
result = await RoleCrud.create(db_session, data, as_response=True)
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.create(db_session, data, as_response=True)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
@@ -777,13 +942,14 @@ class TestAsResponse:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_as_response(self, db_session: AsyncSession):
|
||||
"""Get with as_response=True returns Response."""
|
||||
"""Get with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
@@ -791,16 +957,17 @@ class TestAsResponse:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_as_response(self, db_session: AsyncSession):
|
||||
"""Update with as_response=True returns Response."""
|
||||
"""Update with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="new_name"),
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
)
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="new_name"),
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
@@ -808,13 +975,14 @@ class TestAsResponse:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_as_response(self, db_session: AsyncSession):
|
||||
"""Delete with as_response=True returns Response."""
|
||||
"""Delete with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
|
||||
result = await RoleCrud.delete(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.delete(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is None
|
||||
@@ -1198,3 +1366,656 @@ class TestM2MWithNonM2MCrud:
|
||||
[Post.id == post.id],
|
||||
)
|
||||
assert updated.title == "Updated Plain"
|
||||
|
||||
|
||||
class TestSchemaResponse:
|
||||
"""Tests for the schema parameter on as_response methods."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_schema(self, db_session: AsyncSession):
|
||||
"""create with schema returns Response[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
result = await RoleCrud.create(
|
||||
db_session, RoleCreate(name="schema_role"), schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
assert result.data.name == "schema_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_schema_implies_as_response(self, db_session: AsyncSession):
|
||||
"""create with schema alone wraps in Response without as_response=True."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
result = await RoleCrud.create(
|
||||
db_session, RoleCreate(name="implicit"), schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_schema_filters_fields(self, db_session: AsyncSession):
|
||||
"""create with schema only exposes schema fields, not all model fields."""
|
||||
result = await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="filtered", email="filtered@test.com"),
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.data, UserRead)
|
||||
assert result.data.username == "filtered"
|
||||
assert not hasattr(result.data, "email")
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_with_schema(self, db_session: AsyncSession):
|
||||
"""get with schema returns Response[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="get_schema"))
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
assert result.data.id == created.id
|
||||
assert result.data.name == "get_schema"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_schema_implies_as_response(self, db_session: AsyncSession):
|
||||
"""get with schema alone wraps in Response without as_response=True."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="implicit_get"))
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_with_schema(self, db_session: AsyncSession):
|
||||
"""update with schema returns Response[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="before"))
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="after"),
|
||||
[Role.id == created.id],
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
assert result.data.name == "after"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_schema_implies_as_response(self, db_session: AsyncSession):
|
||||
"""update with schema alone wraps in Response without as_response=True."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="before2"))
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="after2"),
|
||||
[Role.id == created.id],
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_schema(self, db_session: AsyncSession):
|
||||
"""paginate with schema returns PaginatedResponse[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
await RoleCrud.create(db_session, RoleCreate(name="p_role1"))
|
||||
await RoleCrud.create(db_session, RoleCreate(name="p_role2"))
|
||||
|
||||
result = await RoleCrud.paginate(db_session, schema=RoleRead)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert len(result.data) == 2
|
||||
assert all(isinstance(item, RoleRead) for item in result.data)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_schema_filters_fields(self, db_session: AsyncSession):
|
||||
"""paginate with schema only exposes schema fields per item."""
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="pg_user", email="pg@test.com"),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, schema=UserRead)
|
||||
|
||||
assert isinstance(result.data[0], UserRead)
|
||||
assert result.data[0].username == "pg_user"
|
||||
assert not hasattr(result.data[0], "email")
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_as_response_true_without_schema_unchanged(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""as_response=True without schema still returns Response[ModelType] with a warning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="compat"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, Role)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_schema_with_explicit_as_response_true(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""schema combined with explicit as_response=True works correctly."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="combined"))
|
||||
result = await RoleCrud.get(
|
||||
db_session,
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
|
||||
|
||||
class TestPaginateAlias:
|
||||
"""Tests that paginate is a backward-compatible alias for offset_paginate."""
|
||||
|
||||
def test_paginate_is_alias_of_offset_paginate(self):
|
||||
"""paginate and offset_paginate are the same underlying function."""
|
||||
assert RoleCrud.paginate.__func__ is RoleCrud.offset_paginate.__func__
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_alias_returns_offset_pagination(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""paginate() still works and returns PaginatedResponse with OffsetPagination."""
|
||||
from fastapi_toolsets.schemas import OffsetPagination, PaginatedResponse
|
||||
|
||||
for i in range(3):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 3
|
||||
assert result.pagination.page == 1
|
||||
|
||||
|
||||
class TestCursorPaginate:
|
||||
"""Tests for cursor-based pagination via cursor_paginate()."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_first_page_no_cursor(self, db_session: AsyncSession):
|
||||
"""cursor_paginate without cursor returns the first page."""
|
||||
from fastapi_toolsets.schemas import CursorPagination, PaginatedResponse
|
||||
|
||||
for i in range(25):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 10
|
||||
assert result.pagination.has_more is True
|
||||
assert result.pagination.next_cursor is not None
|
||||
assert result.pagination.prev_cursor is None
|
||||
assert result.pagination.items_per_page == 10
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_last_page(self, db_session: AsyncSession):
|
||||
"""cursor_paginate returns has_more=False and next_cursor=None on last page."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
for i in range(5):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 5
|
||||
assert result.pagination.has_more is False
|
||||
assert result.pagination.next_cursor is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_advances_correctly(self, db_session: AsyncSession):
|
||||
"""Providing next_cursor from the first page returns the next page."""
|
||||
for i in range(15):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
assert len(page1.data) == 10
|
||||
assert page1.pagination.has_more is True
|
||||
|
||||
cursor = page1.pagination.next_cursor
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session, cursor=cursor, items_per_page=10
|
||||
)
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert len(page2.data) == 5
|
||||
assert page2.pagination.has_more is False
|
||||
assert page2.pagination.next_cursor is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_no_duplicates_across_pages(self, db_session: AsyncSession):
|
||||
"""Items from consecutive cursor pages are non-overlapping and cover all rows."""
|
||||
for i in range(7):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=4)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=4,
|
||||
)
|
||||
|
||||
ids_page1 = {r.id for r in page1.data}
|
||||
ids_page2 = {r.id for r in page2.data}
|
||||
assert ids_page1.isdisjoint(ids_page2)
|
||||
assert len(ids_page1 | ids_page2) == 7
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_table(self, db_session: AsyncSession):
|
||||
"""cursor_paginate on an empty table returns empty data with no cursor."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=10)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert result.data == []
|
||||
assert result.pagination.has_more is False
|
||||
assert result.pagination.next_cursor is None
|
||||
assert result.pagination.prev_cursor is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_filters(self, db_session: AsyncSession):
|
||||
"""cursor_paginate respects filters."""
|
||||
for i in range(10):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"user{i}",
|
||||
email=f"user{i}@test.com",
|
||||
is_active=i % 2 == 0,
|
||||
),
|
||||
)
|
||||
|
||||
result = await UserCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
items_per_page=20,
|
||||
)
|
||||
|
||||
assert len(result.data) == 5
|
||||
assert all(u.is_active for u in result.data)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_schema(self, db_session: AsyncSession):
|
||||
"""cursor_paginate with schema serializes items into the schema."""
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
for i in range(3):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, schema=RoleRead)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert all(isinstance(item, RoleRead) for item in result.data)
|
||||
assert all(
|
||||
hasattr(item, "id") and hasattr(item, "name") for item in result.data
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_cursor_column(self, db_session: AsyncSession):
|
||||
"""cursor_paginate uses cursor_column set on CrudFactory."""
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
RoleNameCrud = CrudFactory(Role, cursor_column=Role.name)
|
||||
|
||||
for i in range(5):
|
||||
await RoleNameCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleNameCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 3
|
||||
assert result.pagination.has_more is True
|
||||
assert result.pagination.next_cursor is not None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_raises_without_cursor_column(self, db_session: AsyncSession):
|
||||
"""cursor_paginate raises ValueError when cursor_column is not configured."""
|
||||
with pytest.raises(ValueError, match="cursor_column is not set"):
|
||||
await RoleCrud.cursor_paginate(db_session)
|
||||
|
||||
|
||||
class TestCursorPaginatePrevCursor:
|
||||
"""Tests for prev_cursor behavior in cursor_paginate()."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_prev_cursor_none_on_first_page(self, db_session: AsyncSession):
|
||||
"""prev_cursor is None when no cursor was provided (first page)."""
|
||||
for i in range(5):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert result.pagination.prev_cursor is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_prev_cursor_set_on_subsequent_pages(self, db_session: AsyncSession):
|
||||
"""prev_cursor is set when a cursor was provided (subsequent pages)."""
|
||||
for i in range(10):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=5,
|
||||
)
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert page2.pagination.prev_cursor is not None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_prev_cursor_points_to_first_item(self, db_session: AsyncSession):
|
||||
"""prev_cursor encodes the value of the first item on the current page."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
for i in range(10):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
page1 = await RoleCursorCrud.cursor_paginate(db_session, items_per_page=5)
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
page2 = await RoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=5,
|
||||
)
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert page2.pagination.prev_cursor is not None
|
||||
|
||||
# Decode prev_cursor and compare to first item's id
|
||||
decoded = json.loads(
|
||||
base64.b64decode(page2.pagination.prev_cursor.encode()).decode()
|
||||
)
|
||||
first_item_id = str(page2.data[0].id)
|
||||
assert decoded == first_item_id
|
||||
|
||||
|
||||
class TestCursorPaginateWithSearch:
|
||||
"""Tests for cursor_paginate() combined with search."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_paginate_with_search(self, db_session: AsyncSession):
|
||||
"""cursor_paginate respects search filters alongside cursor predicate."""
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
|
||||
# Create a CRUD with searchable fields and cursor column
|
||||
SearchableRoleCrud = CrudFactory(
|
||||
Role, searchable_fields=[Role.name], cursor_column=Role.id
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"admin{i:02d}"))
|
||||
for i in range(5):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"user{i:02d}"))
|
||||
|
||||
result = await SearchableRoleCrud.cursor_paginate(
|
||||
db_session,
|
||||
search="admin",
|
||||
items_per_page=20,
|
||||
)
|
||||
|
||||
assert len(result.data) == 5
|
||||
assert all("admin" in r.name for r in result.data)
|
||||
|
||||
|
||||
class TestCursorPaginateExtraOptions:
|
||||
"""Tests for cursor_paginate() covering joins, load_options, and order_by."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_joins(self, db_session: AsyncSession):
|
||||
"""cursor_paginate applies explicit inner joins."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
|
||||
for i in range(5):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"u{i}",
|
||||
email=f"u{i}@test.com",
|
||||
role_id=role.id,
|
||||
),
|
||||
)
|
||||
# One user without a role to confirm inner join excludes them
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="norole", email="norole@test.com"),
|
||||
)
|
||||
|
||||
result = await UserCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
joins=[(Role, User.role_id == Role.id)],
|
||||
items_per_page=20,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
# Only users with a role are returned (inner join)
|
||||
assert len(result.data) == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_outer_join(self, db_session: AsyncSession):
|
||||
"""cursor_paginate applies LEFT OUTER JOIN when outer_join=True."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"u{i}",
|
||||
email=f"u{i}@test.com",
|
||||
role_id=role.id,
|
||||
),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="norole", email="norole@test.com"),
|
||||
)
|
||||
|
||||
result = await UserCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
joins=[(Role, User.role_id == Role.id)],
|
||||
outer_join=True,
|
||||
items_per_page=20,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
# All users are included (outer join)
|
||||
assert len(result.data) == 4
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_load_options(self, db_session: AsyncSession):
|
||||
"""cursor_paginate passes load_options to the query."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"u{i}",
|
||||
email=f"u{i}@test.com",
|
||||
role_id=role.id,
|
||||
),
|
||||
)
|
||||
|
||||
result = await UserCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
load_options=[selectinload(User.role)],
|
||||
items_per_page=20,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 3
|
||||
# Relationship was eagerly loaded
|
||||
assert all(u.role is not None for u in result.data)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_order_by(self, db_session: AsyncSession):
|
||||
"""cursor_paginate applies additional order_by after the cursor column."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
for i in range(5):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
result = await RoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
order_by=Role.name.desc(),
|
||||
items_per_page=3,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_integer_cursor_column(self, db_session: AsyncSession):
|
||||
"""cursor_paginate decodes Integer cursor values correctly."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
for i in range(5):
|
||||
await IntRoleCursorCrud.create(db_session, IntRoleCreate(name=f"role{i}"))
|
||||
|
||||
page1 = await IntRoleCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
assert len(page1.data) == 3
|
||||
assert page1.pagination.has_more is True
|
||||
|
||||
page2 = await IntRoleCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=3,
|
||||
)
|
||||
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert len(page2.data) == 2
|
||||
assert page2.pagination.has_more is False
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_string_cursor_column(self, db_session: AsyncSession):
|
||||
"""cursor_paginate decodes non-UUID/non-Integer cursor values (string branch)."""
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
RoleNameCursorCrud = CrudFactory(Role, cursor_column=Role.name)
|
||||
|
||||
for i in range(5):
|
||||
await RoleCrud.create(db_session, RoleCreate(name=f"role{i:02d}"))
|
||||
|
||||
page1 = await RoleNameCursorCrud.cursor_paginate(db_session, items_per_page=3)
|
||||
|
||||
assert isinstance(page1.pagination, CursorPagination)
|
||||
assert len(page1.data) == 3
|
||||
assert page1.pagination.has_more is True
|
||||
|
||||
page2 = await RoleNameCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
cursor=page1.pagination.next_cursor,
|
||||
items_per_page=3,
|
||||
)
|
||||
|
||||
assert isinstance(page2.pagination, CursorPagination)
|
||||
assert len(page2.data) == 2
|
||||
assert page2.pagination.has_more is False
|
||||
|
||||
|
||||
class TestCursorPaginateSearchJoins:
|
||||
"""Tests for cursor_paginate() search that traverses relationships (search_joins)."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_via_relationship(self, db_session: AsyncSession):
|
||||
"""cursor_paginate outerjoin search-join when searching through a relationship."""
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
role_admin = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
|
||||
role_user = await RoleCrud.create(db_session, RoleCreate(name="regularuser"))
|
||||
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"admin_u{i}",
|
||||
email=f"admin_u{i}@test.com",
|
||||
role_id=role_admin.id,
|
||||
),
|
||||
)
|
||||
for i in range(2):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"reg_u{i}",
|
||||
email=f"reg_u{i}@test.com",
|
||||
role_id=role_user.id,
|
||||
),
|
||||
)
|
||||
|
||||
result = await UserCursorCrud.cursor_paginate(
|
||||
db_session,
|
||||
search="administrator",
|
||||
search_fields=[(User.role, Role.name)],
|
||||
items_per_page=20,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, CursorPagination)
|
||||
assert len(result.data) == 3
|
||||
|
||||
|
||||
class TestGetWithForUpdate:
|
||||
"""Tests for get() with with_for_update=True."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_with_for_update(self, db_session: AsyncSession):
|
||||
"""get() with with_for_update=True locks the row."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="locked"))
|
||||
|
||||
result = await RoleCrud.get(
|
||||
db_session,
|
||||
filters=[Role.id == role.id],
|
||||
with_for_update=True,
|
||||
)
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "locked"
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
|
||||
from fastapi_toolsets.schemas import OffsetPagination
|
||||
|
||||
from .conftest import (
|
||||
Role,
|
||||
@@ -39,6 +40,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -57,6 +59,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -84,6 +87,7 @@ class TestPaginateSearch:
|
||||
search_fields=[(User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -102,6 +106,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username, (User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -117,6 +122,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -132,6 +138,7 @@ class TestPaginateSearch:
|
||||
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 0
|
||||
|
||||
# Should find (case match)
|
||||
@@ -140,6 +147,7 @@ class TestPaginateSearch:
|
||||
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -153,9 +161,11 @@ class TestPaginateSearch:
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="")
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
result = await UserCrud.paginate(db_session, search=None)
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -177,6 +187,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "active_john"
|
||||
|
||||
@@ -189,6 +200,7 @@ class TestPaginateSearch:
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="findme")
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -204,6 +216,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 0
|
||||
assert result.data == []
|
||||
|
||||
@@ -224,6 +237,7 @@ class TestPaginateSearch:
|
||||
items_per_page=5,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 15
|
||||
assert len(result.data) == 5
|
||||
assert result.pagination.has_more is True
|
||||
@@ -248,6 +262,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -270,6 +285,7 @@ class TestPaginateSearch:
|
||||
order_by=User.username,
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 3
|
||||
usernames = [u.username for u in result.data]
|
||||
assert usernames == ["alice", "bob", "charlie"]
|
||||
@@ -292,6 +308,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.id, User.username],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].id == user_id
|
||||
|
||||
@@ -318,6 +335,7 @@ class TestSearchConfig:
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "john_test"
|
||||
|
||||
@@ -333,6 +351,7 @@ class TestSearchConfig:
|
||||
search=SearchConfig(query="findme", fields=[User.email]),
|
||||
)
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ from pydantic import ValidationError
|
||||
|
||||
from fastapi_toolsets.schemas import (
|
||||
ApiError,
|
||||
CursorPagination,
|
||||
ErrorResponse,
|
||||
OffsetPagination,
|
||||
PaginatedResponse,
|
||||
Pagination,
|
||||
Response,
|
||||
@@ -154,12 +156,12 @@ class TestErrorResponse:
|
||||
assert data["description"] == "Details"
|
||||
|
||||
|
||||
class TestPagination:
|
||||
"""Tests for Pagination schema."""
|
||||
class TestOffsetPagination:
|
||||
"""Tests for OffsetPagination schema (canonical name for offset-based pagination)."""
|
||||
|
||||
def test_create_pagination(self):
|
||||
"""Create Pagination with all fields."""
|
||||
pagination = Pagination(
|
||||
"""Create OffsetPagination with all fields."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=100,
|
||||
items_per_page=10,
|
||||
page=1,
|
||||
@@ -173,7 +175,7 @@ class TestPagination:
|
||||
|
||||
def test_last_page_has_more_false(self):
|
||||
"""Last page has has_more=False."""
|
||||
pagination = Pagination(
|
||||
pagination = OffsetPagination(
|
||||
total_count=25,
|
||||
items_per_page=10,
|
||||
page=3,
|
||||
@@ -183,8 +185,8 @@ class TestPagination:
|
||||
assert pagination.has_more is False
|
||||
|
||||
def test_serialization(self):
|
||||
"""Pagination serializes correctly."""
|
||||
pagination = Pagination(
|
||||
"""OffsetPagination serializes correctly."""
|
||||
pagination = OffsetPagination(
|
||||
total_count=50,
|
||||
items_per_page=20,
|
||||
page=2,
|
||||
@@ -197,6 +199,77 @@ class TestPagination:
|
||||
assert data["page"] == 2
|
||||
assert data["has_more"] is True
|
||||
|
||||
def test_pagination_alias_is_offset_pagination(self):
|
||||
"""Pagination is a backward-compatible alias for OffsetPagination."""
|
||||
assert Pagination is OffsetPagination
|
||||
|
||||
def test_pagination_alias_constructs_offset_pagination(self):
|
||||
"""Code using Pagination(...) still works unchanged."""
|
||||
pagination = Pagination(
|
||||
total_count=10,
|
||||
items_per_page=5,
|
||||
page=2,
|
||||
has_more=False,
|
||||
)
|
||||
assert isinstance(pagination, OffsetPagination)
|
||||
|
||||
|
||||
class TestCursorPagination:
|
||||
"""Tests for CursorPagination schema."""
|
||||
|
||||
def test_create_with_next_cursor(self):
|
||||
"""CursorPagination with a next cursor indicates more pages."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor="eyJ2YWx1ZSI6ICIxMjMifQ==",
|
||||
items_per_page=20,
|
||||
has_more=True,
|
||||
)
|
||||
assert pagination.next_cursor == "eyJ2YWx1ZSI6ICIxMjMifQ=="
|
||||
assert pagination.prev_cursor is None
|
||||
assert pagination.items_per_page == 20
|
||||
assert pagination.has_more is True
|
||||
|
||||
def test_create_last_page(self):
|
||||
"""CursorPagination for the last page has next_cursor=None and has_more=False."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor=None,
|
||||
items_per_page=20,
|
||||
has_more=False,
|
||||
)
|
||||
assert pagination.next_cursor is None
|
||||
assert pagination.has_more is False
|
||||
|
||||
def test_prev_cursor_defaults_to_none(self):
|
||||
"""prev_cursor defaults to None."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
)
|
||||
assert pagination.prev_cursor is None
|
||||
|
||||
def test_prev_cursor_can_be_set(self):
|
||||
"""prev_cursor can be explicitly set."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor="next123",
|
||||
prev_cursor="prev456",
|
||||
items_per_page=10,
|
||||
has_more=True,
|
||||
)
|
||||
assert pagination.prev_cursor == "prev456"
|
||||
|
||||
def test_serialization(self):
|
||||
"""CursorPagination serializes correctly."""
|
||||
pagination = CursorPagination(
|
||||
next_cursor="abc123",
|
||||
prev_cursor="xyz789",
|
||||
items_per_page=20,
|
||||
has_more=True,
|
||||
)
|
||||
data = pagination.model_dump()
|
||||
assert data["next_cursor"] == "abc123"
|
||||
assert data["prev_cursor"] == "xyz789"
|
||||
assert data["items_per_page"] == 20
|
||||
assert data["has_more"] is True
|
||||
|
||||
|
||||
class TestPaginatedResponse:
|
||||
"""Tests for PaginatedResponse schema."""
|
||||
@@ -214,6 +287,7 @@ class TestPaginatedResponse:
|
||||
pagination=pagination,
|
||||
)
|
||||
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
assert len(response.data) == 2
|
||||
assert response.pagination.total_count == 30
|
||||
assert response.status == ResponseStatus.SUCCESS
|
||||
@@ -247,6 +321,7 @@ class TestPaginatedResponse:
|
||||
pagination=pagination,
|
||||
)
|
||||
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
assert response.data == []
|
||||
assert response.pagination.total_count == 0
|
||||
|
||||
@@ -290,6 +365,36 @@ class TestPaginatedResponse:
|
||||
assert data["data"] == ["item1", "item2"]
|
||||
assert data["pagination"]["page"] == 5
|
||||
|
||||
def test_pagination_field_accepts_offset_pagination(self):
|
||||
"""PaginatedResponse.pagination accepts OffsetPagination."""
|
||||
response = PaginatedResponse(
|
||||
data=[1, 2],
|
||||
pagination=OffsetPagination(
|
||||
total_count=2, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
|
||||
def test_pagination_field_accepts_cursor_pagination(self):
|
||||
"""PaginatedResponse.pagination accepts CursorPagination."""
|
||||
response = PaginatedResponse(
|
||||
data=[1, 2],
|
||||
pagination=CursorPagination(
|
||||
next_cursor=None, items_per_page=10, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, CursorPagination)
|
||||
|
||||
def test_pagination_alias_accepted(self):
|
||||
"""Constructing PaginatedResponse with Pagination (alias) still works."""
|
||||
response = PaginatedResponse(
|
||||
data=[],
|
||||
pagination=Pagination(
|
||||
total_count=0, items_per_page=10, page=1, has_more=False
|
||||
),
|
||||
)
|
||||
assert isinstance(response.pagination, OffsetPagination)
|
||||
|
||||
|
||||
class TestFromAttributes:
|
||||
"""Tests for from_attributes config (ORM mode)."""
|
||||
|
||||
Reference in New Issue
Block a user