diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index 95bf390..f6892fd 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -15,12 +15,13 @@ from ..types import ( OrderFieldType, SearchFieldType, ) -from .factory import AsyncCrud, CrudFactory +from .factory import AsyncCrud, CrudFactory, lateral_load from .search import SearchConfig, get_searchable_fields __all__ = [ "AsyncCrud", "CrudFactory", + "lateral_load", "FacetFieldType", "get_searchable_fields", "InvalidFacetFilterError", diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 4bca120..2246bc3 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -10,15 +10,32 @@ from collections.abc import Awaitable, Callable, Sequence from datetime import date, datetime from decimal import Decimal from enum import Enum -from typing import Any, ClassVar, Generic, Literal, Self, cast, overload +from typing import Any, ClassVar, Generic, Literal, NamedTuple, Self, cast, overload from fastapi import Query from pydantic import BaseModel -from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select +from sqlalchemy import ( + Date, + DateTime, + Float, + Integer, + Numeric, + Uuid, + and_, + func, + select, + true, +) from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload +from sqlalchemy.orm import ( + DeclarativeBase, + QueryableAttribute, + RelationshipProperty, + contains_eager, + selectinload, +) from sqlalchemy.sql.base import ExecutableOption from sqlalchemy.sql.roles import WhereHavingRole @@ -35,6 +52,7 @@ from ..schemas import ( from ..types import ( FacetFieldType, JoinType, + LateralJoinType, M2MFieldType, ModelType, OrderByClause, @@ -115,6 +133,78 @@ def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any: return q +class _ResolvedLateral(NamedTuple): + joins: LateralJoinType + eager: list[ExecutableOption] + + +class _LateralLoad: + """Marker used inside ``default_load_options`` for lateral join loading. + + Supports only Many:One and One:One relationships (single row per parent). + """ + + __slots__ = ("rel_attr",) + + def __init__(self, rel_attr: QueryableAttribute) -> None: + prop = rel_attr.property + if not isinstance(prop, RelationshipProperty): + raise TypeError( + f"lateral_load() requires a relationship attribute, got {type(prop).__name__}. " + "Example: lateral_load(User.team)" + ) + if prop.secondary is not None: + raise ValueError( + f"lateral_load({rel_attr}) does not support Many:Many relationships. " + "Use selectinload() instead." + ) + if prop.uselist: + raise ValueError( + f"lateral_load({rel_attr}) does not support One:Many relationships. " + "Use selectinload() instead." + ) + self.rel_attr = rel_attr + + +def lateral_load(rel_attr: QueryableAttribute) -> _LateralLoad: + """Mark a Many:One or One:One relationship for lateral join loading. + + Raises ``ValueError`` for One:Many or Many:Many relationships. + """ + return _LateralLoad(rel_attr) + + +def _build_lateral_from_relationship( + rel_attr: QueryableAttribute, +) -> tuple[Any, Any, ExecutableOption]: + """Introspect a Many:One relationship and build (lateral_subquery, true(), contains_eager).""" + prop = rel_attr.property + target_class = prop.mapper.class_ + parent_class = prop.parent.class_ + + conditions = [ + getattr(target_class, remote_col.key) == getattr(parent_class, local_col.key) + for local_col, remote_col in prop.local_remote_pairs + ] + + lateral_sub = ( + select(target_class) + .where(and_(*conditions)) + .correlate(parent_class) + .lateral(f"_lateral_{prop.key}") + ) + return lateral_sub, true(), contains_eager(rel_attr, alias=lateral_sub) + + +def _apply_lateral_joins(q: Any, lateral_joins: LateralJoinType | None) -> Any: + """Apply lateral subqueries as LEFT JOIN LATERAL to preserve all parent rows.""" + if not lateral_joins: + return q + for subquery, condition in lateral_joins: + q = q.outerjoin(subquery, condition) + return q + + def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any: """Apply relationship-based outer joins (from search/filter_by) to a query.""" seen: set[str] = set() @@ -132,12 +222,17 @@ class AsyncCrud(Generic[ModelType]): Subclass this and set the `model` class variable, or use `CrudFactory`. """ + _resolved_lateral: ClassVar[_ResolvedLateral | None] = None + model: ClassVar[type[DeclarativeBase]] searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None order_fields: ClassVar[Sequence[OrderFieldType] | None] = None m2m_fields: ClassVar[M2MFieldType | None] = None - default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None + default_load_options: ClassVar[Sequence[ExecutableOption | _LateralLoad] | None] = ( + None + ) + lateral_joins: ClassVar[LateralJoinType | None] = None cursor_column: ClassVar[Any | None] = None @classmethod @@ -161,14 +256,48 @@ class AsyncCrud(Generic[ModelType]): ): cls.searchable_fields = [pk_col, *raw_fields] + raw_default_opts = cls.__dict__.get("default_load_options", None) + if raw_default_opts: + joins: LateralJoinType = [] + eager: list[ExecutableOption] = [] + clean: list[ExecutableOption] = [] + for opt in raw_default_opts: + if isinstance(opt, _LateralLoad): + lat_sub, condition, eager_opt = _build_lateral_from_relationship( + opt.rel_attr + ) + joins.append((lat_sub, condition)) + eager.append(eager_opt) + else: + clean.append(opt) + if joins: + cls._resolved_lateral = _ResolvedLateral(joins=joins, eager=eager) + cls.default_load_options = clean or None + + @classmethod + def _get_lateral_joins(cls) -> LateralJoinType | None: + """Merge manual lateral_joins with ones resolved from default_load_options.""" + resolved = cls._resolved_lateral + all_lateral = [ + *(cls.lateral_joins or []), + *(resolved.joins if resolved else []), + ] + return all_lateral or None + @classmethod def _resolve_load_options( cls, load_options: Sequence[ExecutableOption] | None ) -> Sequence[ExecutableOption] | None: - """Return load_options if provided, else fall back to default_load_options.""" + """Return merged load options.""" if load_options is not None: - return load_options - return cls.default_load_options + return list(load_options) or None + resolved = cls._resolved_lateral + # default_load_options is cleaned of _LateralLoad markers in __init_subclass__, + # but its declared type still includes them — cast to reflect the runtime invariant. + base = cast(list[ExecutableOption], cls.default_load_options or []) + lateral = resolved.eager if resolved else [] + merged = [*base, *lateral] + return merged or None @classmethod async def _reload_with_options( @@ -861,6 +990,8 @@ class AsyncCrud(Generic[ModelType]): """ q = select(cls.model) q = _apply_joins(q, joins, outer_join) + if load_options is None: + q = _apply_lateral_joins(q, cls._get_lateral_joins()) q = q.where(and_(*filters)) if resolved := cls._resolve_load_options(load_options): q = q.options(*resolved) @@ -933,6 +1064,8 @@ class AsyncCrud(Generic[ModelType]): """ q = select(cls.model) q = _apply_joins(q, joins, outer_join) + if load_options is None: + q = _apply_lateral_joins(q, cls._get_lateral_joins()) if filters: q = q.where(and_(*filters)) if resolved := cls._resolve_load_options(load_options): @@ -978,6 +1111,8 @@ class AsyncCrud(Generic[ModelType]): """ q = select(cls.model) q = _apply_joins(q, joins, outer_join) + if load_options is None: + q = _apply_lateral_joins(q, cls._get_lateral_joins()) if filters: q = q.where(and_(*filters)) if resolved := cls._resolve_load_options(load_options): @@ -1300,6 +1435,10 @@ class AsyncCrud(Generic[ModelType]): # Apply explicit joins q = _apply_joins(q, joins, outer_join) + # Apply lateral joins (Many:One relationship loading, excluded from count query) + if load_options is None: + q = _apply_lateral_joins(q, cls._get_lateral_joins()) + # Apply search joins (always outer joins for search) q = _apply_search_joins(q, search_joins) @@ -1398,7 +1537,9 @@ class AsyncCrud(Generic[ModelType]): tables. outer_join: Use LEFT OUTER JOIN instead of INNER JOIN. load_options: SQLAlchemy loader options. Falls back to - ``default_load_options`` when not provided. + ``default_load_options`` (including any lateral joins) when not + provided. When explicitly supplied, the caller takes full control + and lateral joins are skipped. order_by: Additional ordering applied after the cursor column. items_per_page: Number of items per page (default 20). search: Search query string or SearchConfig object. @@ -1455,6 +1596,10 @@ class AsyncCrud(Generic[ModelType]): # Apply explicit joins q = _apply_joins(q, joins, outer_join) + # Apply lateral joins (Many:One relationship loading) + if load_options is None: + q = _apply_lateral_joins(q, cls._get_lateral_joins()) + # Apply search joins (always outer joins) q = _apply_search_joins(q, search_joins) diff --git a/src/fastapi_toolsets/types.py b/src/fastapi_toolsets/types.py index 517a3e5..ebac976 100644 --- a/src/fastapi_toolsets/types.py +++ b/src/fastapi_toolsets/types.py @@ -16,6 +16,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel) # CRUD type aliases JoinType = list[tuple[type[DeclarativeBase] | Any, Any]] +LateralJoinType = list[tuple[Any, Any]] M2MFieldType = Mapping[str, QueryableAttribute[Any]] OrderByClause = ColumnElement[Any] | QueryableAttribute[Any] diff --git a/tests/test_crud.py b/tests/test_crud.py index f44652e..901e0d6 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -6,9 +6,15 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from fastapi_toolsets.crud import CrudFactory, PaginationType -from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection +from fastapi_toolsets.crud import CrudFactory, PaginationType, lateral_load +from fastapi_toolsets.crud.factory import ( + AsyncCrud, + _CursorDirection, + _LateralLoad, + _ResolvedLateral, +) from fastapi_toolsets.exceptions import NotFoundError +from fastapi_toolsets.schemas import PydanticBase from .conftest import ( EventCreate, @@ -51,6 +57,12 @@ from .conftest import ( ) +class UserWithRoleRead(PydanticBase): + id: uuid.UUID + username: str + role: RoleRead | None = None + + class TestCrudFactory: """Tests for CrudFactory.""" @@ -208,11 +220,11 @@ class TestResolveLoadOptions: assert crud._resolve_load_options(None) is None def test_empty_list_overrides_default(self): - """An empty list is a valid override and disables default_load_options.""" + """An explicit empty list disables default_load_options (no options applied).""" default = [selectinload(User.role)] crud = CrudFactory(User, default_load_options=default) - # Empty list is not None, so it should replace default - assert crud._resolve_load_options([]) == [] + # Empty list replaces default; None and [] are both falsy → no options applied + assert not crud._resolve_load_options([]) class TestResolveSearchColumns: @@ -359,13 +371,6 @@ class TestDefaultLoadOptionsIntegration: self, db_session: AsyncSession ): """default_load_options loads relationships automatically on offset_paginate().""" - from fastapi_toolsets.schemas import PydanticBase - - class UserWithRoleRead(PydanticBase): - id: uuid.UUID - username: str - role: RoleRead | None = None - UserWithDefaultLoad = CrudFactory( User, default_load_options=[selectinload(User.role)] ) @@ -2462,12 +2467,7 @@ class TestCursorPaginateExtraOptions: @pytest.mark.anyio async def test_with_load_options(self, db_session: AsyncSession): """cursor_paginate passes load_options to the query.""" - from fastapi_toolsets.schemas import CursorPagination, PydanticBase - - class UserWithRoleRead(PydanticBase): - id: uuid.UUID - username: str - role: RoleRead | None = None + from fastapi_toolsets.schemas import CursorPagination role = await RoleCrud.create(db_session, RoleCreate(name="manager")) for i in range(3): @@ -2833,3 +2833,445 @@ class TestPaginate: assert isinstance(result.pagination, OffsetPagination) assert result.pagination.total_count is None + + +class TestLateralLoadValidation: + """lateral_load() raises immediately for bad relationship types.""" + + def test_valid_many_to_one_returns_marker(self): + """lateral_load() on a Many:One rel returns a _LateralLoad with rel_attr set.""" + marker = lateral_load(User.role) + assert isinstance(marker, _LateralLoad) + assert marker.rel_attr is User.role + + def test_raises_type_error_for_plain_column(self): + """lateral_load() raises TypeError when passed a plain column.""" + with pytest.raises(TypeError, match="relationship attribute"): + lateral_load(User.username) + + def test_raises_value_error_for_many_to_many(self): + """lateral_load() raises ValueError for Many:Many (secondary table).""" + with pytest.raises(ValueError, match="Many:Many"): + lateral_load(Post.tags) + + def test_raises_value_error_for_one_to_many(self): + """lateral_load() raises ValueError for One:Many (uselist=True).""" + with pytest.raises(ValueError, match="One:Many"): + lateral_load(Role.users) + + +class TestLateralLoadInSubclass: + """lateral_load() markers in default_load_options are processed at class definition.""" + + def test_marker_extracted_from_default_load_options(self): + """_LateralLoad is removed from default_load_options and stored in _resolved_lateral.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + assert UserLateralCrud.default_load_options is None + assert UserLateralCrud._resolved_lateral is not None + + def test_resolved_lateral_has_one_join_and_eager(self): + """_resolved_lateral contains exactly one join and one eager option.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + resolved = UserLateralCrud._resolved_lateral + assert isinstance(resolved, _ResolvedLateral) + assert len(resolved.joins) == 1 + assert len(resolved.eager) == 1 + + def test_regular_options_preserved_alongside_lateral(self): + """Non-lateral opts stay in default_load_options; lateral marker is extracted.""" + regular = selectinload(User.role) + + class UserMixedCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role), regular] + + assert UserMixedCrud._resolved_lateral is not None + assert UserMixedCrud.default_load_options == [regular] + + def test_no_lateral_leaves_default_load_options_untouched(self): + """When no lateral marker is present, default_load_options is unchanged.""" + opts = [selectinload(User.role)] + + class UserNormalCrud(AsyncCrud[User]): + model = User + default_load_options = opts + + assert UserNormalCrud.default_load_options is opts + assert UserNormalCrud._resolved_lateral is None + + def test_no_default_load_options_leaves_resolved_lateral_none(self): + """_resolved_lateral stays None when default_load_options is not set.""" + + class UserPlainCrud(AsyncCrud[User]): + model = User + + assert UserPlainCrud._resolved_lateral is None + + +class TestResolveLoadOptionsWithLateral: + """_resolve_load_options always appends lateral eager options.""" + + def test_lateral_eager_included_when_no_call_site_opts(self): + """contains_eager from lateral_load is returned when load_options=None.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + resolved = UserLateralCrud._resolve_load_options(None) + assert resolved is not None + assert len(resolved) == 1 # the contains_eager + + def test_call_site_opts_bypass_lateral_eager(self): + """When call-site load_options are provided, lateral eager is NOT appended.""" + extra = selectinload(User.role) + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + resolved = UserLateralCrud._resolve_load_options([extra]) + assert resolved is not None + assert len(resolved) == 1 # only the call-site option; lateral eager skipped + + def test_lateral_eager_appended_to_default_load_options(self): + """default_load_options (regular) + lateral eager are both returned.""" + regular = selectinload(User.role) + + class UserMixedCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role), regular] + + resolved = UserMixedCrud._resolve_load_options(None) + assert resolved is not None + assert len(resolved) == 2 + + +class TestGetLateralJoins: + """_get_lateral_joins merges auto-resolved and manual lateral_joins.""" + + def test_returns_none_when_no_lateral_configured(self): + """Returns None when neither lateral_joins nor lateral_load is set.""" + + class UserPlainCrud(AsyncCrud[User]): + model = User + + assert UserPlainCrud._get_lateral_joins() is None + + def test_returns_resolved_lateral_joins(self): + """Returns the join tuple built from lateral_load().""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + joins = UserLateralCrud._get_lateral_joins() + assert joins is not None + assert len(joins) == 1 + + def test_manual_lateral_joins_included(self): + """Manual lateral_joins class var is included in _get_lateral_joins.""" + from sqlalchemy import select, true + + manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role") + + class UserManualCrud(AsyncCrud[User]): + model = User + lateral_joins = [(manual_sub, true())] + + joins = UserManualCrud._get_lateral_joins() + assert joins is not None + assert len(joins) == 1 + + def test_manual_and_auto_lateral_joins_merged(self): + """Both manual lateral_joins and auto-resolved from lateral_load are combined.""" + from sqlalchemy import select, true + + manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role") + + class UserBothCrud(AsyncCrud[User]): + model = User + lateral_joins = [(manual_sub, true())] + default_load_options = [lateral_load(User.role)] + + joins = UserBothCrud._get_lateral_joins() + assert joins is not None + assert len(joins) == 2 + + +class TestLateralLoadIntegration: + """lateral_load() in real DB queries: relationship loaded, pagination correct.""" + + @pytest.mark.anyio + async def test_get_loads_relationship(self, db_session: AsyncSession): + """get() populates the relationship via lateral join.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + user = await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + + fetched = await UserLateralCrud.get(db_session, [User.id == user.id]) + assert fetched.role is not None + assert fetched.role.name == "admin" + + @pytest.mark.anyio + async def test_get_null_fk_preserved(self, db_session: AsyncSession): + """User with null role_id still returned (LEFT JOIN behaviour).""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + user = await UserCrud.create( + db_session, UserCreate(username="bob", email="bob@test.com") + ) + + fetched = await UserLateralCrud.get(db_session, [User.id == user.id]) + assert fetched is not None + assert fetched.role is None + + @pytest.mark.anyio + async def test_first_loads_relationship(self, db_session: AsyncSession): + """first() populates the relationship via lateral join.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="editor")) + await UserCrud.create( + db_session, + UserCreate(username="carol", email="carol@test.com", role_id=role.id), + ) + + user = await UserLateralCrud.first(db_session) + assert user is not None + assert user.role is not None + assert user.role.name == "editor" + + @pytest.mark.anyio + async def test_get_multi_loads_relationship(self, db_session: AsyncSession): + """get_multi() populates the relationship via lateral join for all rows.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="member")) + for i in range(3): + await UserCrud.create( + db_session, + UserCreate( + username=f"user{i}", email=f"u{i}@test.com", role_id=role.id + ), + ) + + users = await UserLateralCrud.get_multi(db_session) + assert len(users) == 3 + assert all(u.role is not None and u.role.name == "member" for u in users) + + @pytest.mark.anyio + async def test_offset_paginate_correct_count(self, db_session: AsyncSession): + """offset_paginate total_count is not inflated by the lateral join.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + for i in range(5): + await UserCrud.create( + db_session, + UserCreate( + username=f"user{i}", email=f"u{i}@test.com", role_id=role.id + ), + ) + + result = await UserLateralCrud.offset_paginate( + db_session, schema=UserWithRoleRead, items_per_page=10 + ) + assert result.pagination.total_count == 5 + assert len(result.data) == 5 + + @pytest.mark.anyio + async def test_offset_paginate_loads_relationship(self, db_session: AsyncSession): + """offset_paginate serializes relationship data loaded via lateral.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + + result = await UserLateralCrud.offset_paginate( + db_session, schema=UserWithRoleRead, items_per_page=10 + ) + assert result.data[0].role is not None + assert result.data[0].role.name == "admin" + + @pytest.mark.anyio + async def test_offset_paginate_mixed_null_fk(self, db_session: AsyncSession): + """offset_paginate returns all users including those with null role_id.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="with_role", email="a@test.com", role_id=role.id), + ) + await UserCrud.create( + db_session, UserCreate(username="no_role", email="b@test.com") + ) + + result = await UserLateralCrud.offset_paginate( + db_session, schema=UserWithRoleRead, items_per_page=10 + ) + assert result.pagination.total_count == 2 + + @pytest.mark.anyio + async def test_cursor_paginate_loads_relationship(self, db_session: AsyncSession): + """cursor_paginate populates the relationship via lateral join.""" + + class UserLateralCursorCrud(AsyncCrud[User]): + model = User + cursor_column = User.id + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + for i in range(3): + await UserCrud.create( + db_session, + UserCreate( + username=f"user{i}", email=f"u{i}@test.com", role_id=role.id + ), + ) + + result = await UserLateralCursorCrud.cursor_paginate( + db_session, schema=UserWithRoleRead, items_per_page=10 + ) + assert len(result.data) == 3 + assert all(item.role is not None for item in result.data) + + @pytest.mark.anyio + async def test_offset_paginate_with_search_and_lateral( + self, db_session: AsyncSession + ): + """search filter works alongside lateral join.""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + searchable_fields = [User.username] + + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + await UserCrud.create( + db_session, + UserCreate(username="alice", email="a@test.com", role_id=role.id), + ) + await UserCrud.create( + db_session, UserCreate(username="bob", email="b@test.com", role_id=role.id) + ) + + result = await UserLateralCrud.offset_paginate( + db_session, schema=UserWithRoleRead, search="alice", items_per_page=10 + ) + assert result.pagination.total_count == 1 + assert result.data[0].username == "alice" + + @pytest.mark.anyio + async def test_first_call_site_load_options_bypasses_lateral( + self, db_session: AsyncSession + ): + """When load_options is provided, lateral join is skipped (no conflict).""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="admin")) + user = await UserCrud.create( + db_session, + UserCreate(username="alice", email="alice@test.com", role_id=role.id), + ) + + # Passing explicit load_options bypasses the lateral join — role loaded via selectinload + fetched = await UserLateralCrud.first( + db_session, + filters=[User.id == user.id], + load_options=[selectinload(User.role)], + ) + assert fetched is not None + assert fetched.role is not None + assert fetched.role.name == "admin" + + @pytest.mark.anyio + async def test_get_multi_call_site_load_options_bypasses_lateral( + self, db_session: AsyncSession + ): + """When load_options is provided, lateral join is skipped (no conflict).""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="viewer")) + for i in range(2): + await UserCrud.create( + db_session, + UserCreate(username=f"u{i}", email=f"u{i}@test.com", role_id=role.id), + ) + + # Passing explicit load_options bypasses the lateral join — role loaded via selectinload + users = await UserLateralCrud.get_multi( + db_session, load_options=[selectinload(User.role)] + ) + assert len(users) == 2 + assert all(u.role is not None and u.role.name == "viewer" for u in users) + + @pytest.mark.anyio + async def test_offset_paginate_call_site_load_options_bypasses_lateral( + self, db_session: AsyncSession + ): + """When load_options is provided, lateral join is skipped (no conflict).""" + + class UserLateralCrud(AsyncCrud[User]): + model = User + default_load_options = [lateral_load(User.role)] + + role = await RoleCrud.create(db_session, RoleCreate(name="editor")) + for i in range(3): + await UserCrud.create( + db_session, + UserCreate(username=f"e{i}", email=f"e{i}@test.com", role_id=role.id), + ) + + # Passing explicit load_options bypasses the lateral join — role loaded via selectinload + result = await UserLateralCrud.offset_paginate( + db_session, + schema=UserWithRoleRead, + items_per_page=10, + load_options=[selectinload(User.role)], + ) + assert result.pagination.total_count == 3 + assert all(item.role is not None for item in result.data)