From 7482bc5dad89ce0707da283691ccf81b24e451d9 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:02:52 +0100 Subject: [PATCH] feat: add schema parameter to CRUD methods for typed response serialization (#84) --- docs/module/crud.md | 28 +++- src/fastapi_toolsets/crud/factory.py | 176 ++++++++++++++++++++--- tests/conftest.py | 16 +++ tests/test_crud.py | 204 ++++++++++++++++++++++++--- 4 files changed, 382 insertions(+), 42 deletions(-) diff --git a/docs/module/crud.md b/docs/module/crud.md index 4cfd60e..f7cae44 100644 --- a/docs/module/crud.md +++ b/docs/module/crud.md @@ -164,24 +164,42 @@ await UserCrud.upsert( ) ``` -## `as_response` +## `schema` — typed response serialization -Pass `as_response=True` to any write operation to get a [`Response[ModelType]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) back directly for API usage: +!!! info "Added in `v1.1`" + +Pass a Pydantic schema class to `create`, `get`, `update`, or `paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse): ```python +class UserRead(PydanticBase): + id: UUID + username: str + @router.get( "/{uuid}", - response_model=Response[User], responses=generate_error_responses(NotFoundError), ) -async def get_user(session: SessionDep, uuid: UUID): +async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]: return await crud.UserCrud.get( session=session, filters=[User.id == uuid], - as_response=True, + schema=UserRead, + ) + +@router.get("") +async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]: + return await crud.UserCrud.paginate( + session=session, + page=page, + schema=UserRead, ) ``` +The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances. + +!!! warning "Deprecated: `as_response`" + The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`. + --- [:material-api: API Reference](../reference/crud.md) diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index dd9d89d..4927e36 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from collections.abc import Mapping, Sequence from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload @@ -21,6 +22,7 @@ from ..schemas import PaginatedResponse, Pagination, Response from .search import SearchConfig, SearchFieldType, build_search_filters ModelType = TypeVar("ModelType", bound=DeclarativeBase) +SchemaType = TypeVar("SchemaType", bound=BaseModel) JoinType = list[tuple[type[DeclarativeBase], Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]] @@ -101,6 +103,18 @@ class AsyncCrud(Generic[ModelType]): return set() return set(cls.m2m_fields.keys()) + @overload + @classmethod + async def create( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + *, + schema: type[SchemaType], + as_response: bool = ..., + ) -> Response[SchemaType]: ... + + # Backward-compatible - will be removed in v2.0 @overload @classmethod async def create( # pragma: no cover @@ -109,6 +123,7 @@ class AsyncCrud(Generic[ModelType]): obj: BaseModel, *, as_response: Literal[True], + schema: None = ..., ) -> Response[ModelType]: ... @overload @@ -119,6 +134,7 @@ class AsyncCrud(Generic[ModelType]): obj: BaseModel, *, as_response: Literal[False] = ..., + schema: None = ..., ) -> ModelType: ... @classmethod @@ -128,17 +144,28 @@ class AsyncCrud(Generic[ModelType]): obj: BaseModel, *, as_response: bool = False, - ) -> ModelType | Response[ModelType]: + schema: type[BaseModel] | None = None, + ) -> ModelType | Response[ModelType] | Response[Any]: """Create a new record in the database. Args: session: DB async session obj: Pydantic model with data to create - as_response: If True, wrap result in Response object + as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0. + schema: Pydantic schema to serialize the result into. When provided, + the result is automatically wrapped in a ``Response[schema]``. Returns: - Created model instance or Response wrapping it + Created model instance, or ``Response[schema]`` when ``schema`` is given, + or ``Response[ModelType]`` when ``as_response=True`` (deprecated). """ + if as_response and schema is None: + warnings.warn( + "as_response is deprecated and will be removed in v2.0. " + "Use schema=YourSchema instead.", + DeprecationWarning, + stacklevel=2, + ) async with get_transaction(session): m2m_exclude = cls._m2m_schema_fields() data = ( @@ -154,10 +181,27 @@ class AsyncCrud(Generic[ModelType]): session.add(db_model) await session.refresh(db_model) result = cast(ModelType, db_model) - if as_response: - return Response(data=result) + if as_response or schema: + data_out = schema.model_validate(result) if schema else result + return Response(data=data_out) return result + @overload + @classmethod + async def get( # pragma: no cover + cls: type[Self], + session: AsyncSession, + filters: list[Any], + *, + joins: JoinType | None = None, + outer_join: bool = False, + with_for_update: bool = False, + load_options: list[ExecutableOption] | None = None, + schema: type[SchemaType], + as_response: bool = ..., + ) -> Response[SchemaType]: ... + + # Backward-compatible - will be removed in v2.0 @overload @classmethod async def get( # pragma: no cover @@ -170,6 +214,7 @@ class AsyncCrud(Generic[ModelType]): with_for_update: bool = False, load_options: list[ExecutableOption] | None = None, as_response: Literal[True], + schema: None = ..., ) -> Response[ModelType]: ... @overload @@ -184,6 +229,7 @@ class AsyncCrud(Generic[ModelType]): with_for_update: bool = False, load_options: list[ExecutableOption] | None = None, as_response: Literal[False] = ..., + schema: None = ..., ) -> ModelType: ... @classmethod @@ -197,7 +243,8 @@ class AsyncCrud(Generic[ModelType]): with_for_update: bool = False, load_options: list[ExecutableOption] | None = None, as_response: bool = False, - ) -> ModelType | Response[ModelType]: + schema: type[BaseModel] | None = None, + ) -> ModelType | Response[ModelType] | Response[Any]: """Get exactly one record. Raises NotFoundError if not found. Args: @@ -207,15 +254,25 @@ class AsyncCrud(Generic[ModelType]): outer_join: Use LEFT OUTER JOIN instead of INNER JOIN with_for_update: Lock the row for update load_options: SQLAlchemy loader options (e.g., selectinload) - as_response: If True, wrap result in Response object + as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0. + schema: Pydantic schema to serialize the result into. When provided, + the result is automatically wrapped in a ``Response[schema]``. Returns: - Model instance or Response wrapping it + Model instance, or ``Response[schema]`` when ``schema`` is given, + or ``Response[ModelType]`` when ``as_response=True`` (deprecated). Raises: NotFoundError: If no record found MultipleResultsFound: If more than one record found """ + if as_response and schema is None: + warnings.warn( + "as_response is deprecated and will be removed in v2.0. " + "Use schema=YourSchema instead.", + DeprecationWarning, + stacklevel=2, + ) q = select(cls.model) if joins: for model, condition in joins: @@ -234,8 +291,9 @@ class AsyncCrud(Generic[ModelType]): if not item: raise NotFoundError() result = cast(ModelType, item) - if as_response: - return Response(data=result) + if as_response or schema: + data_out = schema.model_validate(result) if schema else result + return Response(data=data_out) return result @classmethod @@ -324,6 +382,21 @@ class AsyncCrud(Generic[ModelType]): result = await session.execute(q) return cast(Sequence[ModelType], result.unique().scalars().all()) + @overload + @classmethod + async def update( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + filters: list[Any], + *, + exclude_unset: bool = True, + exclude_none: bool = False, + schema: type[SchemaType], + as_response: bool = ..., + ) -> Response[SchemaType]: ... + + # Backward-compatible - will be removed in v2.0 @overload @classmethod async def update( # pragma: no cover @@ -335,6 +408,7 @@ class AsyncCrud(Generic[ModelType]): exclude_unset: bool = True, exclude_none: bool = False, as_response: Literal[True], + schema: None = ..., ) -> Response[ModelType]: ... @overload @@ -348,6 +422,7 @@ class AsyncCrud(Generic[ModelType]): exclude_unset: bool = True, exclude_none: bool = False, as_response: Literal[False] = ..., + schema: None = ..., ) -> ModelType: ... @classmethod @@ -360,7 +435,8 @@ class AsyncCrud(Generic[ModelType]): exclude_unset: bool = True, exclude_none: bool = False, as_response: bool = False, - ) -> ModelType | Response[ModelType]: + schema: type[BaseModel] | None = None, + ) -> ModelType | Response[ModelType] | Response[Any]: """Update a record in the database. Args: @@ -369,14 +445,24 @@ class AsyncCrud(Generic[ModelType]): filters: List of SQLAlchemy filter conditions exclude_unset: Exclude fields not explicitly set in the schema exclude_none: Exclude fields with None value - as_response: If True, wrap result in Response object + as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0. + schema: Pydantic schema to serialize the result into. When provided, + the result is automatically wrapped in a ``Response[schema]``. Returns: - Updated model instance or Response wrapping it + Updated model instance, or ``Response[schema]`` when ``schema`` is given, + or ``Response[ModelType]`` when ``as_response=True`` (deprecated). Raises: NotFoundError: If no record found """ + if as_response and schema is None: + warnings.warn( + "as_response is deprecated and will be removed in v2.0. " + "Use schema=YourSchema instead.", + DeprecationWarning, + stacklevel=2, + ) async with get_transaction(session): m2m_exclude = cls._m2m_schema_fields() @@ -406,8 +492,9 @@ class AsyncCrud(Generic[ModelType]): for rel_attr, related_instances in m2m_resolved.items(): setattr(db_model, rel_attr, related_instances) await session.refresh(db_model) - if as_response: - return Response(data=db_model) + if as_response or schema: + data_out = schema.model_validate(db_model) if schema else db_model + return Response(data=data_out) return db_model @classmethod @@ -489,11 +576,20 @@ class AsyncCrud(Generic[ModelType]): Args: session: DB async session filters: List of SQLAlchemy filter conditions - as_response: If True, wrap result in Response object + as_response: Deprecated. Will be removed in v2.0. When ``True``, + returns ``Response[None]`` instead of ``bool``. Returns: - True if deletion was executed, or Response wrapping it + ``True`` if deletion was executed, or ``Response[None]`` when + ``as_response=True`` (deprecated). """ + if as_response: + warnings.warn( + "as_response is deprecated and will be removed in v2.0. " + "Use schema=YourSchema instead.", + DeprecationWarning, + stacklevel=2, + ) async with get_transaction(session): q = sql_delete(cls.model).where(and_(*filters)) await session.execute(q) @@ -566,6 +662,43 @@ class AsyncCrud(Generic[ModelType]): result = await session.execute(q) return bool(result.scalar()) + @overload + @classmethod + async def paginate( # pragma: no cover + cls: type[Self], + session: AsyncSession, + *, + filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, + load_options: list[ExecutableOption] | None = None, + order_by: Any | None = None, + page: int = 1, + items_per_page: int = 20, + search: str | SearchConfig | None = None, + search_fields: Sequence[SearchFieldType] | None = None, + schema: type[SchemaType], + ) -> PaginatedResponse[SchemaType]: ... + + # Backward-compatible - will be removed in v2.0 + @overload + @classmethod + async def paginate( # pragma: no cover + cls: type[Self], + session: AsyncSession, + *, + filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, + load_options: list[ExecutableOption] | None = None, + order_by: Any | None = None, + page: int = 1, + items_per_page: int = 20, + search: str | SearchConfig | None = None, + search_fields: Sequence[SearchFieldType] | None = None, + schema: None = ..., + ) -> PaginatedResponse[ModelType]: ... + @classmethod async def paginate( cls: type[Self], @@ -580,7 +713,8 @@ class AsyncCrud(Generic[ModelType]): items_per_page: int = 20, search: str | SearchConfig | None = None, search_fields: Sequence[SearchFieldType] | None = None, - ) -> PaginatedResponse[ModelType]: + schema: type[BaseModel] | None = None, + ) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]: """Get paginated results with metadata. Args: @@ -594,6 +728,7 @@ class AsyncCrud(Generic[ModelType]): items_per_page: Number of items per page search: Search query string or SearchConfig object search_fields: Fields to search in (overrides class default) + schema: Optional Pydantic schema to serialize each item into. Returns: Dict with 'data' and 'pagination' keys @@ -637,7 +772,10 @@ class AsyncCrud(Generic[ModelType]): q = q.offset(offset).limit(items_per_page) result = await session.execute(q) - items = cast(list[ModelType], result.unique().scalars().all()) + raw_items = cast(list[ModelType], result.unique().scalars().all()) + items: list[Any] = ( + [schema.model_validate(item) for item in raw_items] if schema else raw_items + ) # Count query (with same joins and filters) pk_col = cls.model.__mapper__.primary_key[0] diff --git a/tests/conftest.py b/tests/conftest.py index c0a8db4..aafe043 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,8 @@ import uuid import pytest from pydantic import BaseModel from sqlalchemy import Column, ForeignKey, String, Table, Uuid + +from fastapi_toolsets.schemas import PydanticBase from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -90,6 +92,13 @@ class RoleCreate(BaseModel): name: str +class RoleRead(PydanticBase): + """Schema for reading a role.""" + + id: uuid.UUID + name: str + + class RoleUpdate(BaseModel): """Schema for updating a role.""" @@ -106,6 +115,13 @@ class UserCreate(BaseModel): role_id: uuid.UUID | None = None +class UserRead(PydanticBase): + """Schema for reading a user (subset of fields).""" + + id: uuid.UUID + username: str + + class UserUpdate(BaseModel): """Schema for updating a user.""" diff --git a/tests/test_crud.py b/tests/test_crud.py index f2dba0c..0aa2df8 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -20,12 +20,14 @@ from .conftest import ( Role, RoleCreate, RoleCrud, + RoleRead, RoleUpdate, TagCreate, TagCrud, User, UserCreate, UserCrud, + UserRead, UserUpdate, ) @@ -907,15 +909,16 @@ class TestCrudJoins: class TestAsResponse: - """Tests for as_response parameter.""" + """Tests for as_response parameter (deprecated, kept for backward compat).""" @pytest.mark.anyio async def test_create_as_response(self, db_session: AsyncSession): - """Create with as_response=True returns Response.""" + """Create with as_response=True returns Response and emits DeprecationWarning.""" from fastapi_toolsets.schemas import Response data = RoleCreate(name="response_role") - result = await RoleCrud.create(db_session, data, as_response=True) + with pytest.warns(DeprecationWarning, match="as_response is deprecated"): + result = await RoleCrud.create(db_session, data, as_response=True) assert isinstance(result, Response) assert result.data is not None @@ -923,13 +926,14 @@ class TestAsResponse: @pytest.mark.anyio async def test_get_as_response(self, db_session: AsyncSession): - """Get with as_response=True returns Response.""" + """Get with as_response=True returns Response and emits DeprecationWarning.""" from fastapi_toolsets.schemas import Response created = await RoleCrud.create(db_session, RoleCreate(name="get_response")) - result = await RoleCrud.get( - db_session, [Role.id == created.id], as_response=True - ) + with pytest.warns(DeprecationWarning, match="as_response is deprecated"): + result = await RoleCrud.get( + db_session, [Role.id == created.id], as_response=True + ) assert isinstance(result, Response) assert result.data is not None @@ -937,16 +941,17 @@ class TestAsResponse: @pytest.mark.anyio async def test_update_as_response(self, db_session: AsyncSession): - """Update with as_response=True returns Response.""" + """Update with as_response=True returns Response and emits DeprecationWarning.""" from fastapi_toolsets.schemas import Response created = await RoleCrud.create(db_session, RoleCreate(name="old_name")) - result = await RoleCrud.update( - db_session, - RoleUpdate(name="new_name"), - [Role.id == created.id], - as_response=True, - ) + with pytest.warns(DeprecationWarning, match="as_response is deprecated"): + result = await RoleCrud.update( + db_session, + RoleUpdate(name="new_name"), + [Role.id == created.id], + as_response=True, + ) assert isinstance(result, Response) assert result.data is not None @@ -954,13 +959,14 @@ class TestAsResponse: @pytest.mark.anyio async def test_delete_as_response(self, db_session: AsyncSession): - """Delete with as_response=True returns Response.""" + """Delete with as_response=True returns Response and emits DeprecationWarning.""" from fastapi_toolsets.schemas import Response created = await RoleCrud.create(db_session, RoleCreate(name="to_delete")) - result = await RoleCrud.delete( - db_session, [Role.id == created.id], as_response=True - ) + with pytest.warns(DeprecationWarning, match="as_response is deprecated"): + result = await RoleCrud.delete( + db_session, [Role.id == created.id], as_response=True + ) assert isinstance(result, Response) assert result.data is None @@ -1344,3 +1350,165 @@ class TestM2MWithNonM2MCrud: [Post.id == post.id], ) assert updated.title == "Updated Plain" + + +class TestSchemaResponse: + """Tests for the schema parameter on as_response methods.""" + + @pytest.mark.anyio + async def test_create_with_schema(self, db_session: AsyncSession): + """create with schema returns Response[SchemaType].""" + from fastapi_toolsets.schemas import Response + + result = await RoleCrud.create( + db_session, RoleCreate(name="schema_role"), schema=RoleRead + ) + + assert isinstance(result, Response) + assert isinstance(result.data, RoleRead) + assert result.data.name == "schema_role" + + @pytest.mark.anyio + async def test_create_schema_implies_as_response(self, db_session: AsyncSession): + """create with schema alone wraps in Response without as_response=True.""" + from fastapi_toolsets.schemas import Response + + result = await RoleCrud.create( + db_session, RoleCreate(name="implicit"), schema=RoleRead + ) + + assert isinstance(result, Response) + + @pytest.mark.anyio + async def test_create_schema_filters_fields(self, db_session: AsyncSession): + """create with schema only exposes schema fields, not all model fields.""" + result = await UserCrud.create( + db_session, + UserCreate(username="filtered", email="filtered@test.com"), + schema=UserRead, + ) + + assert isinstance(result.data, UserRead) + assert result.data.username == "filtered" + assert not hasattr(result.data, "email") + + @pytest.mark.anyio + async def test_get_with_schema(self, db_session: AsyncSession): + """get with schema returns Response[SchemaType].""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="get_schema")) + result = await RoleCrud.get( + db_session, [Role.id == created.id], schema=RoleRead + ) + + assert isinstance(result, Response) + assert isinstance(result.data, RoleRead) + assert result.data.id == created.id + assert result.data.name == "get_schema" + + @pytest.mark.anyio + async def test_get_schema_implies_as_response(self, db_session: AsyncSession): + """get with schema alone wraps in Response without as_response=True.""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="implicit_get")) + result = await RoleCrud.get( + db_session, [Role.id == created.id], schema=RoleRead + ) + + assert isinstance(result, Response) + + @pytest.mark.anyio + async def test_update_with_schema(self, db_session: AsyncSession): + """update with schema returns Response[SchemaType].""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="before")) + result = await RoleCrud.update( + db_session, + RoleUpdate(name="after"), + [Role.id == created.id], + schema=RoleRead, + ) + + assert isinstance(result, Response) + assert isinstance(result.data, RoleRead) + assert result.data.name == "after" + + @pytest.mark.anyio + async def test_update_schema_implies_as_response(self, db_session: AsyncSession): + """update with schema alone wraps in Response without as_response=True.""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="before2")) + result = await RoleCrud.update( + db_session, + RoleUpdate(name="after2"), + [Role.id == created.id], + schema=RoleRead, + ) + + assert isinstance(result, Response) + + @pytest.mark.anyio + async def test_paginate_with_schema(self, db_session: AsyncSession): + """paginate with schema returns PaginatedResponse[SchemaType].""" + from fastapi_toolsets.schemas import PaginatedResponse + + await RoleCrud.create(db_session, RoleCreate(name="p_role1")) + await RoleCrud.create(db_session, RoleCreate(name="p_role2")) + + result = await RoleCrud.paginate(db_session, schema=RoleRead) + + assert isinstance(result, PaginatedResponse) + assert len(result.data) == 2 + assert all(isinstance(item, RoleRead) for item in result.data) + + @pytest.mark.anyio + async def test_paginate_schema_filters_fields(self, db_session: AsyncSession): + """paginate with schema only exposes schema fields per item.""" + await UserCrud.create( + db_session, + UserCreate(username="pg_user", email="pg@test.com"), + ) + + result = await UserCrud.paginate(db_session, schema=UserRead) + + assert isinstance(result.data[0], UserRead) + assert result.data[0].username == "pg_user" + assert not hasattr(result.data[0], "email") + + @pytest.mark.anyio + async def test_as_response_true_without_schema_unchanged( + self, db_session: AsyncSession + ): + """as_response=True without schema still returns Response[ModelType] with a warning.""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="compat")) + with pytest.warns(DeprecationWarning, match="as_response is deprecated"): + result = await RoleCrud.get( + db_session, [Role.id == created.id], as_response=True + ) + + assert isinstance(result, Response) + assert isinstance(result.data, Role) + + @pytest.mark.anyio + async def test_schema_with_explicit_as_response_true( + self, db_session: AsyncSession + ): + """schema combined with explicit as_response=True works correctly.""" + from fastapi_toolsets.schemas import Response + + created = await RoleCrud.create(db_session, RoleCreate(name="combined")) + result = await RoleCrud.get( + db_session, + [Role.id == created.id], + as_response=True, + schema=RoleRead, + ) + + assert isinstance(result, Response) + assert isinstance(result.data, RoleRead)