feat: add lateral_load() for efficient Many:One eager loading via JOIN LATERAL

This commit is contained in:
2026-04-07 13:07:34 -04:00
parent 0ed93d62c8
commit c859d6d48b
4 changed files with 534 additions and 27 deletions

View File

@@ -10,15 +10,32 @@ from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from typing import Any, ClassVar, Generic, Literal, NamedTuple, Self, cast, overload
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy import (
Date,
DateTime,
Float,
Integer,
Numeric,
Uuid,
and_,
func,
select,
true,
)
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.orm import (
DeclarativeBase,
QueryableAttribute,
RelationshipProperty,
contains_eager,
selectinload,
)
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole
@@ -35,6 +52,7 @@ from ..schemas import (
from ..types import (
FacetFieldType,
JoinType,
LateralJoinType,
M2MFieldType,
ModelType,
OrderByClause,
@@ -115,6 +133,78 @@ def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
return q
class _ResolvedLateral(NamedTuple):
joins: LateralJoinType
eager: list[ExecutableOption]
class _LateralLoad:
"""Marker used inside ``default_load_options`` for lateral join loading.
Supports only Many:One and One:One relationships (single row per parent).
"""
__slots__ = ("rel_attr",)
def __init__(self, rel_attr: QueryableAttribute) -> None:
prop = rel_attr.property
if not isinstance(prop, RelationshipProperty):
raise TypeError(
f"lateral_load() requires a relationship attribute, got {type(prop).__name__}. "
"Example: lateral_load(User.team)"
)
if prop.secondary is not None:
raise ValueError(
f"lateral_load({rel_attr}) does not support Many:Many relationships. "
"Use selectinload() instead."
)
if prop.uselist:
raise ValueError(
f"lateral_load({rel_attr}) does not support One:Many relationships. "
"Use selectinload() instead."
)
self.rel_attr = rel_attr
def lateral_load(rel_attr: QueryableAttribute) -> _LateralLoad:
"""Mark a Many:One or One:One relationship for lateral join loading.
Raises ``ValueError`` for One:Many or Many:Many relationships.
"""
return _LateralLoad(rel_attr)
def _build_lateral_from_relationship(
rel_attr: QueryableAttribute,
) -> tuple[Any, Any, ExecutableOption]:
"""Introspect a Many:One relationship and build (lateral_subquery, true(), contains_eager)."""
prop = rel_attr.property
target_class = prop.mapper.class_
parent_class = prop.parent.class_
conditions = [
getattr(target_class, remote_col.key) == getattr(parent_class, local_col.key)
for local_col, remote_col in prop.local_remote_pairs
]
lateral_sub = (
select(target_class)
.where(and_(*conditions))
.correlate(parent_class)
.lateral(f"_lateral_{prop.key}")
)
return lateral_sub, true(), contains_eager(rel_attr, alias=lateral_sub)
def _apply_lateral_joins(q: Any, lateral_joins: LateralJoinType | None) -> Any:
"""Apply lateral subqueries as LEFT JOIN LATERAL to preserve all parent rows."""
if not lateral_joins:
return q
for subquery, condition in lateral_joins:
q = q.outerjoin(subquery, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
seen: set[str] = set()
@@ -132,12 +222,17 @@ class AsyncCrud(Generic[ModelType]):
Subclass this and set the `model` class variable, or use `CrudFactory`.
"""
_resolved_lateral: ClassVar[_ResolvedLateral | None] = None
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
order_fields: ClassVar[Sequence[OrderFieldType] | None] = None
m2m_fields: ClassVar[M2MFieldType | None] = None
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
default_load_options: ClassVar[Sequence[ExecutableOption | _LateralLoad] | None] = (
None
)
lateral_joins: ClassVar[LateralJoinType | None] = None
cursor_column: ClassVar[Any | None] = None
@classmethod
@@ -161,14 +256,50 @@ class AsyncCrud(Generic[ModelType]):
):
cls.searchable_fields = [pk_col, *raw_fields]
raw_default_opts = cls.__dict__.get("default_load_options", None)
if raw_default_opts:
joins: LateralJoinType = []
eager: list[ExecutableOption] = []
clean: list[ExecutableOption] = []
for opt in raw_default_opts:
if isinstance(opt, _LateralLoad):
lat_sub, condition, eager_opt = _build_lateral_from_relationship(
opt.rel_attr
)
joins.append((lat_sub, condition))
eager.append(eager_opt)
else:
clean.append(opt)
if joins:
cls._resolved_lateral = _ResolvedLateral(joins=joins, eager=eager)
cls.default_load_options = clean or None
@classmethod
def _get_lateral_joins(cls) -> LateralJoinType | None:
"""Merge manual lateral_joins with ones resolved from default_load_options."""
resolved = cls._resolved_lateral
all_lateral = [
*(cls.lateral_joins or []),
*(resolved.joins if resolved else []),
]
return all_lateral or None
@classmethod
def _resolve_load_options(
cls, load_options: Sequence[ExecutableOption] | None
) -> Sequence[ExecutableOption] | None:
"""Return load_options if provided, else fall back to default_load_options."""
if load_options is not None:
return load_options
return cls.default_load_options
"""Return merged load options: call-site or default, always with lateral eager opts."""
resolved = cls._resolved_lateral
# default_load_options is cleaned of _LateralLoad markers in __init_subclass__,
# but its declared type still includes them — cast to reflect the runtime invariant.
base: Sequence[ExecutableOption] = (
load_options
if load_options is not None
else cast(list[ExecutableOption], cls.default_load_options or [])
)
lateral = resolved.eager if resolved else []
merged = [*base, *lateral]
return merged or None
@classmethod
async def _reload_with_options(
@@ -861,6 +992,7 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
q = _apply_lateral_joins(q, cls._get_lateral_joins())
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
@@ -933,6 +1065,7 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -978,6 +1111,7 @@ class AsyncCrud(Generic[ModelType]):
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
q = _apply_lateral_joins(q, cls._get_lateral_joins())
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -1300,6 +1434,9 @@ class AsyncCrud(Generic[ModelType]):
# Apply explicit joins
q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading, excluded from count query)
q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins for search)
q = _apply_search_joins(q, search_joins)
@@ -1455,6 +1592,9 @@ class AsyncCrud(Generic[ModelType]):
# Apply explicit joins
q = _apply_joins(q, joins, outer_join)
# Apply lateral joins (Many:One relationship loading)
q = _apply_lateral_joins(q, cls._get_lateral_joins())
# Apply search joins (always outer joins)
q = _apply_search_joins(q, search_joins)