mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add many to many support in CrudFactory (#65)
This commit is contained in:
@@ -5,7 +5,7 @@ import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import ForeignKey, String, Uuid
|
||||
from sqlalchemy import Column, ForeignKey, String, Table, Uuid
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
@@ -56,6 +56,25 @@ class User(Base):
|
||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
"""Test tag model."""
|
||||
|
||||
__tablename__ = "tags"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
|
||||
post_tags = Table(
|
||||
"post_tags",
|
||||
Base.metadata,
|
||||
Column(
|
||||
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
|
||||
),
|
||||
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
|
||||
)
|
||||
|
||||
|
||||
class Post(Base):
|
||||
"""Test post model."""
|
||||
|
||||
@@ -67,6 +86,8 @@ class Post(Base):
|
||||
is_published: Mapped[bool] = mapped_column(default=False)
|
||||
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||
|
||||
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Schemas
|
||||
@@ -105,6 +126,13 @@ class UserUpdate(BaseModel):
|
||||
role_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class TagCreate(BaseModel):
|
||||
"""Schema for creating a tag."""
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
name: str
|
||||
|
||||
|
||||
class PostCreate(BaseModel):
|
||||
"""Schema for creating a post."""
|
||||
|
||||
@@ -123,6 +151,26 @@ class PostUpdate(BaseModel):
|
||||
is_published: bool | None = None
|
||||
|
||||
|
||||
class PostM2MCreate(BaseModel):
|
||||
"""Schema for creating a post with M2M tag IDs."""
|
||||
|
||||
id: uuid.UUID | None = None
|
||||
title: str
|
||||
content: str = ""
|
||||
is_published: bool = False
|
||||
author_id: uuid.UUID
|
||||
tag_ids: list[uuid.UUID] = []
|
||||
|
||||
|
||||
class PostM2MUpdate(BaseModel):
|
||||
"""Schema for updating a post with M2M tag IDs."""
|
||||
|
||||
title: str | None = None
|
||||
content: str | None = None
|
||||
is_published: bool | None = None
|
||||
tag_ids: list[uuid.UUID] | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CRUD Classes
|
||||
# =============================================================================
|
||||
@@ -130,6 +178,8 @@ class PostUpdate(BaseModel):
|
||||
RoleCrud = CrudFactory(Role)
|
||||
UserCrud = CrudFactory(User)
|
||||
PostCrud = CrudFactory(Post)
|
||||
TagCrud = CrudFactory(Tag)
|
||||
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -4,6 +4,7 @@ import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||
@@ -13,10 +14,15 @@ from .conftest import (
|
||||
Post,
|
||||
PostCreate,
|
||||
PostCrud,
|
||||
PostM2MCreate,
|
||||
PostM2MCrud,
|
||||
PostM2MUpdate,
|
||||
Role,
|
||||
RoleCreate,
|
||||
RoleCrud,
|
||||
RoleUpdate,
|
||||
TagCreate,
|
||||
TagCrud,
|
||||
User,
|
||||
UserCreate,
|
||||
UserCrud,
|
||||
@@ -812,3 +818,383 @@ class TestAsResponse:
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is None
|
||||
|
||||
|
||||
class TestCrudFactoryM2M:
|
||||
"""Tests for CrudFactory with m2m_fields parameter."""
|
||||
|
||||
def test_creates_crud_with_m2m_fields(self):
|
||||
"""CrudFactory configures m2m_fields on the class."""
|
||||
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||
assert crud.m2m_fields is not None
|
||||
assert "tag_ids" in crud.m2m_fields
|
||||
|
||||
def test_creates_crud_without_m2m_fields(self):
|
||||
"""CrudFactory without m2m_fields has None."""
|
||||
crud = CrudFactory(Post)
|
||||
assert crud.m2m_fields is None
|
||||
|
||||
def test_m2m_schema_fields(self):
|
||||
"""_m2m_schema_fields returns correct field names."""
|
||||
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
|
||||
assert crud._m2m_schema_fields() == {"tag_ids"}
|
||||
|
||||
def test_m2m_schema_fields_empty_when_none(self):
|
||||
"""_m2m_schema_fields returns empty set when no m2m_fields."""
|
||||
crud = CrudFactory(Post)
|
||||
assert crud._m2m_schema_fields() == set()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_resolve_m2m_returns_empty_without_m2m_fields(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""_resolve_m2m returns empty dict when m2m_fields is not configured."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class DummySchema(BaseModel):
|
||||
name: str
|
||||
|
||||
result = await PostCrud._resolve_m2m(db_session, DummySchema(name="test"))
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestM2MResolveNone:
|
||||
"""Tests for _resolve_m2m when IDs field is None."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_resolve_m2m_with_none_ids(self, db_session: AsyncSession):
|
||||
"""_resolve_m2m sets empty list when ids value is None."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SchemaWithNullableTags(BaseModel):
|
||||
tag_ids: list[uuid.UUID] | None = None
|
||||
|
||||
result = await PostM2MCrud._resolve_m2m(
|
||||
db_session, SchemaWithNullableTags(tag_ids=None)
|
||||
)
|
||||
assert result == {"tags": []}
|
||||
|
||||
|
||||
class TestM2MCreate:
|
||||
"""Tests for create with M2M relationships."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_m2m_tags(self, db_session: AsyncSession):
|
||||
"""Create a post with M2M tags resolves tag IDs."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag1 = await TagCrud.create(db_session, TagCreate(name="python"))
|
||||
tag2 = await TagCrud.create(db_session, TagCreate(name="fastapi"))
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="M2M Post",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag1.id, tag2.id],
|
||||
),
|
||||
)
|
||||
|
||||
assert post.id is not None
|
||||
assert post.title == "M2M Post"
|
||||
|
||||
# Reload with tags eagerly loaded
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
tag_names = sorted(t.name for t in loaded.tags)
|
||||
assert tag_names == ["fastapi", "python"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_empty_m2m(self, db_session: AsyncSession):
|
||||
"""Create a post with empty tag_ids list works."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="No Tags Post",
|
||||
author_id=user.id,
|
||||
tag_ids=[],
|
||||
),
|
||||
)
|
||||
|
||||
assert post.id is not None
|
||||
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
assert loaded.tags == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_default_m2m(self, db_session: AsyncSession):
|
||||
"""Create a post using default tag_ids (empty list) works."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(title="Default Tags", author_id=user.id),
|
||||
)
|
||||
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
assert loaded.tags == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_nonexistent_tag_id_raises(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""Create with a nonexistent tag ID raises NotFoundError."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag = await TagCrud.create(db_session, TagCreate(name="valid"))
|
||||
fake_id = uuid.uuid4()
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="Bad Tags",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag.id, fake_id],
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_with_single_tag(self, db_session: AsyncSession):
|
||||
"""Create with a single tag works correctly."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag = await TagCrud.create(db_session, TagCreate(name="solo"))
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="Single Tag",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag.id],
|
||||
),
|
||||
)
|
||||
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
assert len(loaded.tags) == 1
|
||||
assert loaded.tags[0].name == "solo"
|
||||
|
||||
|
||||
class TestM2MUpdate:
|
||||
"""Tests for update with M2M relationships."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_m2m_tags(self, db_session: AsyncSession):
|
||||
"""Update replaces M2M tags when tag_ids is set."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag1 = await TagCrud.create(db_session, TagCreate(name="old_tag"))
|
||||
tag2 = await TagCrud.create(db_session, TagCreate(name="new_tag"))
|
||||
|
||||
# Create with tag1
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="Update Test",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag1.id],
|
||||
),
|
||||
)
|
||||
|
||||
# Update to tag2
|
||||
updated = await PostM2MCrud.update(
|
||||
db_session,
|
||||
PostM2MUpdate(tag_ids=[tag2.id]),
|
||||
[Post.id == post.id],
|
||||
)
|
||||
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == updated.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
assert len(loaded.tags) == 1
|
||||
assert loaded.tags[0].name == "new_tag"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_without_m2m_preserves_tags(self, db_session: AsyncSession):
|
||||
"""Update without setting tag_ids preserves existing tags."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag = await TagCrud.create(db_session, TagCreate(name="keep_me"))
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="Keep Tags",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag.id],
|
||||
),
|
||||
)
|
||||
|
||||
# Update only title, tag_ids not set
|
||||
await PostM2MCrud.update(
|
||||
db_session,
|
||||
PostM2MUpdate(title="Updated Title"),
|
||||
[Post.id == post.id],
|
||||
)
|
||||
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
assert loaded.title == "Updated Title"
|
||||
assert len(loaded.tags) == 1
|
||||
assert loaded.tags[0].name == "keep_me"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_clear_m2m_tags(self, db_session: AsyncSession):
|
||||
"""Update with empty tag_ids clears all tags."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag = await TagCrud.create(db_session, TagCreate(name="remove_me"))
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="Clear Tags",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag.id],
|
||||
),
|
||||
)
|
||||
|
||||
# Explicitly set tag_ids to empty list
|
||||
await PostM2MCrud.update(
|
||||
db_session,
|
||||
PostM2MUpdate(tag_ids=[]),
|
||||
[Post.id == post.id],
|
||||
)
|
||||
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
assert loaded.tags == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_m2m_with_nonexistent_id_raises(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""Update with nonexistent tag ID raises NotFoundError."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag = await TagCrud.create(db_session, TagCreate(name="existing"))
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="Bad Update",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag.id],
|
||||
),
|
||||
)
|
||||
|
||||
fake_id = uuid.uuid4()
|
||||
with pytest.raises(NotFoundError):
|
||||
await PostM2MCrud.update(
|
||||
db_session,
|
||||
PostM2MUpdate(tag_ids=[fake_id]),
|
||||
[Post.id == post.id],
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_m2m_and_scalar_fields(self, db_session: AsyncSession):
|
||||
"""Update both scalar fields and M2M tags together."""
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
tag1 = await TagCrud.create(db_session, TagCreate(name="tag1"))
|
||||
tag2 = await TagCrud.create(db_session, TagCreate(name="tag2"))
|
||||
|
||||
post = await PostM2MCrud.create(
|
||||
db_session,
|
||||
PostM2MCreate(
|
||||
title="Original",
|
||||
author_id=user.id,
|
||||
tag_ids=[tag1.id],
|
||||
),
|
||||
)
|
||||
|
||||
# Update title and tags simultaneously
|
||||
await PostM2MCrud.update(
|
||||
db_session,
|
||||
PostM2MUpdate(title="Updated", tag_ids=[tag1.id, tag2.id]),
|
||||
[Post.id == post.id],
|
||||
)
|
||||
|
||||
loaded = await PostM2MCrud.get(
|
||||
db_session,
|
||||
[Post.id == post.id],
|
||||
load_options=[selectinload(Post.tags)],
|
||||
)
|
||||
assert loaded.title == "Updated"
|
||||
tag_names = sorted(t.name for t in loaded.tags)
|
||||
assert tag_names == ["tag1", "tag2"]
|
||||
|
||||
|
||||
class TestM2MWithNonM2MCrud:
|
||||
"""Tests that non-M2M CRUD classes are unaffected."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_without_m2m_unchanged(self, db_session: AsyncSession):
|
||||
"""Regular PostCrud.create still works without M2M logic."""
|
||||
from .conftest import PostCreate
|
||||
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
post = await PostCrud.create(
|
||||
db_session,
|
||||
PostCreate(title="Plain Post", author_id=user.id),
|
||||
)
|
||||
assert post.id is not None
|
||||
assert post.title == "Plain Post"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_without_m2m_unchanged(self, db_session: AsyncSession):
|
||||
"""Regular PostCrud.update still works without M2M logic."""
|
||||
from .conftest import PostCreate, PostUpdate
|
||||
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="author", email="author@test.com")
|
||||
)
|
||||
post = await PostCrud.create(
|
||||
db_session,
|
||||
PostCreate(title="Plain Post", author_id=user.id),
|
||||
)
|
||||
updated = await PostCrud.update(
|
||||
db_session,
|
||||
PostUpdate(title="Updated Plain"),
|
||||
[Post.id == post.id],
|
||||
)
|
||||
assert updated.title == "Updated Plain"
|
||||
|
||||
Reference in New Issue
Block a user