feat: use PaginatedResponse and Response into crud (#36)

* feat: return PaginatedResponse for paginate crud function

* feat: add as_response argument for get, create, update and delete crud functions
This commit is contained in:
d3vyce
2026-02-05 22:54:07 +01:00
committed by GitHub
parent 3a69c3c788
commit f68793fbdb
3 changed files with 235 additions and 60 deletions

View File

@@ -1,7 +1,7 @@
"""Generic async CRUD operations for SQLAlchemy models.""" """Generic async CRUD operations for SQLAlchemy models."""
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, ClassVar, Generic, Self, TypeVar, cast from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import and_, func, select from sqlalchemy import and_, func, select
@@ -14,6 +14,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
from ..exceptions import NotFoundError from ..exceptions import NotFoundError
from ..schemas import PaginatedResponse, Pagination, Response
from .search import SearchConfig, SearchFieldType, build_search_filters from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
@@ -29,26 +30,80 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[True],
) -> Response[ModelType]: ...
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod @classmethod
async def create( async def create(
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
obj: BaseModel, obj: BaseModel,
) -> ModelType: *,
as_response: bool = False,
) -> ModelType | Response[ModelType]:
"""Create a new record in the database. """Create a new record in the database.
Args: Args:
session: DB async session session: DB async session
obj: Pydantic model with data to create obj: Pydantic model with data to create
as_response: If True, wrap result in Response object
Returns: Returns:
Created model instance Created model instance or Response wrapping it
""" """
async with get_transaction(session): async with get_transaction(session):
db_model = cls.model(**obj.model_dump()) db_model = cls.model(**obj.model_dump())
session.add(db_model) session.add(db_model)
await session.refresh(db_model) await session.refresh(db_model)
return cast(ModelType, db_model) result = cast(ModelType, db_model)
if as_response:
return Response(data=result)
return result
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
as_response: Literal[True],
) -> Response[ModelType]: ...
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod @classmethod
async def get( async def get(
@@ -60,7 +115,8 @@ class AsyncCrud(Generic[ModelType]):
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
) -> ModelType: as_response: bool = False,
) -> ModelType | Response[ModelType]:
"""Get exactly one record. Raises NotFoundError if not found. """Get exactly one record. Raises NotFoundError if not found.
Args: Args:
@@ -70,9 +126,10 @@ class AsyncCrud(Generic[ModelType]):
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload) load_options: SQLAlchemy loader options (e.g., selectinload)
as_response: If True, wrap result in Response object
Returns: Returns:
Model instance Model instance or Response wrapping it
Raises: Raises:
NotFoundError: If no record found NotFoundError: If no record found
@@ -95,7 +152,10 @@ class AsyncCrud(Generic[ModelType]):
item = result.unique().scalar_one_or_none() item = result.unique().scalar_one_or_none()
if not item: if not item:
raise NotFoundError() raise NotFoundError()
return cast(ModelType, item) result = cast(ModelType, item)
if as_response:
return Response(data=result)
return result
@classmethod @classmethod
async def first( async def first(
@@ -183,6 +243,32 @@ class AsyncCrud(Generic[ModelType]):
result = await session.execute(q) result = await session.execute(q)
return cast(Sequence[ModelType], result.unique().scalars().all()) return cast(Sequence[ModelType], result.unique().scalars().all())
@overload
@classmethod
async def update( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
as_response: Literal[True],
) -> Response[ModelType]: ...
@overload
@classmethod
async def update( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod @classmethod
async def update( async def update(
cls: type[Self], cls: type[Self],
@@ -192,7 +278,8 @@ class AsyncCrud(Generic[ModelType]):
*, *,
exclude_unset: bool = True, exclude_unset: bool = True,
exclude_none: bool = False, exclude_none: bool = False,
) -> ModelType: as_response: bool = False,
) -> ModelType | Response[ModelType]:
"""Update a record in the database. """Update a record in the database.
Args: Args:
@@ -201,9 +288,10 @@ class AsyncCrud(Generic[ModelType]):
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
exclude_unset: Exclude fields not explicitly set in the schema exclude_unset: Exclude fields not explicitly set in the schema
exclude_none: Exclude fields with None value exclude_none: Exclude fields with None value
as_response: If True, wrap result in Response object
Returns: Returns:
Updated model instance Updated model instance or Response wrapping it
Raises: Raises:
NotFoundError: If no record found NotFoundError: If no record found
@@ -216,6 +304,8 @@ class AsyncCrud(Generic[ModelType]):
for key, value in values.items(): for key, value in values.items():
setattr(db_model, key, value) setattr(db_model, key, value)
await session.refresh(db_model) await session.refresh(db_model)
if as_response:
return Response(data=db_model)
return db_model return db_model
@classmethod @classmethod
@@ -264,24 +354,49 @@ class AsyncCrud(Generic[ModelType]):
) )
return cast(ModelType | None, db_model) return cast(ModelType | None, db_model)
@overload
@classmethod
async def delete( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
as_response: Literal[True],
) -> Response[None]: ...
@overload
@classmethod
async def delete( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
as_response: Literal[False] = ...,
) -> bool: ...
@classmethod @classmethod
async def delete( async def delete(
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
) -> bool: *,
as_response: bool = False,
) -> bool | Response[None]:
"""Delete records from the database. """Delete records from the database.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
as_response: If True, wrap result in Response object
Returns: Returns:
True if deletion was executed True if deletion was executed, or Response wrapping it
""" """
async with get_transaction(session): async with get_transaction(session):
q = sql_delete(cls.model).where(and_(*filters)) q = sql_delete(cls.model).where(and_(*filters))
await session.execute(q) await session.execute(q)
if as_response:
return Response(data=None)
return True return True
@classmethod @classmethod
@@ -363,7 +478,7 @@ class AsyncCrud(Generic[ModelType]):
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None, search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None, search_fields: Sequence[SearchFieldType] | None = None,
) -> dict[str, Any]: ) -> PaginatedResponse[ModelType]:
"""Get paginated results with metadata. """Get paginated results with metadata.
Args: Args:
@@ -420,7 +535,7 @@ class AsyncCrud(Generic[ModelType]):
q = q.offset(offset).limit(items_per_page) q = q.offset(offset).limit(items_per_page)
result = await session.execute(q) result = await session.execute(q)
items = result.unique().scalars().all() items = cast(list[ModelType], result.unique().scalars().all())
# Count query (with same joins and filters) # Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0] pk_col = cls.model.__mapper__.primary_key[0]
@@ -446,15 +561,15 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q) count_result = await session.execute(count_q)
total_count = count_result.scalar_one() total_count = count_result.scalar_one()
return { return PaginatedResponse(
"data": items, data=items,
"pagination": { pagination=Pagination(
"total_count": total_count, total_count=total_count,
"items_per_page": items_per_page, items_per_page=items_per_page,
"page": page, page=page,
"has_more": page * items_per_page < total_count, has_more=page * items_per_page < total_count,
}, ),
} )
def CrudFactory( def CrudFactory(

View File

@@ -429,11 +429,11 @@ class TestCrudPaginate:
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10) result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
assert len(result["data"]) == 10 assert len(result.data) == 10
assert result["pagination"]["total_count"] == 25 assert result.pagination.total_count == 25
assert result["pagination"]["page"] == 1 assert result.pagination.page == 1
assert result["pagination"]["items_per_page"] == 10 assert result.pagination.items_per_page == 10
assert result["pagination"]["has_more"] is True assert result.pagination.has_more is True
@pytest.mark.anyio @pytest.mark.anyio
async def test_paginate_last_page(self, db_session: AsyncSession): async def test_paginate_last_page(self, db_session: AsyncSession):
@@ -443,8 +443,8 @@ class TestCrudPaginate:
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10) result = await RoleCrud.paginate(db_session, page=3, items_per_page=10)
assert len(result["data"]) == 5 assert len(result.data) == 5
assert result["pagination"]["has_more"] is False assert result.pagination.has_more is False
@pytest.mark.anyio @pytest.mark.anyio
async def test_paginate_with_filters(self, db_session: AsyncSession): async def test_paginate_with_filters(self, db_session: AsyncSession):
@@ -466,7 +466,7 @@ class TestCrudPaginate:
items_per_page=10, items_per_page=10,
) )
assert result["pagination"]["total_count"] == 5 assert result.pagination.total_count == 5
@pytest.mark.anyio @pytest.mark.anyio
async def test_paginate_with_ordering(self, db_session: AsyncSession): async def test_paginate_with_ordering(self, db_session: AsyncSession):
@@ -482,7 +482,7 @@ class TestCrudPaginate:
items_per_page=10, items_per_page=10,
) )
names = [r.name for r in result["data"]] names = [r.name for r in result.data]
assert names == ["alpha", "bravo", "charlie"] assert names == ["alpha", "bravo", "charlie"]
@@ -690,8 +690,8 @@ class TestCrudJoins:
items_per_page=10, items_per_page=10,
) )
assert result["pagination"]["total_count"] == 3 assert result.pagination.total_count == 3
assert len(result["data"]) == 3 assert len(result.data) == 3
@pytest.mark.anyio @pytest.mark.anyio
async def test_paginate_with_outer_join(self, db_session: AsyncSession): async def test_paginate_with_outer_join(self, db_session: AsyncSession):
@@ -721,8 +721,8 @@ class TestCrudJoins:
items_per_page=10, items_per_page=10,
) )
assert result["pagination"]["total_count"] == 2 assert result.pagination.total_count == 2
assert len(result["data"]) == 2 assert len(result.data) == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_multiple_joins(self, db_session: AsyncSession): async def test_multiple_joins(self, db_session: AsyncSession):
@@ -752,3 +752,63 @@ class TestCrudJoins:
) )
assert len(users) == 1 assert len(users) == 1
assert users[0].username == "multi_join" assert users[0].username == "multi_join"
class TestAsResponse:
"""Tests for as_response parameter."""
@pytest.mark.anyio
async def test_create_as_response(self, db_session: AsyncSession):
"""Create with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
data = RoleCreate(name="response_role")
result = await RoleCrud.create(db_session, data, as_response=True)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.name == "response_role"
@pytest.mark.anyio
async def test_get_as_response(self, db_session: AsyncSession):
"""Get with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
result = await RoleCrud.get(
db_session, [Role.id == created.id], as_response=True
)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.id == created.id
@pytest.mark.anyio
async def test_update_as_response(self, db_session: AsyncSession):
"""Update with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
result = await RoleCrud.update(
db_session,
RoleUpdate(name="new_name"),
[Role.id == created.id],
as_response=True,
)
assert isinstance(result, Response)
assert result.data is not None
assert result.data.name == "new_name"
@pytest.mark.anyio
async def test_delete_as_response(self, db_session: AsyncSession):
"""Delete with as_response=True returns Response."""
from fastapi_toolsets.schemas import Response
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
result = await RoleCrud.delete(
db_session, [Role.id == created.id], as_response=True
)
assert isinstance(result, Response)
assert result.data is None

View File

@@ -39,7 +39,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert result["pagination"]["total_count"] == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_multiple_columns(self, db_session: AsyncSession): async def test_search_multiple_columns(self, db_session: AsyncSession):
@@ -57,7 +57,7 @@ class TestPaginateSearch:
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
) )
assert result["pagination"]["total_count"] == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_relationship_depth1(self, db_session: AsyncSession): async def test_search_relationship_depth1(self, db_session: AsyncSession):
@@ -84,7 +84,7 @@ class TestPaginateSearch:
search_fields=[(User.role, Role.name)], search_fields=[(User.role, Role.name)],
) )
assert result["pagination"]["total_count"] == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession): async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
@@ -102,7 +102,7 @@ class TestPaginateSearch:
search_fields=[User.username, (User.role, Role.name)], search_fields=[User.username, (User.role, Role.name)],
) )
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_case_insensitive(self, db_session: AsyncSession): async def test_search_case_insensitive(self, db_session: AsyncSession):
@@ -117,7 +117,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_case_sensitive(self, db_session: AsyncSession): async def test_search_case_sensitive(self, db_session: AsyncSession):
@@ -132,7 +132,7 @@ class TestPaginateSearch:
search=SearchConfig(query="johndoe", case_sensitive=True), search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
) )
assert result["pagination"]["total_count"] == 0 assert result.pagination.total_count == 0
# Should find (case match) # Should find (case match)
result = await UserCrud.paginate( result = await UserCrud.paginate(
@@ -140,7 +140,7 @@ class TestPaginateSearch:
search=SearchConfig(query="JohnDoe", case_sensitive=True), search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username], search_fields=[User.username],
) )
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_empty_query(self, db_session: AsyncSession): async def test_search_empty_query(self, db_session: AsyncSession):
@@ -153,10 +153,10 @@ class TestPaginateSearch:
) )
result = await UserCrud.paginate(db_session, search="") result = await UserCrud.paginate(db_session, search="")
assert result["pagination"]["total_count"] == 2 assert result.pagination.total_count == 2
result = await UserCrud.paginate(db_session, search=None) result = await UserCrud.paginate(db_session, search=None)
assert result["pagination"]["total_count"] == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_with_existing_filters(self, db_session: AsyncSession): async def test_search_with_existing_filters(self, db_session: AsyncSession):
@@ -177,8 +177,8 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
assert result["data"][0].username == "active_john" assert result.data[0].username == "active_john"
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_auto_detect_fields(self, db_session: AsyncSession): async def test_search_auto_detect_fields(self, db_session: AsyncSession):
@@ -189,7 +189,7 @@ class TestPaginateSearch:
result = await UserCrud.paginate(db_session, search="findme") result = await UserCrud.paginate(db_session, search="findme")
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_no_results(self, db_session: AsyncSession): async def test_search_no_results(self, db_session: AsyncSession):
@@ -204,8 +204,8 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert result["pagination"]["total_count"] == 0 assert result.pagination.total_count == 0
assert result["data"] == [] assert result.data == []
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_with_pagination(self, db_session: AsyncSession): async def test_search_with_pagination(self, db_session: AsyncSession):
@@ -224,9 +224,9 @@ class TestPaginateSearch:
items_per_page=5, items_per_page=5,
) )
assert result["pagination"]["total_count"] == 15 assert result.pagination.total_count == 15
assert len(result["data"]) == 5 assert len(result.data) == 5
assert result["pagination"]["has_more"] is True assert result.pagination.has_more is True
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_null_relationship(self, db_session: AsyncSession): async def test_search_null_relationship(self, db_session: AsyncSession):
@@ -248,7 +248,7 @@ class TestPaginateSearch:
search_fields=[User.username], search_fields=[User.username],
) )
assert result["pagination"]["total_count"] == 2 assert result.pagination.total_count == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_with_order_by(self, db_session: AsyncSession): async def test_search_with_order_by(self, db_session: AsyncSession):
@@ -270,8 +270,8 @@ class TestPaginateSearch:
order_by=User.username, order_by=User.username,
) )
assert result["pagination"]["total_count"] == 3 assert result.pagination.total_count == 3
usernames = [u.username for u in result["data"]] usernames = [u.username for u in result.data]
assert usernames == ["alice", "bob", "charlie"] assert usernames == ["alice", "bob", "charlie"]
@pytest.mark.anyio @pytest.mark.anyio
@@ -292,8 +292,8 @@ class TestPaginateSearch:
search_fields=[User.id, User.username], search_fields=[User.id, User.username],
) )
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
assert result["data"][0].id == user_id assert result.data[0].id == user_id
class TestSearchConfig: class TestSearchConfig:
@@ -318,8 +318,8 @@ class TestSearchConfig:
search_fields=[User.username, User.email], search_fields=[User.username, User.email],
) )
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
assert result["data"][0].username == "john_test" assert result.data[0].username == "john_test"
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_config_with_fields(self, db_session: AsyncSession): async def test_search_config_with_fields(self, db_session: AsyncSession):
@@ -333,7 +333,7 @@ class TestSearchConfig:
search=SearchConfig(query="findme", fields=[User.email]), search=SearchConfig(query="findme", fields=[User.email]),
) )
assert result["pagination"]["total_count"] == 1 assert result.pagination.total_count == 1
class TestNoSearchableFieldsError: class TestNoSearchableFieldsError: