feat: auto eager-load relationships in register_fixtures (#243)

This commit is contained in:
d3vyce
2026-04-12 17:04:44 +02:00
committed by GitHub
parent aca3f62a6b
commit c7397faea4
2 changed files with 282 additions and 15 deletions

View File

@@ -1,11 +1,13 @@
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, cast
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, selectinload
from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption
from ..db import get_transaction
from ..fixtures import FixtureRegistry, LoadStrategy
@@ -112,7 +114,7 @@ def _create_fixture_function(
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
@@ -125,6 +127,11 @@ def _create_fixture_function(
session.add(instance)
loaded.append(instance)
if loaded: # pragma: no branch
load_options = _relationship_load_options(type(loaded[0]))
if load_options:
return await _reload_with_relationships(session, loaded, load_options)
return loaded
# Update function signature to include dependencies
@@ -141,6 +148,54 @@ def _create_fixture_function(
return created_func
def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]:
"""Build selectinload options for all direct relationships on a model."""
return [
selectinload(getattr(model, rel.key)) for rel in model.__mapper__.relationships
]
async def _reload_with_relationships(
session: AsyncSession,
instances: list[DeclarativeBase],
load_options: list[ExecutableOption],
) -> list[DeclarativeBase]:
"""Reload instances in a single bulk query with relationship eager-loading.
Uses one SELECT … WHERE pk IN (…) so selectinload can batch all relationship
queries — 1 + N_relationships round-trips regardless of how many instances
there are, instead of one session.get() per instance.
Preserves the original insertion order.
"""
model = type(instances[0])
mapper = model.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
pk_attr = getattr(model, pk_cols[0].key)
pks = [getattr(inst, pk_cols[0].key) for inst in instances]
result = await session.execute(
select(model).where(pk_attr.in_(pks)).options(*load_options)
)
by_pk = {getattr(row, pk_cols[0].key): row for row in result.unique().scalars()}
return [by_pk[pk] for pk in pks]
# Composite PK: fall back to per-instance reload
reloaded: list[DeclarativeBase] = []
for instance in instances:
pk = _get_primary_key(instance)
refreshed = await session.get(
model,
pk,
options=cast(list[ORMOption], load_options),
populate_existing=True,
)
if refreshed is not None: # pragma: no branch
reloaded.append(refreshed)
return reloaded
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__

View File

@@ -1,17 +1,18 @@
"""Tests for fastapi_toolsets.pytest module."""
import uuid
from typing import cast
import pytest
from fastapi import Depends, FastAPI
from httpx import AsyncClient
from sqlalchemy import select, text
from sqlalchemy import ForeignKey, String, select, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.db import get_transaction
from fastapi_toolsets.fixtures import Context, FixtureRegistry
from fastapi_toolsets.fixtures import Context, FixtureRegistry, LoadStrategy
from fastapi_toolsets.pytest import (
create_async_client,
create_db_session,
@@ -19,9 +20,23 @@ from fastapi_toolsets.pytest import (
register_fixtures,
worker_database_url,
)
from fastapi_toolsets.pytest.plugin import (
_get_primary_key,
_relationship_load_options,
_reload_with_relationships,
)
from fastapi_toolsets.pytest.utils import _get_xdist_worker
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
from .conftest import (
DATABASE_URL,
Base,
IntRole,
Permission,
Role,
RoleCrud,
User,
UserCrud,
)
test_registry = FixtureRegistry()
@@ -136,14 +151,8 @@ class TestGeneratedFixtures:
async def test_fixture_relationships_work(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Loaded fixtures have working relationships."""
# Load user with role relationship
user = await UserCrud.get(
db_session,
[User.id == USER_ADMIN_ID],
load_options=[selectinload(User.role)],
)
"""Loaded fixtures have working relationships directly accessible."""
user = next(u for u in fixture_users if u.id == USER_ADMIN_ID)
assert user.role is not None
assert user.role.name == "plugin_admin"
@@ -177,6 +186,15 @@ class TestGeneratedFixtures:
assert users[0].username == "plugin_admin"
assert users[1].username == "plugin_user"
@pytest.mark.anyio
async def test_fixture_auto_loads_relationships(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Fixtures automatically eager-load all direct relationships."""
user = next(u for u in fixture_users if u.username == "plugin_admin")
assert user.role is not None
assert user.role.name == "plugin_admin"
@pytest.mark.anyio
async def test_multiple_fixtures_in_same_test(
self,
@@ -516,3 +534,197 @@ class TestCreateWorkerDatabase:
)
assert result.scalar() is None
await engine.dispose()
# ---------------------------------------------------------------------------
# Local models for composite-PK coverage (own Base → own tables, isolated)
# ---------------------------------------------------------------------------
class _LocalBase(DeclarativeBase):
pass
class _Group(_LocalBase):
__tablename__ = "_test_groups"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50))
class _CompositeItem(_LocalBase):
"""Model with composite PK and a relationship — exercises the fallback path."""
__tablename__ = "_test_composite_items"
group_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("_test_groups.id"), primary_key=True
)
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
group: Mapped["_Group"] = relationship()
class TestGetPrimaryKey:
"""Unit tests for _get_primary_key — no DB needed."""
def test_single_pk_returns_value(self):
rid = uuid.UUID("00000000-0000-0000-0000-000000000001")
role = Role(id=rid, name="x")
assert _get_primary_key(role) == rid
def test_composite_pk_all_set_returns_tuple(self):
perm = Permission(subject="posts", action="read")
assert _get_primary_key(perm) == ("posts", "read")
def test_composite_pk_partial_none_returns_none(self):
perm = Permission(subject=None, action="read")
assert _get_primary_key(perm) is None
def test_composite_pk_all_none_returns_none(self):
perm = Permission(subject=None, action=None)
assert _get_primary_key(perm) is None
class TestRelationshipLoadOptions:
"""Unit tests for _relationship_load_options — no DB needed."""
def test_empty_for_model_with_no_relationships(self):
assert _relationship_load_options(IntRole) == []
def test_returns_options_for_model_with_relationships(self):
opts = _relationship_load_options(User)
assert len(opts) >= 1
class TestFixtureStrategies:
"""Integration tests covering INSERT, SKIP_EXISTING, empty fixture, no-rels model."""
@pytest.mark.anyio
async def test_empty_fixture_returns_empty_list(self, db_session: AsyncSession):
"""Fixture function returning [] produces an empty list."""
registry = FixtureRegistry()
@registry.register()
def empty() -> list[Role]:
return []
local_ns: dict = {}
register_fixtures(registry, local_ns, session_fixture="db_session")
inner = local_ns["fixture_empty"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert result == []
@pytest.mark.anyio
async def test_insert_strategy_no_relationships(self, db_session: AsyncSession):
"""INSERT strategy adds instances; model with no rels skips reload (line 135)."""
registry = FixtureRegistry()
@registry.register()
def int_roles() -> list[IntRole]:
return [IntRole(name="insert_role")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.INSERT,
)
inner = local_ns["fixture_int_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "insert_role"
@pytest.mark.anyio
async def test_skip_existing_inserts_new_record(self, db_session: AsyncSession):
"""SKIP_EXISTING inserts when the record does not yet exist."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register()
def new_roles() -> list[Role]:
return [Role(id=role_id, name="skip_new")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_new_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].id == role_id
@pytest.mark.anyio
async def test_skip_existing_returns_existing_record(
self, db_session: AsyncSession
):
"""SKIP_EXISTING returns the existing DB record when PK already present."""
role_id = uuid.uuid4()
existing = Role(id=role_id, name="already_there")
db_session.add(existing)
await db_session.flush()
registry = FixtureRegistry()
@registry.register()
def dup_roles() -> list[Role]:
return [Role(id=role_id, name="should_not_overwrite")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_dup_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "already_there"
@pytest.mark.anyio
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
"""SKIP_EXISTING with null PK (auto-increment) falls through to session.add()."""
registry = FixtureRegistry()
@registry.register()
def auto_roles() -> list[IntRole]:
return [IntRole(name="auto_int")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_auto_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "auto_int"
class TestReloadWithRelationshipsCompositePK:
"""Integration test for _reload_with_relationships composite-PK fallback."""
@pytest.mark.anyio
async def test_composite_pk_fallback_loads_relationships(self):
"""Models with composite PKs are reloaded per-instance via session.get()."""
async with create_db_session(DATABASE_URL, _LocalBase) as session:
group = _Group(id=uuid.uuid4(), name="g1")
session.add(group)
await session.flush()
item = _CompositeItem(group_id=group.id, item_code="A")
session.add(item)
await session.flush()
load_opts = _relationship_load_options(_CompositeItem)
assert load_opts # _CompositeItem has 'group' relationship
reloaded = await _reload_with_relationships(session, [item], load_opts)
assert len(reloaded) == 1
reloaded_item = cast(_CompositeItem, reloaded[0])
assert reloaded_item.group is not None
assert reloaded_item.group.name == "g1"