diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index 093763b..3aaa5bf 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -1,7 +1,7 @@ """Generic async CRUD operations for SQLAlchemy models.""" from ..exceptions import NoSearchableFieldsError -from .factory import CrudFactory +from .factory import CrudFactory, JoinType, M2MFieldType from .search import ( SearchConfig, get_searchable_fields, @@ -10,6 +10,8 @@ from .search import ( __all__ = [ "CrudFactory", "get_searchable_fields", + "JoinType", + "M2MFieldType", "NoSearchableFieldsError", "SearchConfig", ] diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 96a8c73..6ae7e8b 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from pydantic import BaseModel @@ -11,7 +11,7 @@ from sqlalchemy import delete as sql_delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction @@ -21,6 +21,7 @@ from .search import SearchConfig, SearchFieldType, build_search_filters ModelType = TypeVar("ModelType", bound=DeclarativeBase) JoinType = list[tuple[type[DeclarativeBase], Any]] +M2MFieldType = Mapping[str, QueryableAttribute[Any]] class AsyncCrud(Generic[ModelType]): @@ -31,6 +32,7 @@ class AsyncCrud(Generic[ModelType]): model: ClassVar[type[DeclarativeBase]] searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None + m2m_fields: ClassVar[M2MFieldType | None] = None @overload @classmethod @@ -52,6 +54,62 @@ class AsyncCrud(Generic[ModelType]): as_response: Literal[False] = ..., ) -> ModelType: ... + @classmethod + async def _resolve_m2m( + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + *, + only_set: bool = False, + ) -> dict[str, list[Any]]: + """Resolve M2M fields from a Pydantic schema into related model instances. + + Args: + session: DB async session + obj: Pydantic model containing M2M ID fields + only_set: If True, only process fields explicitly set on the schema + + Returns: + Dict mapping relationship attr names to lists of related instances + """ + result: dict[str, list[Any]] = {} + if not cls.m2m_fields: + return result + + for schema_field, rel in cls.m2m_fields.items(): + rel_attr = rel.property.key + related_model = rel.property.mapper.class_ + if only_set and schema_field not in obj.model_fields_set: + continue + ids = getattr(obj, schema_field, None) + if ids is not None: + related = ( + ( + await session.execute( + select(related_model).where(related_model.id.in_(ids)) + ) + ) + .scalars() + .all() + ) + if len(related) != len(ids): + found_ids = {r.id for r in related} + missing = set(ids) - found_ids + raise NotFoundError( + f"Related {related_model.__name__} not found for IDs: {missing}" + ) + result[rel_attr] = list(related) + else: + result[rel_attr] = [] + return result + + @classmethod + def _m2m_schema_fields(cls: type[Self]) -> set[str]: + """Return the set of schema field names that are M2M fields.""" + if not cls.m2m_fields: + return set() + return set(cls.m2m_fields.keys()) + @classmethod async def create( cls: type[Self], @@ -71,7 +129,17 @@ class AsyncCrud(Generic[ModelType]): Created model instance or Response wrapping it """ async with get_transaction(session): - db_model = cls.model(**obj.model_dump()) + m2m_exclude = cls._m2m_schema_fields() + data = ( + obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump() + ) + db_model = cls.model(**data) + + if m2m_exclude: + m2m_resolved = await cls._resolve_m2m(session, obj) + for rel_attr, related_instances in m2m_resolved.items(): + setattr(db_model, rel_attr, related_instances) + session.add(db_model) await session.refresh(db_model) result = cast(ModelType, db_model) @@ -299,12 +367,33 @@ class AsyncCrud(Generic[ModelType]): NotFoundError: If no record found """ async with get_transaction(session): - db_model = await cls.get(session=session, filters=filters) + m2m_exclude = cls._m2m_schema_fields() + + # Eagerly load M2M relationships that will be updated so that + # setattr does not trigger a lazy load (which fails in async). + m2m_load_options: list[Any] = [] + if m2m_exclude and cls.m2m_fields: + for schema_field, rel in cls.m2m_fields.items(): + if schema_field in obj.model_fields_set: + m2m_load_options.append(selectinload(rel)) + + db_model = await cls.get( + session=session, + filters=filters, + load_options=m2m_load_options or None, + ) values = obj.model_dump( - exclude_unset=exclude_unset, exclude_none=exclude_none + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude=m2m_exclude, ) for key, value in values.items(): setattr(db_model, key, value) + + if m2m_exclude: + m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True) + for rel_attr, related_instances in m2m_resolved.items(): + setattr(db_model, rel_attr, related_instances) await session.refresh(db_model) if as_response: return Response(data=db_model) @@ -578,12 +667,16 @@ def CrudFactory( model: type[ModelType], *, searchable_fields: Sequence[SearchFieldType] | None = None, + m2m_fields: M2MFieldType | None = None, ) -> type[AsyncCrud[ModelType]]: """Create a CRUD class for a specific model. Args: model: SQLAlchemy model class searchable_fields: Optional list of searchable fields + m2m_fields: Optional mapping for many-to-many relationships. + Maps schema field names (containing lists of IDs) to + SQLAlchemy relationship attributes. Returns: AsyncCrud subclass bound to the model @@ -601,10 +694,20 @@ def CrudFactory( searchable_fields=[User.username, User.email, (User.role, Role.name)] ) + # With many-to-many fields: + # Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag + PostCrud = CrudFactory( + Post, + m2m_fields={"tag_ids": Post.tags}, + ) + # Usage user = await UserCrud.get(session, [User.id == 1]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) + # Create with M2M - tag_ids are automatically resolved + post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2])) + # With search result = await UserCrud.paginate(session, search="john") @@ -628,6 +731,7 @@ def CrudFactory( { "model": model, "searchable_fields": searchable_fields, + "m2m_fields": m2m_fields, }, ) return cast(type[AsyncCrud[ModelType]], cls) diff --git a/tests/conftest.py b/tests/conftest.py index e65837a..bd4d09d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import uuid import pytest from pydantic import BaseModel -from sqlalchemy import ForeignKey, String, Uuid +from sqlalchemy import Column, ForeignKey, String, Table, Uuid from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -56,6 +56,25 @@ class User(Base): role: Mapped[Role | None] = relationship(back_populates="users") +class Tag(Base): + """Test tag model.""" + + __tablename__ = "tags" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(50), unique=True) + + +post_tags = Table( + "post_tags", + Base.metadata, + Column( + "post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True + ), + Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True), +) + + class Post(Base): """Test post model.""" @@ -67,6 +86,8 @@ class Post(Base): is_published: Mapped[bool] = mapped_column(default=False) author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id")) + tags: Mapped[list[Tag]] = relationship(secondary=post_tags) + # ============================================================================= # Test Schemas @@ -105,6 +126,13 @@ class UserUpdate(BaseModel): role_id: uuid.UUID | None = None +class TagCreate(BaseModel): + """Schema for creating a tag.""" + + id: uuid.UUID | None = None + name: str + + class PostCreate(BaseModel): """Schema for creating a post.""" @@ -123,6 +151,26 @@ class PostUpdate(BaseModel): is_published: bool | None = None +class PostM2MCreate(BaseModel): + """Schema for creating a post with M2M tag IDs.""" + + id: uuid.UUID | None = None + title: str + content: str = "" + is_published: bool = False + author_id: uuid.UUID + tag_ids: list[uuid.UUID] = [] + + +class PostM2MUpdate(BaseModel): + """Schema for updating a post with M2M tag IDs.""" + + title: str | None = None + content: str | None = None + is_published: bool | None = None + tag_ids: list[uuid.UUID] | None = None + + # ============================================================================= # CRUD Classes # ============================================================================= @@ -130,6 +178,8 @@ class PostUpdate(BaseModel): RoleCrud = CrudFactory(Role) UserCrud = CrudFactory(User) PostCrud = CrudFactory(Post) +TagCrud = CrudFactory(Tag) +PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags}) # ============================================================================= diff --git a/tests/test_crud.py b/tests/test_crud.py index 18fad02..d91e0e0 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -4,6 +4,7 @@ import uuid import pytest from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from fastapi_toolsets.crud import CrudFactory from fastapi_toolsets.crud.factory import AsyncCrud @@ -13,10 +14,15 @@ from .conftest import ( Post, PostCreate, PostCrud, + PostM2MCreate, + PostM2MCrud, + PostM2MUpdate, Role, RoleCreate, RoleCrud, RoleUpdate, + TagCreate, + TagCrud, User, UserCreate, UserCrud, @@ -812,3 +818,383 @@ class TestAsResponse: assert isinstance(result, Response) assert result.data is None + + +class TestCrudFactoryM2M: + """Tests for CrudFactory with m2m_fields parameter.""" + + def test_creates_crud_with_m2m_fields(self): + """CrudFactory configures m2m_fields on the class.""" + crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags}) + assert crud.m2m_fields is not None + assert "tag_ids" in crud.m2m_fields + + def test_creates_crud_without_m2m_fields(self): + """CrudFactory without m2m_fields has None.""" + crud = CrudFactory(Post) + assert crud.m2m_fields is None + + def test_m2m_schema_fields(self): + """_m2m_schema_fields returns correct field names.""" + crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags}) + assert crud._m2m_schema_fields() == {"tag_ids"} + + def test_m2m_schema_fields_empty_when_none(self): + """_m2m_schema_fields returns empty set when no m2m_fields.""" + crud = CrudFactory(Post) + assert crud._m2m_schema_fields() == set() + + @pytest.mark.anyio + async def test_resolve_m2m_returns_empty_without_m2m_fields( + self, db_session: AsyncSession + ): + """_resolve_m2m returns empty dict when m2m_fields is not configured.""" + from pydantic import BaseModel + + class DummySchema(BaseModel): + name: str + + result = await PostCrud._resolve_m2m(db_session, DummySchema(name="test")) + assert result == {} + + +class TestM2MResolveNone: + """Tests for _resolve_m2m when IDs field is None.""" + + @pytest.mark.anyio + async def test_resolve_m2m_with_none_ids(self, db_session: AsyncSession): + """_resolve_m2m sets empty list when ids value is None.""" + from pydantic import BaseModel + + class SchemaWithNullableTags(BaseModel): + tag_ids: list[uuid.UUID] | None = None + + result = await PostM2MCrud._resolve_m2m( + db_session, SchemaWithNullableTags(tag_ids=None) + ) + assert result == {"tags": []} + + +class TestM2MCreate: + """Tests for create with M2M relationships.""" + + @pytest.mark.anyio + async def test_create_with_m2m_tags(self, db_session: AsyncSession): + """Create a post with M2M tags resolves tag IDs.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag1 = await TagCrud.create(db_session, TagCreate(name="python")) + tag2 = await TagCrud.create(db_session, TagCreate(name="fastapi")) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="M2M Post", + author_id=user.id, + tag_ids=[tag1.id, tag2.id], + ), + ) + + assert post.id is not None + assert post.title == "M2M Post" + + # Reload with tags eagerly loaded + loaded = await PostM2MCrud.get( + db_session, + [Post.id == post.id], + load_options=[selectinload(Post.tags)], + ) + tag_names = sorted(t.name for t in loaded.tags) + assert tag_names == ["fastapi", "python"] + + @pytest.mark.anyio + async def test_create_with_empty_m2m(self, db_session: AsyncSession): + """Create a post with empty tag_ids list works.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="No Tags Post", + author_id=user.id, + tag_ids=[], + ), + ) + + assert post.id is not None + + loaded = await PostM2MCrud.get( + db_session, + [Post.id == post.id], + load_options=[selectinload(Post.tags)], + ) + assert loaded.tags == [] + + @pytest.mark.anyio + async def test_create_with_default_m2m(self, db_session: AsyncSession): + """Create a post using default tag_ids (empty list) works.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate(title="Default Tags", author_id=user.id), + ) + + loaded = await PostM2MCrud.get( + db_session, + [Post.id == post.id], + load_options=[selectinload(Post.tags)], + ) + assert loaded.tags == [] + + @pytest.mark.anyio + async def test_create_with_nonexistent_tag_id_raises( + self, db_session: AsyncSession + ): + """Create with a nonexistent tag ID raises NotFoundError.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag = await TagCrud.create(db_session, TagCreate(name="valid")) + fake_id = uuid.uuid4() + + with pytest.raises(NotFoundError): + await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="Bad Tags", + author_id=user.id, + tag_ids=[tag.id, fake_id], + ), + ) + + @pytest.mark.anyio + async def test_create_with_single_tag(self, db_session: AsyncSession): + """Create with a single tag works correctly.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag = await TagCrud.create(db_session, TagCreate(name="solo")) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="Single Tag", + author_id=user.id, + tag_ids=[tag.id], + ), + ) + + loaded = await PostM2MCrud.get( + db_session, + [Post.id == post.id], + load_options=[selectinload(Post.tags)], + ) + assert len(loaded.tags) == 1 + assert loaded.tags[0].name == "solo" + + +class TestM2MUpdate: + """Tests for update with M2M relationships.""" + + @pytest.mark.anyio + async def test_update_m2m_tags(self, db_session: AsyncSession): + """Update replaces M2M tags when tag_ids is set.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag1 = await TagCrud.create(db_session, TagCreate(name="old_tag")) + tag2 = await TagCrud.create(db_session, TagCreate(name="new_tag")) + + # Create with tag1 + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="Update Test", + author_id=user.id, + tag_ids=[tag1.id], + ), + ) + + # Update to tag2 + updated = await PostM2MCrud.update( + db_session, + PostM2MUpdate(tag_ids=[tag2.id]), + [Post.id == post.id], + ) + + loaded = await PostM2MCrud.get( + db_session, + [Post.id == updated.id], + load_options=[selectinload(Post.tags)], + ) + assert len(loaded.tags) == 1 + assert loaded.tags[0].name == "new_tag" + + @pytest.mark.anyio + async def test_update_without_m2m_preserves_tags(self, db_session: AsyncSession): + """Update without setting tag_ids preserves existing tags.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag = await TagCrud.create(db_session, TagCreate(name="keep_me")) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="Keep Tags", + author_id=user.id, + tag_ids=[tag.id], + ), + ) + + # Update only title, tag_ids not set + await PostM2MCrud.update( + db_session, + PostM2MUpdate(title="Updated Title"), + [Post.id == post.id], + ) + + loaded = await PostM2MCrud.get( + db_session, + [Post.id == post.id], + load_options=[selectinload(Post.tags)], + ) + assert loaded.title == "Updated Title" + assert len(loaded.tags) == 1 + assert loaded.tags[0].name == "keep_me" + + @pytest.mark.anyio + async def test_update_clear_m2m_tags(self, db_session: AsyncSession): + """Update with empty tag_ids clears all tags.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag = await TagCrud.create(db_session, TagCreate(name="remove_me")) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="Clear Tags", + author_id=user.id, + tag_ids=[tag.id], + ), + ) + + # Explicitly set tag_ids to empty list + await PostM2MCrud.update( + db_session, + PostM2MUpdate(tag_ids=[]), + [Post.id == post.id], + ) + + loaded = await PostM2MCrud.get( + db_session, + [Post.id == post.id], + load_options=[selectinload(Post.tags)], + ) + assert loaded.tags == [] + + @pytest.mark.anyio + async def test_update_m2m_with_nonexistent_id_raises( + self, db_session: AsyncSession + ): + """Update with nonexistent tag ID raises NotFoundError.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag = await TagCrud.create(db_session, TagCreate(name="existing")) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="Bad Update", + author_id=user.id, + tag_ids=[tag.id], + ), + ) + + fake_id = uuid.uuid4() + with pytest.raises(NotFoundError): + await PostM2MCrud.update( + db_session, + PostM2MUpdate(tag_ids=[fake_id]), + [Post.id == post.id], + ) + + @pytest.mark.anyio + async def test_update_m2m_and_scalar_fields(self, db_session: AsyncSession): + """Update both scalar fields and M2M tags together.""" + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + tag1 = await TagCrud.create(db_session, TagCreate(name="tag1")) + tag2 = await TagCrud.create(db_session, TagCreate(name="tag2")) + + post = await PostM2MCrud.create( + db_session, + PostM2MCreate( + title="Original", + author_id=user.id, + tag_ids=[tag1.id], + ), + ) + + # Update title and tags simultaneously + await PostM2MCrud.update( + db_session, + PostM2MUpdate(title="Updated", tag_ids=[tag1.id, tag2.id]), + [Post.id == post.id], + ) + + loaded = await PostM2MCrud.get( + db_session, + [Post.id == post.id], + load_options=[selectinload(Post.tags)], + ) + assert loaded.title == "Updated" + tag_names = sorted(t.name for t in loaded.tags) + assert tag_names == ["tag1", "tag2"] + + +class TestM2MWithNonM2MCrud: + """Tests that non-M2M CRUD classes are unaffected.""" + + @pytest.mark.anyio + async def test_create_without_m2m_unchanged(self, db_session: AsyncSession): + """Regular PostCrud.create still works without M2M logic.""" + from .conftest import PostCreate + + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + post = await PostCrud.create( + db_session, + PostCreate(title="Plain Post", author_id=user.id), + ) + assert post.id is not None + assert post.title == "Plain Post" + + @pytest.mark.anyio + async def test_update_without_m2m_unchanged(self, db_session: AsyncSession): + """Regular PostCrud.update still works without M2M logic.""" + from .conftest import PostCreate, PostUpdate + + user = await UserCrud.create( + db_session, UserCreate(username="author", email="author@test.com") + ) + post = await PostCrud.create( + db_session, + PostCreate(title="Plain Post", author_id=user.id), + ) + updated = await PostCrud.update( + db_session, + PostUpdate(title="Updated Plain"), + [Post.id == post.id], + ) + assert updated.title == "Updated Plain"