mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: auto eager-load relationships in register_fixtures (#243)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import DeclarativeBase, selectinload
|
||||
from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..fixtures import FixtureRegistry, LoadStrategy
|
||||
@@ -112,7 +114,7 @@ def _create_fixture_function(
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
@@ -125,6 +127,11 @@ def _create_fixture_function(
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
if loaded: # pragma: no branch
|
||||
load_options = _relationship_load_options(type(loaded[0]))
|
||||
if load_options:
|
||||
return await _reload_with_relationships(session, loaded, load_options)
|
||||
|
||||
return loaded
|
||||
|
||||
# Update function signature to include dependencies
|
||||
@@ -141,6 +148,54 @@ def _create_fixture_function(
|
||||
return created_func
|
||||
|
||||
|
||||
def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]:
|
||||
"""Build selectinload options for all direct relationships on a model."""
|
||||
return [
|
||||
selectinload(getattr(model, rel.key)) for rel in model.__mapper__.relationships
|
||||
]
|
||||
|
||||
|
||||
async def _reload_with_relationships(
|
||||
session: AsyncSession,
|
||||
instances: list[DeclarativeBase],
|
||||
load_options: list[ExecutableOption],
|
||||
) -> list[DeclarativeBase]:
|
||||
"""Reload instances in a single bulk query with relationship eager-loading.
|
||||
|
||||
Uses one SELECT … WHERE pk IN (…) so selectinload can batch all relationship
|
||||
queries — 1 + N_relationships round-trips regardless of how many instances
|
||||
there are, instead of one session.get() per instance.
|
||||
|
||||
Preserves the original insertion order.
|
||||
"""
|
||||
model = type(instances[0])
|
||||
mapper = model.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
pk_attr = getattr(model, pk_cols[0].key)
|
||||
pks = [getattr(inst, pk_cols[0].key) for inst in instances]
|
||||
result = await session.execute(
|
||||
select(model).where(pk_attr.in_(pks)).options(*load_options)
|
||||
)
|
||||
by_pk = {getattr(row, pk_cols[0].key): row for row in result.unique().scalars()}
|
||||
return [by_pk[pk] for pk in pks]
|
||||
|
||||
# Composite PK: fall back to per-instance reload
|
||||
reloaded: list[DeclarativeBase] = []
|
||||
for instance in instances:
|
||||
pk = _get_primary_key(instance)
|
||||
refreshed = await session.get(
|
||||
model,
|
||||
pk,
|
||||
options=cast(list[ORMOption], load_options),
|
||||
populate_existing=True,
|
||||
)
|
||||
if refreshed is not None: # pragma: no branch
|
||||
reloaded.append(refreshed)
|
||||
return reloaded
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
|
||||
Reference in New Issue
Block a user