mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: add lateral_load() for efficient Many:One eager loading via JOIN LATERAL
This commit is contained in:
@@ -15,12 +15,13 @@ from ..types import (
|
||||
OrderFieldType,
|
||||
SearchFieldType,
|
||||
)
|
||||
from .factory import AsyncCrud, CrudFactory
|
||||
from .factory import AsyncCrud, CrudFactory, lateral_load
|
||||
from .search import SearchConfig, get_searchable_fields
|
||||
|
||||
__all__ = [
|
||||
"AsyncCrud",
|
||||
"CrudFactory",
|
||||
"lateral_load",
|
||||
"FacetFieldType",
|
||||
"get_searchable_fields",
|
||||
"InvalidFacetFilterError",
|
||||
|
||||
@@ -10,15 +10,32 @@ from collections.abc import Awaitable, Callable, Sequence
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
||||
from typing import Any, ClassVar, Generic, Literal, NamedTuple, Self, cast, overload
|
||||
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
Numeric,
|
||||
Uuid,
|
||||
and_,
|
||||
func,
|
||||
select,
|
||||
true,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
QueryableAttribute,
|
||||
RelationshipProperty,
|
||||
contains_eager,
|
||||
selectinload,
|
||||
)
|
||||
from sqlalchemy.sql.base import ExecutableOption
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
@@ -35,6 +52,7 @@ from ..schemas import (
|
||||
from ..types import (
|
||||
FacetFieldType,
|
||||
JoinType,
|
||||
LateralJoinType,
|
||||
M2MFieldType,
|
||||
ModelType,
|
||||
OrderByClause,
|
||||
@@ -115,6 +133,78 @@ def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
||||
return q
|
||||
|
||||
|
||||
class _ResolvedLateral(NamedTuple):
|
||||
joins: LateralJoinType
|
||||
eager: list[ExecutableOption]
|
||||
|
||||
|
||||
class _LateralLoad:
|
||||
"""Marker used inside ``default_load_options`` for lateral join loading.
|
||||
|
||||
Supports only Many:One and One:One relationships (single row per parent).
|
||||
"""
|
||||
|
||||
__slots__ = ("rel_attr",)
|
||||
|
||||
def __init__(self, rel_attr: QueryableAttribute) -> None:
|
||||
prop = rel_attr.property
|
||||
if not isinstance(prop, RelationshipProperty):
|
||||
raise TypeError(
|
||||
f"lateral_load() requires a relationship attribute, got {type(prop).__name__}. "
|
||||
"Example: lateral_load(User.team)"
|
||||
)
|
||||
if prop.secondary is not None:
|
||||
raise ValueError(
|
||||
f"lateral_load({rel_attr}) does not support Many:Many relationships. "
|
||||
"Use selectinload() instead."
|
||||
)
|
||||
if prop.uselist:
|
||||
raise ValueError(
|
||||
f"lateral_load({rel_attr}) does not support One:Many relationships. "
|
||||
"Use selectinload() instead."
|
||||
)
|
||||
self.rel_attr = rel_attr
|
||||
|
||||
|
||||
def lateral_load(rel_attr: QueryableAttribute) -> _LateralLoad:
|
||||
"""Mark a Many:One or One:One relationship for lateral join loading.
|
||||
|
||||
Raises ``ValueError`` for One:Many or Many:Many relationships.
|
||||
"""
|
||||
return _LateralLoad(rel_attr)
|
||||
|
||||
|
||||
def _build_lateral_from_relationship(
|
||||
rel_attr: QueryableAttribute,
|
||||
) -> tuple[Any, Any, ExecutableOption]:
|
||||
"""Introspect a Many:One relationship and build (lateral_subquery, true(), contains_eager)."""
|
||||
prop = rel_attr.property
|
||||
target_class = prop.mapper.class_
|
||||
parent_class = prop.parent.class_
|
||||
|
||||
conditions = [
|
||||
getattr(target_class, remote_col.key) == getattr(parent_class, local_col.key)
|
||||
for local_col, remote_col in prop.local_remote_pairs
|
||||
]
|
||||
|
||||
lateral_sub = (
|
||||
select(target_class)
|
||||
.where(and_(*conditions))
|
||||
.correlate(parent_class)
|
||||
.lateral(f"_lateral_{prop.key}")
|
||||
)
|
||||
return lateral_sub, true(), contains_eager(rel_attr, alias=lateral_sub)
|
||||
|
||||
|
||||
def _apply_lateral_joins(q: Any, lateral_joins: LateralJoinType | None) -> Any:
|
||||
"""Apply lateral subqueries as LEFT JOIN LATERAL to preserve all parent rows."""
|
||||
if not lateral_joins:
|
||||
return q
|
||||
for subquery, condition in lateral_joins:
|
||||
q = q.outerjoin(subquery, condition)
|
||||
return q
|
||||
|
||||
|
||||
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
|
||||
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
|
||||
seen: set[str] = set()
|
||||
@@ -132,12 +222,17 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
||||
"""
|
||||
|
||||
_resolved_lateral: ClassVar[_ResolvedLateral | None] = None
|
||||
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
||||
order_fields: ClassVar[Sequence[OrderFieldType] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
|
||||
default_load_options: ClassVar[Sequence[ExecutableOption | _LateralLoad] | None] = (
|
||||
None
|
||||
)
|
||||
lateral_joins: ClassVar[LateralJoinType | None] = None
|
||||
cursor_column: ClassVar[Any | None] = None
|
||||
|
||||
@classmethod
|
||||
@@ -161,14 +256,48 @@ class AsyncCrud(Generic[ModelType]):
|
||||
):
|
||||
cls.searchable_fields = [pk_col, *raw_fields]
|
||||
|
||||
raw_default_opts = cls.__dict__.get("default_load_options", None)
|
||||
if raw_default_opts:
|
||||
joins: LateralJoinType = []
|
||||
eager: list[ExecutableOption] = []
|
||||
clean: list[ExecutableOption] = []
|
||||
for opt in raw_default_opts:
|
||||
if isinstance(opt, _LateralLoad):
|
||||
lat_sub, condition, eager_opt = _build_lateral_from_relationship(
|
||||
opt.rel_attr
|
||||
)
|
||||
joins.append((lat_sub, condition))
|
||||
eager.append(eager_opt)
|
||||
else:
|
||||
clean.append(opt)
|
||||
if joins:
|
||||
cls._resolved_lateral = _ResolvedLateral(joins=joins, eager=eager)
|
||||
cls.default_load_options = clean or None
|
||||
|
||||
@classmethod
|
||||
def _get_lateral_joins(cls) -> LateralJoinType | None:
|
||||
"""Merge manual lateral_joins with ones resolved from default_load_options."""
|
||||
resolved = cls._resolved_lateral
|
||||
all_lateral = [
|
||||
*(cls.lateral_joins or []),
|
||||
*(resolved.joins if resolved else []),
|
||||
]
|
||||
return all_lateral or None
|
||||
|
||||
@classmethod
|
||||
def _resolve_load_options(
|
||||
cls, load_options: Sequence[ExecutableOption] | None
|
||||
) -> Sequence[ExecutableOption] | None:
|
||||
"""Return load_options if provided, else fall back to default_load_options."""
|
||||
"""Return merged load options."""
|
||||
if load_options is not None:
|
||||
return load_options
|
||||
return cls.default_load_options
|
||||
return list(load_options) or None
|
||||
resolved = cls._resolved_lateral
|
||||
# default_load_options is cleaned of _LateralLoad markers in __init_subclass__,
|
||||
# but its declared type still includes them — cast to reflect the runtime invariant.
|
||||
base = cast(list[ExecutableOption], cls.default_load_options or [])
|
||||
lateral = resolved.eager if resolved else []
|
||||
merged = [*base, *lateral]
|
||||
return merged or None
|
||||
|
||||
@classmethod
|
||||
async def _reload_with_options(
|
||||
@@ -861,6 +990,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""
|
||||
q = select(cls.model)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
@@ -933,6 +1064,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""
|
||||
q = select(cls.model)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
@@ -978,6 +1111,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""
|
||||
q = select(cls.model)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
@@ -1300,6 +1435,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
# Apply explicit joins
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
|
||||
# Apply lateral joins (Many:One relationship loading, excluded from count query)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
|
||||
# Apply search joins (always outer joins for search)
|
||||
q = _apply_search_joins(q, search_joins)
|
||||
|
||||
@@ -1398,7 +1537,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
tables.
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
|
||||
load_options: SQLAlchemy loader options. Falls back to
|
||||
``default_load_options`` when not provided.
|
||||
``default_load_options`` (including any lateral joins) when not
|
||||
provided. When explicitly supplied, the caller takes full control
|
||||
and lateral joins are skipped.
|
||||
order_by: Additional ordering applied after the cursor column.
|
||||
items_per_page: Number of items per page (default 20).
|
||||
search: Search query string or SearchConfig object.
|
||||
@@ -1455,6 +1596,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
# Apply explicit joins
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
|
||||
# Apply lateral joins (Many:One relationship loading)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
|
||||
# Apply search joins (always outer joins)
|
||||
q = _apply_search_joins(q, search_joins)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
|
||||
# CRUD type aliases
|
||||
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
|
||||
LateralJoinType = list[tuple[Any, Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||
|
||||
|
||||
@@ -6,9 +6,15 @@ import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from fastapi_toolsets.crud import CrudFactory, PaginationType
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
|
||||
from fastapi_toolsets.crud import CrudFactory, PaginationType, lateral_load
|
||||
from fastapi_toolsets.crud.factory import (
|
||||
AsyncCrud,
|
||||
_CursorDirection,
|
||||
_LateralLoad,
|
||||
_ResolvedLateral,
|
||||
)
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
from .conftest import (
|
||||
EventCreate,
|
||||
@@ -51,6 +57,12 @@ from .conftest import (
|
||||
)
|
||||
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
|
||||
|
||||
class TestCrudFactory:
|
||||
"""Tests for CrudFactory."""
|
||||
|
||||
@@ -208,11 +220,11 @@ class TestResolveLoadOptions:
|
||||
assert crud._resolve_load_options(None) is None
|
||||
|
||||
def test_empty_list_overrides_default(self):
|
||||
"""An empty list is a valid override and disables default_load_options."""
|
||||
"""An explicit empty list disables default_load_options (no options applied)."""
|
||||
default = [selectinload(User.role)]
|
||||
crud = CrudFactory(User, default_load_options=default)
|
||||
# Empty list is not None, so it should replace default
|
||||
assert crud._resolve_load_options([]) == []
|
||||
# Empty list replaces default; None and [] are both falsy → no options applied
|
||||
assert not crud._resolve_load_options([])
|
||||
|
||||
|
||||
class TestResolveSearchColumns:
|
||||
@@ -359,13 +371,6 @@ class TestDefaultLoadOptionsIntegration:
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""default_load_options loads relationships automatically on offset_paginate()."""
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
|
||||
UserWithDefaultLoad = CrudFactory(
|
||||
User, default_load_options=[selectinload(User.role)]
|
||||
)
|
||||
@@ -2462,12 +2467,7 @@ class TestCursorPaginateExtraOptions:
|
||||
@pytest.mark.anyio
|
||||
async def test_with_load_options(self, db_session: AsyncSession):
|
||||
"""cursor_paginate passes load_options to the query."""
|
||||
from fastapi_toolsets.schemas import CursorPagination, PydanticBase
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
|
||||
for i in range(3):
|
||||
@@ -2833,3 +2833,445 @@ class TestPaginate:
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count is None
|
||||
|
||||
|
||||
class TestLateralLoadValidation:
|
||||
"""lateral_load() raises immediately for bad relationship types."""
|
||||
|
||||
def test_valid_many_to_one_returns_marker(self):
|
||||
"""lateral_load() on a Many:One rel returns a _LateralLoad with rel_attr set."""
|
||||
marker = lateral_load(User.role)
|
||||
assert isinstance(marker, _LateralLoad)
|
||||
assert marker.rel_attr is User.role
|
||||
|
||||
def test_raises_type_error_for_plain_column(self):
|
||||
"""lateral_load() raises TypeError when passed a plain column."""
|
||||
with pytest.raises(TypeError, match="relationship attribute"):
|
||||
lateral_load(User.username)
|
||||
|
||||
def test_raises_value_error_for_many_to_many(self):
|
||||
"""lateral_load() raises ValueError for Many:Many (secondary table)."""
|
||||
with pytest.raises(ValueError, match="Many:Many"):
|
||||
lateral_load(Post.tags)
|
||||
|
||||
def test_raises_value_error_for_one_to_many(self):
|
||||
"""lateral_load() raises ValueError for One:Many (uselist=True)."""
|
||||
with pytest.raises(ValueError, match="One:Many"):
|
||||
lateral_load(Role.users)
|
||||
|
||||
|
||||
class TestLateralLoadInSubclass:
|
||||
"""lateral_load() markers in default_load_options are processed at class definition."""
|
||||
|
||||
def test_marker_extracted_from_default_load_options(self):
|
||||
"""_LateralLoad is removed from default_load_options and stored in _resolved_lateral."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
assert UserLateralCrud.default_load_options is None
|
||||
assert UserLateralCrud._resolved_lateral is not None
|
||||
|
||||
def test_resolved_lateral_has_one_join_and_eager(self):
|
||||
"""_resolved_lateral contains exactly one join and one eager option."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
resolved = UserLateralCrud._resolved_lateral
|
||||
assert isinstance(resolved, _ResolvedLateral)
|
||||
assert len(resolved.joins) == 1
|
||||
assert len(resolved.eager) == 1
|
||||
|
||||
def test_regular_options_preserved_alongside_lateral(self):
|
||||
"""Non-lateral opts stay in default_load_options; lateral marker is extracted."""
|
||||
regular = selectinload(User.role)
|
||||
|
||||
class UserMixedCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role), regular]
|
||||
|
||||
assert UserMixedCrud._resolved_lateral is not None
|
||||
assert UserMixedCrud.default_load_options == [regular]
|
||||
|
||||
def test_no_lateral_leaves_default_load_options_untouched(self):
|
||||
"""When no lateral marker is present, default_load_options is unchanged."""
|
||||
opts = [selectinload(User.role)]
|
||||
|
||||
class UserNormalCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = opts
|
||||
|
||||
assert UserNormalCrud.default_load_options is opts
|
||||
assert UserNormalCrud._resolved_lateral is None
|
||||
|
||||
def test_no_default_load_options_leaves_resolved_lateral_none(self):
|
||||
"""_resolved_lateral stays None when default_load_options is not set."""
|
||||
|
||||
class UserPlainCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
assert UserPlainCrud._resolved_lateral is None
|
||||
|
||||
|
||||
class TestResolveLoadOptionsWithLateral:
|
||||
"""_resolve_load_options always appends lateral eager options."""
|
||||
|
||||
def test_lateral_eager_included_when_no_call_site_opts(self):
|
||||
"""contains_eager from lateral_load is returned when load_options=None."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
resolved = UserLateralCrud._resolve_load_options(None)
|
||||
assert resolved is not None
|
||||
assert len(resolved) == 1 # the contains_eager
|
||||
|
||||
def test_call_site_opts_bypass_lateral_eager(self):
|
||||
"""When call-site load_options are provided, lateral eager is NOT appended."""
|
||||
extra = selectinload(User.role)
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
resolved = UserLateralCrud._resolve_load_options([extra])
|
||||
assert resolved is not None
|
||||
assert len(resolved) == 1 # only the call-site option; lateral eager skipped
|
||||
|
||||
def test_lateral_eager_appended_to_default_load_options(self):
|
||||
"""default_load_options (regular) + lateral eager are both returned."""
|
||||
regular = selectinload(User.role)
|
||||
|
||||
class UserMixedCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role), regular]
|
||||
|
||||
resolved = UserMixedCrud._resolve_load_options(None)
|
||||
assert resolved is not None
|
||||
assert len(resolved) == 2
|
||||
|
||||
|
||||
class TestGetLateralJoins:
|
||||
"""_get_lateral_joins merges auto-resolved and manual lateral_joins."""
|
||||
|
||||
def test_returns_none_when_no_lateral_configured(self):
|
||||
"""Returns None when neither lateral_joins nor lateral_load is set."""
|
||||
|
||||
class UserPlainCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
assert UserPlainCrud._get_lateral_joins() is None
|
||||
|
||||
def test_returns_resolved_lateral_joins(self):
|
||||
"""Returns the join tuple built from lateral_load()."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
joins = UserLateralCrud._get_lateral_joins()
|
||||
assert joins is not None
|
||||
assert len(joins) == 1
|
||||
|
||||
def test_manual_lateral_joins_included(self):
|
||||
"""Manual lateral_joins class var is included in _get_lateral_joins."""
|
||||
from sqlalchemy import select, true
|
||||
|
||||
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
|
||||
|
||||
class UserManualCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
lateral_joins = [(manual_sub, true())]
|
||||
|
||||
joins = UserManualCrud._get_lateral_joins()
|
||||
assert joins is not None
|
||||
assert len(joins) == 1
|
||||
|
||||
def test_manual_and_auto_lateral_joins_merged(self):
|
||||
"""Both manual lateral_joins and auto-resolved from lateral_load are combined."""
|
||||
from sqlalchemy import select, true
|
||||
|
||||
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
|
||||
|
||||
class UserBothCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
lateral_joins = [(manual_sub, true())]
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
joins = UserBothCrud._get_lateral_joins()
|
||||
assert joins is not None
|
||||
assert len(joins) == 2
|
||||
|
||||
|
||||
class TestLateralLoadIntegration:
|
||||
"""lateral_load() in real DB queries: relationship loaded, pagination correct."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_loads_relationship(self, db_session: AsyncSession):
|
||||
"""get() populates the relationship via lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
user = await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
|
||||
assert fetched.role is not None
|
||||
assert fetched.role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_null_fk_preserved(self, db_session: AsyncSession):
|
||||
"""User with null role_id still returned (LEFT JOIN behaviour)."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="bob@test.com")
|
||||
)
|
||||
|
||||
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
|
||||
assert fetched is not None
|
||||
assert fetched.role is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_first_loads_relationship(self, db_session: AsyncSession):
|
||||
"""first() populates the relationship via lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="carol", email="carol@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
user = await UserLateralCrud.first(db_session)
|
||||
assert user is not None
|
||||
assert user.role is not None
|
||||
assert user.role.name == "editor"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_multi_loads_relationship(self, db_session: AsyncSession):
|
||||
"""get_multi() populates the relationship via lateral join for all rows."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||
),
|
||||
)
|
||||
|
||||
users = await UserLateralCrud.get_multi(db_session)
|
||||
assert len(users) == 3
|
||||
assert all(u.role is not None and u.role.name == "member" for u in users)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_correct_count(self, db_session: AsyncSession):
|
||||
"""offset_paginate total_count is not inflated by the lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
for i in range(5):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||
),
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert result.pagination.total_count == 5
|
||||
assert len(result.data) == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_loads_relationship(self, db_session: AsyncSession):
|
||||
"""offset_paginate serializes relationship data loaded via lateral."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert result.data[0].role is not None
|
||||
assert result.data[0].role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_mixed_null_fk(self, db_session: AsyncSession):
|
||||
"""offset_paginate returns all users including those with null role_id."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="with_role", email="a@test.com", role_id=role.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="no_role", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_paginate_loads_relationship(self, db_session: AsyncSession):
|
||||
"""cursor_paginate populates the relationship via lateral join."""
|
||||
|
||||
class UserLateralCursorCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
cursor_column = User.id
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||
),
|
||||
)
|
||||
|
||||
result = await UserLateralCursorCrud.cursor_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert len(result.data) == 3
|
||||
assert all(item.role is not None for item in result.data)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_with_search_and_lateral(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""search filter works alongside lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
searchable_fields = [User.username]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="a@test.com", role_id=role.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com", role_id=role.id)
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, search="alice", items_per_page=10
|
||||
)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_first_call_site_load_options_bypasses_lateral(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""When load_options is provided, lateral join is skipped (no conflict)."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
user = await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
|
||||
fetched = await UserLateralCrud.first(
|
||||
db_session,
|
||||
filters=[User.id == user.id],
|
||||
load_options=[selectinload(User.role)],
|
||||
)
|
||||
assert fetched is not None
|
||||
assert fetched.role is not None
|
||||
assert fetched.role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_multi_call_site_load_options_bypasses_lateral(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""When load_options is provided, lateral join is skipped (no conflict)."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="viewer"))
|
||||
for i in range(2):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username=f"u{i}", email=f"u{i}@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
|
||||
users = await UserLateralCrud.get_multi(
|
||||
db_session, load_options=[selectinload(User.role)]
|
||||
)
|
||||
assert len(users) == 2
|
||||
assert all(u.role is not None and u.role.name == "viewer" for u in users)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_call_site_load_options_bypasses_lateral(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""When load_options is provided, lateral join is skipped (no conflict)."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username=f"e{i}", email=f"e{i}@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
# Passing explicit load_options bypasses the lateral join — role loaded via selectinload
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session,
|
||||
schema=UserWithRoleRead,
|
||||
items_per_page=10,
|
||||
load_options=[selectinload(User.role)],
|
||||
)
|
||||
assert result.pagination.total_count == 3
|
||||
assert all(item.role is not None for item in result.data)
|
||||
|
||||
Reference in New Issue
Block a user