mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat: add lateral_load() for efficient Many:One eager loading via JOIN LATERAL
This commit is contained in:
@@ -6,9 +6,15 @@ import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from fastapi_toolsets.crud import CrudFactory, PaginationType
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud, _CursorDirection
|
||||
from fastapi_toolsets.crud import CrudFactory, PaginationType, lateral_load
|
||||
from fastapi_toolsets.crud.factory import (
|
||||
AsyncCrud,
|
||||
_CursorDirection,
|
||||
_LateralLoad,
|
||||
_ResolvedLateral,
|
||||
)
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
from .conftest import (
|
||||
EventCreate,
|
||||
@@ -51,6 +57,12 @@ from .conftest import (
|
||||
)
|
||||
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
|
||||
|
||||
class TestCrudFactory:
|
||||
"""Tests for CrudFactory."""
|
||||
|
||||
@@ -208,11 +220,11 @@ class TestResolveLoadOptions:
|
||||
assert crud._resolve_load_options(None) is None
|
||||
|
||||
def test_empty_list_overrides_default(self):
|
||||
"""An empty list is a valid override and disables default_load_options."""
|
||||
"""An explicit empty list disables default_load_options (no options applied)."""
|
||||
default = [selectinload(User.role)]
|
||||
crud = CrudFactory(User, default_load_options=default)
|
||||
# Empty list is not None, so it should replace default
|
||||
assert crud._resolve_load_options([]) == []
|
||||
# Empty list replaces default; None and [] are both falsy → no options applied
|
||||
assert not crud._resolve_load_options([])
|
||||
|
||||
|
||||
class TestResolveSearchColumns:
|
||||
@@ -359,13 +371,6 @@ class TestDefaultLoadOptionsIntegration:
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""default_load_options loads relationships automatically on offset_paginate()."""
|
||||
from fastapi_toolsets.schemas import PydanticBase
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
|
||||
UserWithDefaultLoad = CrudFactory(
|
||||
User, default_load_options=[selectinload(User.role)]
|
||||
)
|
||||
@@ -2462,12 +2467,7 @@ class TestCursorPaginateExtraOptions:
|
||||
@pytest.mark.anyio
|
||||
async def test_with_load_options(self, db_session: AsyncSession):
|
||||
"""cursor_paginate passes load_options to the query."""
|
||||
from fastapi_toolsets.schemas import CursorPagination, PydanticBase
|
||||
|
||||
class UserWithRoleRead(PydanticBase):
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
role: RoleRead | None = None
|
||||
from fastapi_toolsets.schemas import CursorPagination
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="manager"))
|
||||
for i in range(3):
|
||||
@@ -2833,3 +2833,368 @@ class TestPaginate:
|
||||
|
||||
assert isinstance(result.pagination, OffsetPagination)
|
||||
assert result.pagination.total_count is None
|
||||
|
||||
|
||||
class TestLateralLoadValidation:
|
||||
"""lateral_load() raises immediately for bad relationship types."""
|
||||
|
||||
def test_valid_many_to_one_returns_marker(self):
|
||||
"""lateral_load() on a Many:One rel returns a _LateralLoad with rel_attr set."""
|
||||
marker = lateral_load(User.role)
|
||||
assert isinstance(marker, _LateralLoad)
|
||||
assert marker.rel_attr is User.role
|
||||
|
||||
def test_raises_type_error_for_plain_column(self):
|
||||
"""lateral_load() raises TypeError when passed a plain column."""
|
||||
with pytest.raises(TypeError, match="relationship attribute"):
|
||||
lateral_load(User.username)
|
||||
|
||||
def test_raises_value_error_for_many_to_many(self):
|
||||
"""lateral_load() raises ValueError for Many:Many (secondary table)."""
|
||||
with pytest.raises(ValueError, match="Many:Many"):
|
||||
lateral_load(Post.tags)
|
||||
|
||||
def test_raises_value_error_for_one_to_many(self):
|
||||
"""lateral_load() raises ValueError for One:Many (uselist=True)."""
|
||||
with pytest.raises(ValueError, match="One:Many"):
|
||||
lateral_load(Role.users)
|
||||
|
||||
|
||||
class TestLateralLoadInSubclass:
|
||||
"""lateral_load() markers in default_load_options are processed at class definition."""
|
||||
|
||||
def test_marker_extracted_from_default_load_options(self):
|
||||
"""_LateralLoad is removed from default_load_options and stored in _resolved_lateral."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
assert UserLateralCrud.default_load_options is None
|
||||
assert UserLateralCrud._resolved_lateral is not None
|
||||
|
||||
def test_resolved_lateral_has_one_join_and_eager(self):
|
||||
"""_resolved_lateral contains exactly one join and one eager option."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
resolved = UserLateralCrud._resolved_lateral
|
||||
assert isinstance(resolved, _ResolvedLateral)
|
||||
assert len(resolved.joins) == 1
|
||||
assert len(resolved.eager) == 1
|
||||
|
||||
def test_regular_options_preserved_alongside_lateral(self):
|
||||
"""Non-lateral opts stay in default_load_options; lateral marker is extracted."""
|
||||
regular = selectinload(User.role)
|
||||
|
||||
class UserMixedCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role), regular]
|
||||
|
||||
assert UserMixedCrud._resolved_lateral is not None
|
||||
assert UserMixedCrud.default_load_options == [regular]
|
||||
|
||||
def test_no_lateral_leaves_default_load_options_untouched(self):
|
||||
"""When no lateral marker is present, default_load_options is unchanged."""
|
||||
opts = [selectinload(User.role)]
|
||||
|
||||
class UserNormalCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = opts
|
||||
|
||||
assert UserNormalCrud.default_load_options is opts
|
||||
assert UserNormalCrud._resolved_lateral is None
|
||||
|
||||
def test_no_default_load_options_leaves_resolved_lateral_none(self):
|
||||
"""_resolved_lateral stays None when default_load_options is not set."""
|
||||
|
||||
class UserPlainCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
assert UserPlainCrud._resolved_lateral is None
|
||||
|
||||
|
||||
class TestResolveLoadOptionsWithLateral:
|
||||
"""_resolve_load_options always appends lateral eager options."""
|
||||
|
||||
def test_lateral_eager_included_when_no_call_site_opts(self):
|
||||
"""contains_eager from lateral_load is returned when load_options=None."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
resolved = UserLateralCrud._resolve_load_options(None)
|
||||
assert resolved is not None
|
||||
assert len(resolved) == 1 # the contains_eager
|
||||
|
||||
def test_lateral_eager_appended_to_call_site_opts(self):
|
||||
"""call-site load_options + lateral eager are both returned."""
|
||||
extra = selectinload(User.role)
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
resolved = UserLateralCrud._resolve_load_options([extra])
|
||||
assert resolved is not None
|
||||
assert len(resolved) == 2
|
||||
|
||||
def test_lateral_eager_appended_to_default_load_options(self):
|
||||
"""default_load_options (regular) + lateral eager are both returned."""
|
||||
regular = selectinload(User.role)
|
||||
|
||||
class UserMixedCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role), regular]
|
||||
|
||||
resolved = UserMixedCrud._resolve_load_options(None)
|
||||
assert resolved is not None
|
||||
assert len(resolved) == 2
|
||||
|
||||
|
||||
class TestGetLateralJoins:
|
||||
"""_get_lateral_joins merges auto-resolved and manual lateral_joins."""
|
||||
|
||||
def test_returns_none_when_no_lateral_configured(self):
|
||||
"""Returns None when neither lateral_joins nor lateral_load is set."""
|
||||
|
||||
class UserPlainCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
assert UserPlainCrud._get_lateral_joins() is None
|
||||
|
||||
def test_returns_resolved_lateral_joins(self):
|
||||
"""Returns the join tuple built from lateral_load()."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
joins = UserLateralCrud._get_lateral_joins()
|
||||
assert joins is not None
|
||||
assert len(joins) == 1
|
||||
|
||||
def test_manual_lateral_joins_included(self):
|
||||
"""Manual lateral_joins class var is included in _get_lateral_joins."""
|
||||
from sqlalchemy import select, true
|
||||
|
||||
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
|
||||
|
||||
class UserManualCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
lateral_joins = [(manual_sub, true())]
|
||||
|
||||
joins = UserManualCrud._get_lateral_joins()
|
||||
assert joins is not None
|
||||
assert len(joins) == 1
|
||||
|
||||
def test_manual_and_auto_lateral_joins_merged(self):
|
||||
"""Both manual lateral_joins and auto-resolved from lateral_load are combined."""
|
||||
from sqlalchemy import select, true
|
||||
|
||||
manual_sub = select(Role).where(Role.id == User.role_id).lateral("_manual_role")
|
||||
|
||||
class UserBothCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
lateral_joins = [(manual_sub, true())]
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
joins = UserBothCrud._get_lateral_joins()
|
||||
assert joins is not None
|
||||
assert len(joins) == 2
|
||||
|
||||
|
||||
class TestLateralLoadIntegration:
|
||||
"""lateral_load() in real DB queries: relationship loaded, pagination correct."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_loads_relationship(self, db_session: AsyncSession):
|
||||
"""get() populates the relationship via lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
user = await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
|
||||
assert fetched.role is not None
|
||||
assert fetched.role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_null_fk_preserved(self, db_session: AsyncSession):
|
||||
"""User with null role_id still returned (LEFT JOIN behaviour)."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
user = await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="bob@test.com")
|
||||
)
|
||||
|
||||
fetched = await UserLateralCrud.get(db_session, [User.id == user.id])
|
||||
assert fetched is not None
|
||||
assert fetched.role is None
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_first_loads_relationship(self, db_session: AsyncSession):
|
||||
"""first() populates the relationship via lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="editor"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="carol", email="carol@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
user = await UserLateralCrud.first(db_session)
|
||||
assert user is not None
|
||||
assert user.role is not None
|
||||
assert user.role.name == "editor"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_multi_loads_relationship(self, db_session: AsyncSession):
|
||||
"""get_multi() populates the relationship via lateral join for all rows."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="member"))
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||
),
|
||||
)
|
||||
|
||||
users = await UserLateralCrud.get_multi(db_session)
|
||||
assert len(users) == 3
|
||||
assert all(u.role is not None and u.role.name == "member" for u in users)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_correct_count(self, db_session: AsyncSession):
|
||||
"""offset_paginate total_count is not inflated by the lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
for i in range(5):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||
),
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert result.pagination.total_count == 5
|
||||
assert len(result.data) == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_loads_relationship(self, db_session: AsyncSession):
|
||||
"""offset_paginate serializes relationship data loaded via lateral."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="alice@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert result.data[0].role is not None
|
||||
assert result.data[0].role.name == "admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_mixed_null_fk(self, db_session: AsyncSession):
|
||||
"""offset_paginate returns all users including those with null role_id."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="with_role", email="a@test.com", role_id=role.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="no_role", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cursor_paginate_loads_relationship(self, db_session: AsyncSession):
|
||||
"""cursor_paginate populates the relationship via lateral join."""
|
||||
|
||||
class UserLateralCursorCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
cursor_column = User.id
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
for i in range(3):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(
|
||||
username=f"user{i}", email=f"u{i}@test.com", role_id=role.id
|
||||
),
|
||||
)
|
||||
|
||||
result = await UserLateralCursorCrud.cursor_paginate(
|
||||
db_session, schema=UserWithRoleRead, items_per_page=10
|
||||
)
|
||||
assert len(result.data) == 3
|
||||
assert all(item.role is not None for item in result.data)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_offset_paginate_with_search_and_lateral(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""search filter works alongside lateral join."""
|
||||
|
||||
class UserLateralCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
default_load_options = [lateral_load(User.role)]
|
||||
searchable_fields = [User.username]
|
||||
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="alice", email="a@test.com", role_id=role.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com", role_id=role.id)
|
||||
)
|
||||
|
||||
result = await UserLateralCrud.offset_paginate(
|
||||
db_session, schema=UserWithRoleRead, search="alice", items_per_page=10
|
||||
)
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "alice"
|
||||
|
||||
Reference in New Issue
Block a user