feat: use PaginatedResponse and Response into crud (#36)

* feat: return PaginatedResponse for paginate crud function

* feat: add as_response argument for get, create, update and delete crud functions
This commit is contained in:
d3vyce
2026-02-05 22:54:07 +01:00
committed by GitHub
parent 3a69c3c788
commit f68793fbdb
3 changed files with 235 additions and 60 deletions

View File

@@ -1,7 +1,7 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel
from sqlalchemy import and_, func, select
@@ -14,6 +14,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import NotFoundError
from ..schemas import PaginatedResponse, Pagination, Response
from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
@@ -29,26 +30,80 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[True],
) -> Response[ModelType]: ...
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod
async def create(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
) -> ModelType:
*,
as_response: bool = False,
) -> ModelType | Response[ModelType]:
"""Create a new record in the database.
Args:
session: DB async session
obj: Pydantic model with data to create
as_response: If True, wrap result in Response object
Returns:
Created model instance
Created model instance or Response wrapping it
"""
async with get_transaction(session):
db_model = cls.model(**obj.model_dump())
session.add(db_model)
await session.refresh(db_model)
return cast(ModelType, db_model)
result = cast(ModelType, db_model)
if as_response:
return Response(data=result)
return result
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
as_response: Literal[True],
) -> Response[ModelType]: ...
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod
async def get(
@@ -60,7 +115,8 @@ class AsyncCrud(Generic[ModelType]):
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
) -> ModelType:
as_response: bool = False,
) -> ModelType | Response[ModelType]:
"""Get exactly one record. Raises NotFoundError if not found.
Args:
@@ -70,9 +126,10 @@ class AsyncCrud(Generic[ModelType]):
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload)
as_response: If True, wrap result in Response object
Returns:
Model instance
Model instance or Response wrapping it
Raises:
NotFoundError: If no record found
@@ -95,7 +152,10 @@ class AsyncCrud(Generic[ModelType]):
item = result.unique().scalar_one_or_none()
if not item:
raise NotFoundError()
return cast(ModelType, item)
result = cast(ModelType, item)
if as_response:
return Response(data=result)
return result
@classmethod
async def first(
@@ -183,6 +243,32 @@ class AsyncCrud(Generic[ModelType]):
result = await session.execute(q)
return cast(Sequence[ModelType], result.unique().scalars().all())
@overload
@classmethod
async def update( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
as_response: Literal[True],
) -> Response[ModelType]: ...
@overload
@classmethod
async def update( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod
async def update(
cls: type[Self],
@@ -192,7 +278,8 @@ class AsyncCrud(Generic[ModelType]):
*,
exclude_unset: bool = True,
exclude_none: bool = False,
) -> ModelType:
as_response: bool = False,
) -> ModelType | Response[ModelType]:
"""Update a record in the database.
Args:
@@ -201,9 +288,10 @@ class AsyncCrud(Generic[ModelType]):
filters: List of SQLAlchemy filter conditions
exclude_unset: Exclude fields not explicitly set in the schema
exclude_none: Exclude fields with None value
as_response: If True, wrap result in Response object
Returns:
Updated model instance
Updated model instance or Response wrapping it
Raises:
NotFoundError: If no record found
@@ -216,6 +304,8 @@ class AsyncCrud(Generic[ModelType]):
for key, value in values.items():
setattr(db_model, key, value)
await session.refresh(db_model)
if as_response:
return Response(data=db_model)
return db_model
@classmethod
@@ -264,24 +354,49 @@ class AsyncCrud(Generic[ModelType]):
)
return cast(ModelType | None, db_model)
@overload
@classmethod
async def delete( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
as_response: Literal[True],
) -> Response[None]: ...
@overload
@classmethod
async def delete( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
as_response: Literal[False] = ...,
) -> bool: ...
@classmethod
async def delete(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
) -> bool:
*,
as_response: bool = False,
) -> bool | Response[None]:
"""Delete records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
as_response: If True, wrap result in Response object
Returns:
True if deletion was executed
True if deletion was executed, or Response wrapping it
"""
async with get_transaction(session):
q = sql_delete(cls.model).where(and_(*filters))
await session.execute(q)
if as_response:
return Response(data=None)
return True
@classmethod
@@ -363,7 +478,7 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
) -> dict[str, Any]:
) -> PaginatedResponse[ModelType]:
"""Get paginated results with metadata.
Args:
@@ -420,7 +535,7 @@ class AsyncCrud(Generic[ModelType]):
q = q.offset(offset).limit(items_per_page)
result = await session.execute(q)
items = result.unique().scalars().all()
items = cast(list[ModelType], result.unique().scalars().all())
# Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0]
@@ -446,15 +561,15 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
return {
"data": items,
"pagination": {
"total_count": total_count,
"items_per_page": items_per_page,
"page": page,
"has_more": page * items_per_page < total_count,
},
}
return PaginatedResponse(
data=items,
pagination=Pagination(
total_count=total_count,
items_per_page=items_per_page,
page=page,
has_more=page * items_per_page < total_count,
),
)
def CrudFactory(

View File

@@ -429,11 +429,11 @@ class TestCrudPaginate:
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
assert len(result["data"]) == 10
assert result["pagination"]["total_count"] == 25
assert result["pagination"]["page"] == 1
assert result["pagination"]["items_per_page"] == 10
assert result["pagination"]["has_more"] is True
assert len(result.data) == 10
assert result.pagination.total_count == 25
assert result.pagination.page == 1
assert result.pagination.items_per_page == 10
assert result.pagination.has_more is True
@pytest.mark.anyio
async def test_paginate_last_page(self, db_session: AsyncSession):
@@ -443,8 +443,8 @@ class TestCrudPaginate:
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10)
assert len(result["data"]) == 5
assert result["pagination"]["has_more"] is False
assert len(result.data) == 5
assert result.pagination.has_more is False
@pytest.mark.anyio
async def test_paginate_with_filters(self, db_session: AsyncSession):
@@ -466,7 +466,7 @@ class TestCrudPaginate:
items_per_page=10,
)
assert result["pagination"]["total_count"] == 5
assert result.pagination.total_count == 5
@pytest.mark.anyio
async def test_paginate_with_ordering(self, db_session: AsyncSession):
@@ -482,7 +482,7 @@ class TestCrudPaginate:
items_per_page=10,
)
names = [r.name for r in result["data"]]
names = [r.name for r in result.data]
assert names == ["alpha", "bravo", "charlie"]
@@ -690,8 +690,8 @@ class TestCrudJoins:
items_per_page=10,
)
assert result["pagination"]["total_count"] == 3
assert len(result["data"]) == 3
assert result.pagination.total_count == 3
assert len(result.data) == 3
@pytest.mark.anyio
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
@@ -721,8 +721,8 @@ class TestCrudJoins:
items_per_page=10,
)
assert result["pagination"]["total_count"] == 2
assert len(result["data"]) == 2
assert result.pagination.total_count == 2
assert len(result.data) == 2
@pytest.mark.anyio
async def test_multiple_joins(self, db_session: AsyncSession):
@@ -752,3 +752,63 @@ class TestCrudJoins:
)
assert len(users) == 1
assert users[0].username == "multi_join"
class TestAsResponse:
"""Tests for as_response parameter."""
@pytest.mark.anyio
async def test_create_as_response(self, db_session: AsyncSession):
"""Create with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
data = RoleCreate(name="response_role")
result = await RoleCrud.create(db_session, data, as_response=True)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.name == "response_role"
@pytest.mark.anyio
async def test_get_as_response(self, db_session: AsyncSession):
"""Get with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
result = await RoleCrud.get(
db_session, [Role.id == created.id], as_response=True
)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.id == created.id
@pytest.mark.anyio
async def test_update_as_response(self, db_session: AsyncSession):
"""Update with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
result = await RoleCrud.update(
db_session,
RoleUpdate(name="new_name"),
[Role.id == created.id],
as_response=True,
)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.name == "new_name"
@pytest.mark.anyio
async def test_delete_as_response(self, db_session: AsyncSession):
"""Delete with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
result = await RoleCrud.delete(
db_session, [Role.id == created.id], as_response=True
)
assert isinstance(result, Response)
assert result.data is None

View File

@@ -39,7 +39,7 @@ class TestPaginateSearch:
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
assert result.pagination.total_count == 2
@pytest.mark.anyio
async def test_search_multiple_columns(self, db_session: AsyncSession):
@@ -57,7 +57,7 @@ class TestPaginateSearch:
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 2
assert result.pagination.total_count == 2
@pytest.mark.anyio
async def test_search_relationship_depth1(self, db_session: AsyncSession):
@@ -84,7 +84,7 @@ class TestPaginateSearch:
search_fields=[(User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 2
assert result.pagination.total_count == 2
@pytest.mark.anyio
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
@@ -102,7 +102,7 @@ class TestPaginateSearch:
search_fields=[User.username, (User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 1
assert result.pagination.total_count == 1
@pytest.mark.anyio
async def test_search_case_insensitive(self, db_session: AsyncSession):
@@ -117,7 +117,7 @@ class TestPaginateSearch:
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
assert result.pagination.total_count == 1
@pytest.mark.anyio
async def test_search_case_sensitive(self, db_session: AsyncSession):
@@ -132,7 +132,7 @@ class TestPaginateSearch:
search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
assert result.pagination.total_count == 0
# Should find (case match)
result = await UserCrud.paginate(
@@ -140,7 +140,7 @@ class TestPaginateSearch:
search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
assert result.pagination.total_count == 1
@pytest.mark.anyio
async def test_search_empty_query(self, db_session: AsyncSession):
@@ -153,10 +153,10 @@ class TestPaginateSearch:
)
result = await UserCrud.paginate(db_session, search="")
assert result["pagination"]["total_count"] == 2
assert result.pagination.total_count == 2
result = await UserCrud.paginate(db_session, search=None)
assert result["pagination"]["total_count"] == 2
assert result.pagination.total_count == 2
@pytest.mark.anyio
async def test_search_with_existing_filters(self, db_session: AsyncSession):
@@ -177,8 +177,8 @@ class TestPaginateSearch:
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "active_john"
assert result.pagination.total_count == 1
assert result.data[0].username == "active_john"
@pytest.mark.anyio
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
@@ -189,7 +189,7 @@ class TestPaginateSearch:
result = await UserCrud.paginate(db_session, search="findme")
assert result["pagination"]["total_count"] == 1
assert result.pagination.total_count == 1
@pytest.mark.anyio
async def test_search_no_results(self, db_session: AsyncSession):
@@ -204,8 +204,8 @@ class TestPaginateSearch:
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
assert result["data"] == []
assert result.pagination.total_count == 0
assert result.data == []
@pytest.mark.anyio
async def test_search_with_pagination(self, db_session: AsyncSession):
@@ -224,9 +224,9 @@ class TestPaginateSearch:
items_per_page=5,
)
assert result["pagination"]["total_count"] == 15
assert len(result["data"]) == 5
assert result["pagination"]["has_more"] is True
assert result.pagination.total_count == 15
assert len(result.data) == 5
assert result.pagination.has_more is True
@pytest.mark.anyio
async def test_search_null_relationship(self, db_session: AsyncSession):
@@ -248,7 +248,7 @@ class TestPaginateSearch:
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
assert result.pagination.total_count == 2
@pytest.mark.anyio
async def test_search_with_order_by(self, db_session: AsyncSession):
@@ -270,8 +270,8 @@ class TestPaginateSearch:
order_by=User.username,
)
assert result["pagination"]["total_count"] == 3
usernames = [u.username for u in result["data"]]
assert result.pagination.total_count == 3
usernames = [u.username for u in result.data]
assert usernames == ["alice", "bob", "charlie"]
@pytest.mark.anyio
@@ -292,8 +292,8 @@ class TestPaginateSearch:
search_fields=[User.id, User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].id == user_id
assert result.pagination.total_count == 1
assert result.data[0].id == user_id
class TestSearchConfig:
@@ -318,8 +318,8 @@ class TestSearchConfig:
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "john_test"
assert result.pagination.total_count == 1
assert result.data[0].username == "john_test"
@pytest.mark.anyio
async def test_search_config_with_fields(self, db_session: AsyncSession):
@@ -333,7 +333,7 @@ class TestSearchConfig:
search=SearchConfig(query="findme", fields=[User.email]),
)
assert result["pagination"]["total_count"] == 1
assert result.pagination.total_count == 1
class TestNoSearchableFieldsError: