Compare commits

..

5 Commits

2 changed files with 97 additions and 15 deletions

View File

@@ -288,15 +288,13 @@ class AsyncCrud(Generic[ModelType]):
def _resolve_load_options( def _resolve_load_options(
cls, load_options: Sequence[ExecutableOption] | None cls, load_options: Sequence[ExecutableOption] | None
) -> Sequence[ExecutableOption] | None: ) -> Sequence[ExecutableOption] | None:
"""Return merged load options: call-site or default, always with lateral eager opts.""" """Return merged load options."""
if load_options is not None:
return list(load_options) or None
resolved = cls._resolved_lateral resolved = cls._resolved_lateral
# default_load_options is cleaned of _LateralLoad markers in __init_subclass__, # default_load_options is cleaned of _LateralLoad markers in __init_subclass__,
# but its declared type still includes them — cast to reflect the runtime invariant. # but its declared type still includes them — cast to reflect the runtime invariant.
base: Sequence[ExecutableOption] = ( base = cast(list[ExecutableOption], cls.default_load_options or [])
load_options
if load_options is not None
else cast(list[ExecutableOption], cls.default_load_options or [])
)
lateral = resolved.eager if resolved else [] lateral = resolved.eager if resolved else []
merged = [*base, *lateral] merged = [*base, *lateral]
return merged or None return merged or None
@@ -992,6 +990,7 @@ class AsyncCrud(Generic[ModelType]):
""" """
q = select(cls.model) q = select(cls.model)
q = _apply_joins(q, joins, outer_join) q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins()) q = _apply_lateral_joins(q, cls._get_lateral_joins())
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options): if resolved := cls._resolve_load_options(load_options):
@@ -1065,6 +1064,7 @@ class AsyncCrud(Generic[ModelType]):
""" """
q = select(cls.model) q = select(cls.model)
q = _apply_joins(q, joins, outer_join) q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins()) q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
@@ -1111,6 +1111,7 @@ class AsyncCrud(Generic[ModelType]):
""" """
q = select(cls.model) q = select(cls.model)
q = _apply_joins(q, joins, outer_join) q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins()) q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
@@ -1435,6 +1436,7 @@ class AsyncCrud(Generic[ModelType]):
q = _apply_joins(q, joins, outer_join) q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading, excluded from count query) # Apply lateral joins (Many:One relationship loading, excluded from count query)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins()) q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins for search) # Apply search joins (always outer joins for search)
@@ -1535,7 +1537,9 @@ class AsyncCrud(Generic[ModelType]):
tables. tables.
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN. outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
load_options: SQLAlchemy loader options. Falls back to load_options: SQLAlchemy loader options. Falls back to
``default_load_options`` when not provided. ``default_load_options`` (including any lateral joins) when not
provided. When explicitly supplied, the caller takes full control
and lateral joins are skipped.
order_by: Additional ordering applied after the cursor column. order_by: Additional ordering applied after the cursor column.
items_per_page: Number of items per page (default 20). items_per_page: Number of items per page (default 20).
search: Search query string or SearchConfig object. search: Search query string or SearchConfig object.
@@ -1593,6 +1597,7 @@ class AsyncCrud(Generic[ModelType]):
q = _apply_joins(q, joins, outer_join) q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading) # Apply lateral joins (Many:One relationship loading)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins()) q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins) # Apply search joins (always outer joins)

View File

@@ -2930,8 +2930,8 @@ class TestResolveLoadOptionsWithLateral:
assert resolved is not None assert resolved is not None
assert len(resolved) == 1 # the contains_eager assert len(resolved) == 1 # the contains_eager
def test_lateral_eager_appended_to_call_site_opts(self): def test_call_site_opts_bypass_lateral_eager(self):
"""call-site load_options + lateral eager are both returned.""" """When call-site load_options are provided, lateral eager is NOT appended."""
extra = selectinload(User.role) extra = selectinload(User.role)
class UserLateralCrud(AsyncCrud[User]): class UserLateralCrud(AsyncCrud[User]):
@@ -2940,7 +2940,7 @@ class TestResolveLoadOptionsWithLateral:
resolved = UserLateralCrud._resolve_load_options([extra]) resolved = UserLateralCrud._resolve_load_options([extra])
assert resolved is not None assert resolved is not None
assert len(resolved) == 2 assert len(resolved) == 1 # only the call-site option; lateral eager skipped
def test_lateral_eager_appended_to_default_load_options(self): def test_lateral_eager_appended_to_default_load_options(self):
"""default_load_options (regular) + lateral eager are both returned.""" """default_load_options (regular) + lateral eager are both returned."""
@@ -3198,3 +3198,80 @@ class TestLateralLoadIntegration:
) )
assert result.pagination.total_count == 1 assert result.pagination.total_count == 1
assert result.data[0].username == "alice" assert result.data[0].username == "alice"
@pytest.mark.anyio
async def test_first_call_site_load_options_bypasses_lateral(
self, db_session: AsyncSession
):
"""When load_options is provided, lateral join is skipped (no conflict)."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
user = await UserCrud.create(
db_session,
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
)
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
fetched = await UserLateralCrud.first(
db_session,
filters=[User.id == user.id],
load_options=[selectinload(User.role)],
)
assert fetched is not None
assert fetched.role is not None
assert fetched.role.name == "admin"
@pytest.mark.anyio
async def test_get_multi_call_site_load_options_bypasses_lateral(
self, db_session: AsyncSession
):
"""When load_options is provided, lateral join is skipped (no conflict)."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="viewer"))
for i in range(2):
await UserCrud.create(
db_session,
UserCreate(username=f"u{i}", email=f"u{i}@test.com", role_id=role.id),
)
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
users = await UserLateralCrud.get_multi(
db_session, load_options=[selectinload(User.role)]
)
assert len(users) == 2
assert all(u.role is not None and u.role.name == "viewer" for u in users)
@pytest.mark.anyio
async def test_offset_paginate_call_site_load_options_bypasses_lateral(
self, db_session: AsyncSession
):
"""When load_options is provided, lateral join is skipped (no conflict)."""
class UserLateralCrud(AsyncCrud[User]):
model = User
default_load_options = [lateral_load(User.role)]
role = await RoleCrud.create(db_session, RoleCreate(name="editor"))
for i in range(3):
await UserCrud.create(
db_session,
UserCreate(username=f"e{i}", email=f"e{i}@test.com", role_id=role.id),
)
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
result = await UserLateralCrud.offset_paginate(
db_session,
schema=UserWithRoleRead,
items_per_page=10,
load_options=[selectinload(User.role)],
)
assert result.pagination.total_count == 3
assert all(item.role is not None for item in result.data)