feat: add lateral_load() for efficient Many:One eager loading via JOIN LATERAL

This commit is contained in:
2026-04-07 13:07:34 -04:00
parent 0ed93d62c8
commit e1f96ad7fe
4 changed files with 616 additions and 27 deletions

View File

@@ -15,12 +15,13 @@ from ..types import (
OrderFieldType,
SearchFieldType,
)
from .factory import AsyncCrud, CrudFactory
from .factory import AsyncCrud, CrudFactory, lateral_load
from .search import SearchConfig, get_searchable_fields
__all__ = [
"AsyncCrud",
"CrudFactory",
"lateral_load",
"FacetFieldType",
"get_searchable_fields",
"InvalidFacetFilterError",

View File

@@ -10,15 +10,32 @@ from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from typing import Any, ClassVar, Generic, Literal, NamedTuple, Self, cast, overload
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy import (
Date,
DateTime,
Float,
Integer,
Numeric,
Uuid,
and_,
func,
select,
true,
)
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.orm import (
DeclarativeBase,
QueryableAttribute,
RelationshipProperty,
contains_eager,
selectinload,
)
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole
@@ -35,6 +52,7 @@ from ..schemas import (
from ..types import (
FacetFieldType,
JoinType,
LateralJoinType,
M2MFieldType,
ModelType,
OrderByClause,
@@ -115,6 +133,78 @@ def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
return q
class _ResolvedLateral(NamedTuple):
joins: LateralJoinType
eager: list[ExecutableOption]
class _LateralLoad:
"""Marker used inside ``default_load_options`` for lateral join loading.
Supports only Many:One and One:One relationships (single row per parent).
"""
__slots__ = ("rel_attr",)
def __init__(self, rel_attr: QueryableAttribute) -> None:
prop = rel_attr.property
if not isinstance(prop, RelationshipProperty):
raise TypeError(
f"lateral_load() requires a relationship attribute, got {type(prop).__name__}. "
"Example: lateral_load(User.team)"
)
if prop.secondary is not None:
raise ValueError(
f"lateral_load({rel_attr}) does not support Many:Many relationships. "
"Use selectinload() instead."
)
if prop.uselist:
raise ValueError(
f"lateral_load({rel_attr}) does not support One:Many relationships. "
"Use selectinload() instead."
)
self.rel_attr = rel_attr
def lateral_load(rel_attr: QueryableAttribute) -> _LateralLoad:
"""Mark a Many:One or One:One relationship for lateral join loading.
Raises ``ValueError`` for One:Many or Many:Many relationships.
"""
return _LateralLoad(rel_attr)
def _build_lateral_from_relationship(
rel_attr: QueryableAttribute,
) -> tuple[Any, Any, ExecutableOption]:
"""Introspect a Many:One relationship and build (lateral_subquery, true(), contains_eager)."""
prop = rel_attr.property
target_class = prop.mapper.class_
parent_class = prop.parent.class_
conditions = [
getattr(target_class, remote_col.key) == getattr(parent_class, local_col.key)
for local_col, remote_col in prop.local_remote_pairs
]
lateral_sub = (
select(target_class)
.where(and_(*conditions))
.correlate(parent_class)
.lateral(f"_lateral_{prop.key}")
)
return lateral_sub, true(), contains_eager(rel_attr, alias=lateral_sub)
def _apply_lateral_joins(q: Any, lateral_joins: LateralJoinType | None) -> Any:
"""Apply lateral subqueries as LEFT JOIN LATERAL to preserve all parent rows."""
if not lateral_joins:
return q
for subquery, condition in lateral_joins:
q = q.outerjoin(subquery, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
seen: set[str] = set()
@@ -132,12 +222,17 @@ class AsyncCrud(Generic[ModelType]):
Subclass this and set the `model` class variable, or use `CrudFactory`.
"""
_resolved_lateral: ClassVar[_ResolvedLateral | None] = None
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
order_fields: ClassVar[Sequence[OrderFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
default_load_options: ClassVar[Sequence[ExecutableOption | _LateralLoad] | None] = (
None
)
lateral_joins: ClassVar[LateralJoinType | None] = None
cursor_column: ClassVar[Any | None] = None
@classmethod
@@ -161,14 +256,48 @@ class AsyncCrud(Generic[ModelType]):
):
cls.searchable_fields = [pk_col, *raw_fields]
raw_default_opts = cls.__dict__.get("default_load_options", None)
if raw_default_opts:
joins: LateralJoinType = []
eager: list[ExecutableOption] = []
clean: list[ExecutableOption] = []
for opt in raw_default_opts:
if isinstance(opt, _LateralLoad):
lat_sub, condition, eager_opt = _build_lateral_from_relationship(
opt.rel_attr
)
joins.append((lat_sub, condition))
eager.append(eager_opt)
else:
clean.append(opt)
if joins:
cls._resolved_lateral = _ResolvedLateral(joins=joins, eager=eager)
cls.default_load_options = clean or None
@classmethod
def _get_lateral_joins(cls) -> LateralJoinType | None:
"""Merge manual lateral_joins with ones resolved from default_load_options."""
resolved = cls._resolved_lateral
all_lateral = [
*(cls.lateral_joins or []),
*(resolved.joins if resolved else []),
]
return all_lateral or None
@classmethod
def _resolve_load_options(
cls, load_options: Sequence[ExecutableOption] | None
) -> Sequence[ExecutableOption] | None:
"""Return load_options if provided, else fall back to default_load_options."""
"""Return merged load options."""
if load_options is not None:
return load_options
return cls.default_load_options
return list(load_options) or None
resolved = cls._resolved_lateral
# default_load_options is cleaned of _LateralLoad markers in __init_subclass__,
# but its declared type still includes them — cast to reflect the runtime invariant.
base = cast(list[ExecutableOption], cls.default_load_options or [])
lateral = resolved.eager if resolved else []
merged = [*base, *lateral]
return merged or None
@classmethod
async def _reload_with_options(
@@ -861,6 +990,8 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
@@ -933,6 +1064,8 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -978,6 +1111,8 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -1300,6 +1435,10 @@ class AsyncCrud(Generic[ModelType]):
# Apply explicit joins
q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading, excluded from count query)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins for search)
q = _apply_search_joins(q, search_joins)
@@ -1398,7 +1537,9 @@ class AsyncCrud(Generic[ModelType]):
tables.
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
load_options: SQLAlchemy loader options. Falls back to
``default_load_options`` when not provided.
``default_load_options`` (including any lateral joins) when not
provided. When explicitly supplied, the caller takes full control
and lateral joins are skipped.
order_by: Additional ordering applied after the cursor column.
items_per_page: Number of items per page (default 20).
search: Search query string or SearchConfig object.
@@ -1455,6 +1596,10 @@ class AsyncCrud(Generic[ModelType]):
# Apply explicit joins
q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins)
q = _apply_search_joins(q, search_joins)

View File

@@ -16,6 +16,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
# CRUD type aliases
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
LateralJoinType = list[tuple[Any, Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]

View File

@@ -6,9 +6,15 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from fastapi_toolsets.crud import CrudFactory, PaginationType
from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
from fastapi_toolsets.crud import CrudFactory, PaginationType, lateral_load
from fastapi_toolsets.crud.factory import (
AsyncCrud,
_CursorDirection,
_LateralLoad,
_ResolvedLateral,
)
from fastapi_toolsets.exceptions import NotFoundError
from fastapi_toolsets.schemas import PydanticBase
from .conftest import (
EventCreate,
@@ -51,6 +57,12 @@ from .conftest import (
)
class UserWithRoleRead(PydanticBase):
id: uuid.UUID
username: str
role: RoleRead | None = None
class TestCrudFactory:
"""Tests for CrudFactory."""
@@ -208,11 +220,11 @@ class TestResolveLoadOptions:
assert crud._resolve_load_options(None) is None
def test_empty_list_overrides_default(self):
"""An empty list is a valid override and disables default_load_options."""
"""An explicit empty list disables default_load_options (no options applied)."""
default = [selectinload(User.role)]
crud = CrudFactory(User, default_load_options=default)
# Empty list is not None, so it should replace default
assert crud._resolve_load_options([]) == []
# Empty list replaces default; None and [] are both falsy → no options applied
assert not crud._resolve_load_options([])
class TestResolveSearchColumns:
@@ -359,13 +371,6 @@ class TestDefaultLoadOptionsIntegration:
self, db_session: AsyncSession
):
"""default_load_options loads relationships automatically on offset_paginate()."""
from fastapi_toolsets.schemas import PydanticBase
class UserWithRoleRead(PydanticBase):
id: uuid.UUID
username: str
role: RoleRead | None = None
UserWithDefaultLoad = CrudFactory(
User, default_load_options=[selectinload(User.role)]
)
@@ -2462,12 +2467,7 @@ class TestCursorPaginateExtraOptions:
@pytest.mark.anyio
async def test_with_load_options(self, db_session: AsyncSession):
"""cursor_paginate passes load_options to the query."""
from fastapi_toolsets.schemas import CursorPagination, PydanticBase
class UserWithRoleRead(PydanticBase):
id: uuid.UUID
username: str
role: RoleRead | None = None
from fastapi_toolsets.schemas import CursorPagination
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
for i in range(3):
@@ -2833,3 +2833,445 @@ class TestPaginate:
assert isinstance(result.pagination, OffsetPagination)
assert result.pagination.total_count is None
class TestLateralLoadValidation:
"""lateral_load() raises immediately for bad relationship types."""
def test_valid_many_to_one_returns_marker(self):
"""lateral_load() on a Many:One rel returns a _LateralLoad with rel_attr set."""
marker = lateral_load(User.role)
assert isinstance(marker, _LateralLoad)
assert marker.rel_attr is User.role
def test_raises_type_error_for_plain_column(self):
"""lateral_load() raises TypeError when passed a plain column."""
with pytest.raises(TypeError, match="relationship attribute"):
lateral_load(User.username)
def test_raises_value_error_for_many_to_many(self):
"""lateral_load() raises ValueError for Many:Many (secondary table)."""
with pytest.raises(ValueError, match="Many:Many"):
lateral_load(Post.tags)
def test_raises_value_error_for_one_to_many(self):
"""lateral_load() raises ValueError for One:Many (uselist=True)."""
with pytest.raises(ValueError, match="One:Many"):
lateral_load(Role.users)
class TestLateralLoadInSubclass:
"""lateral_load() markers in default_load_options are processed at class definition."""
def test_marker_extracted_from_default_load_options(self):
"""_LateralLoad is removed from default_load_options and stored in _resolved_lateral."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
assert UserLateralCrud.default_load_options is None
assert UserLateralCrud._resolved_lateral is not None
def test_resolved_lateral_has_one_join_and_eager(self):
"""_resolved_lateral contains exactly one join and one eager option."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
resolved = UserLateralCrud._resolved_lateral
assert isinstance(resolved, _ResolvedLateral)
assert len(resolved.joins) == 1
assert len(resolved.eager) == 1
def test_regular_options_preserved_alongside_lateral(self):
"""Non-lateral opts stay in default_load_options; lateral marker is extracted."""
regular = selectinload(User.role)
class UserMixedCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role), regular]
assert UserMixedCrud._resolved_lateral is not None
assert UserMixedCrud.default_load_options == [regular]
def test_no_lateral_leaves_default_load_options_untouched(self):
"""When no lateral marker is present, default_load_options is unchanged."""
opts = [selectinload(User.role)]
class UserNormalCrud(AsyncCrud[User]):
model = User
default_load_options = opts
assert UserNormalCrud.default_load_options is opts
assert UserNormalCrud._resolved_lateral is None
def test_no_default_load_options_leaves_resolved_lateral_none(self):
"""_resolved_lateral stays None when default_load_options is not set."""
class UserPlainCrud(AsyncCrud[User]):
model = User
assert UserPlainCrud._resolved_lateral is None
class TestResolveLoadOptionsWithLateral:
"""_resolve_load_options always appends lateral eager options."""
def test_lateral_eager_included_when_no_call_site_opts(self):
"""contains_eager from lateral_load is returned when load_options=None."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
resolved = UserLateralCrud._resolve_load_options(None)
assert resolved is not None
assert len(resolved) == 1 # the contains_eager
def test_call_site_opts_bypass_lateral_eager(self):
"""When call-site load_options are provided, lateral eager is NOT appended."""
extra = selectinload(User.role)
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
resolved = UserLateralCrud._resolve_load_options([extra])
assert resolved is not None
assert len(resolved) == 1 # only the call-site option; lateral eager skipped
def test_lateral_eager_appended_to_default_load_options(self):
"""default_load_options (regular) + lateral eager are both returned."""
regular = selectinload(User.role)
class UserMixedCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role), regular]
resolved = UserMixedCrud._resolve_load_options(None)
assert resolved is not None
assert len(resolved) == 2
class TestGetLateralJoins:
"""_get_lateral_joins merges auto-resolved and manual lateral_joins."""
def test_returns_none_when_no_lateral_configured(self):
"""Returns None when neither lateral_joins nor lateral_load is set."""
class UserPlainCrud(AsyncCrud[User]):
model = User
assert UserPlainCrud._get_lateral_joins() is None
def test_returns_resolved_lateral_joins(self):
"""Returns the join tuple built from lateral_load()."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
joins = UserLateralCrud._get_lateral_joins()
assert joins is not None
assert len(joins) == 1
def test_manual_lateral_joins_included(self):
"""Manual lateral_joins class var is included in _get_lateral_joins."""
from sqlalchemy import select, true
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
class UserManualCrud(AsyncCrud[User]):
model = User
lateral_joins = [(manual_sub, true())]
joins = UserManualCrud._get_lateral_joins()
assert joins is not None
assert len(joins) == 1
def test_manual_and_auto_lateral_joins_merged(self):
"""Both manual lateral_joins and auto-resolved from lateral_load are combined."""
from sqlalchemy import select, true
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
class UserBothCrud(AsyncCrud[User]):
model = User
lateral_joins = [(manual_sub, true())]
default_load_options = [lateral_load(User.role)]
joins = UserBothCrud._get_lateral_joins()
assert joins is not None
assert len(joins) == 2
class TestLateralLoadIntegration:
"""lateral_load() in real DB queries: relationship loaded, pagination correct."""
@pytest.mark.anyio
async def test_get_loads_relationship(self, db_session: AsyncSession):
"""get() populates the relationship via lateral join."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
user = await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
assert fetched.role is not None
assert fetched.role.name == "admin"
@pytest.mark.anyio
async def test_get_null_fk_preserved(self, db_session: AsyncSession):
"""User with null role_id still returned (LEFT JOIN behaviour)."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
user = await UserCrud.create(
db_session, UserCreate(username="bob", email="bob@test.com")
)
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
assert fetched is not None
assert fetched.role is None
@pytest.mark.anyio
async def test_first_loads_relationship(self, db_session: AsyncSession):
"""first() populates the relationship via lateral join."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="editor"))
await UserCrud.create(
db_session,
UserCreate(username="carol", email="carol@test.com", role_id=role.id),
)
user = await UserLateralCrud.first(db_session)
assert user is not None
assert user.role is not None
assert user.role.name == "editor"
@pytest.mark.anyio
async def test_get_multi_loads_relationship(self, db_session: AsyncSession):
"""get_multi() populates the relationship via lateral join for all rows."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
for i in range(3):
await UserCrud.create(
db_session,
UserCreate(
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
),
)
users = await UserLateralCrud.get_multi(db_session)
assert len(users) == 3
assert all(u.role is not None and u.role.name == "member" for u in users)
@pytest.mark.anyio
async def test_offset_paginate_correct_count(self, db_session: AsyncSession):
"""offset_paginate total_count is not inflated by the lateral join."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
for i in range(5):
await UserCrud.create(
db_session,
UserCreate(
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
),
)
result = await UserLateralCrud.offset_paginate(
db_session, schema=UserWithRoleRead, items_per_page=10
)
assert result.pagination.total_count == 5
assert len(result.data) == 5
@pytest.mark.anyio
async def test_offset_paginate_loads_relationship(self, db_session: AsyncSession):
"""offset_paginate serializes relationship data loaded via lateral."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
result = await UserLateralCrud.offset_paginate(
db_session, schema=UserWithRoleRead, items_per_page=10
)
assert result.data[0].role is not None
assert result.data[0].role.name == "admin"
@pytest.mark.anyio
async def test_offset_paginate_mixed_null_fk(self, db_session: AsyncSession):
"""offset_paginate returns all users including those with null role_id."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="with_role", email="a@test.com", role_id=role.id),
)
await UserCrud.create(
db_session, UserCreate(username="no_role", email="b@test.com")
)
result = await UserLateralCrud.offset_paginate(
db_session, schema=UserWithRoleRead, items_per_page=10
)
assert result.pagination.total_count == 2
@pytest.mark.anyio
async def test_cursor_paginate_loads_relationship(self, db_session: AsyncSession):
"""cursor_paginate populates the relationship via lateral join."""
class UserLateralCursorCrud(AsyncCrud[User]):
model = User
cursor_column = User.id
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
for i in range(3):
await UserCrud.create(
db_session,
UserCreate(
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
),
)
result = await UserLateralCursorCrud.cursor_paginate(
db_session, schema=UserWithRoleRead, items_per_page=10
)
assert len(result.data) == 3
assert all(item.role is not None for item in result.data)
@pytest.mark.anyio
async def test_offset_paginate_with_search_and_lateral(
self, db_session: AsyncSession
):
"""search filter works alongside lateral join."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
searchable_fields = [User.username]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="alice", email="a@test.com", role_id=role.id),
)
await UserCrud.create(
db_session, UserCreate(username="bob", email="b@test.com", role_id=role.id)
)
result = await UserLateralCrud.offset_paginate(
db_session, schema=UserWithRoleRead, search="alice", items_per_page=10
)
assert result.pagination.total_count == 1
assert result.data[0].username == "alice"
@pytest.mark.anyio
async def test_first_call_site_load_options_bypasses_lateral(
self, db_session: AsyncSession
):
"""When load_options is provided, lateral join is skipped (no conflict)."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
user = await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
fetched = await UserLateralCrud.first(
db_session,
filters=[User.id == user.id],
load_options=[selectinload(User.role)],
)
assert fetched is not None
assert fetched.role is not None
assert fetched.role.name == "admin"
@pytest.mark.anyio
async def test_get_multi_call_site_load_options_bypasses_lateral(
self, db_session: AsyncSession
):
"""When load_options is provided, lateral join is skipped (no conflict)."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="viewer"))
for i in range(2):
await UserCrud.create(
db_session,
UserCreate(username=f"u{i}", email=f"u{i}@test.com", role_id=role.id),
)
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
users = await UserLateralCrud.get_multi(
db_session, load_options=[selectinload(User.role)]
)
assert len(users) == 2
assert all(u.role is not None and u.role.name == "viewer" for u in users)
@pytest.mark.anyio
async def test_offset_paginate_call_site_load_options_bypasses_lateral(
self, db_session: AsyncSession
):
"""When load_options is provided, lateral join is skipped (no conflict)."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="editor"))
for i in range(3):
await UserCrud.create(
db_session,
UserCreate(username=f"e{i}", email=f"e{i}@test.com", role_id=role.id),
)
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
result = await UserLateralCrud.offset_paginate(
db_session,
schema=UserWithRoleRead,
items_per_page=10,
load_options=[selectinload(User.role)],
)
assert result.pagination.total_count == 3
assert all(item.role is not None for item in result.data)