diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 565178c..a1e7eb7 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -1,7 +1,7 @@ """Generic async CRUD operations for SQLAlchemy models.""" from collections.abc import Sequence -from typing import Any, ClassVar, Generic, Self, TypeVar, cast +from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from pydantic import BaseModel from sqlalchemy import and_, func, select @@ -14,6 +14,7 @@ from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction from ..exceptions import NotFoundError +from ..schemas import PaginatedResponse, Pagination, Response from .search import SearchConfig, SearchFieldType, build_search_filters ModelType = TypeVar("ModelType", bound=DeclarativeBase) @@ -29,26 +30,80 @@ class AsyncCrud(Generic[ModelType]): model: ClassVar[type[DeclarativeBase]] searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None + @overload + @classmethod + async def create( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + *, + as_response: Literal[True], + ) -> Response[ModelType]: ... + + @overload + @classmethod + async def create( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + *, + as_response: Literal[False] = ..., + ) -> ModelType: ... + @classmethod async def create( cls: type[Self], session: AsyncSession, obj: BaseModel, - ) -> ModelType: + *, + as_response: bool = False, + ) -> ModelType | Response[ModelType]: """Create a new record in the database. Args: session: DB async session obj: Pydantic model with data to create + as_response: If True, wrap result in Response object Returns: - Created model instance + Created model instance or Response wrapping it """ async with get_transaction(session): db_model = cls.model(**obj.model_dump()) session.add(db_model) await session.refresh(db_model) - return cast(ModelType, db_model) + result = cast(ModelType, db_model) + if as_response: + return Response(data=result) + return result + + @overload + @classmethod + async def get( # pragma: no cover + cls: type[Self], + session: AsyncSession, + filters: list[Any], + *, + joins: JoinType | None = None, + outer_join: bool = False, + with_for_update: bool = False, + load_options: list[Any] | None = None, + as_response: Literal[True], + ) -> Response[ModelType]: ... + + @overload + @classmethod + async def get( # pragma: no cover + cls: type[Self], + session: AsyncSession, + filters: list[Any], + *, + joins: JoinType | None = None, + outer_join: bool = False, + with_for_update: bool = False, + load_options: list[Any] | None = None, + as_response: Literal[False] = ..., + ) -> ModelType: ... @classmethod async def get( @@ -60,7 +115,8 @@ class AsyncCrud(Generic[ModelType]): outer_join: bool = False, with_for_update: bool = False, load_options: list[Any] | None = None, - ) -> ModelType: + as_response: bool = False, + ) -> ModelType | Response[ModelType]: """Get exactly one record. Raises NotFoundError if not found. Args: @@ -70,9 +126,10 @@ class AsyncCrud(Generic[ModelType]): outer_join: Use LEFT OUTER JOIN instead of INNER JOIN with_for_update: Lock the row for update load_options: SQLAlchemy loader options (e.g., selectinload) + as_response: If True, wrap result in Response object Returns: - Model instance + Model instance or Response wrapping it Raises: NotFoundError: If no record found @@ -95,7 +152,10 @@ class AsyncCrud(Generic[ModelType]): item = result.unique().scalar_one_or_none() if not item: raise NotFoundError() - return cast(ModelType, item) + result = cast(ModelType, item) + if as_response: + return Response(data=result) + return result @classmethod async def first( @@ -183,6 +243,32 @@ class AsyncCrud(Generic[ModelType]): result = await session.execute(q) return cast(Sequence[ModelType], result.unique().scalars().all()) + @overload + @classmethod + async def update( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + filters: list[Any], + *, + exclude_unset: bool = True, + exclude_none: bool = False, + as_response: Literal[True], + ) -> Response[ModelType]: ... + + @overload + @classmethod + async def update( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + filters: list[Any], + *, + exclude_unset: bool = True, + exclude_none: bool = False, + as_response: Literal[False] = ..., + ) -> ModelType: ... + @classmethod async def update( cls: type[Self], @@ -192,7 +278,8 @@ class AsyncCrud(Generic[ModelType]): *, exclude_unset: bool = True, exclude_none: bool = False, - ) -> ModelType: + as_response: bool = False, + ) -> ModelType | Response[ModelType]: """Update a record in the database. Args: @@ -201,9 +288,10 @@ class AsyncCrud(Generic[ModelType]): filters: List of SQLAlchemy filter conditions exclude_unset: Exclude fields not explicitly set in the schema exclude_none: Exclude fields with None value + as_response: If True, wrap result in Response object Returns: - Updated model instance + Updated model instance or Response wrapping it Raises: NotFoundError: If no record found @@ -216,6 +304,8 @@ class AsyncCrud(Generic[ModelType]): for key, value in values.items(): setattr(db_model, key, value) await session.refresh(db_model) + if as_response: + return Response(data=db_model) return db_model @classmethod @@ -264,24 +354,49 @@ class AsyncCrud(Generic[ModelType]): ) return cast(ModelType | None, db_model) + @overload + @classmethod + async def delete( # pragma: no cover + cls: type[Self], + session: AsyncSession, + filters: list[Any], + *, + as_response: Literal[True], + ) -> Response[None]: ... + + @overload + @classmethod + async def delete( # pragma: no cover + cls: type[Self], + session: AsyncSession, + filters: list[Any], + *, + as_response: Literal[False] = ..., + ) -> bool: ... + @classmethod async def delete( cls: type[Self], session: AsyncSession, filters: list[Any], - ) -> bool: + *, + as_response: bool = False, + ) -> bool | Response[None]: """Delete records from the database. Args: session: DB async session filters: List of SQLAlchemy filter conditions + as_response: If True, wrap result in Response object Returns: - True if deletion was executed + True if deletion was executed, or Response wrapping it """ async with get_transaction(session): q = sql_delete(cls.model).where(and_(*filters)) await session.execute(q) + if as_response: + return Response(data=None) return True @classmethod @@ -363,7 +478,7 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, - ) -> dict[str, Any]: + ) -> PaginatedResponse[ModelType]: """Get paginated results with metadata. Args: @@ -420,7 +535,7 @@ class AsyncCrud(Generic[ModelType]): q = q.offset(offset).limit(items_per_page) result = await session.execute(q) - items = result.unique().scalars().all() + items = cast(list[ModelType], result.unique().scalars().all()) # Count query (with same joins and filters) pk_col = cls.model.__mapper__.primary_key[0] @@ -446,15 +561,15 @@ class AsyncCrud(Generic[ModelType]): count_result = await session.execute(count_q) total_count = count_result.scalar_one() - return { - "data": items, - "pagination": { - "total_count": total_count, - "items_per_page": items_per_page, - "page": page, - "has_more": page * items_per_page < total_count, - }, - } + return PaginatedResponse( + data=items, + pagination=Pagination( + total_count=total_count, + items_per_page=items_per_page, + page=page, + has_more=page * items_per_page < total_count, + ), + ) def CrudFactory( diff --git a/tests/test_crud.py b/tests/test_crud.py index 043f01b..18fad02 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -429,11 +429,11 @@ class TestCrudPaginate: result = await RoleCrud.paginate(db_session, page=1, items_per_page=10) - assert len(result["data"]) == 10 - assert result["pagination"]["total_count"] == 25 - assert result["pagination"]["page"] == 1 - assert result["pagination"]["items_per_page"] == 10 - assert result["pagination"]["has_more"] is True + assert len(result.data) == 10 + assert result.pagination.total_count == 25 + assert result.pagination.page == 1 + assert result.pagination.items_per_page == 10 + assert result.pagination.has_more is True @pytest.mark.anyio async def test_paginate_last_page(self, db_session: AsyncSession): @@ -443,8 +443,8 @@ class TestCrudPaginate: result = await RoleCrud.paginate(db_session, page=3, items_per_page=10) - assert len(result["data"]) == 5 - assert result["pagination"]["has_more"] is False + assert len(result.data) == 5 + assert result.pagination.has_more is False @pytest.mark.anyio async def test_paginate_with_filters(self, db_session: AsyncSession): @@ -466,7 +466,7 @@ class TestCrudPaginate: items_per_page=10, ) - assert result["pagination"]["total_count"] == 5 + assert result.pagination.total_count == 5 @pytest.mark.anyio async def test_paginate_with_ordering(self, db_session: AsyncSession): @@ -482,7 +482,7 @@ class TestCrudPaginate: items_per_page=10, ) - names = [r.name for r in result["data"]] + names = [r.name for r in result.data] assert names == ["alpha", "bravo", "charlie"] @@ -690,8 +690,8 @@ class TestCrudJoins: items_per_page=10, ) - assert result["pagination"]["total_count"] == 3 - assert len(result["data"]) == 3 + assert result.pagination.total_count == 3 + assert len(result.data) == 3 @pytest.mark.anyio async def test_paginate_with_outer_join(self, db_session: AsyncSession): @@ -721,8 +721,8 @@ class TestCrudJoins: items_per_page=10, ) - assert result["pagination"]["total_count"] == 2 - assert len(result["data"]) == 2 + assert result.pagination.total_count == 2 + assert len(result.data) == 2 @pytest.mark.anyio async def test_multiple_joins(self, db_session: AsyncSession): @@ -752,3 +752,63 @@ class TestCrudJoins: ) assert len(users) == 1 assert users[0].username == "multi_join" + + +class TestAsResponse: + """Tests for as_response parameter.""" + + @pytest.mark.anyio + async def test_create_as_response(self, db_session: AsyncSession): + """Create with as_response=True returns Response.""" + from fastapi_toolsets.schemas import Response + + data = RoleCreate(name="response_role") + result = await RoleCrud.create(db_session, data, as_response=True) + + assert isinstance(result, Response) + assert result.data is not None + assert result.data.name == "response_role" + + @pytest.mark.anyio + async def test_get_as_response(self, db_session: AsyncSession): + """Get with as_response=True returns Response.""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="get_response")) + result = await RoleCrud.get( + db_session, [Role.id == created.id], as_response=True + ) + + assert isinstance(result, Response) + assert result.data is not None + assert result.data.id == created.id + + @pytest.mark.anyio + async def test_update_as_response(self, db_session: AsyncSession): + """Update with as_response=True returns Response.""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="old_name")) + result = await RoleCrud.update( + db_session, + RoleUpdate(name="new_name"), + [Role.id == created.id], + as_response=True, + ) + + assert isinstance(result, Response) + assert result.data is not None + assert result.data.name == "new_name" + + @pytest.mark.anyio + async def test_delete_as_response(self, db_session: AsyncSession): + """Delete with as_response=True returns Response.""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="to_delete")) + result = await RoleCrud.delete( + db_session, [Role.id == created.id], as_response=True + ) + + assert isinstance(result, Response) + assert result.data is None diff --git a/tests/test_crud_search.py b/tests/test_crud_search.py index 4a6e886..8f45826 100644 --- a/tests/test_crud_search.py +++ b/tests/test_crud_search.py @@ -39,7 +39,7 @@ class TestPaginateSearch: search_fields=[User.username], ) - assert result["pagination"]["total_count"] == 2 + assert result.pagination.total_count == 2 @pytest.mark.anyio async def test_search_multiple_columns(self, db_session: AsyncSession): @@ -57,7 +57,7 @@ class TestPaginateSearch: search_fields=[User.username, User.email], ) - assert result["pagination"]["total_count"] == 2 + assert result.pagination.total_count == 2 @pytest.mark.anyio async def test_search_relationship_depth1(self, db_session: AsyncSession): @@ -84,7 +84,7 @@ class TestPaginateSearch: search_fields=[(User.role, Role.name)], ) - assert result["pagination"]["total_count"] == 2 + assert result.pagination.total_count == 2 @pytest.mark.anyio async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession): @@ -102,7 +102,7 @@ class TestPaginateSearch: search_fields=[User.username, (User.role, Role.name)], ) - assert result["pagination"]["total_count"] == 1 + assert result.pagination.total_count == 1 @pytest.mark.anyio async def test_search_case_insensitive(self, db_session: AsyncSession): @@ -117,7 +117,7 @@ class TestPaginateSearch: search_fields=[User.username], ) - assert result["pagination"]["total_count"] == 1 + assert result.pagination.total_count == 1 @pytest.mark.anyio async def test_search_case_sensitive(self, db_session: AsyncSession): @@ -132,7 +132,7 @@ class TestPaginateSearch: search=SearchConfig(query="johndoe", case_sensitive=True), search_fields=[User.username], ) - assert result["pagination"]["total_count"] == 0 + assert result.pagination.total_count == 0 # Should find (case match) result = await UserCrud.paginate( @@ -140,7 +140,7 @@ class TestPaginateSearch: search=SearchConfig(query="JohnDoe", case_sensitive=True), search_fields=[User.username], ) - assert result["pagination"]["total_count"] == 1 + assert result.pagination.total_count == 1 @pytest.mark.anyio async def test_search_empty_query(self, db_session: AsyncSession): @@ -153,10 +153,10 @@ class TestPaginateSearch: ) result = await UserCrud.paginate(db_session, search="") - assert result["pagination"]["total_count"] == 2 + assert result.pagination.total_count == 2 result = await UserCrud.paginate(db_session, search=None) - assert result["pagination"]["total_count"] == 2 + assert result.pagination.total_count == 2 @pytest.mark.anyio async def test_search_with_existing_filters(self, db_session: AsyncSession): @@ -177,8 +177,8 @@ class TestPaginateSearch: search_fields=[User.username], ) - assert result["pagination"]["total_count"] == 1 - assert result["data"][0].username == "active_john" + assert result.pagination.total_count == 1 + assert result.data[0].username == "active_john" @pytest.mark.anyio async def test_search_auto_detect_fields(self, db_session: AsyncSession): @@ -189,7 +189,7 @@ class TestPaginateSearch: result = await UserCrud.paginate(db_session, search="findme") - assert result["pagination"]["total_count"] == 1 + assert result.pagination.total_count == 1 @pytest.mark.anyio async def test_search_no_results(self, db_session: AsyncSession): @@ -204,8 +204,8 @@ class TestPaginateSearch: search_fields=[User.username], ) - assert result["pagination"]["total_count"] == 0 - assert result["data"] == [] + assert result.pagination.total_count == 0 + assert result.data == [] @pytest.mark.anyio async def test_search_with_pagination(self, db_session: AsyncSession): @@ -224,9 +224,9 @@ class TestPaginateSearch: items_per_page=5, ) - assert result["pagination"]["total_count"] == 15 - assert len(result["data"]) == 5 - assert result["pagination"]["has_more"] is True + assert result.pagination.total_count == 15 + assert len(result.data) == 5 + assert result.pagination.has_more is True @pytest.mark.anyio async def test_search_null_relationship(self, db_session: AsyncSession): @@ -248,7 +248,7 @@ class TestPaginateSearch: search_fields=[User.username], ) - assert result["pagination"]["total_count"] == 2 + assert result.pagination.total_count == 2 @pytest.mark.anyio async def test_search_with_order_by(self, db_session: AsyncSession): @@ -270,8 +270,8 @@ class TestPaginateSearch: order_by=User.username, ) - assert result["pagination"]["total_count"] == 3 - usernames = [u.username for u in result["data"]] + assert result.pagination.total_count == 3 + usernames = [u.username for u in result.data] assert usernames == ["alice", "bob", "charlie"] @pytest.mark.anyio @@ -292,8 +292,8 @@ class TestPaginateSearch: search_fields=[User.id, User.username], ) - assert result["pagination"]["total_count"] == 1 - assert result["data"][0].id == user_id + assert result.pagination.total_count == 1 + assert result.data[0].id == user_id class TestSearchConfig: @@ -318,8 +318,8 @@ class TestSearchConfig: search_fields=[User.username, User.email], ) - assert result["pagination"]["total_count"] == 1 - assert result["data"][0].username == "john_test" + assert result.pagination.total_count == 1 + assert result.data[0].username == "john_test" @pytest.mark.anyio async def test_search_config_with_fields(self, db_session: AsyncSession): @@ -333,7 +333,7 @@ class TestSearchConfig: search=SearchConfig(query="findme", fields=[User.email]), ) - assert result["pagination"]["total_count"] == 1 + assert result.pagination.total_count == 1 class TestNoSearchableFieldsError: