feat: move db related function from pytest to db module (#119)

This commit is contained in:
d3vyce
2026-03-09 17:15:50 +01:00
committed by GitHub
parent 96d445e3f3
commit baebf022f6
8 changed files with 335 additions and 223 deletions

View File

@@ -7,17 +7,19 @@ from enum import Enum
from typing import Any, TypeVar
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from .exceptions import NotFoundError
__all__ = [
"LockMode",
"cleanup_tables",
"create_database",
"create_db_context",
"create_db_dependency",
"lock_tables",
"get_transaction",
"lock_tables",
"wait_for_row_change",
]
@@ -188,6 +190,71 @@ async def lock_tables(
yield session
async def create_database(
db_name: str,
*,
server_url: str,
) -> None:
"""Create a database.
Connects to *server_url* using ``AUTOCOMMIT`` isolation and issues a
``CREATE DATABASE`` statement for *db_name*.
Args:
db_name: Name of the database to create.
server_url: URL used for server-level DDL (must point to an existing
database on the same server).
Example:
```python
from fastapi_toolsets.db import create_database
SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
await create_database("myapp_test", server_url=SERVER_URL)
```
"""
engine = create_async_engine(server_url, isolation_level="AUTOCOMMIT")
try:
async with engine.connect() as conn:
await conn.execute(text(f"CREATE DATABASE {db_name}"))
finally:
await engine.dispose()
async def cleanup_tables(
session: AsyncSession,
base: type[DeclarativeBase],
) -> None:
"""Truncate all tables for fast between-test cleanup.
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
across every table in *base*'s metadata, which is significantly faster
than dropping and re-creating tables between tests.
This is a no-op when the metadata contains no tables.
Args:
session: An active async database session.
base: SQLAlchemy DeclarativeBase class containing model metadata.
Example:
```python
@pytest.fixture
async def db_session(worker_db_url):
async with create_db_session(worker_db_url, Base) as session:
yield session
await cleanup_tables(session, Base)
```
"""
tables = base.metadata.sorted_tables
if not tables:
return
table_names = ", ".join(f'"{t.name}"' for t in tables)
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
await session.commit()
_M = TypeVar("_M", bound=DeclarativeBase)