mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat: auto eager-load relationships in register_fixtures (#243)
This commit is contained in:
@@ -1,11 +1,13 @@
|
|||||||
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase, selectinload
|
||||||
|
from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from ..fixtures import FixtureRegistry, LoadStrategy
|
from ..fixtures import FixtureRegistry, LoadStrategy
|
||||||
@@ -112,7 +114,7 @@ def _create_fixture_function(
|
|||||||
elif strategy == LoadStrategy.MERGE:
|
elif strategy == LoadStrategy.MERGE:
|
||||||
merged = await session.merge(instance)
|
merged = await session.merge(instance)
|
||||||
loaded.append(merged)
|
loaded.append(merged)
|
||||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch
|
||||||
pk = _get_primary_key(instance)
|
pk = _get_primary_key(instance)
|
||||||
if pk is not None:
|
if pk is not None:
|
||||||
existing = await session.get(type(instance), pk)
|
existing = await session.get(type(instance), pk)
|
||||||
@@ -125,6 +127,11 @@ def _create_fixture_function(
|
|||||||
session.add(instance)
|
session.add(instance)
|
||||||
loaded.append(instance)
|
loaded.append(instance)
|
||||||
|
|
||||||
|
if loaded: # pragma: no branch
|
||||||
|
load_options = _relationship_load_options(type(loaded[0]))
|
||||||
|
if load_options:
|
||||||
|
return await _reload_with_relationships(session, loaded, load_options)
|
||||||
|
|
||||||
return loaded
|
return loaded
|
||||||
|
|
||||||
# Update function signature to include dependencies
|
# Update function signature to include dependencies
|
||||||
@@ -141,6 +148,54 @@ def _create_fixture_function(
|
|||||||
return created_func
|
return created_func
|
||||||
|
|
||||||
|
|
||||||
|
def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]:
|
||||||
|
"""Build selectinload options for all direct relationships on a model."""
|
||||||
|
return [
|
||||||
|
selectinload(getattr(model, rel.key)) for rel in model.__mapper__.relationships
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def _reload_with_relationships(
|
||||||
|
session: AsyncSession,
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
load_options: list[ExecutableOption],
|
||||||
|
) -> list[DeclarativeBase]:
|
||||||
|
"""Reload instances in a single bulk query with relationship eager-loading.
|
||||||
|
|
||||||
|
Uses one SELECT … WHERE pk IN (…) so selectinload can batch all relationship
|
||||||
|
queries — 1 + N_relationships round-trips regardless of how many instances
|
||||||
|
there are, instead of one session.get() per instance.
|
||||||
|
|
||||||
|
Preserves the original insertion order.
|
||||||
|
"""
|
||||||
|
model = type(instances[0])
|
||||||
|
mapper = model.__mapper__
|
||||||
|
pk_cols = mapper.primary_key
|
||||||
|
|
||||||
|
if len(pk_cols) == 1:
|
||||||
|
pk_attr = getattr(model, pk_cols[0].key)
|
||||||
|
pks = [getattr(inst, pk_cols[0].key) for inst in instances]
|
||||||
|
result = await session.execute(
|
||||||
|
select(model).where(pk_attr.in_(pks)).options(*load_options)
|
||||||
|
)
|
||||||
|
by_pk = {getattr(row, pk_cols[0].key): row for row in result.unique().scalars()}
|
||||||
|
return [by_pk[pk] for pk in pks]
|
||||||
|
|
||||||
|
# Composite PK: fall back to per-instance reload
|
||||||
|
reloaded: list[DeclarativeBase] = []
|
||||||
|
for instance in instances:
|
||||||
|
pk = _get_primary_key(instance)
|
||||||
|
refreshed = await session.get(
|
||||||
|
model,
|
||||||
|
pk,
|
||||||
|
options=cast(list[ORMOption], load_options),
|
||||||
|
populate_existing=True,
|
||||||
|
)
|
||||||
|
if refreshed is not None: # pragma: no branch
|
||||||
|
reloaded.append(refreshed)
|
||||||
|
return reloaded
|
||||||
|
|
||||||
|
|
||||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||||
"""Get the primary key value of a model instance."""
|
"""Get the primary key value of a model instance."""
|
||||||
mapper = instance.__class__.__mapper__
|
mapper = instance.__class__.__mapper__
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
"""Tests for fastapi_toolsets.pytest module."""
|
"""Tests for fastapi_toolsets.pytest module."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from sqlalchemy import select, text
|
from sqlalchemy import ForeignKey, String, select, text
|
||||||
from sqlalchemy.engine import make_url
|
from sqlalchemy.engine import make_url
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from fastapi_toolsets.db import get_transaction
|
from fastapi_toolsets.db import get_transaction
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry, LoadStrategy
|
||||||
from fastapi_toolsets.pytest import (
|
from fastapi_toolsets.pytest import (
|
||||||
create_async_client,
|
create_async_client,
|
||||||
create_db_session,
|
create_db_session,
|
||||||
@@ -19,9 +20,23 @@ from fastapi_toolsets.pytest import (
|
|||||||
register_fixtures,
|
register_fixtures,
|
||||||
worker_database_url,
|
worker_database_url,
|
||||||
)
|
)
|
||||||
|
from fastapi_toolsets.pytest.plugin import (
|
||||||
|
_get_primary_key,
|
||||||
|
_relationship_load_options,
|
||||||
|
_reload_with_relationships,
|
||||||
|
)
|
||||||
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
from .conftest import (
|
||||||
|
DATABASE_URL,
|
||||||
|
Base,
|
||||||
|
IntRole,
|
||||||
|
Permission,
|
||||||
|
Role,
|
||||||
|
RoleCrud,
|
||||||
|
User,
|
||||||
|
UserCrud,
|
||||||
|
)
|
||||||
|
|
||||||
test_registry = FixtureRegistry()
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
@@ -136,14 +151,8 @@ class TestGeneratedFixtures:
|
|||||||
async def test_fixture_relationships_work(
|
async def test_fixture_relationships_work(
|
||||||
self, db_session: AsyncSession, fixture_users: list[User]
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
):
|
):
|
||||||
"""Loaded fixtures have working relationships."""
|
"""Loaded fixtures have working relationships directly accessible."""
|
||||||
# Load user with role relationship
|
user = next(u for u in fixture_users if u.id == USER_ADMIN_ID)
|
||||||
user = await UserCrud.get(
|
|
||||||
db_session,
|
|
||||||
[User.id == USER_ADMIN_ID],
|
|
||||||
load_options=[selectinload(User.role)],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert user.role is not None
|
assert user.role is not None
|
||||||
assert user.role.name == "plugin_admin"
|
assert user.role.name == "plugin_admin"
|
||||||
|
|
||||||
@@ -177,6 +186,15 @@ class TestGeneratedFixtures:
|
|||||||
assert users[0].username == "plugin_admin"
|
assert users[0].username == "plugin_admin"
|
||||||
assert users[1].username == "plugin_user"
|
assert users[1].username == "plugin_user"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_fixture_auto_loads_relationships(
|
||||||
|
self, db_session: AsyncSession, fixture_users: list[User]
|
||||||
|
):
|
||||||
|
"""Fixtures automatically eager-load all direct relationships."""
|
||||||
|
user = next(u for u in fixture_users if u.username == "plugin_admin")
|
||||||
|
assert user.role is not None
|
||||||
|
assert user.role.name == "plugin_admin"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_multiple_fixtures_in_same_test(
|
async def test_multiple_fixtures_in_same_test(
|
||||||
self,
|
self,
|
||||||
@@ -516,3 +534,197 @@ class TestCreateWorkerDatabase:
|
|||||||
)
|
)
|
||||||
assert result.scalar() is None
|
assert result.scalar() is None
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Local models for composite-PK coverage (own Base → own tables, isolated)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _LocalBase(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _Group(_LocalBase):
|
||||||
|
__tablename__ = "_test_groups"
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
name: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
|
||||||
|
class _CompositeItem(_LocalBase):
|
||||||
|
"""Model with composite PK and a relationship — exercises the fallback path."""
|
||||||
|
|
||||||
|
__tablename__ = "_test_composite_items"
|
||||||
|
group_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("_test_groups.id"), primary_key=True
|
||||||
|
)
|
||||||
|
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||||
|
group: Mapped["_Group"] = relationship()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetPrimaryKey:
|
||||||
|
"""Unit tests for _get_primary_key — no DB needed."""
|
||||||
|
|
||||||
|
def test_single_pk_returns_value(self):
|
||||||
|
rid = uuid.UUID("00000000-0000-0000-0000-000000000001")
|
||||||
|
role = Role(id=rid, name="x")
|
||||||
|
assert _get_primary_key(role) == rid
|
||||||
|
|
||||||
|
def test_composite_pk_all_set_returns_tuple(self):
|
||||||
|
perm = Permission(subject="posts", action="read")
|
||||||
|
assert _get_primary_key(perm) == ("posts", "read")
|
||||||
|
|
||||||
|
def test_composite_pk_partial_none_returns_none(self):
|
||||||
|
perm = Permission(subject=None, action="read")
|
||||||
|
assert _get_primary_key(perm) is None
|
||||||
|
|
||||||
|
def test_composite_pk_all_none_returns_none(self):
|
||||||
|
perm = Permission(subject=None, action=None)
|
||||||
|
assert _get_primary_key(perm) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRelationshipLoadOptions:
|
||||||
|
"""Unit tests for _relationship_load_options — no DB needed."""
|
||||||
|
|
||||||
|
def test_empty_for_model_with_no_relationships(self):
|
||||||
|
assert _relationship_load_options(IntRole) == []
|
||||||
|
|
||||||
|
def test_returns_options_for_model_with_relationships(self):
|
||||||
|
opts = _relationship_load_options(User)
|
||||||
|
assert len(opts) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixtureStrategies:
|
||||||
|
"""Integration tests covering INSERT, SKIP_EXISTING, empty fixture, no-rels model."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_empty_fixture_returns_empty_list(self, db_session: AsyncSession):
|
||||||
|
"""Fixture function returning [] produces an empty list."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register()
|
||||||
|
def empty() -> list[Role]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
local_ns: dict = {}
|
||||||
|
register_fixtures(registry, local_ns, session_fixture="db_session")
|
||||||
|
inner = local_ns["fixture_empty"].__wrapped__ # type: ignore[attr-defined]
|
||||||
|
result = await inner(db_session=db_session)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_insert_strategy_no_relationships(self, db_session: AsyncSession):
|
||||||
|
"""INSERT strategy adds instances; model with no rels skips reload (line 135)."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register()
|
||||||
|
def int_roles() -> list[IntRole]:
|
||||||
|
return [IntRole(name="insert_role")]
|
||||||
|
|
||||||
|
local_ns: dict = {}
|
||||||
|
register_fixtures(
|
||||||
|
registry,
|
||||||
|
local_ns,
|
||||||
|
session_fixture="db_session",
|
||||||
|
strategy=LoadStrategy.INSERT,
|
||||||
|
)
|
||||||
|
inner = local_ns["fixture_int_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||||
|
result = await inner(db_session=db_session)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].name == "insert_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_skip_existing_inserts_new_record(self, db_session: AsyncSession):
|
||||||
|
"""SKIP_EXISTING inserts when the record does not yet exist."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
|
@registry.register()
|
||||||
|
def new_roles() -> list[Role]:
|
||||||
|
return [Role(id=role_id, name="skip_new")]
|
||||||
|
|
||||||
|
local_ns: dict = {}
|
||||||
|
register_fixtures(
|
||||||
|
registry,
|
||||||
|
local_ns,
|
||||||
|
session_fixture="db_session",
|
||||||
|
strategy=LoadStrategy.SKIP_EXISTING,
|
||||||
|
)
|
||||||
|
inner = local_ns["fixture_new_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||||
|
result = await inner(db_session=db_session)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].id == role_id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_skip_existing_returns_existing_record(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""SKIP_EXISTING returns the existing DB record when PK already present."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
existing = Role(id=role_id, name="already_there")
|
||||||
|
db_session.add(existing)
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register()
|
||||||
|
def dup_roles() -> list[Role]:
|
||||||
|
return [Role(id=role_id, name="should_not_overwrite")]
|
||||||
|
|
||||||
|
local_ns: dict = {}
|
||||||
|
register_fixtures(
|
||||||
|
registry,
|
||||||
|
local_ns,
|
||||||
|
session_fixture="db_session",
|
||||||
|
strategy=LoadStrategy.SKIP_EXISTING,
|
||||||
|
)
|
||||||
|
inner = local_ns["fixture_dup_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||||
|
result = await inner(db_session=db_session)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].name == "already_there"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
|
||||||
|
"""SKIP_EXISTING with null PK (auto-increment) falls through to session.add()."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register()
|
||||||
|
def auto_roles() -> list[IntRole]:
|
||||||
|
return [IntRole(name="auto_int")]
|
||||||
|
|
||||||
|
local_ns: dict = {}
|
||||||
|
register_fixtures(
|
||||||
|
registry,
|
||||||
|
local_ns,
|
||||||
|
session_fixture="db_session",
|
||||||
|
strategy=LoadStrategy.SKIP_EXISTING,
|
||||||
|
)
|
||||||
|
inner = local_ns["fixture_auto_roles"].__wrapped__ # type: ignore[attr-defined]
|
||||||
|
result = await inner(db_session=db_session)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].name == "auto_int"
|
||||||
|
|
||||||
|
|
||||||
|
class TestReloadWithRelationshipsCompositePK:
|
||||||
|
"""Integration test for _reload_with_relationships composite-PK fallback."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_composite_pk_fallback_loads_relationships(self):
|
||||||
|
"""Models with composite PKs are reloaded per-instance via session.get()."""
|
||||||
|
async with create_db_session(DATABASE_URL, _LocalBase) as session:
|
||||||
|
group = _Group(id=uuid.uuid4(), name="g1")
|
||||||
|
session.add(group)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
item = _CompositeItem(group_id=group.id, item_code="A")
|
||||||
|
session.add(item)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
load_opts = _relationship_load_options(_CompositeItem)
|
||||||
|
assert load_opts # _CompositeItem has 'group' relationship
|
||||||
|
|
||||||
|
reloaded = await _reload_with_relationships(session, [item], load_opts)
|
||||||
|
assert len(reloaded) == 1
|
||||||
|
reloaded_item = cast(_CompositeItem, reloaded[0])
|
||||||
|
assert reloaded_item.group is not None
|
||||||
|
assert reloaded_item.group.name == "g1"
|
||||||
|
|||||||
Reference in New Issue
Block a user