mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add schema parameter to CRUD methods for typed response serialization (#84)
This commit is contained in:
@@ -164,24 +164,42 @@ await UserCrud.upsert(
|
||||
)
|
||||
```
|
||||
|
||||
## `as_response`
|
||||
## `schema` — typed response serialization
|
||||
|
||||
Pass `as_response=True` to any write operation to get a [`Response[ModelType]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) back directly for API usage:
|
||||
!!! info "Added in `v1.1`"
|
||||
|
||||
Pass a Pydantic schema class to `create`, `get`, `update`, or `paginate` to serialize the result directly into that schema and wrap it in a [`Response[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.Response) or [`PaginatedResponse[schema]`](../reference/schemas.md#fastapi_toolsets.schemas.PaginatedResponse):
|
||||
|
||||
```python
|
||||
class UserRead(PydanticBase):
|
||||
id: UUID
|
||||
username: str
|
||||
|
||||
@router.get(
|
||||
"/{uuid}",
|
||||
response_model=Response[User],
|
||||
responses=generate_error_responses(NotFoundError),
|
||||
)
|
||||
async def get_user(session: SessionDep, uuid: UUID):
|
||||
async def get_user(session: SessionDep, uuid: UUID) -> Response[UserRead]:
|
||||
return await crud.UserCrud.get(
|
||||
session=session,
|
||||
filters=[User.id == uuid],
|
||||
as_response=True,
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
@router.get("")
|
||||
async def list_users(session: SessionDep, page: int = 1) -> PaginatedResponse[UserRead]:
|
||||
return await crud.UserCrud.paginate(
|
||||
session=session,
|
||||
page=page,
|
||||
schema=UserRead,
|
||||
)
|
||||
```
|
||||
|
||||
The schema must have `from_attributes=True` (or inherit from [`PydanticBase`](../reference/schemas.md#fastapi_toolsets.schemas.PydanticBase)) so it can be built from SQLAlchemy model instances.
|
||||
|
||||
!!! warning "Deprecated: `as_response`"
|
||||
The `as_response=True` parameter is **deprecated** and will be removed in **v2.0**. Replace it with `schema=YourSchema`.
|
||||
|
||||
---
|
||||
|
||||
[:material-api: API Reference](../reference/crud.md)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
@@ -21,6 +22,7 @@ from ..schemas import PaginatedResponse, Pagination, Response
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
|
||||
@@ -101,6 +103,18 @@ class AsyncCrud(Generic[ModelType]):
|
||||
return set()
|
||||
return set(cls.m2m_fields.keys())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
@@ -109,6 +123,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@@ -119,6 +134,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
@@ -128,17 +144,28 @@ class AsyncCrud(Generic[ModelType]):
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Created model instance or Response wrapping it
|
||||
Created model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
data = (
|
||||
@@ -154,10 +181,27 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
result = cast(ModelType, db_model)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(result) if schema else result
|
||||
return Response(data=data_out)
|
||||
return result
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
@@ -170,6 +214,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@@ -184,6 +229,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
@@ -197,7 +243,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
with_for_update: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
@@ -207,15 +254,25 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Model instance or Response wrapping it
|
||||
Model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
MultipleResultsFound: If more than one record found
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
@@ -234,8 +291,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
result = cast(ModelType, item)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(result) if schema else result
|
||||
return Response(data=data_out)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -324,6 +382,21 @@ class AsyncCrud(Generic[ModelType]):
|
||||
result = await session.execute(q)
|
||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
schema: type[SchemaType],
|
||||
as_response: bool = ...,
|
||||
) -> Response[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
@@ -335,6 +408,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[True],
|
||||
schema: None = ...,
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@@ -348,6 +422,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[False] = ...,
|
||||
schema: None = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
@@ -360,7 +435,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> ModelType | Response[ModelType] | Response[Any]:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
@@ -369,14 +445,24 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Use ``schema`` instead. Will be removed in v2.0.
|
||||
schema: Pydantic schema to serialize the result into. When provided,
|
||||
the result is automatically wrapped in a ``Response[schema]``.
|
||||
|
||||
Returns:
|
||||
Updated model instance or Response wrapping it
|
||||
Updated model instance, or ``Response[schema]`` when ``schema`` is given,
|
||||
or ``Response[ModelType]`` when ``as_response=True`` (deprecated).
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
"""
|
||||
if as_response and schema is None:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
|
||||
@@ -406,8 +492,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
for rel_attr, related_instances in m2m_resolved.items():
|
||||
setattr(db_model, rel_attr, related_instances)
|
||||
await session.refresh(db_model)
|
||||
if as_response:
|
||||
return Response(data=db_model)
|
||||
if as_response or schema:
|
||||
data_out = schema.model_validate(db_model) if schema else db_model
|
||||
return Response(data=data_out)
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
@@ -489,11 +576,20 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
as_response: If True, wrap result in Response object
|
||||
as_response: Deprecated. Will be removed in v2.0. When ``True``,
|
||||
returns ``Response[None]`` instead of ``bool``.
|
||||
|
||||
Returns:
|
||||
True if deletion was executed, or Response wrapping it
|
||||
``True`` if deletion was executed, or ``Response[None]`` when
|
||||
``as_response=True`` (deprecated).
|
||||
"""
|
||||
if as_response:
|
||||
warnings.warn(
|
||||
"as_response is deprecated and will be removed in v2.0. "
|
||||
"Use schema=YourSchema instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
@@ -566,6 +662,43 @@ class AsyncCrud(Generic[ModelType]):
|
||||
result = await session.execute(q)
|
||||
return bool(result.scalar())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: type[SchemaType],
|
||||
) -> PaginatedResponse[SchemaType]: ...
|
||||
|
||||
# Backward-compatible - will be removed in v2.0
|
||||
@overload
|
||||
@classmethod
|
||||
async def paginate( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
schema: None = ...,
|
||||
) -> PaginatedResponse[ModelType]: ...
|
||||
|
||||
@classmethod
|
||||
async def paginate(
|
||||
cls: type[Self],
|
||||
@@ -580,7 +713,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> PaginatedResponse[ModelType]:
|
||||
schema: type[BaseModel] | None = None,
|
||||
) -> PaginatedResponse[ModelType] | PaginatedResponse[Any]:
|
||||
"""Get paginated results with metadata.
|
||||
|
||||
Args:
|
||||
@@ -594,6 +728,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: Number of items per page
|
||||
search: Search query string or SearchConfig object
|
||||
search_fields: Fields to search in (overrides class default)
|
||||
schema: Optional Pydantic schema to serialize each item into.
|
||||
|
||||
Returns:
|
||||
Dict with 'data' and 'pagination' keys
|
||||
@@ -637,7 +772,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
q = q.offset(offset).limit(items_per_page)
|
||||
result = await session.execute(q)
|
||||
items = cast(list[ModelType], result.unique().scalars().all())
|
||||
raw_items = cast(list[ModelType], result.unique().scalars().all())
|
||||
items: list[Any] = (
|
||||
[schema.model_validate(item) for item in raw_items] if schema else raw_items
|
||||
)
|
||||
|
||||
# Count query (with same joins and filters)
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
|
||||
@@ -6,6 +6,8 @@ import uuid
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column, ForeignKey, String, Table, Uuid
|
||||
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
@@ -90,6 +92,13 @@ class RoleCreate(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class RoleRead(PydanticBase):
|
||||
"""Schema for reading a role."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class RoleUpdate(BaseModel):
|
||||
"""Schema for updating a role."""
|
||||
|
||||
@@ -106,6 +115,13 @@ class UserCreate(BaseModel):
|
||||
role_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class UserRead(PydanticBase):
|
||||
"""Schema for reading a user (subset of fields)."""
|
||||
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""Schema for updating a user."""
|
||||
|
||||
|
||||
@@ -20,12 +20,14 @@ from .conftest import (
|
||||
Role,
|
||||
RoleCreate,
|
||||
RoleCrud,
|
||||
RoleRead,
|
||||
RoleUpdate,
|
||||
TagCreate,
|
||||
TagCrud,
|
||||
User,
|
||||
UserCreate,
|
||||
UserCrud,
|
||||
UserRead,
|
||||
UserUpdate,
|
||||
)
|
||||
|
||||
@@ -907,14 +909,15 @@ class TestCrudJoins:
|
||||
|
||||
|
||||
class TestAsResponse:
|
||||
"""Tests for as_response parameter."""
|
||||
"""Tests for as_response parameter (deprecated, kept for backward compat)."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_as_response(self, db_session: AsyncSession):
|
||||
"""Create with as_response=True returns Response."""
|
||||
"""Create with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
data = RoleCreate(name="response_role")
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.create(db_session, data, as_response=True)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
@@ -923,10 +926,11 @@ class TestAsResponse:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_as_response(self, db_session: AsyncSession):
|
||||
"""Get with as_response=True returns Response."""
|
||||
"""Get with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
@@ -937,10 +941,11 @@ class TestAsResponse:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_as_response(self, db_session: AsyncSession):
|
||||
"""Update with as_response=True returns Response."""
|
||||
"""Update with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="new_name"),
|
||||
@@ -954,10 +959,11 @@ class TestAsResponse:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_as_response(self, db_session: AsyncSession):
|
||||
"""Delete with as_response=True returns Response."""
|
||||
"""Delete with as_response=True returns Response and emits DeprecationWarning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.delete(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
@@ -1344,3 +1350,165 @@ class TestM2MWithNonM2MCrud:
|
||||
[Post.id == post.id],
|
||||
)
|
||||
assert updated.title == "Updated Plain"
|
||||
|
||||
|
||||
class TestSchemaResponse:
|
||||
"""Tests for the schema parameter on as_response methods."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_schema(self, db_session: AsyncSession):
|
||||
"""create with schema returns Response[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
result = await RoleCrud.create(
|
||||
db_session, RoleCreate(name="schema_role"), schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
assert result.data.name == "schema_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_schema_implies_as_response(self, db_session: AsyncSession):
|
||||
"""create with schema alone wraps in Response without as_response=True."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
result = await RoleCrud.create(
|
||||
db_session, RoleCreate(name="implicit"), schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_schema_filters_fields(self, db_session: AsyncSession):
|
||||
"""create with schema only exposes schema fields, not all model fields."""
|
||||
result = await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="filtered", email="filtered@test.com"),
|
||||
schema=UserRead,
|
||||
)
|
||||
|
||||
assert isinstance(result.data, UserRead)
|
||||
assert result.data.username == "filtered"
|
||||
assert not hasattr(result.data, "email")
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_with_schema(self, db_session: AsyncSession):
|
||||
"""get with schema returns Response[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="get_schema"))
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
assert result.data.id == created.id
|
||||
assert result.data.name == "get_schema"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_schema_implies_as_response(self, db_session: AsyncSession):
|
||||
"""get with schema alone wraps in Response without as_response=True."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="implicit_get"))
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], schema=RoleRead
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_with_schema(self, db_session: AsyncSession):
|
||||
"""update with schema returns Response[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="before"))
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="after"),
|
||||
[Role.id == created.id],
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
assert result.data.name == "after"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_schema_implies_as_response(self, db_session: AsyncSession):
|
||||
"""update with schema alone wraps in Response without as_response=True."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="before2"))
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="after2"),
|
||||
[Role.id == created.id],
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_schema(self, db_session: AsyncSession):
|
||||
"""paginate with schema returns PaginatedResponse[SchemaType]."""
|
||||
from fastapi_toolsets.schemas import PaginatedResponse
|
||||
|
||||
await RoleCrud.create(db_session, RoleCreate(name="p_role1"))
|
||||
await RoleCrud.create(db_session, RoleCreate(name="p_role2"))
|
||||
|
||||
result = await RoleCrud.paginate(db_session, schema=RoleRead)
|
||||
|
||||
assert isinstance(result, PaginatedResponse)
|
||||
assert len(result.data) == 2
|
||||
assert all(isinstance(item, RoleRead) for item in result.data)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_schema_filters_fields(self, db_session: AsyncSession):
|
||||
"""paginate with schema only exposes schema fields per item."""
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="pg_user", email="pg@test.com"),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, schema=UserRead)
|
||||
|
||||
assert isinstance(result.data[0], UserRead)
|
||||
assert result.data[0].username == "pg_user"
|
||||
assert not hasattr(result.data[0], "email")
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_as_response_true_without_schema_unchanged(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""as_response=True without schema still returns Response[ModelType] with a warning."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="compat"))
|
||||
with pytest.warns(DeprecationWarning, match="as_response is deprecated"):
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, Role)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_schema_with_explicit_as_response_true(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""schema combined with explicit as_response=True works correctly."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="combined"))
|
||||
result = await RoleCrud.get(
|
||||
db_session,
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
schema=RoleRead,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert isinstance(result.data, RoleRead)
|
||||
|
||||
Reference in New Issue
Block a user