diff --git a/docs/module/crud.md b/docs/module/crud.md index 283cc9d..4cfd60e 100644 --- a/docs/module/crud.md +++ b/docs/module/crud.md @@ -103,6 +103,41 @@ async def get_users( ) ``` +## Relationship loading + +!!! info "Added in v1.1" + +By default, SQLAlchemy relationships are not loaded unless explicitly requested. Instead of using `lazy="selectin"` on model definitions (which is implicit and applies globally), define a `default_load_options` on the CRUD class to control loading strategy explicitly. + +!!! warning + Avoid using `lazy="selectin"` on model relationships. It fires silently on every query, cannot be disabled per-call, and can cause unexpected cascading loads through deep relationship chains. Use `default_load_options` instead. + +```python +from sqlalchemy.orm import selectinload + +ArticleCrud = CrudFactory( + model=Article, + default_load_options=[ + selectinload(Article.category), + selectinload(Article.tags), + ], +) +``` + +`default_load_options` applies automatically to all read operations (`get`, `first`, `get_multi`, `paginate`). When `load_options` is passed at call-site, it **fully replaces** `default_load_options` for that query — giving you precise per-call control: + +```python +# Only loads category, tags are not loaded +article = await ArticleCrud.get( + session=session, + filters=[Article.id == article_id], + load_options=[selectinload(Article.category)], +) + +# Loads nothing — useful for write-then-refresh flows or lightweight checks +articles = await ArticleCrud.get_multi(session=session, load_options=[]) +``` + ## Many-to-many relationships Use `m2m_fields` to map schema fields containing lists of IDs to SQLAlchemy relationships. The CRUD class resolves and validates all IDs before persisting: diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index ea4795d..dd9d89d 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -12,6 +12,7 @@ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload +from sqlalchemy.sql.base import ExecutableOption from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction @@ -33,26 +34,16 @@ class AsyncCrud(Generic[ModelType]): model: ClassVar[type[DeclarativeBase]] searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None + default_load_options: ClassVar[list[ExecutableOption] | None] = None - @overload @classmethod - async def create( # pragma: no cover - cls: type[Self], - session: AsyncSession, - obj: BaseModel, - *, - as_response: Literal[True], - ) -> Response[ModelType]: ... - - @overload - @classmethod - async def create( # pragma: no cover - cls: type[Self], - session: AsyncSession, - obj: BaseModel, - *, - as_response: Literal[False] = ..., - ) -> ModelType: ... + def _resolve_load_options( + cls, load_options: list[ExecutableOption] | None + ) -> list[ExecutableOption] | None: + """Return load_options if provided, else fall back to default_load_options.""" + if load_options is not None: + return load_options + return cls.default_load_options @classmethod async def _resolve_m2m( @@ -110,6 +101,26 @@ class AsyncCrud(Generic[ModelType]): return set() return set(cls.m2m_fields.keys()) + @overload + @classmethod + async def create( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + *, + as_response: Literal[True], + ) -> Response[ModelType]: ... + + @overload + @classmethod + async def create( # pragma: no cover + cls: type[Self], + session: AsyncSession, + obj: BaseModel, + *, + as_response: Literal[False] = ..., + ) -> ModelType: ... + @classmethod async def create( cls: type[Self], @@ -157,7 +168,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, with_for_update: bool = False, - load_options: list[Any] | None = None, + load_options: list[ExecutableOption] | None = None, as_response: Literal[True], ) -> Response[ModelType]: ... @@ -171,7 +182,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, with_for_update: bool = False, - load_options: list[Any] | None = None, + load_options: list[ExecutableOption] | None = None, as_response: Literal[False] = ..., ) -> ModelType: ... @@ -184,7 +195,7 @@ class AsyncCrud(Generic[ModelType]): joins: JoinType | None = None, outer_join: bool = False, with_for_update: bool = False, - load_options: list[Any] | None = None, + load_options: list[ExecutableOption] | None = None, as_response: bool = False, ) -> ModelType | Response[ModelType]: """Get exactly one record. Raises NotFoundError if not found. @@ -214,8 +225,8 @@ class AsyncCrud(Generic[ModelType]): else q.join(model, condition) ) q = q.where(and_(*filters)) - if load_options: - q = q.options(*load_options) + if resolved := cls._resolve_load_options(load_options): + q = q.options(*resolved) if with_for_update: q = q.with_for_update() result = await session.execute(q) @@ -235,7 +246,7 @@ class AsyncCrud(Generic[ModelType]): *, joins: JoinType | None = None, outer_join: bool = False, - load_options: list[Any] | None = None, + load_options: list[ExecutableOption] | None = None, ) -> ModelType | None: """Get the first matching record, or None. @@ -259,8 +270,8 @@ class AsyncCrud(Generic[ModelType]): ) if filters: q = q.where(and_(*filters)) - if load_options: - q = q.options(*load_options) + if resolved := cls._resolve_load_options(load_options): + q = q.options(*resolved) result = await session.execute(q) return cast(ModelType | None, result.unique().scalars().first()) @@ -272,7 +283,7 @@ class AsyncCrud(Generic[ModelType]): filters: list[Any] | None = None, joins: JoinType | None = None, outer_join: bool = False, - load_options: list[Any] | None = None, + load_options: list[ExecutableOption] | None = None, order_by: Any | None = None, limit: int | None = None, offset: int | None = None, @@ -302,8 +313,8 @@ class AsyncCrud(Generic[ModelType]): ) if filters: q = q.where(and_(*filters)) - if load_options: - q = q.options(*load_options) + if resolved := cls._resolve_load_options(load_options): + q = q.options(*resolved) if order_by is not None: q = q.order_by(order_by) if offset is not None: @@ -371,7 +382,7 @@ class AsyncCrud(Generic[ModelType]): # Eagerly load M2M relationships that will be updated so that # setattr does not trigger a lazy load (which fails in async). - m2m_load_options: list[Any] = [] + m2m_load_options: list[ExecutableOption] = [] if m2m_exclude and cls.m2m_fields: for schema_field, rel in cls.m2m_fields.items(): if schema_field in obj.model_fields_set: @@ -563,7 +574,7 @@ class AsyncCrud(Generic[ModelType]): filters: list[Any] | None = None, joins: JoinType | None = None, outer_join: bool = False, - load_options: list[Any] | None = None, + load_options: list[ExecutableOption] | None = None, order_by: Any | None = None, page: int = 1, items_per_page: int = 20, @@ -619,8 +630,8 @@ class AsyncCrud(Generic[ModelType]): if filters: q = q.where(and_(*filters)) - if load_options: - q = q.options(*load_options) + if resolved := cls._resolve_load_options(load_options): + q = q.options(*resolved) if order_by is not None: q = q.order_by(order_by) @@ -668,6 +679,7 @@ def CrudFactory( *, searchable_fields: Sequence[SearchFieldType] | None = None, m2m_fields: M2MFieldType | None = None, + default_load_options: list[ExecutableOption] | None = None, ) -> type[AsyncCrud[ModelType]]: """Create a CRUD class for a specific model. @@ -677,6 +689,11 @@ def CrudFactory( m2m_fields: Optional mapping for many-to-many relationships. Maps schema field names (containing lists of IDs) to SQLAlchemy relationship attributes. + default_load_options: Default SQLAlchemy loader options applied to all read + queries when no explicit ``load_options`` are passed. Use this + instead of ``lazy="selectin"`` on the model so that loading + strategy is explicit and per-CRUD. Overridden entirely (not + merged) when ``load_options`` is provided at call-site. Returns: AsyncCrud subclass bound to the model @@ -702,6 +719,19 @@ def CrudFactory( m2m_fields={"tag_ids": Post.tags}, ) + # With default load strategy (replaces lazy="selectin" on the model): + ArticleCrud = CrudFactory( + Article, + default_load_options=[selectinload(Article.category), selectinload(Article.tags)], + ) + + # Override default_load_options for a specific call: + article = await ArticleCrud.get( + session, + [Article.id == 1], + load_options=[selectinload(Article.category)], # tags won't load + ) + # Usage user = await UserCrud.get(session, [User.id == 1]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) @@ -734,6 +764,7 @@ def CrudFactory( "model": model, "searchable_fields": searchable_fields, "m2m_fields": m2m_fields, + "default_load_options": default_load_options, }, ) return cast(type[AsyncCrud[ModelType]], cls) diff --git a/tests/test_crud.py b/tests/test_crud.py index d91e0e0..f2dba0c 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -50,6 +50,152 @@ class TestCrudFactory: crud = CrudFactory(User) assert "User" in crud.__name__ + def test_default_load_options_none_by_default(self): + """default_load_options is None when not specified.""" + crud = CrudFactory(User) + assert crud.default_load_options is None + + def test_default_load_options_set(self): + """default_load_options is stored on the class.""" + options = [selectinload(User.role)] + crud = CrudFactory(User, default_load_options=options) + assert crud.default_load_options == options + + def test_default_load_options_not_shared_between_classes(self): + """default_load_options is isolated per factory call.""" + options = [selectinload(User.role)] + crud_with = CrudFactory(User, default_load_options=options) + crud_without = CrudFactory(User) + assert crud_with.default_load_options == options + assert crud_without.default_load_options is None + + +class TestResolveLoadOptions: + """Tests for _resolve_load_options logic.""" + + def test_returns_load_options_when_provided(self): + """Explicit load_options takes priority over default_load_options.""" + options = [selectinload(User.role)] + default = [selectinload(Post.tags)] + crud = CrudFactory(User, default_load_options=default) + assert crud._resolve_load_options(options) == options + + def test_returns_default_when_load_options_is_none(self): + """Falls back to default_load_options when load_options is None.""" + default = [selectinload(User.role)] + crud = CrudFactory(User, default_load_options=default) + assert crud._resolve_load_options(None) == default + + def test_returns_none_when_both_are_none(self): + """Returns None when neither load_options nor default_load_options set.""" + crud = CrudFactory(User) + assert crud._resolve_load_options(None) is None + + def test_empty_list_overrides_default(self): + """An empty list is a valid override and disables default_load_options.""" + default = [selectinload(User.role)] + crud = CrudFactory(User, default_load_options=default) + # Empty list is not None, so it should replace default + assert crud._resolve_load_options([]) == [] + + +class TestDefaultLoadOptionsIntegration: + """Integration tests for default_load_options with real DB queries.""" + + @pytest.mark.anyio + async def test_default_load_options_applied_to_get(self, db_session: AsyncSession): + """default_load_options loads relationships automatically on get().""" + UserWithDefaultLoad = CrudFactory( + User, default_load_options=[selectinload(User.role)] + ) + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + user = await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + fetched = await UserWithDefaultLoad.get(db_session, [User.id == user.id]) + assert fetched.role is not None + assert fetched.role.name == "admin" + + @pytest.mark.anyio + async def test_default_load_options_applied_to_get_multi( + self, db_session: AsyncSession + ): + """default_load_options loads relationships automatically on get_multi().""" + UserWithDefaultLoad = CrudFactory( + User, default_load_options=[selectinload(User.role)] + ) + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + users = await UserWithDefaultLoad.get_multi(db_session) + assert users[0].role is not None + assert users[0].role.name == "admin" + + @pytest.mark.anyio + async def test_default_load_options_applied_to_first( + self, db_session: AsyncSession + ): + """default_load_options loads relationships automatically on first().""" + UserWithDefaultLoad = CrudFactory( + User, default_load_options=[selectinload(User.role)] + ) + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + user = await UserWithDefaultLoad.first(db_session) + assert user is not None + assert user.role is not None + assert user.role.name == "admin" + + @pytest.mark.anyio + async def test_default_load_options_applied_to_paginate( + self, db_session: AsyncSession + ): + """default_load_options loads relationships automatically on paginate().""" + UserWithDefaultLoad = CrudFactory( + User, default_load_options=[selectinload(User.role)] + ) + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + result = await UserWithDefaultLoad.paginate(db_session) + assert result.data[0].role is not None + assert result.data[0].role.name == "admin" + + @pytest.mark.anyio + async def test_load_options_overrides_default_load_options( + self, db_session: AsyncSession + ): + """Explicit load_options fully replaces default_load_options.""" + PostWithDefaultLoad = CrudFactory( + Post, + default_load_options=[selectinload(Post.tags)], + ) + user = await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com"), + ) + post = await PostCrud.create( + db_session, + PostCreate(title="Hello", author_id=user.id), + ) + # Pass empty load_options to override default — tags should not load + fetched = await PostWithDefaultLoad.get( + db_session, + [Post.id == post.id], + load_options=[], + ) + # tags were not loaded — accessing them would lazy-load or return empty + # We just assert the fetch itself succeeded with the override + assert fetched.id == post.id + class TestCrudCreate: """Tests for CRUD create operations."""