mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
484 lines
16 KiB
Python
484 lines
16 KiB
Python
"""Database utilities: sessions, transactions, and locks."""
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncGenerator, Callable
|
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
from enum import Enum
|
|
from typing import Any, TypeVar, cast
|
|
|
|
from sqlalchemy import Table, delete, text, tuple_
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
|
from sqlalchemy.orm.relationships import RelationshipProperty
|
|
|
|
from .exceptions import NotFoundError
|
|
|
|
__all__ = [
|
|
"LockMode",
|
|
"cleanup_tables",
|
|
"create_database",
|
|
"create_db_context",
|
|
"create_db_dependency",
|
|
"get_transaction",
|
|
"lock_tables",
|
|
"m2m_add",
|
|
"m2m_remove",
|
|
"m2m_set",
|
|
"wait_for_row_change",
|
|
]
|
|
|
|
|
|
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
|
|
|
|
|
|
def create_db_dependency(
    session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AsyncGenerator[_SessionT, None]]:
    """Build a FastAPI dependency that yields database sessions.

    The returned generator opens a session per request, eagerly acquires a
    connection, hands the session to the endpoint, and commits on the way
    out if a transaction is still open.

    Args:
        session_maker: Async session factory (``async_sessionmaker``).

    Returns:
        An async generator function suitable for FastAPI's ``Depends()``.

    Example:
        ```python
        from fastapi import Depends
        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_dependency

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(SessionLocal)

        @app.get("/users")
        async def list_users(session: AsyncSession = Depends(get_db)):
            ...
        ```
    """

    async def get_db() -> AsyncGenerator[_SessionT, None]:
        async with session_maker() as db_session:
            # Force connection acquisition up front so connectivity errors
            # surface before the endpoint body runs.
            await db_session.connection()
            yield db_session
            # Auto-commit pending work once the request handler returns.
            if db_session.in_transaction():
                await db_session.commit()

    return get_db
|
|
|
|
|
|
def create_db_context(
    session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
    """Build an async context manager factory for database sessions.

    Intended for code that runs outside FastAPI's request cycle —
    background tasks, CLI commands, tests — but wants the same
    session lifecycle (auto-commit on clean exit) as the request
    dependency.

    Args:
        session_maker: Async session factory (``async_sessionmaker``).

    Returns:
        A callable producing an async context manager yielding a session.

    Example:
        ```python
        from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_context

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db_context = create_db_context(SessionLocal)

        async def background_task():
            async with get_db_context() as session:
                user = await UserCrud.get(session, [User.id == 1])
                ...
        ```
    """
    # Reuse the dependency generator; asynccontextmanager adapts it to
    # `async with` usage so both entry points share identical behavior.
    return asynccontextmanager(create_db_dependency(session_maker))
|
|
|
|
|
|
@asynccontextmanager
async def get_transaction(
    session: AsyncSession,
) -> AsyncGenerator[AsyncSession, None]:
    """Open a transaction scope, nesting via savepoint when required.

    When the session already has an active transaction a savepoint
    (nested transaction) is created; otherwise a fresh top-level
    transaction is started. Either commits on clean exit and rolls
    back on exception.

    Args:
        session: AsyncSession instance

    Yields:
        The session, inside the transaction scope

    Example:
        ```python
        async with get_transaction(session):
            session.add(model)
            # Auto-commits on exit, rolls back on exception
        ```
    """
    # Savepoint when a transaction is already open, new transaction otherwise.
    transaction = (
        session.begin_nested() if session.in_transaction() else session.begin()
    )
    async with transaction:
        yield session
|
|
|
|
|
|
class LockMode(str, Enum):
    """PostgreSQL table lock modes.

    Members are ordered from least to most restrictive, mirroring the
    PostgreSQL documentation.

    See: https://www.postgresql.org/docs/current/explicit-locking.html
    """

    # Weakest mode; conflicts only with ACCESS EXCLUSIVE (per PG docs).
    ACCESS_SHARE = "ACCESS SHARE"
    ROW_SHARE = "ROW SHARE"
    ROW_EXCLUSIVE = "ROW EXCLUSIVE"
    # Default for lock_tables(): blocks concurrent schema changes and
    # VACUUM while still allowing normal reads/writes (per PG docs).
    SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
    SHARE = "SHARE"
    SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
    EXCLUSIVE = "EXCLUSIVE"
    # Strongest mode; conflicts with every other lock mode (per PG docs).
    ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
|
|
|
|
|
|
@asynccontextmanager
async def lock_tables(
    session: AsyncSession,
    tables: list[type[DeclarativeBase]],
    *,
    mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
    timeout: str = "5s",
) -> AsyncGenerator[AsyncSession, None]:
    """Hold PostgreSQL table locks for the duration of a transaction.

    Takes table-level locks that persist until the surrounding
    transaction ends, preventing concurrent modifications during
    critical sections.

    Args:
        session: AsyncSession instance
        tables: List of SQLAlchemy model classes to lock
        mode: Lock mode (default: SHARE UPDATE EXCLUSIVE)
        timeout: Lock timeout (default: "5s")

    Yields:
        The session with locked tables

    Raises:
        SQLAlchemyError: If lock cannot be acquired within timeout

    Example:
        ```python
        from fastapi_toolsets.db import lock_tables, LockMode

        async with lock_tables(session, [User, Account]):
            # Tables are locked with SHARE UPDATE EXCLUSIVE mode
            user = await UserCrud.get(session, [User.id == 1])
            user.balance += 100

        # With custom lock mode
        async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
            # Exclusive lock - no other transactions can access
            await process_order(session, order_id)
        ```
    """
    target_list = ",".join(model.__tablename__ for model in tables)

    async with get_transaction(session):
        # NOTE(review): table names and the timeout string are interpolated
        # into raw SQL — they must come from trusted code, not user input.
        # SET LOCAL scopes the lock_timeout to this transaction only.
        await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
        await session.execute(text(f"LOCK {target_list} IN {mode.value} MODE"))
        yield session
|
|
|
|
|
|
async def create_database(
    db_name: str,
    *,
    server_url: str,
) -> None:
    """Create a database.

    Connects to *server_url* with ``AUTOCOMMIT`` isolation (``CREATE
    DATABASE`` cannot run inside a transaction block) and issues a
    ``CREATE DATABASE`` statement for *db_name*.

    Args:
        db_name: Name of the database to create.
        server_url: URL used for server-level DDL (must point to an existing
            database on the same server).

    Example:
        ```python
        from fastapi_toolsets.db import create_database

        SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
        await create_database("myapp_test", server_url=SERVER_URL)
        ```
    """
    engine = create_async_engine(server_url, isolation_level="AUTOCOMMIT")
    try:
        async with engine.connect() as connection:
            # NOTE(review): db_name is interpolated unquoted/unescaped into
            # DDL — only pass trusted, code-controlled database names here.
            await connection.execute(text(f"CREATE DATABASE {db_name}"))
    finally:
        # Always release the throwaway engine's connection pool.
        await engine.dispose()
|
|
|
|
|
|
async def cleanup_tables(
    session: AsyncSession,
    base: type[DeclarativeBase],
) -> None:
    """Truncate every table for fast between-test cleanup.

    Issues one ``TRUNCATE … RESTART IDENTITY CASCADE`` covering all
    tables in *base*'s metadata — much faster than dropping and
    re-creating the schema between tests.

    No-op when the metadata contains no tables.

    Args:
        session: An active async database session.
        base: SQLAlchemy DeclarativeBase class containing model metadata.

    Example:
        ```python
        @pytest.fixture
        async def db_session(worker_db_url):
            async with create_db_session(worker_db_url, Base) as session:
                yield session
                await cleanup_tables(session, Base)
        ```
    """
    all_tables = base.metadata.sorted_tables
    if not all_tables:
        return  # nothing to truncate

    # Quote each name so mixed-case / reserved-word table names survive.
    quoted_names = ", ".join(f'"{table.name}"' for table in all_tables)
    await session.execute(text(f"TRUNCATE {quoted_names} RESTART IDENTITY CASCADE"))
    await session.commit()
|
|
|
|
|
|
_M = TypeVar("_M", bound=DeclarativeBase)
|
|
|
|
|
|
async def wait_for_row_change(
    session: AsyncSession,
    model: type[_M],
    pk_value: Any,
    *,
    columns: list[str] | None = None,
    interval: float = 0.5,
    timeout: float | None = None,
) -> _M:
    """Poll a database row until a change is detected.

    Queries the row every ``interval`` seconds and returns the model instance
    once a change is detected in any column (or only the specified ``columns``).

    The timeout is measured against the event loop's monotonic clock, so it
    accounts for time spent querying as well as time spent sleeping.

    Args:
        session: AsyncSession instance
        model: SQLAlchemy model class
        pk_value: Primary key value of the row to watch
        columns: Optional list of column names to watch. If None, all columns
            are watched.
        interval: Polling interval in seconds (default: 0.5)
        timeout: Maximum time to wait in seconds. None means wait forever.

    Returns:
        The refreshed model instance with updated values

    Raises:
        NotFoundError: If the row does not exist or is deleted during polling
        TimeoutError: If timeout expires before a change is detected

    Example:
        ```python
        from fastapi_toolsets.db import wait_for_row_change

        # Wait for any column to change
        updated = await wait_for_row_change(session, User, user_id)

        # Watch specific columns with a timeout
        updated = await wait_for_row_change(
            session, User, user_id,
            columns=["status", "email"],
            interval=1.0,
            timeout=30.0,
        )
        ```
    """
    instance = await session.get(model, pk_value)
    if instance is None:
        raise NotFoundError(f"{model.__name__} with pk={pk_value!r} not found")

    # Default to watching every mapped column.
    if columns is not None:
        watch_cols = columns
    else:
        watch_cols = [attr.key for attr in model.__mapper__.column_attrs]

    initial = {col: getattr(instance, col) for col in watch_cols}

    # Fix: the previous implementation accumulated `elapsed += interval`,
    # which ignores time spent in the queries themselves, so the effective
    # timeout could drift well past the requested one. Use a monotonic
    # deadline instead.
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout

    while True:
        await asyncio.sleep(interval)

        if deadline is not None and loop.time() >= deadline:
            raise TimeoutError(
                f"No change detected on {model.__name__} "
                f"with pk={pk_value!r} within {timeout}s"
            )

        # Drop the cached instance from the identity map so session.get()
        # issues a fresh SELECT instead of returning the stale object.
        # NOTE(review): within one open transaction at REPEATABLE READ or
        # stricter, re-querying may never observe other transactions'
        # commits — confirm the session runs at READ COMMITTED.
        session.expunge(instance)
        instance = await session.get(model, pk_value)

        if instance is None:
            raise NotFoundError(f"{model.__name__} with pk={pk_value!r} was deleted")

        current = {col: getattr(instance, col) for col in watch_cols}
        if current != initial:
            return instance
|
|
|
|
|
|
def _m2m_prop(rel_attr: QueryableAttribute) -> RelationshipProperty:  # type: ignore[type-arg]
    """Validate *rel_attr* and return its Many-to-Many RelationshipProperty.

    Raises:
        TypeError: If *rel_attr* is not a relationship backed by a
            secondary (association) table.
    """
    relationship_prop = rel_attr.property
    # A true M2M relationship must be a RelationshipProperty AND declare
    # a secondary (association) table.
    is_many_to_many = (
        isinstance(relationship_prop, RelationshipProperty)
        and relationship_prop.secondary is not None
    )
    if not is_many_to_many:
        raise TypeError(
            f"m2m helpers require a Many-to-Many relationship attribute, "
            f"got {rel_attr!r}. Use a relationship with a secondary table."
        )
    return relationship_prop
|
|
|
|
|
|
async def m2m_add(
    session: AsyncSession,
    instance: DeclarativeBase,
    rel_attr: QueryableAttribute,
    *related: DeclarativeBase,
    ignore_conflicts: bool = False,
) -> None:
    """Insert rows into a Many-to-Many association table without loading the ORM collection.

    Args:
        session: DB async session.
        instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
        rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
        *related: One or more related instances to associate with ``instance``.
        ignore_conflicts: When ``True``, silently skip rows that already exist
            in the association table (``ON CONFLICT DO NOTHING``).

    Raises:
        TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
    """
    prop = _m2m_prop(rel_attr)
    if not related:
        return  # nothing to associate

    secondary = cast(Table, prop.secondary)
    assert secondary is not None  # guaranteed by _m2m_prop
    sync_pairs = prop.secondary_synchronize_pairs
    assert sync_pairs is not None  # set whenever secondary is set

    # synchronize_pairs: [(parent_col, assoc_col), ...]
    # secondary_synchronize_pairs: [(related_col, assoc_col), ...]
    # The owner-side values are identical for every row, so build them once.
    owner_values: dict[str, Any] = {
        assoc_col.name: getattr(instance, cast(str, parent_col.key))
        for parent_col, assoc_col in prop.synchronize_pairs
    }
    rows: list[dict[str, Any]] = [
        owner_values
        | {
            assoc_col.name: getattr(rel_instance, cast(str, related_col.key))
            for related_col, assoc_col in sync_pairs
        }
        for rel_instance in related
    ]

    stmt = pg_insert(secondary).values(rows)
    if ignore_conflicts:
        stmt = stmt.on_conflict_do_nothing()
    await session.execute(stmt)
|
|
|
|
|
|
async def m2m_remove(
    session: AsyncSession,
    instance: DeclarativeBase,
    rel_attr: QueryableAttribute,
    *related: DeclarativeBase,
) -> None:
    """Remove rows from a Many-to-Many association table without loading the ORM collection.

    Args:
        session: DB async session.
        instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
        rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
        *related: One or more related instances to disassociate from ``instance``.

    Raises:
        TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
    """
    prop = _m2m_prop(rel_attr)
    if not related:
        return  # nothing to disassociate

    secondary = cast(Table, prop.secondary)
    assert secondary is not None  # guaranteed by _m2m_prop
    related_pairs = prop.secondary_synchronize_pairs
    assert related_pairs is not None  # set whenever secondary is set

    # Conditions pinning the association rows to this owner instance.
    owner_conditions = [
        assoc_col == getattr(instance, cast(str, parent_col.key))
        for parent_col, assoc_col in prop.synchronize_pairs
    ]

    if len(related_pairs) == 1:
        # Single-column link: a plain IN clause suffices.
        related_col, assoc_col = related_pairs[0]
        target_values = [getattr(obj, cast(str, related_col.key)) for obj in related]
        related_condition = assoc_col.in_(target_values)
    else:
        # Composite link: match row tuples against all association columns.
        assoc_cols = [assoc_col for _, assoc_col in related_pairs]
        source_cols = [related_col for related_col, _ in related_pairs]
        value_tuples = [
            tuple(getattr(obj, cast(str, col.key)) for col in source_cols)
            for obj in related
        ]
        related_condition = tuple_(*assoc_cols).in_(value_tuples)

    await session.execute(delete(secondary).where(*owner_conditions, related_condition))
|
|
|
|
|
|
async def m2m_set(
    session: AsyncSession,
    instance: DeclarativeBase,
    rel_attr: QueryableAttribute,
    *related: DeclarativeBase,
) -> None:
    """Replace the entire Many-to-Many association set atomically.

    Args:
        session: DB async session.
        instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
        rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
        *related: The new complete set of related instances.

    Raises:
        TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
    """
    prop = _m2m_prop(rel_attr)
    secondary = cast(Table, prop.secondary)
    assert secondary is not None  # guaranteed by _m2m_prop

    # NOTE(review): the delete and insert are two statements on the same
    # session — atomicity holds only while the caller's transaction spans
    # both; confirm callers commit/rollback around this helper.
    owner_conditions = [
        assoc_col == getattr(instance, cast(str, parent_col.key))
        for parent_col, assoc_col in prop.synchronize_pairs
    ]
    # Clear every existing association row for this owner...
    await session.execute(delete(secondary).where(*owner_conditions))

    # ...then insert the replacement set, if any.
    if related:
        await m2m_add(session, instance, rel_attr, *related)
|