mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: auto eager-load relationships in register_fixtures (#243)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import DeclarativeBase, selectinload
|
||||
from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..fixtures import FixtureRegistry, LoadStrategy
|
||||
@@ -112,7 +114,7 @@ def _create_fixture_function(
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
@@ -125,6 +127,11 @@ def _create_fixture_function(
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
if loaded: # pragma: no branch
|
||||
load_options = _relationship_load_options(type(loaded[0]))
|
||||
if load_options:
|
||||
return await _reload_with_relationships(session, loaded, load_options)
|
||||
|
||||
return loaded
|
||||
|
||||
# Update function signature to include dependencies
|
||||
@@ -141,6 +148,54 @@ def _create_fixture_function(
|
||||
return created_func
|
||||
|
||||
|
||||
def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]:
|
||||
"""Build selectinload options for all direct relationships on a model."""
|
||||
return [
|
||||
selectinload(getattr(model, rel.key)) for rel in model.__mapper__.relationships
|
||||
]
|
||||
|
||||
|
||||
async def _reload_with_relationships(
|
||||
session: AsyncSession,
|
||||
instances: list[DeclarativeBase],
|
||||
load_options: list[ExecutableOption],
|
||||
) -> list[DeclarativeBase]:
|
||||
"""Reload instances in a single bulk query with relationship eager-loading.
|
||||
|
||||
Uses one SELECT … WHERE pk IN (…) so selectinload can batch all relationship
|
||||
queries — 1 + N_relationships round-trips regardless of how many instances
|
||||
there are, instead of one session.get() per instance.
|
||||
|
||||
Preserves the original insertion order.
|
||||
"""
|
||||
model = type(instances[0])
|
||||
mapper = model.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
pk_attr = getattr(model, pk_cols[0].key)
|
||||
pks = [getattr(inst, pk_cols[0].key) for inst in instances]
|
||||
result = await session.execute(
|
||||
select(model).where(pk_attr.in_(pks)).options(*load_options)
|
||||
)
|
||||
by_pk = {getattr(row, pk_cols[0].key): row for row in result.unique().scalars()}
|
||||
return [by_pk[pk] for pk in pks]
|
||||
|
||||
# Composite PK: fall back to per-instance reload
|
||||
reloaded: list[DeclarativeBase] = []
|
||||
for instance in instances:
|
||||
pk = _get_primary_key(instance)
|
||||
refreshed = await session.get(
|
||||
model,
|
||||
pk,
|
||||
options=cast(list[ORMOption], load_options),
|
||||
populate_existing=True,
|
||||
)
|
||||
if refreshed is not None: # pragma: no branch
|
||||
reloaded.append(refreshed)
|
||||
return reloaded
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
"""Tests for fastapi_toolsets.pytest module."""
|
||||
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy import ForeignKey, String, select, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
from fastapi_toolsets.db import get_transaction
|
||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry, LoadStrategy
|
||||
from fastapi_toolsets.pytest import (
|
||||
create_async_client,
|
||||
create_db_session,
|
||||
@@ -19,9 +20,23 @@ from fastapi_toolsets.pytest import (
|
||||
register_fixtures,
|
||||
worker_database_url,
|
||||
)
|
||||
from fastapi_toolsets.pytest.plugin import (
|
||||
_get_primary_key,
|
||||
_relationship_load_options,
|
||||
_reload_with_relationships,
|
||||
)
|
||||
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||
from .conftest import (
|
||||
DATABASE_URL,
|
||||
Base,
|
||||
IntRole,
|
||||
Permission,
|
||||
Role,
|
||||
RoleCrud,
|
||||
User,
|
||||
UserCrud,
|
||||
)
|
||||
|
||||
test_registry = FixtureRegistry()
|
||||
|
||||
@@ -136,14 +151,8 @@ class TestGeneratedFixtures:
|
||||
async def test_fixture_relationships_work(
|
||||
self, db_session: AsyncSession, fixture_users: list[User]
|
||||
):
|
||||
"""Loaded fixtures have working relationships."""
|
||||
# Load user with role relationship
|
||||
user = await UserCrud.get(
|
||||
db_session,
|
||||
[User.id == USER_ADMIN_ID],
|
||||
load_options=[selectinload(User.role)],
|
||||
)
|
||||
|
||||
"""Loaded fixtures have working relationships directly accessible."""
|
||||
user = next(u for u in fixture_users if u.id == USER_ADMIN_ID)
|
||||
assert user.role is not None
|
||||
assert user.role.name == "plugin_admin"
|
||||
|
||||
@@ -177,6 +186,15 @@ class TestGeneratedFixtures:
|
||||
assert users[0].username == "plugin_admin"
|
||||
assert users[1].username == "plugin_user"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_fixture_auto_loads_relationships(
|
||||
self, db_session: AsyncSession, fixture_users: list[User]
|
||||
):
|
||||
"""Fixtures automatically eager-load all direct relationships."""
|
||||
user = next(u for u in fixture_users if u.username == "plugin_admin")
|
||||
assert user.role is not None
|
||||
assert user.role.name == "plugin_admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_fixtures_in_same_test(
|
||||
self,
|
||||
@@ -516,3 +534,197 @@ class TestCreateWorkerDatabase:
|
||||
)
|
||||
assert result.scalar() is None
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local models for composite-PK coverage (own Base → own tables, isolated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _LocalBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class _Group(_LocalBase):
|
||||
__tablename__ = "_test_groups"
|
||||
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
|
||||
class _CompositeItem(_LocalBase):
|
||||
"""Model with composite PK and a relationship — exercises the fallback path."""
|
||||
|
||||
__tablename__ = "_test_composite_items"
|
||||
group_id: Mapped[uuid.UUID] = mapped_column(
|
||||
ForeignKey("_test_groups.id"), primary_key=True
|
||||
)
|
||||
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
group: Mapped["_Group"] = relationship()
|
||||
|
||||
|
||||
class TestGetPrimaryKey:
|
||||
"""Unit tests for _get_primary_key — no DB needed."""
|
||||
|
||||
def test_single_pk_returns_value(self):
|
||||
rid = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||
role = Role(id=rid, name="x")
|
||||
assert _get_primary_key(role) == rid
|
||||
|
||||
def test_composite_pk_all_set_returns_tuple(self):
|
||||
perm = Permission(subject="posts", action="read")
|
||||
assert _get_primary_key(perm) == ("posts", "read")
|
||||
|
||||
def test_composite_pk_partial_none_returns_none(self):
|
||||
perm = Permission(subject=None, action="read")
|
||||
assert _get_primary_key(perm) is None
|
||||
|
||||
def test_composite_pk_all_none_returns_none(self):
|
||||
perm = Permission(subject=None, action=None)
|
||||
assert _get_primary_key(perm) is None
|
||||
|
||||
|
||||
class TestRelationshipLoadOptions:
|
||||
"""Unit tests for _relationship_load_options — no DB needed."""
|
||||
|
||||
def test_empty_for_model_with_no_relationships(self):
|
||||
assert _relationship_load_options(IntRole) == []
|
||||
|
||||
def test_returns_options_for_model_with_relationships(self):
|
||||
opts = _relationship_load_options(User)
|
||||
assert len(opts) >= 1
|
||||
|
||||
|
||||
class TestFixtureStrategies:
|
||||
"""Integration tests covering INSERT, SKIP_EXISTING, empty fixture, no-rels model."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_fixture_returns_empty_list(self, db_session: AsyncSession):
|
||||
"""Fixture function returning [] produces an empty list."""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register()
|
||||
def empty() -> list[Role]:
|
||||
return []
|
||||
|
||||
local_ns: dict = {}
|
||||
register_fixtures(registry, local_ns, session_fixture="db_session")
|
||||
inner = local_ns["fixture_empty"].__wrapped__ # type: ignore[attr-defined]
|
||||
result = await inner(db_session=db_session)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_insert_strategy_no_relationships(self, db_session: AsyncSession):
|
||||
"""INSERT strategy adds instances; model with no rels skips reload (line 135)."""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register()
|
||||
def int_roles() -> list[IntRole]:
|
||||
return [IntRole(name="insert_role")]
|
||||
|
||||
local_ns: dict = {}
|
||||
register_fixtures(
|
||||
registry,
|
||||
local_ns,
|
||||
session_fixture="db_session",
|
||||
strategy=LoadStrategy.INSERT,
|
||||
)
|
||||
inner = local_ns["fixture_int_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||
result = await inner(db_session=db_session)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "insert_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_inserts_new_record(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING inserts when the record does not yet exist."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register()
|
||||
def new_roles() -> list[Role]:
|
||||
return [Role(id=role_id, name="skip_new")]
|
||||
|
||||
local_ns: dict = {}
|
||||
register_fixtures(
|
||||
registry,
|
||||
local_ns,
|
||||
session_fixture="db_session",
|
||||
strategy=LoadStrategy.SKIP_EXISTING,
|
||||
)
|
||||
inner = local_ns["fixture_new_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||
result = await inner(db_session=db_session)
|
||||
assert len(result) == 1
|
||||
assert result[0].id == role_id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_returns_existing_record(
|
||||
self, db_session: AsyncSession
|
||||
):
|
||||
"""SKIP_EXISTING returns the existing DB record when PK already present."""
|
||||
role_id = uuid.uuid4()
|
||||
existing = Role(id=role_id, name="already_there")
|
||||
db_session.add(existing)
|
||||
await db_session.flush()
|
||||
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register()
|
||||
def dup_roles() -> list[Role]:
|
||||
return [Role(id=role_id, name="should_not_overwrite")]
|
||||
|
||||
local_ns: dict = {}
|
||||
register_fixtures(
|
||||
registry,
|
||||
local_ns,
|
||||
session_fixture="db_session",
|
||||
strategy=LoadStrategy.SKIP_EXISTING,
|
||||
)
|
||||
inner = local_ns["fixture_dup_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||
result = await inner(db_session=db_session)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "already_there"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING with null PK (auto-increment) falls through to session.add()."""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register()
|
||||
def auto_roles() -> list[IntRole]:
|
||||
return [IntRole(name="auto_int")]
|
||||
|
||||
local_ns: dict = {}
|
||||
register_fixtures(
|
||||
registry,
|
||||
local_ns,
|
||||
session_fixture="db_session",
|
||||
strategy=LoadStrategy.SKIP_EXISTING,
|
||||
)
|
||||
inner = local_ns["fixture_auto_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||
result = await inner(db_session=db_session)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "auto_int"
|
||||
|
||||
|
||||
class TestReloadWithRelationshipsCompositePK:
|
||||
"""Integration test for _reload_with_relationships composite-PK fallback."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_composite_pk_fallback_loads_relationships(self):
|
||||
"""Models with composite PKs are reloaded per-instance via session.get()."""
|
||||
async with create_db_session(DATABASE_URL, _LocalBase) as session:
|
||||
group = _Group(id=uuid.uuid4(), name="g1")
|
||||
session.add(group)
|
||||
await session.flush()
|
||||
|
||||
item = _CompositeItem(group_id=group.id, item_code="A")
|
||||
session.add(item)
|
||||
await session.flush()
|
||||
|
||||
load_opts = _relationship_load_options(_CompositeItem)
|
||||
assert load_opts # _CompositeItem has 'group' relationship
|
||||
|
||||
reloaded = await _reload_with_relationships(session, [item], load_opts)
|
||||
assert len(reloaded) == 1
|
||||
reloaded_item = cast(_CompositeItem, reloaded[0])
|
||||
assert reloaded_item.group is not None
|
||||
assert reloaded_item.group.name == "g1"
|
||||
|
||||
Reference in New Issue
Block a user