feat: auto eager-load relationships in register_fixtures (#243)

This commit is contained in:
d3vyce
2026-04-12 17:04:44 +02:00
committed by GitHub
parent aca3f62a6b
commit c7397faea4
2 changed files with 282 additions and 15 deletions

View File

@@ -1,11 +1,13 @@
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, cast
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, selectinload
from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption
from ..db import get_transaction
from ..fixtures import FixtureRegistry, LoadStrategy
@@ -112,7 +114,7 @@ def _create_fixture_function(
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
@@ -125,6 +127,11 @@ def _create_fixture_function(
session.add(instance)
loaded.append(instance)
if loaded: # pragma: no branch
load_options = _relationship_load_options(type(loaded[0]))
if load_options:
return await _reload_with_relationships(session, loaded, load_options)
return loaded
# Update function signature to include dependencies
@@ -141,6 +148,54 @@ def _create_fixture_function(
return created_func
def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]:
"""Build selectinload options for all direct relationships on a model."""
return [
selectinload(getattr(model, rel.key)) for rel in model.__mapper__.relationships
]
async def _reload_with_relationships(
session: AsyncSession,
instances: list[DeclarativeBase],
load_options: list[ExecutableOption],
) -> list[DeclarativeBase]:
"""Reload instances in a single bulk query with relationship eager-loading.
Uses one SELECT … WHERE pk IN (…) so selectinload can batch all relationship
queries — 1 + N_relationships round-trips regardless of how many instances
there are, instead of one session.get() per instance.
Preserves the original insertion order.
"""
model = type(instances[0])
mapper = model.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
pk_attr = getattr(model, pk_cols[0].key)
pks = [getattr(inst, pk_cols[0].key) for inst in instances]
result = await session.execute(
select(model).where(pk_attr.in_(pks)).options(*load_options)
)
by_pk = {getattr(row, pk_cols[0].key): row for row in result.unique().scalars()}
return [by_pk[pk] for pk in pks]
# Composite PK: fall back to per-instance reload
reloaded: list[DeclarativeBase] = []
for instance in instances:
pk = _get_primary_key(instance)
refreshed = await session.get(
model,
pk,
options=cast(list[ORMOption], load_options),
populate_existing=True,
)
if refreshed is not None: # pragma: no branch
reloaded.append(refreshed)
return reloaded
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__