mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: use PaginatedResponse and Response into crud (#36)
* feat: return PaginatedResponse for paginate crud function * feat: add as_response argument for get, create, update and delete crud functions
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from ..schemas import PaginatedResponse, Pagination, Response
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
@@ -29,26 +30,80 @@ class AsyncCrud(Generic[ModelType]):
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
) -> ModelType:
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Created model instance
|
||||
Created model instance or Response wrapping it
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = cls.model(**obj.model_dump())
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
return cast(ModelType, db_model)
|
||||
result = cast(ModelType, db_model)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
return result
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
@@ -60,7 +115,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType:
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
@@ -70,9 +126,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
Model instance or Response wrapping it
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
@@ -95,7 +152,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
item = result.unique().scalar_one_or_none()
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
return cast(ModelType, item)
|
||||
result = cast(ModelType, item)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def first(
|
||||
@@ -183,6 +243,32 @@ class AsyncCrud(Generic[ModelType]):
|
||||
result = await session.execute(q)
|
||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def update(
|
||||
cls: type[Self],
|
||||
@@ -192,7 +278,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
) -> ModelType:
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
@@ -201,9 +288,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Updated model instance
|
||||
Updated model instance or Response wrapping it
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
@@ -216,6 +304,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
for key, value in values.items():
|
||||
setattr(db_model, key, value)
|
||||
await session.refresh(db_model)
|
||||
if as_response:
|
||||
return Response(data=db_model)
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
@@ -264,24 +354,49 @@ class AsyncCrud(Generic[ModelType]):
|
||||
)
|
||||
return cast(ModelType | None, db_model)
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def delete( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[None]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def delete( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> bool: ...
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
) -> bool:
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> bool | Response[None]:
|
||||
"""Delete records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
True if deletion was executed
|
||||
True if deletion was executed, or Response wrapping it
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
if as_response:
|
||||
return Response(data=None)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@@ -363,7 +478,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> PaginatedResponse[ModelType]:
|
||||
"""Get paginated results with metadata.
|
||||
|
||||
Args:
|
||||
@@ -420,7 +535,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
q = q.offset(offset).limit(items_per_page)
|
||||
result = await session.execute(q)
|
||||
items = result.unique().scalars().all()
|
||||
items = cast(list[ModelType], result.unique().scalars().all())
|
||||
|
||||
# Count query (with same joins and filters)
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
@@ -446,15 +561,15 @@ class AsyncCrud(Generic[ModelType]):
|
||||
count_result = await session.execute(count_q)
|
||||
total_count = count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"data": items,
|
||||
"pagination": {
|
||||
"total_count": total_count,
|
||||
"items_per_page": items_per_page,
|
||||
"page": page,
|
||||
"has_more": page * items_per_page < total_count,
|
||||
},
|
||||
}
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=Pagination(
|
||||
total_count=total_count,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
has_more=page * items_per_page < total_count,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def CrudFactory(
|
||||
|
||||
Reference in New Issue
Block a user