mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add cleanup parameter to create_db_session (#60)
This commit is contained in:
@@ -60,6 +60,7 @@ async def create_db_session(
|
||||
echo: bool = False,
|
||||
expire_on_commit: bool = False,
|
||||
drop_tables: bool = True,
|
||||
cleanup: bool = False,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a database session for testing.
|
||||
|
||||
@@ -72,6 +73,8 @@ async def create_db_session(
|
||||
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||
drop_tables: Drop tables after test. Defaults to True.
|
||||
cleanup: Truncate all tables after test using
|
||||
:func:`cleanup_tables`. Defaults to False.
|
||||
|
||||
Yields:
|
||||
An AsyncSession ready for database operations.
|
||||
@@ -84,7 +87,9 @@ async def create_db_session(
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
async with create_db_session(DATABASE_URL, Base) as session:
|
||||
async with create_db_session(
|
||||
DATABASE_URL, Base, cleanup=True
|
||||
) as session:
|
||||
yield session
|
||||
|
||||
async def test_create_user(db_session: AsyncSession):
|
||||
@@ -106,6 +111,9 @@ async def create_db_session(
|
||||
async with get_session() as session:
|
||||
yield session
|
||||
|
||||
if cleanup:
|
||||
await cleanup_tables(session, base)
|
||||
|
||||
if drop_tables:
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(base.metadata.drop_all)
|
||||
@@ -193,7 +201,7 @@ async def create_worker_database(
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.pytest import (
|
||||
create_worker_database, create_db_session, cleanup_tables
|
||||
create_worker_database, create_db_session,
|
||||
)
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||
@@ -205,9 +213,10 @@ async def create_worker_database(
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(worker_db_url):
|
||||
async with create_db_session(worker_db_url, Base) as session:
|
||||
async with create_db_session(
|
||||
worker_db_url, Base, cleanup=True
|
||||
) as session:
|
||||
yield session
|
||||
await cleanup_tables(session, Base)
|
||||
"""
|
||||
worker_url = worker_database_url(
|
||||
database_url=database_url, default_test_db=default_test_db
|
||||
|
||||
@@ -297,6 +297,22 @@ class TestCreateDbSession:
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||
pass
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup_truncates_tables(self):
|
||||
"""Tables are truncated after session closes when cleanup=True."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(
|
||||
DATABASE_URL, Base, cleanup=True, drop_tables=False
|
||||
) as session:
|
||||
role = Role(id=role_id, name="will_be_cleaned")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
# Data should have been truncated, but tables still exist
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
result = await session.execute(select(Role))
|
||||
assert result.all() == []
|
||||
|
||||
|
||||
class TestGetXdistWorker:
|
||||
"""Tests for _get_xdist_worker helper."""
|
||||
|
||||
Reference in New Issue
Block a user