mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add opt-in default_load_options parameter in CrudFactory (#82)
* feat: add opt-in default_load_options parameter in CrudFactory * docs: add Relationship loading in CRUD
This commit is contained in:
@@ -12,6 +12,7 @@ from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||
from sqlalchemy.sql.base import ExecutableOption
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
@@ -33,26 +34,16 @@ class AsyncCrud(Generic[ModelType]):
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
default_load_options: ClassVar[list[ExecutableOption] | None] = None
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
def _resolve_load_options(
|
||||
cls, load_options: list[ExecutableOption] | None
|
||||
) -> list[ExecutableOption] | None:
|
||||
"""Return load_options if provided, else fall back to default_load_options."""
|
||||
if load_options is not None:
|
||||
return load_options
|
||||
return cls.default_load_options
|
||||
|
||||
@classmethod
|
||||
async def _resolve_m2m(
|
||||
@@ -110,6 +101,26 @@ class AsyncCrud(Generic[ModelType]):
|
||||
return set()
|
||||
return set(cls.m2m_fields.keys())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
@@ -157,7 +168,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@@ -171,7 +182,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@@ -184,7 +195,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
@@ -214,8 +225,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
if with_for_update:
|
||||
q = q.with_for_update()
|
||||
result = await session.execute(q)
|
||||
@@ -235,7 +246,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
) -> ModelType | None:
|
||||
"""Get the first matching record, or None.
|
||||
|
||||
@@ -259,8 +270,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
result = await session.execute(q)
|
||||
return cast(ModelType | None, result.unique().scalars().first())
|
||||
|
||||
@@ -272,7 +283,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
@@ -302,8 +313,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
if offset is not None:
|
||||
@@ -371,7 +382,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
# Eagerly load M2M relationships that will be updated so that
|
||||
# setattr does not trigger a lazy load (which fails in async).
|
||||
m2m_load_options: list[Any] = []
|
||||
m2m_load_options: list[ExecutableOption] = []
|
||||
if m2m_exclude and cls.m2m_fields:
|
||||
for schema_field, rel in cls.m2m_fields.items():
|
||||
if schema_field in obj.model_fields_set:
|
||||
@@ -563,7 +574,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
load_options: list[ExecutableOption] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
@@ -619,8 +630,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
|
||||
@@ -668,6 +679,7 @@ def CrudFactory(
|
||||
*,
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
m2m_fields: M2MFieldType | None = None,
|
||||
default_load_options: list[ExecutableOption] | None = None,
|
||||
) -> type[AsyncCrud[ModelType]]:
|
||||
"""Create a CRUD class for a specific model.
|
||||
|
||||
@@ -677,6 +689,11 @@ def CrudFactory(
|
||||
m2m_fields: Optional mapping for many-to-many relationships.
|
||||
Maps schema field names (containing lists of IDs) to
|
||||
SQLAlchemy relationship attributes.
|
||||
default_load_options: Default SQLAlchemy loader options applied to all read
|
||||
queries when no explicit ``load_options`` are passed. Use this
|
||||
instead of ``lazy="selectin"`` on the model so that loading
|
||||
strategy is explicit and per-CRUD. Overridden entirely (not
|
||||
merged) when ``load_options`` is provided at call-site.
|
||||
|
||||
Returns:
|
||||
AsyncCrud subclass bound to the model
|
||||
@@ -702,6 +719,19 @@ def CrudFactory(
|
||||
m2m_fields={"tag_ids": Post.tags},
|
||||
)
|
||||
|
||||
# With default load strategy (replaces lazy="selectin" on the model):
|
||||
ArticleCrud = CrudFactory(
|
||||
Article,
|
||||
default_load_options=[selectinload(Article.category), selectinload(Article.tags)],
|
||||
)
|
||||
|
||||
# Override default_load_options for a specific call:
|
||||
article = await ArticleCrud.get(
|
||||
session,
|
||||
[Article.id == 1],
|
||||
load_options=[selectinload(Article.category)], # tags won't load
|
||||
)
|
||||
|
||||
# Usage
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||
@@ -734,6 +764,7 @@ def CrudFactory(
|
||||
"model": model,
|
||||
"searchable_fields": searchable_fields,
|
||||
"m2m_fields": m2m_fields,
|
||||
"default_load_options": default_load_options,
|
||||
},
|
||||
)
|
||||
return cast(type[AsyncCrud[ModelType]], cls)
|
||||
|
||||
Reference in New Issue
Block a user