From 666c621fda7a408ad7f97cfd2b8cff1d5462a7ab Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:57:40 +0100 Subject: [PATCH] fix: create_db_session commits via real transaction, not savepoint (#184) --- src/fastapi_toolsets/pytest/utils.py | 17 +++------ tests/test_pytest.py | 52 +++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 13 deletions(-) diff --git a/src/fastapi_toolsets/pytest/utils.py b/src/fastapi_toolsets/pytest/utils.py index 51d6502..db5f118 100644 --- a/src/fastapi_toolsets/pytest/utils.py +++ b/src/fastapi_toolsets/pytest/utils.py @@ -7,6 +7,7 @@ from contextlib import asynccontextmanager from typing import Any from httpx import ASGITransport, AsyncClient +from sqlalchemy import text from sqlalchemy.engine import make_url from sqlalchemy.ext.asyncio import ( AsyncSession, @@ -15,13 +16,8 @@ from sqlalchemy.ext.asyncio import ( ) from sqlalchemy.orm import DeclarativeBase -from sqlalchemy import text - -from ..db import ( - cleanup_tables as _cleanup_tables, - create_database, - create_db_context, -) +from ..db import cleanup_tables as _cleanup_tables +from ..db import create_database async def cleanup_tables( @@ -269,15 +265,12 @@ async def create_db_session( async with engine.begin() as conn: await conn.run_sync(base.metadata.create_all) - # Create session using existing db context utility session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit) - get_session = create_db_context(session_maker) - - async with get_session() as session: + async with session_maker() as session: yield session if cleanup: - await cleanup_tables(session, base) + await _cleanup_tables(session=session, base=base) if drop_tables: async with engine.begin() as conn: diff --git a/tests/test_pytest.py b/tests/test_pytest.py index 7791f78..b0ffa67 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -7,9 +7,10 @@ from fastapi import Depends, FastAPI from httpx import AsyncClient from sqlalchemy import select, text from sqlalchemy.engine import make_url -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import selectinload +from fastapi_toolsets.db import get_transaction from fastapi_toolsets.fixtures import Context, FixtureRegistry from fastapi_toolsets.pytest import ( create_async_client, @@ -336,6 +337,55 @@ class TestCreateDbSession: result = await session.execute(select(Role)) assert result.all() == [] + @pytest.mark.anyio + async def test_get_transaction_commits_visible_to_separate_session(self): + """Data written via get_transaction() is committed and visible to other sessions.""" + role_id = uuid.uuid4() + + async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: + # Simulate what _create_fixture_function does: insert via get_transaction + # with no explicit commit afterward. + async with get_transaction(session): + role = Role(id=role_id, name="visible_to_other_session") + session.add(role) + + # The data must have been committed (begin/commit, not a savepoint), + # so a separate engine/session can read it. + other_engine = create_async_engine(DATABASE_URL, echo=False) + try: + other_session_maker = async_sessionmaker( + other_engine, expire_on_commit=False + ) + async with other_session_maker() as other: + result = await other.execute(select(Role).where(Role.id == role_id)) + fetched = result.scalar_one_or_none() + assert fetched is not None, ( + "Fixture data inserted via get_transaction() must be committed " + "and visible to a separate session. If create_db_session uses " + "create_db_context, auto-begin forces get_transaction() into " + "savepoints instead of real commits." + ) + assert fetched.name == "visible_to_other_session" + finally: + await other_engine.dispose() + + # Cleanup + async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _: + pass + + +class TestDeprecatedCleanupTables: + """Tests for the deprecated cleanup_tables re-export in fastapi_toolsets.pytest.""" + + @pytest.mark.anyio + async def test_emits_deprecation_warning(self): + """cleanup_tables imported from fastapi_toolsets.pytest emits DeprecationWarning.""" + from fastapi_toolsets.pytest.utils import cleanup_tables + + async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: + with pytest.warns(DeprecationWarning, match="fastapi_toolsets.db"): + await cleanup_tables(session, Base) + class TestGetXdistWorker: """Tests for _get_xdist_worker helper."""