mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add many to many support in CrudFactory (#65)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from ..exceptions import NoSearchableFieldsError
|
||||
from .factory import CrudFactory
|
||||
from .factory import CrudFactory, JoinType, M2MFieldType
|
||||
from .search import (
|
||||
SearchConfig,
|
||||
get_searchable_fields,
|
||||
@@ -10,6 +10,8 @@ from .search import (
|
||||
__all__ = [
|
||||
"CrudFactory",
|
||||
"get_searchable_fields",
|
||||
"JoinType",
|
||||
"M2MFieldType",
|
||||
"NoSearchableFieldsError",
|
||||
"SearchConfig",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import delete as sql_delete
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
@@ -21,6 +21,7 @@ from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
|
||||
|
||||
class AsyncCrud(Generic[ModelType]):
|
||||
@@ -31,6 +32,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
@@ -52,6 +54,62 @@ class AsyncCrud(Generic[ModelType]):
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def _resolve_m2m(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
only_set: bool = False,
|
||||
) -> dict[str, list[Any]]:
|
||||
"""Resolve M2M fields from a Pydantic schema into related model instances.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model containing M2M ID fields
|
||||
only_set: If True, only process fields explicitly set on the schema
|
||||
|
||||
Returns:
|
||||
Dict mapping relationship attr names to lists of related instances
|
||||
"""
|
||||
result: dict[str, list[Any]] = {}
|
||||
if not cls.m2m_fields:
|
||||
return result
|
||||
|
||||
for schema_field, rel in cls.m2m_fields.items():
|
||||
rel_attr = rel.property.key
|
||||
related_model = rel.property.mapper.class_
|
||||
if only_set and schema_field not in obj.model_fields_set:
|
||||
continue
|
||||
ids = getattr(obj, schema_field, None)
|
||||
if ids is not None:
|
||||
related = (
|
||||
(
|
||||
await session.execute(
|
||||
select(related_model).where(related_model.id.in_(ids))
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
if len(related) != len(ids):
|
||||
found_ids = {r.id for r in related}
|
||||
missing = set(ids) - found_ids
|
||||
raise NotFoundError(
|
||||
f"Related {related_model.__name__} not found for IDs: {missing}"
|
||||
)
|
||||
result[rel_attr] = list(related)
|
||||
else:
|
||||
result[rel_attr] = []
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _m2m_schema_fields(cls: type[Self]) -> set[str]:
|
||||
"""Return the set of schema field names that are M2M fields."""
|
||||
if not cls.m2m_fields:
|
||||
return set()
|
||||
return set(cls.m2m_fields.keys())
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
@@ -71,7 +129,17 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Created model instance or Response wrapping it
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = cls.model(**obj.model_dump())
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
data = (
|
||||
obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump()
|
||||
)
|
||||
db_model = cls.model(**data)
|
||||
|
||||
if m2m_exclude:
|
||||
m2m_resolved = await cls._resolve_m2m(session, obj)
|
||||
for rel_attr, related_instances in m2m_resolved.items():
|
||||
setattr(db_model, rel_attr, related_instances)
|
||||
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
result = cast(ModelType, db_model)
|
||||
@@ -299,12 +367,33 @@ class AsyncCrud(Generic[ModelType]):
|
||||
NotFoundError: If no record found
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = await cls.get(session=session, filters=filters)
|
||||
m2m_exclude = cls._m2m_schema_fields()
|
||||
|
||||
# Eagerly load M2M relationships that will be updated so that
|
||||
# setattr does not trigger a lazy load (which fails in async).
|
||||
m2m_load_options: list[Any] = []
|
||||
if m2m_exclude and cls.m2m_fields:
|
||||
for schema_field, rel in cls.m2m_fields.items():
|
||||
if schema_field in obj.model_fields_set:
|
||||
m2m_load_options.append(selectinload(rel))
|
||||
|
||||
db_model = await cls.get(
|
||||
session=session,
|
||||
filters=filters,
|
||||
load_options=m2m_load_options or None,
|
||||
)
|
||||
values = obj.model_dump(
|
||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_none=exclude_none,
|
||||
exclude=m2m_exclude,
|
||||
)
|
||||
for key, value in values.items():
|
||||
setattr(db_model, key, value)
|
||||
|
||||
if m2m_exclude:
|
||||
m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True)
|
||||
for rel_attr, related_instances in m2m_resolved.items():
|
||||
setattr(db_model, rel_attr, related_instances)
|
||||
await session.refresh(db_model)
|
||||
if as_response:
|
||||
return Response(data=db_model)
|
||||
@@ -578,12 +667,16 @@ def CrudFactory(
|
||||
model: type[ModelType],
|
||||
*,
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
m2m_fields: M2MFieldType | None = None,
|
||||
) -> type[AsyncCrud[ModelType]]:
|
||||
"""Create a CRUD class for a specific model.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
searchable_fields: Optional list of searchable fields
|
||||
m2m_fields: Optional mapping for many-to-many relationships.
|
||||
Maps schema field names (containing lists of IDs) to
|
||||
SQLAlchemy relationship attributes.
|
||||
|
||||
Returns:
|
||||
AsyncCrud subclass bound to the model
|
||||
@@ -601,10 +694,20 @@ def CrudFactory(
|
||||
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
||||
)
|
||||
|
||||
# With many-to-many fields:
|
||||
# Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
|
||||
PostCrud = CrudFactory(
|
||||
Post,
|
||||
m2m_fields={"tag_ids": Post.tags},
|
||||
)
|
||||
|
||||
# Usage
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||
|
||||
# Create with M2M - tag_ids are automatically resolved
|
||||
post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))
|
||||
|
||||
# With search
|
||||
result = await UserCrud.paginate(session, search="john")
|
||||
|
||||
@@ -628,6 +731,7 @@ def CrudFactory(
|
||||
{
|
||||
"model": model,
|
||||
"searchable_fields": searchable_fields,
|
||||
"m2m_fields": m2m_fields,
|
||||
},
|
||||
)
|
||||
return cast(type[AsyncCrud[ModelType]], cls)
|
||||
|
||||
Reference in New Issue
Block a user