From c7397faea4742f75492bd4d36b4bf3373aa5c571 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Sun, 12 Apr 2026 17:04:44 +0200 Subject: [PATCH] feat: auto eager-load relationships in register_fixtures (#243) --- src/fastapi_toolsets/pytest/plugin.py | 61 ++++++- tests/test_pytest.py | 236 ++++++++++++++++++++++++-- 2 files changed, 282 insertions(+), 15 deletions(-) diff --git a/src/fastapi_toolsets/pytest/plugin.py b/src/fastapi_toolsets/pytest/plugin.py index 1cec985..2aba2f9 100644 --- a/src/fastapi_toolsets/pytest/plugin.py +++ b/src/fastapi_toolsets/pytest/plugin.py @@ -1,11 +1,13 @@ """Pytest plugin for using FixtureRegistry fixtures in tests.""" from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, cast import pytest +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, selectinload +from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption from ..db import get_transaction from ..fixtures import FixtureRegistry, LoadStrategy @@ -112,7 +114,7 @@ def _create_fixture_function( elif strategy == LoadStrategy.MERGE: merged = await session.merge(instance) loaded.append(merged) - elif strategy == LoadStrategy.SKIP_EXISTING: + elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch pk = _get_primary_key(instance) if pk is not None: existing = await session.get(type(instance), pk) @@ -125,6 +127,11 @@ def _create_fixture_function( session.add(instance) loaded.append(instance) + if loaded: # pragma: no branch + load_options = _relationship_load_options(type(loaded[0])) + if load_options: + return await _reload_with_relationships(session, loaded, load_options) + return loaded # Update function signature to include dependencies @@ -141,6 +148,54 @@ def _create_fixture_function( return created_func +def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]: + """Build selectinload options for all direct relationships on a model.""" + return [ + selectinload(getattr(model, rel.key)) for rel in model.__mapper__.relationships + ] + + +async def _reload_with_relationships( + session: AsyncSession, + instances: list[DeclarativeBase], + load_options: list[ExecutableOption], +) -> list[DeclarativeBase]: + """Reload instances in a single bulk query with relationship eager-loading. + + Uses one SELECT … WHERE pk IN (…) so selectinload can batch all relationship + queries — 1 + N_relationships round-trips regardless of how many instances + there are, instead of one session.get() per instance. + + Preserves the original insertion order. + """ + model = type(instances[0]) + mapper = model.__mapper__ + pk_cols = mapper.primary_key + + if len(pk_cols) == 1: + pk_attr = getattr(model, pk_cols[0].key) + pks = [getattr(inst, pk_cols[0].key) for inst in instances] + result = await session.execute( + select(model).where(pk_attr.in_(pks)).options(*load_options) + ) + by_pk = {getattr(row, pk_cols[0].key): row for row in result.unique().scalars()} + return [by_pk[pk] for pk in pks] + + # Composite PK: fall back to per-instance reload + reloaded: list[DeclarativeBase] = [] + for instance in instances: + pk = _get_primary_key(instance) + refreshed = await session.get( + model, + pk, + options=cast(list[ORMOption], load_options), + populate_existing=True, + ) + if refreshed is not None: # pragma: no branch + reloaded.append(refreshed) + return reloaded + + def _get_primary_key(instance: DeclarativeBase) -> Any | None: """Get the primary key value of a model instance.""" mapper = instance.__class__.__mapper__ diff --git a/tests/test_pytest.py b/tests/test_pytest.py index b277251..623e043 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -1,17 +1,18 @@ """Tests for fastapi_toolsets.pytest module.""" import uuid +from typing import cast import pytest from fastapi import Depends, FastAPI from httpx import AsyncClient -from sqlalchemy import select, text +from sqlalchemy import ForeignKey, String, select, text from sqlalchemy.engine import make_url from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from fastapi_toolsets.db import get_transaction -from fastapi_toolsets.fixtures import Context, FixtureRegistry +from fastapi_toolsets.fixtures import Context, FixtureRegistry, LoadStrategy from fastapi_toolsets.pytest import ( create_async_client, create_db_session, @@ -19,9 +20,23 @@ from fastapi_toolsets.pytest import ( register_fixtures, worker_database_url, ) +from fastapi_toolsets.pytest.plugin import ( + _get_primary_key, + _relationship_load_options, + _reload_with_relationships, +) from fastapi_toolsets.pytest.utils import _get_xdist_worker -from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud +from .conftest import ( + DATABASE_URL, + Base, + IntRole, + Permission, + Role, + RoleCrud, + User, + UserCrud, +) test_registry = FixtureRegistry() @@ -136,14 +151,8 @@ class TestGeneratedFixtures: async def test_fixture_relationships_work( self, db_session: AsyncSession, fixture_users: list[User] ): - """Loaded fixtures have working relationships.""" - # Load user with role relationship - user = await UserCrud.get( - db_session, - [User.id == USER_ADMIN_ID], - load_options=[selectinload(User.role)], - ) - + """Loaded fixtures have working relationships directly accessible.""" + user = next(u for u in fixture_users if u.id == USER_ADMIN_ID) assert user.role is not None assert user.role.name == "plugin_admin" @@ -177,6 +186,15 @@ class TestGeneratedFixtures: assert users[0].username == "plugin_admin" assert users[1].username == "plugin_user" + @pytest.mark.anyio + async def test_fixture_auto_loads_relationships( + self, db_session: AsyncSession, fixture_users: list[User] + ): + """Fixtures automatically eager-load all direct relationships.""" + user = next(u for u in fixture_users if u.username == "plugin_admin") + assert user.role is not None + assert user.role.name == "plugin_admin" + @pytest.mark.anyio async def test_multiple_fixtures_in_same_test( self, @@ -516,3 +534,197 @@ class TestCreateWorkerDatabase: ) assert result.scalar() is None await engine.dispose() + + +# --------------------------------------------------------------------------- +# Local models for composite-PK coverage (own Base → own tables, isolated) +# --------------------------------------------------------------------------- + + +class _LocalBase(DeclarativeBase): + pass + + +class _Group(_LocalBase): + __tablename__ = "_test_groups" + id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4) + name: Mapped[str] = mapped_column(String(50)) + + +class _CompositeItem(_LocalBase): + """Model with composite PK and a relationship — exercises the fallback path.""" + + __tablename__ = "_test_composite_items" + group_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("_test_groups.id"), primary_key=True + ) + item_code: Mapped[str] = mapped_column(String(50), primary_key=True) + group: Mapped["_Group"] = relationship() + + +class TestGetPrimaryKey: + """Unit tests for _get_primary_key — no DB needed.""" + + def test_single_pk_returns_value(self): + rid = uuid.UUID("00000000-0000-0000-0000-000000000001") + role = Role(id=rid, name="x") + assert _get_primary_key(role) == rid + + def test_composite_pk_all_set_returns_tuple(self): + perm = Permission(subject="posts", action="read") + assert _get_primary_key(perm) == ("posts", "read") + + def test_composite_pk_partial_none_returns_none(self): + perm = Permission(subject=None, action="read") + assert _get_primary_key(perm) is None + + def test_composite_pk_all_none_returns_none(self): + perm = Permission(subject=None, action=None) + assert _get_primary_key(perm) is None + + +class TestRelationshipLoadOptions: + """Unit tests for _relationship_load_options — no DB needed.""" + + def test_empty_for_model_with_no_relationships(self): + assert _relationship_load_options(IntRole) == [] + + def test_returns_options_for_model_with_relationships(self): + opts = _relationship_load_options(User) + assert len(opts) >= 1 + + +class TestFixtureStrategies: + """Integration tests covering INSERT, SKIP_EXISTING, empty fixture, no-rels model.""" + + @pytest.mark.anyio + async def test_empty_fixture_returns_empty_list(self, db_session: AsyncSession): + """Fixture function returning [] produces an empty list.""" + registry = FixtureRegistry() + + @registry.register() + def empty() -> list[Role]: + return [] + + local_ns: dict = {} + register_fixtures(registry, local_ns, session_fixture="db_session") + inner = local_ns["fixture_empty"].__wrapped__ # type: ignore[attr-defined] + result = await inner(db_session=db_session) + assert result == [] + + @pytest.mark.anyio + async def test_insert_strategy_no_relationships(self, db_session: AsyncSession): + """INSERT strategy adds instances; model with no rels skips reload (line 135).""" + registry = FixtureRegistry() + + @registry.register() + def int_roles() -> list[IntRole]: + return [IntRole(name="insert_role")] + + local_ns: dict = {} + register_fixtures( + registry, + local_ns, + session_fixture="db_session", + strategy=LoadStrategy.INSERT, + ) + inner = local_ns["fixture_int_roles"].__wrapped__ # type: ignore[attr-defined] + result = await inner(db_session=db_session) + assert len(result) == 1 + assert result[0].name == "insert_role" + + @pytest.mark.anyio + async def test_skip_existing_inserts_new_record(self, db_session: AsyncSession): + """SKIP_EXISTING inserts when the record does not yet exist.""" + registry = FixtureRegistry() + role_id = uuid.uuid4() + + @registry.register() + def new_roles() -> list[Role]: + return [Role(id=role_id, name="skip_new")] + + local_ns: dict = {} + register_fixtures( + registry, + local_ns, + session_fixture="db_session", + strategy=LoadStrategy.SKIP_EXISTING, + ) + inner = local_ns["fixture_new_roles"].__wrapped__ # type: ignore[attr-defined] + result = await inner(db_session=db_session) + assert len(result) == 1 + assert result[0].id == role_id + + @pytest.mark.anyio + async def test_skip_existing_returns_existing_record( + self, db_session: AsyncSession + ): + """SKIP_EXISTING returns the existing DB record when PK already present.""" + role_id = uuid.uuid4() + existing = Role(id=role_id, name="already_there") + db_session.add(existing) + await db_session.flush() + + registry = FixtureRegistry() + + @registry.register() + def dup_roles() -> list[Role]: + return [Role(id=role_id, name="should_not_overwrite")] + + local_ns: dict = {} + register_fixtures( + registry, + local_ns, + session_fixture="db_session", + strategy=LoadStrategy.SKIP_EXISTING, + ) + inner = local_ns["fixture_dup_roles"].__wrapped__ # type: ignore[attr-defined] + result = await inner(db_session=db_session) + assert len(result) == 1 + assert result[0].name == "already_there" + + @pytest.mark.anyio + async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession): + """SKIP_EXISTING with null PK (auto-increment) falls through to session.add().""" + registry = FixtureRegistry() + + @registry.register() + def auto_roles() -> list[IntRole]: + return [IntRole(name="auto_int")] + + local_ns: dict = {} + register_fixtures( + registry, + local_ns, + session_fixture="db_session", + strategy=LoadStrategy.SKIP_EXISTING, + ) + inner = local_ns["fixture_auto_roles"].__wrapped__ # type: ignore[attr-defined] + result = await inner(db_session=db_session) + assert len(result) == 1 + assert result[0].name == "auto_int" + + +class TestReloadWithRelationshipsCompositePK: + """Integration test for _reload_with_relationships composite-PK fallback.""" + + @pytest.mark.anyio + async def test_composite_pk_fallback_loads_relationships(self): + """Models with composite PKs are reloaded per-instance via session.get().""" + async with create_db_session(DATABASE_URL, _LocalBase) as session: + group = _Group(id=uuid.uuid4(), name="g1") + session.add(group) + await session.flush() + + item = _CompositeItem(group_id=group.id, item_code="A") + session.add(item) + await session.flush() + + load_opts = _relationship_load_options(_CompositeItem) + assert load_opts # _CompositeItem has 'group' relationship + + reloaded = await _reload_with_relationships(session, [item], load_opts) + assert len(reloaded) == 1 + reloaded_item = cast(_CompositeItem, reloaded[0]) + assert reloaded_item.group is not None + assert reloaded_item.group.name == "g1"