mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add many to many support in CrudFactory (#65)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
from ..exceptions import NoSearchableFieldsError
|
from ..exceptions import NoSearchableFieldsError
|
||||||
from .factory import CrudFactory
|
from .factory import CrudFactory, JoinType, M2MFieldType
|
||||||
from .search import (
|
from .search import (
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
get_searchable_fields,
|
get_searchable_fields,
|
||||||
@@ -10,6 +10,8 @@ from .search import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"CrudFactory",
|
"CrudFactory",
|
||||||
"get_searchable_fields",
|
"get_searchable_fields",
|
||||||
|
"JoinType",
|
||||||
|
"M2MFieldType",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
"SearchConfig",
|
"SearchConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -11,7 +11,7 @@ from sqlalchemy import delete as sql_delete
|
|||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
from sqlalchemy.sql.roles import WhereHavingRole
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
@@ -21,6 +21,7 @@ from .search import SearchConfig, SearchFieldType, build_search_filters
|
|||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||||
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
class AsyncCrud(Generic[ModelType]):
|
||||||
@@ -31,6 +32,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
model: ClassVar[type[DeclarativeBase]]
|
model: ClassVar[type[DeclarativeBase]]
|
||||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||||
|
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -52,6 +54,62 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
as_response: Literal[False] = ...,
|
as_response: Literal[False] = ...,
|
||||||
) -> ModelType: ...
|
) -> ModelType: ...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _resolve_m2m(
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
obj: BaseModel,
|
||||||
|
*,
|
||||||
|
only_set: bool = False,
|
||||||
|
) -> dict[str, list[Any]]:
|
||||||
|
"""Resolve M2M fields from a Pydantic schema into related model instances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: DB async session
|
||||||
|
obj: Pydantic model containing M2M ID fields
|
||||||
|
only_set: If True, only process fields explicitly set on the schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping relationship attr names to lists of related instances
|
||||||
|
"""
|
||||||
|
result: dict[str, list[Any]] = {}
|
||||||
|
if not cls.m2m_fields:
|
||||||
|
return result
|
||||||
|
|
||||||
|
for schema_field, rel in cls.m2m_fields.items():
|
||||||
|
rel_attr = rel.property.key
|
||||||
|
related_model = rel.property.mapper.class_
|
||||||
|
if only_set and schema_field not in obj.model_fields_set:
|
||||||
|
continue
|
||||||
|
ids = getattr(obj, schema_field, None)
|
||||||
|
if ids is not None:
|
||||||
|
related = (
|
||||||
|
(
|
||||||
|
await session.execute(
|
||||||
|
select(related_model).where(related_model.id.in_(ids))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if len(related) != len(ids):
|
||||||
|
found_ids = {r.id for r in related}
|
||||||
|
missing = set(ids) - found_ids
|
||||||
|
raise NotFoundError(
|
||||||
|
f"Related {related_model.__name__} not found for IDs: {missing}"
|
||||||
|
)
|
||||||
|
result[rel_attr] = list(related)
|
||||||
|
else:
|
||||||
|
result[rel_attr] = []
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _m2m_schema_fields(cls: type[Self]) -> set[str]:
|
||||||
|
"""Return the set of schema field names that are M2M fields."""
|
||||||
|
if not cls.m2m_fields:
|
||||||
|
return set()
|
||||||
|
return set(cls.m2m_fields.keys())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
@@ -71,7 +129,17 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Created model instance or Response wrapping it
|
Created model instance or Response wrapping it
|
||||||
"""
|
"""
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
db_model = cls.model(**obj.model_dump())
|
m2m_exclude = cls._m2m_schema_fields()
|
||||||
|
data = (
|
||||||
|
obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump()
|
||||||
|
)
|
||||||
|
db_model = cls.model(**data)
|
||||||
|
|
||||||
|
if m2m_exclude:
|
||||||
|
m2m_resolved = await cls._resolve_m2m(session, obj)
|
||||||
|
for rel_attr, related_instances in m2m_resolved.items():
|
||||||
|
setattr(db_model, rel_attr, related_instances)
|
||||||
|
|
||||||
session.add(db_model)
|
session.add(db_model)
|
||||||
await session.refresh(db_model)
|
await session.refresh(db_model)
|
||||||
result = cast(ModelType, db_model)
|
result = cast(ModelType, db_model)
|
||||||
@@ -299,12 +367,33 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
NotFoundError: If no record found
|
NotFoundError: If no record found
|
||||||
"""
|
"""
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
db_model = await cls.get(session=session, filters=filters)
|
m2m_exclude = cls._m2m_schema_fields()
|
||||||
|
|
||||||
|
# Eagerly load M2M relationships that will be updated so that
|
||||||
|
# setattr does not trigger a lazy load (which fails in async).
|
||||||
|
m2m_load_options: list[Any] = []
|
||||||
|
if m2m_exclude and cls.m2m_fields:
|
||||||
|
for schema_field, rel in cls.m2m_fields.items():
|
||||||
|
if schema_field in obj.model_fields_set:
|
||||||
|
m2m_load_options.append(selectinload(rel))
|
||||||
|
|
||||||
|
db_model = await cls.get(
|
||||||
|
session=session,
|
||||||
|
filters=filters,
|
||||||
|
load_options=m2m_load_options or None,
|
||||||
|
)
|
||||||
values = obj.model_dump(
|
values = obj.model_dump(
|
||||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
exclude=m2m_exclude,
|
||||||
)
|
)
|
||||||
for key, value in values.items():
|
for key, value in values.items():
|
||||||
setattr(db_model, key, value)
|
setattr(db_model, key, value)
|
||||||
|
|
||||||
|
if m2m_exclude:
|
||||||
|
m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True)
|
||||||
|
for rel_attr, related_instances in m2m_resolved.items():
|
||||||
|
setattr(db_model, rel_attr, related_instances)
|
||||||
await session.refresh(db_model)
|
await session.refresh(db_model)
|
||||||
if as_response:
|
if as_response:
|
||||||
return Response(data=db_model)
|
return Response(data=db_model)
|
||||||
@@ -578,12 +667,16 @@ def CrudFactory(
|
|||||||
model: type[ModelType],
|
model: type[ModelType],
|
||||||
*,
|
*,
|
||||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
m2m_fields: M2MFieldType | None = None,
|
||||||
) -> type[AsyncCrud[ModelType]]:
|
) -> type[AsyncCrud[ModelType]]:
|
||||||
"""Create a CRUD class for a specific model.
|
"""Create a CRUD class for a specific model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: SQLAlchemy model class
|
model: SQLAlchemy model class
|
||||||
searchable_fields: Optional list of searchable fields
|
searchable_fields: Optional list of searchable fields
|
||||||
|
m2m_fields: Optional mapping for many-to-many relationships.
|
||||||
|
Maps schema field names (containing lists of IDs) to
|
||||||
|
SQLAlchemy relationship attributes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AsyncCrud subclass bound to the model
|
AsyncCrud subclass bound to the model
|
||||||
@@ -601,10 +694,20 @@ def CrudFactory(
|
|||||||
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# With many-to-many fields:
|
||||||
|
# Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
|
||||||
|
PostCrud = CrudFactory(
|
||||||
|
Post,
|
||||||
|
m2m_fields={"tag_ids": Post.tags},
|
||||||
|
)
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||||
|
|
||||||
|
# Create with M2M - tag_ids are automatically resolved
|
||||||
|
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
|
||||||
|
|
||||||
# With search
|
# With search
|
||||||
result = await UserCrud.paginate(session, search="john")
|
result = await UserCrud.paginate(session, search="john")
|
||||||
|
|
||||||
@@ -628,6 +731,7 @@ def CrudFactory(
|
|||||||
{
|
{
|
||||||
"model": model,
|
"model": model,
|
||||||
"searchable_fields": searchable_fields,
|
"searchable_fields": searchable_fields,
|
||||||
|
"m2m_fields": m2m_fields,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return cast(type[AsyncCrud[ModelType]], cls)
|
return cast(type[AsyncCrud[ModelType]], cls)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import uuid
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import ForeignKey, String, Uuid
|
from sqlalchemy import Column, ForeignKey, String, Table, Uuid
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@@ -56,6 +56,25 @@ class User(Base):
|
|||||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Base):
|
||||||
|
"""Test tag model."""
|
||||||
|
|
||||||
|
__tablename__ = "tags"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
post_tags = Table(
|
||||||
|
"post_tags",
|
||||||
|
Base.metadata,
|
||||||
|
Column(
|
||||||
|
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
),
|
||||||
|
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
"""Test post model."""
|
"""Test post model."""
|
||||||
|
|
||||||
@@ -67,6 +86,8 @@ class Post(Base):
|
|||||||
is_published: Mapped[bool] = mapped_column(default=False)
|
is_published: Mapped[bool] = mapped_column(default=False)
|
||||||
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Test Schemas
|
# Test Schemas
|
||||||
@@ -105,6 +126,13 @@ class UserUpdate(BaseModel):
|
|||||||
role_id: uuid.UUID | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TagCreate(BaseModel):
|
||||||
|
"""Schema for creating a tag."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
class PostCreate(BaseModel):
|
||||||
"""Schema for creating a post."""
|
"""Schema for creating a post."""
|
||||||
|
|
||||||
@@ -123,6 +151,26 @@ class PostUpdate(BaseModel):
|
|||||||
is_published: bool | None = None
|
is_published: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PostM2MCreate(BaseModel):
|
||||||
|
"""Schema for creating a post with M2M tag IDs."""
|
||||||
|
|
||||||
|
id: uuid.UUID | None = None
|
||||||
|
title: str
|
||||||
|
content: str = ""
|
||||||
|
is_published: bool = False
|
||||||
|
author_id: uuid.UUID
|
||||||
|
tag_ids: list[uuid.UUID] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PostM2MUpdate(BaseModel):
|
||||||
|
"""Schema for updating a post with M2M tag IDs."""
|
||||||
|
|
||||||
|
title: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
is_published: bool | None = None
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# CRUD Classes
|
# CRUD Classes
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -130,6 +178,8 @@ class PostUpdate(BaseModel):
|
|||||||
RoleCrud = CrudFactory(Role)
|
RoleCrud = CrudFactory(Role)
|
||||||
UserCrud = CrudFactory(User)
|
UserCrud = CrudFactory(User)
|
||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
|
TagCrud = CrudFactory(Tag)
|
||||||
|
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import uuid
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
@@ -13,10 +14,15 @@ from .conftest import (
|
|||||||
Post,
|
Post,
|
||||||
PostCreate,
|
PostCreate,
|
||||||
PostCrud,
|
PostCrud,
|
||||||
|
PostM2MCreate,
|
||||||
|
PostM2MCrud,
|
||||||
|
PostM2MUpdate,
|
||||||
Role,
|
Role,
|
||||||
RoleCreate,
|
RoleCreate,
|
||||||
RoleCrud,
|
RoleCrud,
|
||||||
RoleUpdate,
|
RoleUpdate,
|
||||||
|
TagCreate,
|
||||||
|
TagCrud,
|
||||||
User,
|
User,
|
||||||
UserCreate,
|
UserCreate,
|
||||||
UserCrud,
|
UserCrud,
|
||||||
@@ -812,3 +818,383 @@ class TestAsResponse:
|
|||||||
|
|
||||||
assert isinstance(result, Response)
|
assert isinstance(result, Response)
|
||||||
assert result.data is None
|
assert result.data is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrudFactoryM2M:
|
||||||
|
"""Tests for CrudFactory with m2m_fields parameter."""
|
||||||
|
|
||||||
|
def test_creates_crud_with_m2m_fields(self):
|
||||||
|
"""CrudFactory configures m2m_fields on the class."""
|
||||||
|
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
|
assert crud.m2m_fields is not None
|
||||||
|
assert "tag_ids" in crud.m2m_fields
|
||||||
|
|
||||||
|
def test_creates_crud_without_m2m_fields(self):
|
||||||
|
"""CrudFactory without m2m_fields has None."""
|
||||||
|
crud = CrudFactory(Post)
|
||||||
|
assert crud.m2m_fields is None
|
||||||
|
|
||||||
|
def test_m2m_schema_fields(self):
|
||||||
|
"""_m2m_schema_fields returns correct field names."""
|
||||||
|
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||||
|
assert crud._m2m_schema_fields() == {"tag_ids"}
|
||||||
|
|
||||||
|
def test_m2m_schema_fields_empty_when_none(self):
|
||||||
|
"""_m2m_schema_fields returns empty set when no m2m_fields."""
|
||||||
|
crud = CrudFactory(Post)
|
||||||
|
assert crud._m2m_schema_fields() == set()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_resolve_m2m_returns_empty_without_m2m_fields(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""_resolve_m2m returns empty dict when m2m_fields is not configured."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class DummySchema(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
result = await PostCrud._resolve_m2m(db_session, DummySchema(name="test"))
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MResolveNone:
|
||||||
|
"""Tests for _resolve_m2m when IDs field is None."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_resolve_m2m_with_none_ids(self, db_session: AsyncSession):
|
||||||
|
"""_resolve_m2m sets empty list when ids value is None."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class SchemaWithNullableTags(BaseModel):
|
||||||
|
tag_ids: list[uuid.UUID] | None = None
|
||||||
|
|
||||||
|
result = await PostM2MCrud._resolve_m2m(
|
||||||
|
db_session, SchemaWithNullableTags(tag_ids=None)
|
||||||
|
)
|
||||||
|
assert result == {"tags": []}
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MCreate:
|
||||||
|
"""Tests for create with M2M relationships."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_m2m_tags(self, db_session: AsyncSession):
|
||||||
|
"""Create a post with M2M tags resolves tag IDs."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="python"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="fastapi"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="M2M Post",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id, tag2.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert post.id is not None
|
||||||
|
assert post.title == "M2M Post"
|
||||||
|
|
||||||
|
# Reload with tags eagerly loaded
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
tag_names = sorted(t.name for t in loaded.tags)
|
||||||
|
assert tag_names == ["fastapi", "python"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_empty_m2m(self, db_session: AsyncSession):
|
||||||
|
"""Create a post with empty tag_ids list works."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="No Tags Post",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert post.id is not None
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.tags == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_default_m2m(self, db_session: AsyncSession):
|
||||||
|
"""Create a post using default tag_ids (empty list) works."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(title="Default Tags", author_id=user.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.tags == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_nonexistent_tag_id_raises(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Create with a nonexistent tag ID raises NotFoundError."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="valid"))
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Bad Tags",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id, fake_id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_with_single_tag(self, db_session: AsyncSession):
|
||||||
|
"""Create with a single tag works correctly."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="solo"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Single Tag",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert len(loaded.tags) == 1
|
||||||
|
assert loaded.tags[0].name == "solo"
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MUpdate:
|
||||||
|
"""Tests for update with M2M relationships."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_m2m_tags(self, db_session: AsyncSession):
|
||||||
|
"""Update replaces M2M tags when tag_ids is set."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="old_tag"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="new_tag"))
|
||||||
|
|
||||||
|
# Create with tag1
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Update Test",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update to tag2
|
||||||
|
updated = await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(tag_ids=[tag2.id]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == updated.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert len(loaded.tags) == 1
|
||||||
|
assert loaded.tags[0].name == "new_tag"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_without_m2m_preserves_tags(self, db_session: AsyncSession):
|
||||||
|
"""Update without setting tag_ids preserves existing tags."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="keep_me"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Keep Tags",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update only title, tag_ids not set
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(title="Updated Title"),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.title == "Updated Title"
|
||||||
|
assert len(loaded.tags) == 1
|
||||||
|
assert loaded.tags[0].name == "keep_me"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_clear_m2m_tags(self, db_session: AsyncSession):
|
||||||
|
"""Update with empty tag_ids clears all tags."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="remove_me"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Clear Tags",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicitly set tag_ids to empty list
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(tag_ids=[]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.tags == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_m2m_with_nonexistent_id_raises(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Update with nonexistent tag ID raises NotFoundError."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag = await TagCrud.create(db_session, TagCreate(name="existing"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Bad Update",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
with pytest.raises(NotFoundError):
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(tag_ids=[fake_id]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_m2m_and_scalar_fields(self, db_session: AsyncSession):
|
||||||
|
"""Update both scalar fields and M2M tags together."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
tag1 = await TagCrud.create(db_session, TagCreate(name="tag1"))
|
||||||
|
tag2 = await TagCrud.create(db_session, TagCreate(name="tag2"))
|
||||||
|
|
||||||
|
post = await PostM2MCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostM2MCreate(
|
||||||
|
title="Original",
|
||||||
|
author_id=user.id,
|
||||||
|
tag_ids=[tag1.id],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update title and tags simultaneously
|
||||||
|
await PostM2MCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostM2MUpdate(title="Updated", tag_ids=[tag1.id, tag2.id]),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
loaded = await PostM2MCrud.get(
|
||||||
|
db_session,
|
||||||
|
[Post.id == post.id],
|
||||||
|
load_options=[selectinload(Post.tags)],
|
||||||
|
)
|
||||||
|
assert loaded.title == "Updated"
|
||||||
|
tag_names = sorted(t.name for t in loaded.tags)
|
||||||
|
assert tag_names == ["tag1", "tag2"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestM2MWithNonM2MCrud:
|
||||||
|
"""Tests that non-M2M CRUD classes are unaffected."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_without_m2m_unchanged(self, db_session: AsyncSession):
|
||||||
|
"""Regular PostCrud.create still works without M2M logic."""
|
||||||
|
from .conftest import PostCreate
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
post = await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Plain Post", author_id=user.id),
|
||||||
|
)
|
||||||
|
assert post.id is not None
|
||||||
|
assert post.title == "Plain Post"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_without_m2m_unchanged(self, db_session: AsyncSession):
|
||||||
|
"""Regular PostCrud.update still works without M2M logic."""
|
||||||
|
from .conftest import PostCreate, PostUpdate
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="author", email="author@test.com")
|
||||||
|
)
|
||||||
|
post = await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Plain Post", author_id=user.id),
|
||||||
|
)
|
||||||
|
updated = await PostCrud.update(
|
||||||
|
db_session,
|
||||||
|
PostUpdate(title="Updated Plain"),
|
||||||
|
[Post.id == post.id],
|
||||||
|
)
|
||||||
|
assert updated.title == "Updated Plain"
|
||||||
|
|||||||
Reference in New Issue
Block a user