feat: add opt-in default_load_options parameter in CrudFactory (#82)

* feat: add opt-in default_load_options parameter in CrudFactory

* docs: add Relationship loading in CRUD
This commit is contained in:
d3vyce
2026-02-21 12:35:15 +01:00
committed by GitHub
parent 31678935aa
commit 9d07dfea85
3 changed files with 245 additions and 33 deletions

View File

@@ -103,6 +103,41 @@ async def get_users(
) )
``` ```
## Relationship loading
!!! info "Added in v1.1"
By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly.
!!! warning
Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead.
```python
from sqlalchemy.orm import selectinload
ArticleCrud = CrudFactory(
model=Article,
default_load_options=[
selectinload(Article.category),
selectinload(Article.tags),
],
)
```
`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control:
```python
# Only loads category, tags are not loaded
article = await ArticleCrud.get(
session=session,
filters=[Article.id == article_id],
load_options=[selectinload(Article.category)],
)
# Loads nothing — useful for write-then-refresh flows or lightweight checks
articles = await ArticleCrud.get_multi(session=session, load_options=[])
```
## Many-to-many relationships ## Many-to-many relationships
Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting: Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting:

View File

@@ -12,6 +12,7 @@ from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
@@ -33,26 +34,16 @@ class AsyncCrud(Generic[ModelType]):
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[list[ExecutableOption] | None] = None
@overload
@classmethod @classmethod
async def create( # pragma: no cover def _resolve_load_options(
cls: type[Self], cls, load_options: list[ExecutableOption] | None
session: AsyncSession, ) -> list[ExecutableOption] | None:
obj: BaseModel, """Return load_options if provided, else fall back to default_load_options."""
*, if load_options is not None:
as_response: Literal[True], return load_options
) -> Response[ModelType]: ... return cls.default_load_options
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod @classmethod
async def _resolve_m2m( async def _resolve_m2m(
@@ -110,6 +101,26 @@ class AsyncCrud(Generic[ModelType]):
return set() return set()
return set(cls.m2m_fields.keys()) return set(cls.m2m_fields.keys())
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[True],
) -> Response[ModelType]: ...
@overload
@classmethod
async def create( # pragma: no cover
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
as_response: Literal[False] = ...,
) -> ModelType: ...
@classmethod @classmethod
async def create( async def create(
cls: type[Self], cls: type[Self],
@@ -157,7 +168,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
as_response: Literal[True], as_response: Literal[True],
) -> Response[ModelType]: ... ) -> Response[ModelType]: ...
@@ -171,7 +182,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
as_response: Literal[False] = ..., as_response: Literal[False] = ...,
) -> ModelType: ... ) -> ModelType: ...
@@ -184,7 +195,7 @@ class AsyncCrud(Generic[ModelType]):
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
as_response: bool = False, as_response: bool = False,
) -> ModelType | Response[ModelType]: ) -> ModelType | Response[ModelType]:
"""Get exactly one record. Raises NotFoundError if not found. """Get exactly one record. Raises NotFoundError if not found.
@@ -214,8 +225,8 @@ class AsyncCrud(Generic[ModelType]):
else q.join(model, condition) else q.join(model, condition)
) )
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
if with_for_update: if with_for_update:
q = q.with_for_update() q = q.with_for_update()
result = await session.execute(q) result = await session.execute(q)
@@ -235,7 +246,7 @@ class AsyncCrud(Generic[ModelType]):
*, *,
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
) -> ModelType | None: ) -> ModelType | None:
"""Get the first matching record, or None. """Get the first matching record, or None.
@@ -259,8 +270,8 @@ class AsyncCrud(Generic[ModelType]):
) )
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
result = await session.execute(q) result = await session.execute(q)
return cast(ModelType | None, result.unique().scalars().first()) return cast(ModelType | None, result.unique().scalars().first())
@@ -272,7 +283,7 @@ class AsyncCrud(Generic[ModelType]):
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
limit: int | None = None, limit: int | None = None,
offset: int | None = None, offset: int | None = None,
@@ -302,8 +313,8 @@ class AsyncCrud(Generic[ModelType]):
) )
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
if order_by is not None: if order_by is not None:
q = q.order_by(order_by) q = q.order_by(order_by)
if offset is not None: if offset is not None:
@@ -371,7 +382,7 @@ class AsyncCrud(Generic[ModelType]):
# Eagerly load M2M relationships that will be updated so that # Eagerly load M2M relationships that will be updated so that
# setattr does not trigger a lazy load (which fails in async). # setattr does not trigger a lazy load (which fails in async).
m2m_load_options: list[Any] = [] m2m_load_options: list[ExecutableOption] = []
if m2m_exclude and cls.m2m_fields: if m2m_exclude and cls.m2m_fields:
for schema_field, rel in cls.m2m_fields.items(): for schema_field, rel in cls.m2m_fields.items():
if schema_field in obj.model_fields_set: if schema_field in obj.model_fields_set:
@@ -563,7 +574,7 @@ class AsyncCrud(Generic[ModelType]):
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None, joins: JoinType | None = None,
outer_join: bool = False, outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[ExecutableOption] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
@@ -619,8 +630,8 @@ class AsyncCrud(Generic[ModelType]):
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if resolved := cls._resolve_load_options(load_options):
q = q.options(*load_options) q = q.options(*resolved)
if order_by is not None: if order_by is not None:
q = q.order_by(order_by) q = q.order_by(order_by)
@@ -668,6 +679,7 @@ def CrudFactory(
*, *,
searchable_fields: Sequence[SearchFieldType] | None = None, searchable_fields: Sequence[SearchFieldType] | None = None,
m2m_fields: M2MFieldType | None = None, m2m_fields: M2MFieldType | None = None,
default_load_options: list[ExecutableOption] | None = None,
) -> type[AsyncCrud[ModelType]]: ) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model. """Create a CRUD class for a specific model.
@@ -677,6 +689,11 @@ def CrudFactory(
m2m_fields: Optional mapping for many-to-many relationships. m2m_fields: Optional mapping for many-to-many relationships.
Maps schema field names (containing lists of IDs) to Maps schema field names (containing lists of IDs) to
SQLAlchemy relationship attributes. SQLAlchemy relationship attributes.
default_load_options: Default SQLAlchemy loader options applied to all read
queries when no explicit ``load_options`` are passed. Use this
instead of ``lazy="selectin"`` on the model so that loading
strategy is explicit and per-CRUD. Overridden entirely (not
merged) when ``load_options`` is provided at call-site.
Returns: Returns:
AsyncCrud subclass bound to the model AsyncCrud subclass bound to the model
@@ -702,6 +719,19 @@ def CrudFactory(
m2m_fields={"tag_ids": Post.tags}, m2m_fields={"tag_ids": Post.tags},
) )
# With default load strategy (replaces lazy="selectin" on the model):
ArticleCrud = CrudFactory(
Article,
default_load_options=[selectinload(Article.category), selectinload(Article.tags)],
)
# Override default_load_options for a specific call:
article = await ArticleCrud.get(
session,
[Article.id == 1],
load_options=[selectinload(Article.category)], # tags won't load
)
# Usage # Usage
user = await UserCrud.get(session, [User.id == 1]) user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
@@ -734,6 +764,7 @@ def CrudFactory(
"model": model, "model": model,
"searchable_fields": searchable_fields, "searchable_fields": searchable_fields,
"m2m_fields": m2m_fields, "m2m_fields": m2m_fields,
"default_load_options": default_load_options,
}, },
) )
return cast(type[AsyncCrud[ModelType]], cls) return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -50,6 +50,152 @@ class TestCrudFactory:
crud = CrudFactory(User) crud = CrudFactory(User)
assert "User" in crud.__name__ assert "User" in crud.__name__
def test_default_load_options_none_by_default(self):
"""default_load_options is None when not specified."""
crud = CrudFactory(User)
assert crud.default_load_options is None
def test_default_load_options_set(self):
"""default_load_options is stored on the class."""
options = [selectinload(User.role)]
crud = CrudFactory(User, default_load_options=options)
assert crud.default_load_options == options
def test_default_load_options_not_shared_between_classes(self):
"""default_load_options is isolated per factory call."""
options = [selectinload(User.role)]
crud_with = CrudFactory(User, default_load_options=options)
crud_without = CrudFactory(User)
assert crud_with.default_load_options == options
assert crud_without.default_load_options is None
class TestResolveLoadOptions:
"""Tests for _resolve_load_options logic."""
def test_returns_load_options_when_provided(self):
"""Explicit load_options takes priority over default_load_options."""
options = [selectinload(User.role)]
default = [selectinload(Post.tags)]
crud = CrudFactory(User, default_load_options=default)
assert crud._resolve_load_options(options) == options
def test_returns_default_when_load_options_is_none(self):
"""Falls back to default_load_options when load_options is None."""
default = [selectinload(User.role)]
crud = CrudFactory(User, default_load_options=default)
assert crud._resolve_load_options(None) == default
def test_returns_none_when_both_are_none(self):
"""Returns None when neither load_options nor default_load_options set."""
crud = CrudFactory(User)
assert crud._resolve_load_options(None) is None
def test_empty_list_overrides_default(self):
"""An empty list is a valid override and disables default_load_options."""
default = [selectinload(User.role)]
crud = CrudFactory(User, default_load_options=default)
# Empty list is not None, so it should replace default
assert crud._resolve_load_options([]) == []
class TestDefaultLoadOptionsIntegration:
"""Integration tests for default_load_options with real DB queries."""
@pytest.mark.anyio
async def test_default_load_options_applied_to_get(self, db_session: AsyncSession):
"""default_load_options loads relationships automatically on get()."""
UserWithDefaultLoad = CrudFactory(
User, default_load_options=[selectinload(User.role)]
)
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
user = await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
fetched = await UserWithDefaultLoad.get(db_session, [User.id == user.id])
assert fetched.role is not None
assert fetched.role.name == "admin"
@pytest.mark.anyio
async def test_default_load_options_applied_to_get_multi(
self, db_session: AsyncSession
):
"""default_load_options loads relationships automatically on get_multi()."""
UserWithDefaultLoad = CrudFactory(
User, default_load_options=[selectinload(User.role)]
)
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
users = await UserWithDefaultLoad.get_multi(db_session)
assert users[0].role is not None
assert users[0].role.name == "admin"
@pytest.mark.anyio
async def test_default_load_options_applied_to_first(
self, db_session: AsyncSession
):
"""default_load_options loads relationships automatically on first()."""
UserWithDefaultLoad = CrudFactory(
User, default_load_options=[selectinload(User.role)]
)
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
user = await UserWithDefaultLoad.first(db_session)
assert user is not None
assert user.role is not None
assert user.role.name == "admin"
@pytest.mark.anyio
async def test_default_load_options_applied_to_paginate(
self, db_session: AsyncSession
):
"""default_load_options loads relationships automatically on paginate()."""
UserWithDefaultLoad = CrudFactory(
User, default_load_options=[selectinload(User.role)]
)
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
result = await UserWithDefaultLoad.paginate(db_session)
assert result.data[0].role is not None
assert result.data[0].role.name == "admin"
@pytest.mark.anyio
async def test_load_options_overrides_default_load_options(
self, db_session: AsyncSession
):
"""Explicit load_options fully replaces default_load_options."""
PostWithDefaultLoad = CrudFactory(
Post,
default_load_options=[selectinload(Post.tags)],
)
user = await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com"),
)
post = await PostCrud.create(
db_session,
PostCreate(title="Hello", author_id=user.id),
)
# Pass empty load_options to override default — tags should not load
fetched = await PostWithDefaultLoad.get(
db_session,
[Post.id == post.id],
load_options=[],
)
# tags were not loaded — accessing them would lazy-load or return empty
# We just assert the fetch itself succeeded with the override
assert fetched.id == post.id
class TestCrudCreate: class TestCrudCreate:
"""Tests for CRUD create operations.""" """Tests for CRUD create operations."""