feat: add many to many support in CrudFactory (#65)

This commit is contained in:
d3vyce
2026-02-15 15:57:15 +01:00
committed by GitHub
parent d971261f98
commit c32f2e18be
4 changed files with 549 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory
from .factory import CrudFactory, JoinType, M2MFieldType
from .search import (
SearchConfig,
get_searchable_fields,
@@ -10,6 +10,8 @@ from .search import (
__all__ = [
"CrudFactory",
"get_searchable_fields",
"JoinType",
"M2MFieldType",
"NoSearchableFieldsError",
"SearchConfig",
]

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from pydantic import BaseModel
@@ -11,7 +11,7 @@ from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
@@ -21,6 +21,7 @@ from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
class AsyncCrud(Generic[ModelType]):
@@ -31,6 +32,7 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
@overload
@classmethod
@@ -52,6 +54,62 @@ class AsyncCrud(Generic[ModelType]):
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod
async def _resolve_m2m(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
only_set: bool = False,
) -> dict[str, list[Any]]:
"""Resolve M2M fields from a Pydantic schema into related model instances.
Args:
session: DB async session
obj: Pydantic model containing M2M ID fields
only_set: If True, only process fields explicitly set on the schema
Returns:
Dict mapping relationship attr names to lists of related instances
"""
result: dict[str, list[Any]] = {}
if not cls.m2m_fields:
return result
for schema_field, rel in cls.m2m_fields.items():
rel_attr = rel.property.key
related_model = rel.property.mapper.class_
if only_set and schema_field not in obj.model_fields_set:
continue
ids = getattr(obj, schema_field, None)
if ids is not None:
related = (
(
await session.execute(
select(related_model).where(related_model.id.in_(ids))
)
)
.scalars()
.all()
)
if len(related) != len(ids):
found_ids = {r.id for r in related}
missing = set(ids) - found_ids
raise NotFoundError(
f"Related {related_model.__name__} not found for IDs: {missing}"
)
result[rel_attr] = list(related)
else:
result[rel_attr] = []
return result
@classmethod
def _m2m_schema_fields(cls: type[Self]) -> set[str]:
"""Return the set of schema field names that are M2M fields."""
if not cls.m2m_fields:
return set()
return set(cls.m2m_fields.keys())
@classmethod
async def create(
cls: type[Self],
@@ -71,7 +129,17 @@ class AsyncCrud(Generic[ModelType]):
Created model instance or Response wrapping it
"""
async with get_transaction(session):
db_model = cls.model(**obj.model_dump())
m2m_exclude = cls._m2m_schema_fields()
data = (
obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump()
)
db_model = cls.model(**data)
if m2m_exclude:
m2m_resolved = await cls._resolve_m2m(session, obj)
for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances)
session.add(db_model)
await session.refresh(db_model)
result = cast(ModelType, db_model)
@@ -299,12 +367,33 @@ class AsyncCrud(Generic[ModelType]):
NotFoundError: If no record found
"""
async with get_transaction(session):
db_model = await cls.get(session=session, filters=filters)
m2m_exclude = cls._m2m_schema_fields()
# Eagerly load M2M relationships that will be updated so that
# setattr does not trigger a lazy load (which fails in async).
m2m_load_options: list[Any] = []
if m2m_exclude and cls.m2m_fields:
for schema_field, rel in cls.m2m_fields.items():
if schema_field in obj.model_fields_set:
m2m_load_options.append(selectinload(rel))
db_model = await cls.get(
session=session,
filters=filters,
load_options=m2m_load_options or None,
)
values = obj.model_dump(
exclude_unset=exclude_unset, exclude_none=exclude_none
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude=m2m_exclude,
)
for key, value in values.items():
setattr(db_model, key, value)
if m2m_exclude:
m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True)
for rel_attr, related_instances in m2m_resolved.items():
setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model)
if as_response:
return Response(data=db_model)
@@ -578,12 +667,16 @@ def CrudFactory(
model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
m2m_fields: M2MFieldType | None = None,
) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model.
Args:
model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes.
Returns:
AsyncCrud subclass bound to the model
@@ -601,10 +694,20 @@ def CrudFactory(
searchable_fields=[User.username, User.email, (User.role, Role.name)]
)
# With many-to-many fields:
# Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
PostCrud = CrudFactory(
Post,
m2m_fields={"tag_ids": Post.tags},
)
# Usage
user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
# Create with M2M - tag_ids are automatically resolved
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
# With search
result = await UserCrud.paginate(session, search="john")
@@ -628,6 +731,7 @@ def CrudFactory(
{
"model": model,
"searchable_fields": searchable_fields,
"m2m_fields": m2m_fields,
},
)
return cast(type[AsyncCrud[ModelType]], cls)