mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat: add lateral_load() for efficient Many:One eager loading via JOIN LATERAL
This commit is contained in:
@@ -15,12 +15,13 @@ from ..types import (
|
|||||||
OrderFieldType,
|
OrderFieldType,
|
||||||
SearchFieldType,
|
SearchFieldType,
|
||||||
)
|
)
|
||||||
from .factory import AsyncCrud, CrudFactory
|
from .factory import AsyncCrud, CrudFactory, lateral_load
|
||||||
from .search import SearchConfig, get_searchable_fields
|
from .search import SearchConfig, get_searchable_fields
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AsyncCrud",
|
"AsyncCrud",
|
||||||
"CrudFactory",
|
"CrudFactory",
|
||||||
|
"lateral_load",
|
||||||
"FacetFieldType",
|
"FacetFieldType",
|
||||||
"get_searchable_fields",
|
"get_searchable_fields",
|
||||||
"InvalidFacetFilterError",
|
"InvalidFacetFilterError",
|
||||||
|
|||||||
@@ -10,15 +10,32 @@ from collections.abc import Awaitable, Callable, Sequence
|
|||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
from typing import Any, ClassVar, Generic, Literal, NamedTuple, Self, cast, overload
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
|
from sqlalchemy import (
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
Float,
|
||||||
|
Integer,
|
||||||
|
Numeric,
|
||||||
|
Uuid,
|
||||||
|
and_,
|
||||||
|
func,
|
||||||
|
select,
|
||||||
|
true,
|
||||||
|
)
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
from sqlalchemy.orm import (
|
||||||
|
DeclarativeBase,
|
||||||
|
QueryableAttribute,
|
||||||
|
RelationshipProperty,
|
||||||
|
contains_eager,
|
||||||
|
selectinload,
|
||||||
|
)
|
||||||
from sqlalchemy.sql.base import ExecutableOption
|
from sqlalchemy.sql.base import ExecutableOption
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
from sqlalchemy.sql.roles import WhereHavingRole
|
||||||
|
|
||||||
@@ -35,6 +52,7 @@ from ..schemas import (
|
|||||||
from ..types import (
|
from ..types import (
|
||||||
FacetFieldType,
|
FacetFieldType,
|
||||||
JoinType,
|
JoinType,
|
||||||
|
LateralJoinType,
|
||||||
M2MFieldType,
|
M2MFieldType,
|
||||||
ModelType,
|
ModelType,
|
||||||
OrderByClause,
|
OrderByClause,
|
||||||
@@ -115,6 +133,78 @@ def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
|||||||
return q
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
class _ResolvedLateral(NamedTuple):
|
||||||
|
joins: LateralJoinType
|
||||||
|
eager: list[ExecutableOption]
|
||||||
|
|
||||||
|
|
||||||
|
class _LateralLoad:
|
||||||
|
"""Marker used inside ``default_load_options`` for lateral join loading.
|
||||||
|
|
||||||
|
Supports only Many:One and One:One relationships (single row per parent).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("rel_attr",)
|
||||||
|
|
||||||
|
def __init__(self, rel_attr: QueryableAttribute) -> None:
|
||||||
|
prop = rel_attr.property
|
||||||
|
if not isinstance(prop, RelationshipProperty):
|
||||||
|
raise TypeError(
|
||||||
|
f"lateral_load() requires a relationship attribute, got {type(prop).__name__}. "
|
||||||
|
"Example: lateral_load(User.team)"
|
||||||
|
)
|
||||||
|
if prop.secondary is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"lateral_load({rel_attr}) does not support Many:Many relationships. "
|
||||||
|
"Use selectinload() instead."
|
||||||
|
)
|
||||||
|
if prop.uselist:
|
||||||
|
raise ValueError(
|
||||||
|
f"lateral_load({rel_attr}) does not support One:Many relationships. "
|
||||||
|
"Use selectinload() instead."
|
||||||
|
)
|
||||||
|
self.rel_attr = rel_attr
|
||||||
|
|
||||||
|
|
||||||
|
def lateral_load(rel_attr: QueryableAttribute) -> _LateralLoad:
|
||||||
|
"""Mark a Many:One or One:One relationship for lateral join loading.
|
||||||
|
|
||||||
|
Raises ``ValueError`` for One:Many or Many:Many relationships.
|
||||||
|
"""
|
||||||
|
return _LateralLoad(rel_attr)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_lateral_from_relationship(
|
||||||
|
rel_attr: QueryableAttribute,
|
||||||
|
) -> tuple[Any, Any, ExecutableOption]:
|
||||||
|
"""Introspect a Many:One relationship and build (lateral_subquery, true(), contains_eager)."""
|
||||||
|
prop = rel_attr.property
|
||||||
|
target_class = prop.mapper.class_
|
||||||
|
parent_class = prop.parent.class_
|
||||||
|
|
||||||
|
conditions = [
|
||||||
|
getattr(target_class, remote_col.key) == getattr(parent_class, local_col.key)
|
||||||
|
for local_col, remote_col in prop.local_remote_pairs
|
||||||
|
]
|
||||||
|
|
||||||
|
lateral_sub = (
|
||||||
|
select(target_class)
|
||||||
|
.where(and_(*conditions))
|
||||||
|
.correlate(parent_class)
|
||||||
|
.lateral(f"_lateral_{prop.key}")
|
||||||
|
)
|
||||||
|
return lateral_sub, true(), contains_eager(rel_attr, alias=lateral_sub)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lateral_joins(q: Any, lateral_joins: LateralJoinType | None) -> Any:
|
||||||
|
"""Apply lateral subqueries as LEFT JOIN LATERAL to preserve all parent rows."""
|
||||||
|
if not lateral_joins:
|
||||||
|
return q
|
||||||
|
for subquery, condition in lateral_joins:
|
||||||
|
q = q.outerjoin(subquery, condition)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
|
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
|
||||||
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
|
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
@@ -132,12 +222,17 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_resolved_lateral: ClassVar[_ResolvedLateral | None] = None
|
||||||
|
|
||||||
model: ClassVar[type[DeclarativeBase]]
|
model: ClassVar[type[DeclarativeBase]]
|
||||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||||
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
|
||||||
order_fields: ClassVar[Sequence[OrderFieldType] | None] = None
|
order_fields: ClassVar[Sequence[OrderFieldType] | None] = None
|
||||||
m2m_fields: ClassVar[M2MFieldType | None] = None
|
m2m_fields: ClassVar[M2MFieldType | None] = None
|
||||||
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
|
default_load_options: ClassVar[Sequence[ExecutableOption | _LateralLoad] | None] = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
lateral_joins: ClassVar[LateralJoinType | None] = None
|
||||||
cursor_column: ClassVar[Any | None] = None
|
cursor_column: ClassVar[Any | None] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -161,14 +256,50 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
):
|
):
|
||||||
cls.searchable_fields = [pk_col, *raw_fields]
|
cls.searchable_fields = [pk_col, *raw_fields]
|
||||||
|
|
||||||
|
raw_default_opts = cls.__dict__.get("default_load_options", None)
|
||||||
|
if raw_default_opts:
|
||||||
|
joins: LateralJoinType = []
|
||||||
|
eager: list[ExecutableOption] = []
|
||||||
|
clean: list[ExecutableOption] = []
|
||||||
|
for opt in raw_default_opts:
|
||||||
|
if isinstance(opt, _LateralLoad):
|
||||||
|
lat_sub, condition, eager_opt = _build_lateral_from_relationship(
|
||||||
|
opt.rel_attr
|
||||||
|
)
|
||||||
|
joins.append((lat_sub, condition))
|
||||||
|
eager.append(eager_opt)
|
||||||
|
else:
|
||||||
|
clean.append(opt)
|
||||||
|
if joins:
|
||||||
|
cls._resolved_lateral = _ResolvedLateral(joins=joins, eager=eager)
|
||||||
|
cls.default_load_options = clean or None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_lateral_joins(cls) -> LateralJoinType | None:
|
||||||
|
"""Merge manual lateral_joins with ones resolved from default_load_options."""
|
||||||
|
resolved = cls._resolved_lateral
|
||||||
|
all_lateral = [
|
||||||
|
*(cls.lateral_joins or []),
|
||||||
|
*(resolved.joins if resolved else []),
|
||||||
|
]
|
||||||
|
return all_lateral or None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve_load_options(
|
def _resolve_load_options(
|
||||||
cls, load_options: Sequence[ExecutableOption] | None
|
cls, load_options: Sequence[ExecutableOption] | None
|
||||||
) -> Sequence[ExecutableOption] | None:
|
) -> Sequence[ExecutableOption] | None:
|
||||||
"""Return load_options if provided, else fall back to default_load_options."""
|
"""Return merged load options: call-site or default, always with lateral eager opts."""
|
||||||
if load_options is not None:
|
resolved = cls._resolved_lateral
|
||||||
return load_options
|
# default_load_options is cleaned of _LateralLoad markers in __init_subclass__,
|
||||||
return cls.default_load_options
|
# but its declared type still includes them — cast to reflect the runtime invariant.
|
||||||
|
base: Sequence[ExecutableOption] = (
|
||||||
|
load_options
|
||||||
|
if load_options is not None
|
||||||
|
else cast(list[ExecutableOption], cls.default_load_options or [])
|
||||||
|
)
|
||||||
|
lateral = resolved.eager if resolved else []
|
||||||
|
merged = [*base, *lateral]
|
||||||
|
return merged or None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _reload_with_options(
|
async def _reload_with_options(
|
||||||
@@ -861,6 +992,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
q = _apply_joins(q, joins, outer_join)
|
q = _apply_joins(q, joins, outer_join)
|
||||||
|
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if resolved := cls._resolve_load_options(load_options):
|
if resolved := cls._resolve_load_options(load_options):
|
||||||
q = q.options(*resolved)
|
q = q.options(*resolved)
|
||||||
@@ -933,6 +1065,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
q = _apply_joins(q, joins, outer_join)
|
q = _apply_joins(q, joins, outer_join)
|
||||||
|
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if resolved := cls._resolve_load_options(load_options):
|
if resolved := cls._resolve_load_options(load_options):
|
||||||
@@ -978,6 +1111,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
q = _apply_joins(q, joins, outer_join)
|
q = _apply_joins(q, joins, outer_join)
|
||||||
|
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if resolved := cls._resolve_load_options(load_options):
|
if resolved := cls._resolve_load_options(load_options):
|
||||||
@@ -1300,6 +1434,9 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
# Apply explicit joins
|
# Apply explicit joins
|
||||||
q = _apply_joins(q, joins, outer_join)
|
q = _apply_joins(q, joins, outer_join)
|
||||||
|
|
||||||
|
# Apply lateral joins (Many:One relationship loading, excluded from count query)
|
||||||
|
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||||
|
|
||||||
# Apply search joins (always outer joins for search)
|
# Apply search joins (always outer joins for search)
|
||||||
q = _apply_search_joins(q, search_joins)
|
q = _apply_search_joins(q, search_joins)
|
||||||
|
|
||||||
@@ -1455,6 +1592,9 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
# Apply explicit joins
|
# Apply explicit joins
|
||||||
q = _apply_joins(q, joins, outer_join)
|
q = _apply_joins(q, joins, outer_join)
|
||||||
|
|
||||||
|
# Apply lateral joins (Many:One relationship loading)
|
||||||
|
q = _apply_lateral_joins(q, cls._get_lateral_joins())
|
||||||
|
|
||||||
# Apply search joins (always outer joins)
|
# Apply search joins (always outer joins)
|
||||||
q = _apply_search_joins(q, search_joins)
|
q = _apply_search_joins(q, search_joins)
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
|||||||
|
|
||||||
# CRUD type aliases
|
# CRUD type aliases
|
||||||
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
|
JoinType = list[tuple[type[DeclarativeBase] | Any, Any]]
|
||||||
|
LateralJoinType = list[tuple[Any, Any]]
|
||||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,15 @@ import pytest
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.crud import CrudFactory, PaginationType
|
from fastapi_toolsets.crud import CrudFactory, PaginationType, lateral_load
|
||||||
from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
|
from fastapi_toolsets.crud.factory import (
|
||||||
|
AsyncCrud,
|
||||||
|
_CursorDirection,
|
||||||
|
_LateralLoad,
|
||||||
|
_ResolvedLateral,
|
||||||
|
)
|
||||||
from fastapi_toolsets.exceptions import NotFoundError
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
from fastapi_toolsets.schemas import PydanticBase
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
EventCreate,
|
EventCreate,
|
||||||
@@ -51,6 +57,12 @@ from .conftest import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserWithRoleRead(PydanticBase):
|
||||||
|
id: uuid.UUID
|
||||||
|
username: str
|
||||||
|
role: RoleRead | None = None
|
||||||
|
|
||||||
|
|
||||||
class TestCrudFactory:
|
class TestCrudFactory:
|
||||||
"""Tests for CrudFactory."""
|
"""Tests for CrudFactory."""
|
||||||
|
|
||||||
@@ -208,11 +220,11 @@ class TestResolveLoadOptions:
|
|||||||
assert crud._resolve_load_options(None) is None
|
assert crud._resolve_load_options(None) is None
|
||||||
|
|
||||||
def test_empty_list_overrides_default(self):
|
def test_empty_list_overrides_default(self):
|
||||||
"""An empty list is a valid override and disables default_load_options."""
|
"""An explicit empty list disables default_load_options (no options applied)."""
|
||||||
default = [selectinload(User.role)]
|
default = [selectinload(User.role)]
|
||||||
crud = CrudFactory(User, default_load_options=default)
|
crud = CrudFactory(User, default_load_options=default)
|
||||||
# Empty list is not None, so it should replace default
|
# Empty list replaces default; None and [] are both falsy → no options applied
|
||||||
assert crud._resolve_load_options([]) == []
|
assert not crud._resolve_load_options([])
|
||||||
|
|
||||||
|
|
||||||
class TestResolveSearchColumns:
|
class TestResolveSearchColumns:
|
||||||
@@ -359,13 +371,6 @@ class TestDefaultLoadOptionsIntegration:
|
|||||||
self, db_session: AsyncSession
|
self, db_session: AsyncSession
|
||||||
):
|
):
|
||||||
"""default_load_options loads relationships automatically on offset_paginate()."""
|
"""default_load_options loads relationships automatically on offset_paginate()."""
|
||||||
from fastapi_toolsets.schemas import PydanticBase
|
|
||||||
|
|
||||||
class UserWithRoleRead(PydanticBase):
|
|
||||||
id: uuid.UUID
|
|
||||||
username: str
|
|
||||||
role: RoleRead | None = None
|
|
||||||
|
|
||||||
UserWithDefaultLoad = CrudFactory(
|
UserWithDefaultLoad = CrudFactory(
|
||||||
User, default_load_options=[selectinload(User.role)]
|
User, default_load_options=[selectinload(User.role)]
|
||||||
)
|
)
|
||||||
@@ -2462,12 +2467,7 @@ class TestCursorPaginateExtraOptions:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_with_load_options(self, db_session: AsyncSession):
|
async def test_with_load_options(self, db_session: AsyncSession):
|
||||||
"""cursor_paginate passes load_options to the query."""
|
"""cursor_paginate passes load_options to the query."""
|
||||||
from fastapi_toolsets.schemas import CursorPagination, PydanticBase
|
from fastapi_toolsets.schemas import CursorPagination
|
||||||
|
|
||||||
class UserWithRoleRead(PydanticBase):
|
|
||||||
id: uuid.UUID
|
|
||||||
username: str
|
|
||||||
role: RoleRead | None = None
|
|
||||||
|
|
||||||
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
|
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
@@ -2833,3 +2833,368 @@ class TestPaginate:
|
|||||||
|
|
||||||
assert isinstance(result.pagination, OffsetPagination)
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
assert result.pagination.total_count is None
|
assert result.pagination.total_count is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestLateralLoadValidation:
|
||||||
|
"""lateral_load() raises immediately for bad relationship types."""
|
||||||
|
|
||||||
|
def test_valid_many_to_one_returns_marker(self):
|
||||||
|
"""lateral_load() on a Many:One rel returns a _LateralLoad with rel_attr set."""
|
||||||
|
marker = lateral_load(User.role)
|
||||||
|
assert isinstance(marker, _LateralLoad)
|
||||||
|
assert marker.rel_attr is User.role
|
||||||
|
|
||||||
|
def test_raises_type_error_for_plain_column(self):
|
||||||
|
"""lateral_load() raises TypeError when passed a plain column."""
|
||||||
|
with pytest.raises(TypeError, match="relationship attribute"):
|
||||||
|
lateral_load(User.username)
|
||||||
|
|
||||||
|
def test_raises_value_error_for_many_to_many(self):
|
||||||
|
"""lateral_load() raises ValueError for Many:Many (secondary table)."""
|
||||||
|
with pytest.raises(ValueError, match="Many:Many"):
|
||||||
|
lateral_load(Post.tags)
|
||||||
|
|
||||||
|
def test_raises_value_error_for_one_to_many(self):
|
||||||
|
"""lateral_load() raises ValueError for One:Many (uselist=True)."""
|
||||||
|
with pytest.raises(ValueError, match="One:Many"):
|
||||||
|
lateral_load(Role.users)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLateralLoadInSubclass:
|
||||||
|
"""lateral_load() markers in default_load_options are processed at class definition."""
|
||||||
|
|
||||||
|
def test_marker_extracted_from_default_load_options(self):
|
||||||
|
"""_LateralLoad is removed from default_load_options and stored in _resolved_lateral."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
assert UserLateralCrud.default_load_options is None
|
||||||
|
assert UserLateralCrud._resolved_lateral is not None
|
||||||
|
|
||||||
|
def test_resolved_lateral_has_one_join_and_eager(self):
|
||||||
|
"""_resolved_lateral contains exactly one join and one eager option."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
resolved = UserLateralCrud._resolved_lateral
|
||||||
|
assert isinstance(resolved, _ResolvedLateral)
|
||||||
|
assert len(resolved.joins) == 1
|
||||||
|
assert len(resolved.eager) == 1
|
||||||
|
|
||||||
|
def test_regular_options_preserved_alongside_lateral(self):
|
||||||
|
"""Non-lateral opts stay in default_load_options; lateral marker is extracted."""
|
||||||
|
regular = selectinload(User.role)
|
||||||
|
|
||||||
|
class UserMixedCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role), regular]
|
||||||
|
|
||||||
|
assert UserMixedCrud._resolved_lateral is not None
|
||||||
|
assert UserMixedCrud.default_load_options == [regular]
|
||||||
|
|
||||||
|
def test_no_lateral_leaves_default_load_options_untouched(self):
|
||||||
|
"""When no lateral marker is present, default_load_options is unchanged."""
|
||||||
|
opts = [selectinload(User.role)]
|
||||||
|
|
||||||
|
class UserNormalCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = opts
|
||||||
|
|
||||||
|
assert UserNormalCrud.default_load_options is opts
|
||||||
|
assert UserNormalCrud._resolved_lateral is None
|
||||||
|
|
||||||
|
def test_no_default_load_options_leaves_resolved_lateral_none(self):
|
||||||
|
"""_resolved_lateral stays None when default_load_options is not set."""
|
||||||
|
|
||||||
|
class UserPlainCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
|
||||||
|
assert UserPlainCrud._resolved_lateral is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveLoadOptionsWithLateral:
|
||||||
|
"""_resolve_load_options always appends lateral eager options."""
|
||||||
|
|
||||||
|
def test_lateral_eager_included_when_no_call_site_opts(self):
|
||||||
|
"""contains_eager from lateral_load is returned when load_options=None."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
resolved = UserLateralCrud._resolve_load_options(None)
|
||||||
|
assert resolved is not None
|
||||||
|
assert len(resolved) == 1 # the contains_eager
|
||||||
|
|
||||||
|
def test_lateral_eager_appended_to_call_site_opts(self):
|
||||||
|
"""call-site load_options + lateral eager are both returned."""
|
||||||
|
extra = selectinload(User.role)
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
resolved = UserLateralCrud._resolve_load_options([extra])
|
||||||
|
assert resolved is not None
|
||||||
|
assert len(resolved) == 2
|
||||||
|
|
||||||
|
def test_lateral_eager_appended_to_default_load_options(self):
|
||||||
|
"""default_load_options (regular) + lateral eager are both returned."""
|
||||||
|
regular = selectinload(User.role)
|
||||||
|
|
||||||
|
class UserMixedCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role), regular]
|
||||||
|
|
||||||
|
resolved = UserMixedCrud._resolve_load_options(None)
|
||||||
|
assert resolved is not None
|
||||||
|
assert len(resolved) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetLateralJoins:
|
||||||
|
"""_get_lateral_joins merges auto-resolved and manual lateral_joins."""
|
||||||
|
|
||||||
|
def test_returns_none_when_no_lateral_configured(self):
|
||||||
|
"""Returns None when neither lateral_joins nor lateral_load is set."""
|
||||||
|
|
||||||
|
class UserPlainCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
|
||||||
|
assert UserPlainCrud._get_lateral_joins() is None
|
||||||
|
|
||||||
|
def test_returns_resolved_lateral_joins(self):
|
||||||
|
"""Returns the join tuple built from lateral_load()."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
joins = UserLateralCrud._get_lateral_joins()
|
||||||
|
assert joins is not None
|
||||||
|
assert len(joins) == 1
|
||||||
|
|
||||||
|
def test_manual_lateral_joins_included(self):
|
||||||
|
"""Manual lateral_joins class var is included in _get_lateral_joins."""
|
||||||
|
from sqlalchemy import select, true
|
||||||
|
|
||||||
|
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
|
||||||
|
|
||||||
|
class UserManualCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
lateral_joins = [(manual_sub, true())]
|
||||||
|
|
||||||
|
joins = UserManualCrud._get_lateral_joins()
|
||||||
|
assert joins is not None
|
||||||
|
assert len(joins) == 1
|
||||||
|
|
||||||
|
def test_manual_and_auto_lateral_joins_merged(self):
|
||||||
|
"""Both manual lateral_joins and auto-resolved from lateral_load are combined."""
|
||||||
|
from sqlalchemy import select, true
|
||||||
|
|
||||||
|
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
|
||||||
|
|
||||||
|
class UserBothCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
lateral_joins = [(manual_sub, true())]
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
joins = UserBothCrud._get_lateral_joins()
|
||||||
|
assert joins is not None
|
||||||
|
assert len(joins) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestLateralLoadIntegration:
|
||||||
|
"""lateral_load() in real DB queries: relationship loaded, pagination correct."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_loads_relationship(self, db_session: AsyncSession):
|
||||||
|
"""get() populates the relationship via lateral join."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
|
||||||
|
assert fetched.role is not None
|
||||||
|
assert fetched.role.name == "admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_null_fk_preserved(self, db_session: AsyncSession):
|
||||||
|
"""User with null role_id still returned (LEFT JOIN behaviour)."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="bob@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.role is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_loads_relationship(self, db_session: AsyncSession):
|
||||||
|
"""first() populates the relationship via lateral join."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="carol", email="carol@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await UserLateralCrud.first(db_session)
|
||||||
|
assert user is not None
|
||||||
|
assert user.role is not None
|
||||||
|
assert user.role.name == "editor"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_loads_relationship(self, db_session: AsyncSession):
|
||||||
|
"""get_multi() populates the relationship via lateral join for all rows."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
|
||||||
|
for i in range(3):
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(
|
||||||
|
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
users = await UserLateralCrud.get_multi(db_session)
|
||||||
|
assert len(users) == 3
|
||||||
|
assert all(u.role is not None and u.role.name == "member" for u in users)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_offset_paginate_correct_count(self, db_session: AsyncSession):
|
||||||
|
"""offset_paginate total_count is not inflated by the lateral join."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
for i in range(5):
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(
|
||||||
|
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserLateralCrud.offset_paginate(
|
||||||
|
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||||
|
)
|
||||||
|
assert result.pagination.total_count == 5
|
||||||
|
assert len(result.data) == 5
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_offset_paginate_loads_relationship(self, db_session: AsyncSession):
|
||||||
|
"""offset_paginate serializes relationship data loaded via lateral."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserLateralCrud.offset_paginate(
|
||||||
|
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||||
|
)
|
||||||
|
assert result.data[0].role is not None
|
||||||
|
assert result.data[0].role.name == "admin"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_offset_paginate_mixed_null_fk(self, db_session: AsyncSession):
|
||||||
|
"""offset_paginate returns all users including those with null role_id."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="with_role", email="a@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="no_role", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserLateralCrud.offset_paginate(
|
||||||
|
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||||
|
)
|
||||||
|
assert result.pagination.total_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cursor_paginate_loads_relationship(self, db_session: AsyncSession):
|
||||||
|
"""cursor_paginate populates the relationship via lateral join."""
|
||||||
|
|
||||||
|
class UserLateralCursorCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
cursor_column = User.id
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
for i in range(3):
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(
|
||||||
|
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserLateralCursorCrud.cursor_paginate(
|
||||||
|
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||||
|
)
|
||||||
|
assert len(result.data) == 3
|
||||||
|
assert all(item.role is not None for item in result.data)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_offset_paginate_with_search_and_lateral(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""search filter works alongside lateral join."""
|
||||||
|
|
||||||
|
class UserLateralCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = [lateral_load(User.role)]
|
||||||
|
searchable_fields = [User.username]
|
||||||
|
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="alice", email="a@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com", role_id=role.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserLateralCrud.offset_paginate(
|
||||||
|
db_session, schema=UserWithRoleRead, search="alice", items_per_page=10
|
||||||
|
)
|
||||||
|
assert result.pagination.total_count == 1
|
||||||
|
assert result.data[0].username == "alice"
|
||||||
|
|||||||
Reference in New Issue
Block a user