mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
feat: add lateral_load() for efficient Many:One eager loading via JOIN LATERAL
This commit is contained in:
@@ -15,12 +15,13 @@ from ..types import (
|
||||
OrderFieldType,
|
||||
SearchFieldType,
|
||||
)
|
||||
from .factory import AsyncCrud, CrudFactory
|
||||
from .factory import AsyncCrud, CrudFactory, lateral_load
|
||||
from .search import SearchConfig, get_searchable_fields
|
||||
|
||||
__all__ = [
|
||||
"AsyncCrud",
|
||||
"CrudFactory",
|
||||
"lateral_load",
|
||||
"FacetFieldType",
|
||||
"get_searchable_fields",
|
||||
"InvalidFacetFilterError",
|
||||
|
||||
@@ -10,15 +10,32 @@ from collections.abc import Awaitable, Callable, Sequence
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
||||
from typing import Any, ClassVar, Generic, Literal, NamedTuple, Self, cast, overload
|
||||
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
||||
from sqlalchemy import (
|
||||
Date,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
Numeric,
|
||||
Uuid,
|
||||
and_,
|
||||
func,
|
||||
select,
|
||||
true,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
QueryableAttribute,
|
||||
RelationshipProperty,
|
||||
contains_eager,
|
||||
selectinload,
|
||||
)
|
||||
from sqlalchemy.sql.base import ExecutableOption
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
@@ -35,6 +52,7 @@ from ..schemas import (
|
||||
from ..types import (
|
||||
FacetFieldType,
|
||||
JoinType,
|
||||
LateralJoinType,
|
||||
M2MFieldType,
|
||||
ModelType,
|
||||
OrderByClause,
|
||||
@@ -115,6 +133,78 @@ def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
||||
return q
|
||||
|
||||
|
||||
class _ResolvedLateral(NamedTuple):
|
||||
joins: LateralJoinType
|
||||
eager: list[ExecutableOption]
|
||||
|
||||
|
||||
class _LateralLoad:
|
||||
"""Marker used inside ``default_load_options`` for lateral join loading.
|
||||
|
||||
Supports only Many:One and One:One relationships (single row per parent).
|
||||
"""
|
||||
|
||||
__slots__ = ("rel_attr",)
|
||||
|
||||
def __init__(self, rel_attr: QueryableAttribute) -> None:
|
||||
prop = rel_attr.property
|
||||
if not isinstance(prop, RelationshipProperty):
|
||||
raise TypeError(
|
||||
f"lateral_load() requires a relationship attribute, got {type(prop).__name__}. "
|
||||
"Example: lateral_load(User.team)"
|
||||
)
|
||||
if prop.secondary is not None:
|
||||
raise ValueError(
|
||||
f"lateral_load({rel_attr}) does not support Many:Many relationships. "
|
||||
"Use selectinload() instead."
|
||||
)
|
||||
if prop.uselist:
|
||||
raise ValueError(
|
||||
f"lateral_load({rel_attr}) does not support One:Many relationships. "
|
||||
"Use selectinload() instead."
|
||||
)
|
||||
self.rel_attr = rel_attr
|
||||
|
||||
|
||||
def lateral_load(rel_attr: QueryableAttribute) -> _LateralLoad:
|
||||
"""Mark a Many:One or One:One relationship for lateral join loading.
|
||||
|
||||
Raises ``ValueError`` for One:Many or Many:Many relationships.
|
||||
"""
|
||||
return _LateralLoad(rel_attr)
|
||||
|
||||
|
||||
def _build_lateral_from_relationship(
|
||||
rel_attr: QueryableAttribute,
|
||||
) -> tuple[Any, Any, ExecutableOption]:
|
||||
"""Introspect a Many:One relationship and build (lateral_subquery, true(), contains_eager)."""
|
||||
prop = rel_attr.property
|
||||
target_class = prop.mapper.class_
|
||||
parent_class = prop.parent.class_
|
||||
|
||||
conditions = [
|
||||
getattr(target_class, remote_col.key) == getattr(parent_class, local_col.key)
|
||||
for local_col, remote_col in prop.local_remote_pairs
|
||||
]
|
||||
|
||||
lateral_sub = (
|
||||
select(target_class)
|
||||
.where(and_(*conditions))
|
||||
.correlate(parent_class)
|
||||
.lateral(f"_lateral_{prop.key}")
|
||||
)
|
||||
return lateral_sub, true(), contains_eager(rel_attr, alias=lateral_sub)
|
||||
|
||||
|
||||
def _apply_lateral_joins(q: Any, lateral_joins: LateralJoinType | None) -> Any:
|
||||
"""Apply lateral subqueries as LEFT JOIN LATERAL to preserve all parent rows."""
|
||||
if not lateral_joins:
|
||||
return q
|
||||
for subquery, condition in lateral_joins:
|
||||
q = q.outerjoin(subquery, condition)
|
||||
return q
|
||||
|
||||
|
||||
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
|
||||
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
|
||||
seen: set[str] = set()
|
||||
@@ -132,12 +222,17 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
||||
"""
|
||||
|
||||
_resolved_lateral: ClassVar[_ResolvedLateral | None] = None
|
||||
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
||||
order_fields: ClassVar[Sequence[OrderFieldType] | None] = None
|
||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
|
||||
default_load_options: ClassVar[Sequence[ExecutableOption | _LateralLoad] | None] = (
|
||||
None
|
||||
)
|
||||
lateral_joins: ClassVar[LateralJoinType | None] = None
|
||||
cursor_column: ClassVar[Any | None] = None
|
||||
|
||||
@classmethod
|
||||
@@ -161,14 +256,48 @@ class AsyncCrud(Generic[ModelType]):
|
||||
):
|
||||
cls.searchable_fields = [pk_col, *raw_fields]
|
||||
|
||||
raw_default_opts = cls.__dict__.get("default_load_options", None)
|
||||
if raw_default_opts:
|
||||
joins: LateralJoinType = []
|
||||
eager: list[ExecutableOption] = []
|
||||
clean: list[ExecutableOption] = []
|
||||
for opt in raw_default_opts:
|
||||
if isinstance(opt, _LateralLoad):
|
||||
lat_sub, condition, eager_opt = _build_lateral_from_relationship(
|
||||
opt.rel_attr
|
||||
)
|
||||
joins.append((lat_sub, condition))
|
||||
eager.append(eager_opt)
|
||||
else:
|
||||
clean.append(opt)
|
||||
if joins:
|
||||
cls._resolved_lateral = _ResolvedLateral(joins=joins, eager=eager)
|
||||
cls.default_load_options = clean or None
|
||||
|
||||
@classmethod
|
||||
def _get_lateral_joins(cls) -> LateralJoinType | None:
|
||||
"""Merge manual lateral_joins with ones resolved from default_load_options."""
|
||||
resolved = cls._resolved_lateral
|
||||
all_lateral = [
|
||||
*(cls.lateral_joins or []),
|
||||
*(resolved.joins if resolved else []),
|
||||
]
|
||||
return all_lateral or None
|
||||
|
||||
@classmethod
|
||||
def _resolve_load_options(
|
||||
cls, load_options: Sequence[ExecutableOption] | None
|
||||
) -> Sequence[ExecutableOption] | None:
|
||||
"""Return load_options if provided, else fall back to default_load_options."""
|
||||
"""Return merged load options."""
|
||||
if load_options is not None:
|
||||
return load_options
|
||||
return cls.default_load_options
|
||||
return list(load_options) or None
|
||||
resolved = cls._resolved_lateral
|
||||
# default_load_options is cleaned of _LateralLoad markers in __init_subclass__,
|
||||
# but its declared type still includes them — cast to reflect the runtime invariant.
|
||||
base = cast(list[ExecutableOption], cls.default_load_options or [])
|
||||
lateral = resolved.eager if resolved else []
|
||||
merged = [*base, *lateral]
|
||||
return merged or None
|
||||
|
||||
@classmethod
|
||||
async def _reload_with_options(
|
||||
@@ -861,6 +990,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""
|
||||
q = select(cls.model)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
q = q.options(*resolved)
|
||||
@@ -933,6 +1064,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""
|
||||
q = select(cls.model)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
@@ -978,6 +1111,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""
|
||||
q = select(cls.model)
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if resolved := cls._resolve_load_options(load_options):
|
||||
@@ -1300,6 +1435,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
# Apply explicit joins
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
|
||||
# Apply lateral joins (Many:One relationship loading, excluded from count query)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
|
||||
# Apply search joins (always outer joins for search)
|
||||
q = _apply_search_joins(q, search_joins)
|
||||
|
||||
@@ -1398,7 +1537,9 @@ class AsyncCrud(Generic[ModelType]):
|
||||
tables.
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
|
||||
load_options: SQLAlchemy loader options. Falls back to
|
||||
``default_load_options`` when not provided.
|
||||
``default_load_options`` (including any lateral joins) when not
|
||||
provided. When explicitly supplied, the caller takes full control
|
||||
and lateral joins are skipped.
|
||||
order_by: Additional ordering applied after the cursor column.
|
||||
items_per_page: Number of items per page (default 20).
|
||||
search: Search query string or SearchConfig object.
|
||||
@@ -1455,6 +1596,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
# Apply explicit joins
|
||||
q = _apply_joins(q, joins, outer_join)
|
||||
|
||||
# Apply lateral joins (Many:One relationship loading)
|
||||
if load_options is None:
|
||||
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||
|
||||
# Apply search joins (always outer joins)
|
||||
q = _apply_search_joins(q, search_joins)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
|
||||
# CRUD type aliases
|
||||
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
|
||||
LateralJoinType = list[tuple[Any, Any]]
|
||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user