feat: add lateral_load() for efficient Many:One eager loading via JOIN LATERAL

This commit is contained in:
2026-04-07 13:07:34 -04:00
parent 0ed93d62c8
commit e1f96ad7fe
4 changed files with 616 additions and 27 deletions

View File

@@ -10,15 +10,32 @@ from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from typing import Any, ClassVar, Generic, Literal, NamedTuple, Self, cast, overload
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy import (
Date,
DateTime,
Float,
Integer,
Numeric,
Uuid,
and_,
func,
select,
true,
)
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.orm import (
DeclarativeBase,
QueryableAttribute,
RelationshipProperty,
contains_eager,
selectinload,
)
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole
@@ -35,6 +52,7 @@ from ..schemas import (
from ..types import (
FacetFieldType,
JoinType,
LateralJoinType,
M2MFieldType,
ModelType,
OrderByClause,
@@ -115,6 +133,78 @@ def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
return q
class _ResolvedLateral(NamedTuple):
joins: LateralJoinType
eager: list[ExecutableOption]
class _LateralLoad:
    """Placeholder placed in ``default_load_options`` to request lateral-join loading.

    Only relationships that yield a single related row per parent
    (Many:One and One:One) are accepted; everything else is rejected at
    construction time.
    """

    __slots__ = ("rel_attr",)

    def __init__(self, rel_attr: QueryableAttribute) -> None:
        """Validate *rel_attr* and store it.

        Args:
            rel_attr: Mapped relationship attribute, e.g. ``User.team``.

        Raises:
            TypeError: The attribute's property is not a relationship.
            ValueError: The relationship goes through a ``secondary`` table
                (Many:Many) or is list-valued (One:Many).
        """
        relationship = rel_attr.property
        if not isinstance(relationship, RelationshipProperty):
            kind = type(relationship).__name__
            raise TypeError(
                f"lateral_load() requires a relationship attribute, got {kind}. "
                "Example: lateral_load(User.team)"
            )

        # The secondary check comes first so Many:Many is reported as such
        # rather than as One:Many.
        rejected: str | None = None
        if relationship.secondary is not None:
            rejected = "Many:Many"
        elif relationship.uselist:
            rejected = "One:Many"

        if rejected is not None:
            raise ValueError(
                f"lateral_load({rel_attr}) does not support {rejected} relationships. "
                "Use selectinload() instead."
            )

        self.rel_attr = rel_attr
def lateral_load(rel_attr: QueryableAttribute) -> _LateralLoad:
    """Wrap a Many:One or One:One relationship in a lateral-join load marker.

    Args:
        rel_attr: Mapped relationship attribute, e.g. ``User.team``.

    Returns:
        A ``_LateralLoad`` marker holding *rel_attr*.

    Raises:
        TypeError: *rel_attr* is not a relationship attribute.
        ValueError: The relationship is One:Many or Many:Many.
    """
    marker = _LateralLoad(rel_attr)
    return marker
def _build_lateral_from_relationship(
    rel_attr: QueryableAttribute,
) -> tuple[Any, Any, ExecutableOption]:
    """Build the lateral subquery and loader option for one relationship.

    Args:
        rel_attr: Relationship attribute to load (expected to have been
            validated by ``_LateralLoad``).

    Returns:
        A 3-tuple ``(lateral_subquery, true(), contains_eager_option)``:
        the subquery selects the target entity filtered by the relationship's
        local/remote column pairs and correlated to the parent entity; the
        ``true()`` element is the join's ON condition; the option tells the
        ORM to populate the relationship from the lateral alias.
    """
    prop = rel_attr.property
    target_class = prop.mapper.class_
    parent_class = prop.parent.class_

    # Compare the Column objects themselves instead of looking attributes up
    # by ``column.key``: a mapped attribute may be named differently from its
    # column (e.g. ``team_id = mapped_column("TeamID")``), in which case
    # ``getattr(cls, column.key)`` raises AttributeError.
    # NOTE(review): for a self-referential relationship (target_class is
    # parent_class) both columns live on the same table and correlating the
    # parent would strip the subquery's own FROM — confirm such relationships
    # are never passed here.
    conditions = [
        remote_col == local_col
        for local_col, remote_col in prop.local_remote_pairs
    ]

    lateral_sub = (
        select(target_class)
        .where(and_(*conditions))
        .correlate(parent_class)
        .lateral(f"_lateral_{prop.key}")
    )
    return lateral_sub, true(), contains_eager(rel_attr, alias=lateral_sub)
def _apply_lateral_joins(q: Any, lateral_joins: LateralJoinType | None) -> Any:
"""Apply lateral subqueries as LEFT JOIN LATERAL to preserve all parent rows."""
if not lateral_joins:
return q
for subquery, condition in lateral_joins:
q = q.outerjoin(subquery, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
seen: set[str] = set()
@@ -132,12 +222,17 @@ class AsyncCrud(Generic[ModelType]):
Subclass this and set the `model` class variable, or use `CrudFactory`.
"""
_resolved_lateral: ClassVar[_ResolvedLateral | None] = None
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
order_fields: ClassVar[Sequence[OrderFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
default_load_options: ClassVar[Sequence[ExecutableOption | _LateralLoad] | None] = (
None
)
lateral_joins: ClassVar[LateralJoinType | None] = None
cursor_column: ClassVar[Any | None] = None
@classmethod
@@ -161,14 +256,48 @@ class AsyncCrud(Generic[ModelType]):
):
cls.searchable_fields = [pk_col, *raw_fields]
raw_default_opts = cls.__dict__.get("default_load_options", None)
if raw_default_opts:
joins: LateralJoinType = []
eager: list[ExecutableOption] = []
clean: list[ExecutableOption] = []
for opt in raw_default_opts:
if isinstance(opt, _LateralLoad):
lat_sub, condition, eager_opt = _build_lateral_from_relationship(
opt.rel_attr
)
joins.append((lat_sub, condition))
eager.append(eager_opt)
else:
clean.append(opt)
if joins:
cls._resolved_lateral = _ResolvedLateral(joins=joins, eager=eager)
cls.default_load_options = clean or None
@classmethod
def _get_lateral_joins(cls) -> LateralJoinType | None:
"""Merge manual lateral_joins with ones resolved from default_load_options."""
resolved = cls._resolved_lateral
all_lateral = [
*(cls.lateral_joins or []),
*(resolved.joins if resolved else []),
]
return all_lateral or None
@classmethod
def _resolve_load_options(
cls, load_options: Sequence[ExecutableOption] | None
) -> Sequence[ExecutableOption] | None:
"""Return load_options if provided, else fall back to default_load_options."""
"""Return merged load options."""
if load_options is not None:
return load_options
return cls.default_load_options
return list(load_options) or None
resolved = cls._resolved_lateral
# default_load_options is cleaned of _LateralLoad markers in __init_subclass__,
# but its declared type still includes them — cast to reflect the runtime invariant.
base = cast(list[ExecutableOption], cls.default_load_options or [])
lateral = resolved.eager if resolved else []
merged = [*base, *lateral]
return merged or None
@classmethod
async def _reload_with_options(
@@ -861,6 +990,8 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
@@ -933,6 +1064,8 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -978,6 +1111,8 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -1300,6 +1435,10 @@ class AsyncCrud(Generic[ModelType]):
# Apply explicit joins
q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading, excluded from count query)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins for search)
q = _apply_search_joins(q, search_joins)
@@ -1398,7 +1537,9 @@ class AsyncCrud(Generic[ModelType]):
tables.
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
load_options: SQLAlchemy loader options. Falls back to
``default_load_options`` when not provided.
``default_load_options`` (including any lateral joins) when not
provided. When explicitly supplied, the caller takes full control
and lateral joins are skipped.
order_by: Additional ordering applied after the cursor column.
items_per_page: Number of items per page (default 20).
search: Search query string or SearchConfig object.
@@ -1455,6 +1596,10 @@ class AsyncCrud(Generic[ModelType]):
# Apply explicit joins
q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading)
if load_options is None:
q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins)
q = _apply_search_joins(q, search_joins)