1 Commits

Author SHA1 Message Date
d3vyce
c32f2e18be feat: add many to many support in CrudFactory (#65) 2026-02-15 15:57:15 +01:00
4 changed files with 549 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory
from .factory import CrudFactory, JoinType, M2MFieldType
from .search import (
SearchConfig,
get_searchable_fields,
@@ -10,6 +10,8 @@ from .search import (
__all__ = [
"CrudFactory",
"get_searchable_fields",
"JoinType",
"M2MFieldType",
"NoSearchableFieldsError",
"SearchConfig",
]

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel
@@ -11,7 +11,7 @@ from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
@@ -21,6 +21,7 @@ from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
class AsyncCrud(Generic[ModelType]):
@@ -31,6 +32,7 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
@overload
@classmethod
@@ -52,6 +54,62 @@ class AsyncCrud(Generic[ModelType]):
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod
async def _resolve_m2m(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
only_set: bool = False,
) -> dict[str, list[Any]]:
"""Resolve M2M fields from a Pydantic schema into related model instances.
Args:
session: DB async session
obj: Pydantic model containing M2M ID fields
only_set: If True, only process fields explicitly set on the schema
Returns:
Dict mapping relationship attr names to lists of related instances
"""
result: dict[str, list[Any]] = {}
if not cls.m2m_fields:
return result
for schema_field, rel in cls.m2m_fields.items():
rel_attr = rel.property.key
related_model = rel.property.mapper.class_
if only_set and schema_field not in obj.model_fields_set:
continue
ids = getattr(obj, schema_field, None)
if ids is not None:
related = (
(
await session.execute(
select(related_model).where(related_model.id.in_(ids))
)
)
.scalars()
.all()
)
if len(related) != len(ids):
found_ids = {r.id for r in related}
missing = set(ids) - found_ids
raise NotFoundError(
f"Related {related_model.__name__} not found for IDs: {missing}"
)
result[rel_attr] = list(related)
else:
result[rel_attr] = []
return result
@classmethod
def _m2m_schema_fields(cls: type[Self]) -> set[str]:
"""Return the set of schema field names that are M2M fields."""
if not cls.m2m_fields:
return set()
return set(cls.m2m_fields.keys())
@classmethod
async def create(
cls: type[Self],
@@ -71,7 +129,17 @@ class AsyncCrud(Generic[ModelType]):
Created model instance or Response wrapping it
"""
async with get_transaction(session):
db_model = cls.model(**obj.model_dump())
m2m_exclude = cls._m2m_schema_fields()
data = (
obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump()
)
db_model = cls.model(**data)
if m2m_exclude:
m2m_resolved = await cls._resolve_m2m(session, obj)
for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances)
session.add(db_model)
await session.refresh(db_model)
result = cast(ModelType, db_model)
@@ -299,12 +367,33 @@ class AsyncCrud(Generic[ModelType]):
NotFoundError: If no record found
"""
async with get_transaction(session):
db_model = await cls.get(session=session, filters=filters)
m2m_exclude = cls._m2m_schema_fields()
# Eagerly load M2M relationships that will be updated so that
# setattr does not trigger a lazy load (which fails in async).
m2m_load_options: list[Any] = []
if m2m_exclude and cls.m2m_fields:
for schema_field, rel in cls.m2m_fields.items():
if schema_field in obj.model_fields_set:
m2m_load_options.append(selectinload(rel))
db_model = await cls.get(
session=session,
filters=filters,
load_options=m2m_load_options or None,
)
values = obj.model_dump(
exclude_unset=exclude_unset, exclude_none=exclude_none
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude=m2m_exclude,
)
for key, value in values.items():
setattr(db_model, key, value)
if m2m_exclude:
m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True)
for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model)
if as_response:
return Response(data=db_model)
@@ -578,12 +667,16 @@ def CrudFactory(
model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
m2m_fields: M2MFieldType | None = None,
) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model.
Args:
model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes.
Returns:
AsyncCrud subclass bound to the model
@@ -601,10 +694,20 @@ def CrudFactory(
searchable_fields=[User.username, User.email, (User.role, Role.name)]
)
# With many-to-many fields:
# Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
PostCrud = CrudFactory(
Post,
m2m_fields={"tag_ids": Post.tags},
)
# Usage
user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
# Create with M2M - tag_ids are automatically resolved
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
# With search
result = await UserCrud.paginate(session, search="john")
@@ -628,6 +731,7 @@ def CrudFactory(
{
"model": model,
"searchable_fields": searchable_fields,
"m2m_fields": m2m_fields,
},
)
return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -5,7 +5,7 @@ import uuid
import pytest
from pydantic import BaseModel
from sqlalchemy import ForeignKey, String, Uuid
from sqlalchemy import Column, ForeignKey, String, Table, Uuid
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -56,6 +56,25 @@ class User(Base):
role: Mapped[Role | None] = relationship(back_populates="users")
class Tag(Base):
"""Test tag model."""
__tablename__ = "tags"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True)
post_tags = Table(
"post_tags",
Base.metadata,
Column(
"post_id", Uuid, ForeignKey("posts.id", ondelete="CASCADE"), primary_key=True
),
Column("tag_id", Uuid, ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True),
)
class Post(Base):
"""Test post model."""
@@ -67,6 +86,8 @@ class Post(Base):
is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
tags: Mapped[list[Tag]] = relationship(secondary=post_tags)
# =============================================================================
# Test Schemas
@@ -105,6 +126,13 @@ class UserUpdate(BaseModel):
role_id: uuid.UUID | None = None
class TagCreate(BaseModel):
"""Schema for creating a tag."""
id: uuid.UUID | None = None
name: str
class PostCreate(BaseModel):
"""Schema for creating a post."""
@@ -123,6 +151,26 @@ class PostUpdate(BaseModel):
is_published: bool | None = None
class PostM2MCreate(BaseModel):
"""Schema for creating a post with M2M tag IDs."""
id: uuid.UUID | None = None
title: str
content: str = ""
is_published: bool = False
author_id: uuid.UUID
tag_ids: list[uuid.UUID] = []
class PostM2MUpdate(BaseModel):
"""Schema for updating a post with M2M tag IDs."""
title: str | None = None
content: str | None = None
is_published: bool | None = None
tag_ids: list[uuid.UUID] | None = None
# =============================================================================
# CRUD Classes
# =============================================================================
@@ -130,6 +178,8 @@ class PostUpdate(BaseModel):
RoleCrud = CrudFactory(Role)
UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post)
TagCrud = CrudFactory(Tag)
PostM2MCrud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
# =============================================================================

View File

@@ -4,6 +4,7 @@ import uuid
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.crud.factory import AsyncCrud
@@ -13,10 +14,15 @@ from .conftest import (
Post,
PostCreate,
PostCrud,
PostM2MCreate,
PostM2MCrud,
PostM2MUpdate,
Role,
RoleCreate,
RoleCrud,
RoleUpdate,
TagCreate,
TagCrud,
User,
UserCreate,
UserCrud,
@@ -812,3 +818,383 @@ class TestAsResponse:
assert isinstance(result, Response)
assert result.data is None
class TestCrudFactoryM2M:
"""Tests for CrudFactory with m2m_fields parameter."""
def test_creates_crud_with_m2m_fields(self):
"""CrudFactory configures m2m_fields on the class."""
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
assert crud.m2m_fields is not None
assert "tag_ids" in crud.m2m_fields
def test_creates_crud_without_m2m_fields(self):
"""CrudFactory without m2m_fields has None."""
crud = CrudFactory(Post)
assert crud.m2m_fields is None
def test_m2m_schema_fields(self):
"""_m2m_schema_fields returns correct field names."""
crud = CrudFactory(Post, m2m_fields={"tag_ids": Post.tags})
assert crud._m2m_schema_fields() == {"tag_ids"}
def test_m2m_schema_fields_empty_when_none(self):
"""_m2m_schema_fields returns empty set when no m2m_fields."""
crud = CrudFactory(Post)
assert crud._m2m_schema_fields() == set()
@pytest.mark.anyio
async def test_resolve_m2m_returns_empty_without_m2m_fields(
self, db_session: AsyncSession
):
"""_resolve_m2m returns empty dict when m2m_fields is not configured."""
from pydantic import BaseModel
class DummySchema(BaseModel):
name: str
result = await PostCrud._resolve_m2m(db_session, DummySchema(name="test"))
assert result == {}
class TestM2MResolveNone:
"""Tests for _resolve_m2m when IDs field is None."""
@pytest.mark.anyio
async def test_resolve_m2m_with_none_ids(self, db_session: AsyncSession):
"""_resolve_m2m sets empty list when ids value is None."""
from pydantic import BaseModel
class SchemaWithNullableTags(BaseModel):
tag_ids: list[uuid.UUID] | None = None
result = await PostM2MCrud._resolve_m2m(
db_session, SchemaWithNullableTags(tag_ids=None)
)
assert result == {"tags": []}
class TestM2MCreate:
"""Tests for create with M2M relationships."""
@pytest.mark.anyio
async def test_create_with_m2m_tags(self, db_session: AsyncSession):
"""Create a post with M2M tags resolves tag IDs."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag1 = await TagCrud.create(db_session, TagCreate(name="python"))
tag2 = await TagCrud.create(db_session, TagCreate(name="fastapi"))
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="M2M Post",
author_id=user.id,
tag_ids=[tag1.id, tag2.id],
),
)
assert post.id is not None
assert post.title == "M2M Post"
# Reload with tags eagerly loaded
loaded = await PostM2MCrud.get(
db_session,
[Post.id == post.id],
load_options=[selectinload(Post.tags)],
)
tag_names = sorted(t.name for t in loaded.tags)
assert tag_names == ["fastapi", "python"]
@pytest.mark.anyio
async def test_create_with_empty_m2m(self, db_session: AsyncSession):
"""Create a post with empty tag_ids list works."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="No Tags Post",
author_id=user.id,
tag_ids=[],
),
)
assert post.id is not None
loaded = await PostM2MCrud.get(
db_session,
[Post.id == post.id],
load_options=[selectinload(Post.tags)],
)
assert loaded.tags == []
@pytest.mark.anyio
async def test_create_with_default_m2m(self, db_session: AsyncSession):
"""Create a post using default tag_ids (empty list) works."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(title="Default Tags", author_id=user.id),
)
loaded = await PostM2MCrud.get(
db_session,
[Post.id == post.id],
load_options=[selectinload(Post.tags)],
)
assert loaded.tags == []
@pytest.mark.anyio
async def test_create_with_nonexistent_tag_id_raises(
self, db_session: AsyncSession
):
"""Create with a nonexistent tag ID raises NotFoundError."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag = await TagCrud.create(db_session, TagCreate(name="valid"))
fake_id = uuid.uuid4()
with pytest.raises(NotFoundError):
await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="Bad Tags",
author_id=user.id,
tag_ids=[tag.id, fake_id],
),
)
@pytest.mark.anyio
async def test_create_with_single_tag(self, db_session: AsyncSession):
"""Create with a single tag works correctly."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag = await TagCrud.create(db_session, TagCreate(name="solo"))
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="Single Tag",
author_id=user.id,
tag_ids=[tag.id],
),
)
loaded = await PostM2MCrud.get(
db_session,
[Post.id == post.id],
load_options=[selectinload(Post.tags)],
)
assert len(loaded.tags) == 1
assert loaded.tags[0].name == "solo"
class TestM2MUpdate:
"""Tests for update with M2M relationships."""
@pytest.mark.anyio
async def test_update_m2m_tags(self, db_session: AsyncSession):
"""Update replaces M2M tags when tag_ids is set."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag1 = await TagCrud.create(db_session, TagCreate(name="old_tag"))
tag2 = await TagCrud.create(db_session, TagCreate(name="new_tag"))
# Create with tag1
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="Update Test",
author_id=user.id,
tag_ids=[tag1.id],
),
)
# Update to tag2
updated = await PostM2MCrud.update(
db_session,
PostM2MUpdate(tag_ids=[tag2.id]),
[Post.id == post.id],
)
loaded = await PostM2MCrud.get(
db_session,
[Post.id == updated.id],
load_options=[selectinload(Post.tags)],
)
assert len(loaded.tags) == 1
assert loaded.tags[0].name == "new_tag"
@pytest.mark.anyio
async def test_update_without_m2m_preserves_tags(self, db_session: AsyncSession):
"""Update without setting tag_ids preserves existing tags."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag = await TagCrud.create(db_session, TagCreate(name="keep_me"))
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="Keep Tags",
author_id=user.id,
tag_ids=[tag.id],
),
)
# Update only title, tag_ids not set
await PostM2MCrud.update(
db_session,
PostM2MUpdate(title="Updated Title"),
[Post.id == post.id],
)
loaded = await PostM2MCrud.get(
db_session,
[Post.id == post.id],
load_options=[selectinload(Post.tags)],
)
assert loaded.title == "Updated Title"
assert len(loaded.tags) == 1
assert loaded.tags[0].name == "keep_me"
@pytest.mark.anyio
async def test_update_clear_m2m_tags(self, db_session: AsyncSession):
"""Update with empty tag_ids clears all tags."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag = await TagCrud.create(db_session, TagCreate(name="remove_me"))
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="Clear Tags",
author_id=user.id,
tag_ids=[tag.id],
),
)
# Explicitly set tag_ids to empty list
await PostM2MCrud.update(
db_session,
PostM2MUpdate(tag_ids=[]),
[Post.id == post.id],
)
loaded = await PostM2MCrud.get(
db_session,
[Post.id == post.id],
load_options=[selectinload(Post.tags)],
)
assert loaded.tags == []
@pytest.mark.anyio
async def test_update_m2m_with_nonexistent_id_raises(
self, db_session: AsyncSession
):
"""Update with nonexistent tag ID raises NotFoundError."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag = await TagCrud.create(db_session, TagCreate(name="existing"))
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="Bad Update",
author_id=user.id,
tag_ids=[tag.id],
),
)
fake_id = uuid.uuid4()
with pytest.raises(NotFoundError):
await PostM2MCrud.update(
db_session,
PostM2MUpdate(tag_ids=[fake_id]),
[Post.id == post.id],
)
@pytest.mark.anyio
async def test_update_m2m_and_scalar_fields(self, db_session: AsyncSession):
"""Update both scalar fields and M2M tags together."""
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
tag1 = await TagCrud.create(db_session, TagCreate(name="tag1"))
tag2 = await TagCrud.create(db_session, TagCreate(name="tag2"))
post = await PostM2MCrud.create(
db_session,
PostM2MCreate(
title="Original",
author_id=user.id,
tag_ids=[tag1.id],
),
)
# Update title and tags simultaneously
await PostM2MCrud.update(
db_session,
PostM2MUpdate(title="Updated", tag_ids=[tag1.id, tag2.id]),
[Post.id == post.id],
)
loaded = await PostM2MCrud.get(
db_session,
[Post.id == post.id],
load_options=[selectinload(Post.tags)],
)
assert loaded.title == "Updated"
tag_names = sorted(t.name for t in loaded.tags)
assert tag_names == ["tag1", "tag2"]
class TestM2MWithNonM2MCrud:
"""Tests that non-M2M CRUD classes are unaffected."""
@pytest.mark.anyio
async def test_create_without_m2m_unchanged(self, db_session: AsyncSession):
"""Regular PostCrud.create still works without M2M logic."""
from .conftest import PostCreate
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
post = await PostCrud.create(
db_session,
PostCreate(title="Plain Post", author_id=user.id),
)
assert post.id is not None
assert post.title == "Plain Post"
@pytest.mark.anyio
async def test_update_without_m2m_unchanged(self, db_session: AsyncSession):
"""Regular PostCrud.update still works without M2M logic."""
from .conftest import PostCreate, PostUpdate
user = await UserCrud.create(
db_session, UserCreate(username="author", email="author@test.com")
)
post = await PostCrud.create(
db_session,
PostCreate(title="Plain Post", author_id=user.id),
)
updated = await PostCrud.update(
db_session,
PostUpdate(title="Updated Plain"),
[Post.id == post.id],
)
assert updated.title == "Updated Plain"