From d4498e20630083980f57a97456229e2b2448300a Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:03:28 +0100 Subject: [PATCH] feat: add cleanup parameter to create_db_session (#60) --- src/fastapi_toolsets/pytest/utils.py | 17 +++++++++++++---- tests/test_pytest.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/fastapi_toolsets/pytest/utils.py b/src/fastapi_toolsets/pytest/utils.py index a1d9f6c..9e59c9a 100644 --- a/src/fastapi_toolsets/pytest/utils.py +++ b/src/fastapi_toolsets/pytest/utils.py @@ -60,6 +60,7 @@ async def create_db_session( echo: bool = False, expire_on_commit: bool = False, drop_tables: bool = True, + cleanup: bool = False, ) -> AsyncGenerator[AsyncSession, None]: """Create a database session for testing. @@ -72,6 +73,8 @@ async def create_db_session( echo: Enable SQLAlchemy query logging. Defaults to False. expire_on_commit: Expire objects after commit. Defaults to False. drop_tables: Drop tables after test. Defaults to True. + cleanup: Truncate all tables after test using + :func:`cleanup_tables`. Defaults to False. Yields: An AsyncSession ready for database operations. @@ -84,7 +87,9 @@ async def create_db_session( @pytest.fixture async def db_session(): - async with create_db_session(DATABASE_URL, Base) as session: + async with create_db_session( + DATABASE_URL, Base, cleanup=True + ) as session: yield session async def test_create_user(db_session: AsyncSession): @@ -106,6 +111,9 @@ async def create_db_session( async with get_session() as session: yield session + if cleanup: + await cleanup_tables(session, base) + if drop_tables: async with engine.begin() as conn: await conn.run_sync(base.metadata.drop_all) @@ -193,7 +201,7 @@ async def create_worker_database( Example: from fastapi_toolsets.pytest import ( - create_worker_database, create_db_session, cleanup_tables + create_worker_database, create_db_session, ) DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db" @@ -205,9 +213,10 @@ async def create_worker_database( @pytest.fixture async def db_session(worker_db_url): - async with create_db_session(worker_db_url, Base) as session: + async with create_db_session( + worker_db_url, Base, cleanup=True + ) as session: yield session - await cleanup_tables(session, Base) """ worker_url = worker_database_url( database_url=database_url, default_test_db=default_test_db diff --git a/tests/test_pytest.py b/tests/test_pytest.py index 40cb083..410a82b 100644 --- a/tests/test_pytest.py +++ b/tests/test_pytest.py @@ -297,6 +297,22 @@ class TestCreateDbSession: async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _: pass + @pytest.mark.anyio + async def test_cleanup_truncates_tables(self): + """Tables are truncated after session closes when cleanup=True.""" + role_id = uuid.uuid4() + async with create_db_session( + DATABASE_URL, Base, cleanup=True, drop_tables=False + ) as session: + role = Role(id=role_id, name="will_be_cleaned") + session.add(role) + await session.commit() + + # Data should have been truncated, but tables still exist + async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: + result = await session.execute(select(Role)) + assert result.all() == [] + class TestGetXdistWorker: """Tests for _get_xdist_worker helper."""