mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-02 01:10:47 +01:00
176 lines
5.5 KiB
Python
176 lines
5.5 KiB
Python
"""Database utilities: sessions, transactions, and locks."""
|
|
|
|
from collections.abc import AsyncGenerator, Callable
|
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
from enum import Enum
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
|
|
# Public API of this module: the names exported by a star-import.
__all__ = [
    "LockMode",
    "create_db_context",
    "create_db_dependency",
    "lock_tables",
    "get_transaction",
]
|
|
|
|
|
|
def create_db_dependency(
    session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
    """Build a FastAPI dependency that hands out database sessions.

    The returned async generator function opens a session from
    ``session_maker``, yields it to the caller and, once control comes back
    normally, commits if the session still reports an open transaction.
    The session is released when the surrounding ``async with`` exits.

    Args:
        session_maker: Async session factory called once per dependency use.

    Returns:
        An async generator function usable with FastAPI's ``Depends()``.

    Example:
        from fastapi import Depends
        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_dependency

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(SessionLocal)

        @app.get("/users")
        async def list_users(session: AsyncSession = Depends(get_db)):
            ...
    """

    async def get_db() -> AsyncGenerator[AsyncSession, None]:
        async with session_maker() as db_session:
            yield db_session
            # Only commit when a transaction is still open after the caller
            # is done; otherwise there is nothing left to persist.
            if not db_session.in_transaction():
                return
            await db_session.commit()

    return get_db
|
|
|
|
|
|
def create_db_context(
    session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
    """Build an async context manager factory for database sessions.

    Intended for code that runs outside FastAPI request handling, such as
    background tasks, CLI commands, or tests. It wraps the generator
    produced by ``create_db_dependency`` with ``asynccontextmanager``, so
    the session lifecycle is the one that generator implements.

    Args:
        session_maker: Async session factory used to open each session.

    Returns:
        A callable that returns an async context manager yielding a session.

    Example:
        from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_context

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db_context = create_db_context(SessionLocal)

        async def background_task():
            async with get_db_context() as session:
                user = await UserCrud.get(session, [User.id == 1])
                ...
    """
    return asynccontextmanager(create_db_dependency(session_maker))
|
|
|
|
|
|
@asynccontextmanager
async def get_transaction(
    session: AsyncSession,
) -> AsyncGenerator[AsyncSession, None]:
    """Run a block inside a transaction, nesting when one is already open.

    When ``session`` reports an active transaction the block runs inside
    ``session.begin_nested()``; otherwise it runs inside ``session.begin()``.
    Exiting the block is delegated to whichever of those context managers
    was entered.

    Args:
        session: AsyncSession instance to run the transaction on.

    Yields:
        The same session, now inside the transaction context.

    Example:
        async with get_transaction(session):
            session.add(model)
        # Auto-commits on exit, rolls back on exception
    """
    # Pick the scope first so the body is written once for both cases.
    scope = session.begin_nested() if session.in_transaction() else session.begin()
    async with scope:
        yield session
|
|
|
|
|
|
class LockMode(str, Enum):
    """Table-level lock modes for PostgreSQL's ``LOCK`` statement.

    Members are also ``str`` instances. Each value is the keyword phrase
    placed between ``IN`` and ``MODE`` when the lock statement is built.

    Reference: https://www.postgresql.org/docs/current/explicit-locking.html
    """

    ACCESS_SHARE = "ACCESS SHARE"
    ROW_SHARE = "ROW SHARE"
    ROW_EXCLUSIVE = "ROW EXCLUSIVE"
    SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
    SHARE = "SHARE"
    SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
    EXCLUSIVE = "EXCLUSIVE"
    ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
|
|
|
|
|
|
@asynccontextmanager
async def lock_tables(
    session: AsyncSession,
    tables: list[type[DeclarativeBase]],
    *,
    mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
    timeout: str = "5s",
) -> AsyncGenerator[AsyncSession, None]:
    """Lock PostgreSQL tables for the duration of a transaction.

    Opens a transaction (or a savepoint if one is already active) via
    ``get_transaction``, sets a transaction-local ``lock_timeout``, and
    issues a single ``LOCK`` statement covering all given tables.
    Useful for preventing concurrent modifications during critical operations.

    The timeout is sent as a bound parameter and table names are emitted as
    quoted identifiers, so neither value can alter the SQL that is executed.
    An empty ``tables`` list is a no-op: the transaction is still opened and
    the session yielded, but no statements are sent.

    Args:
        session: AsyncSession instance
        tables: List of SQLAlchemy model classes to lock; each must define
            ``__tablename__``
        mode: Lock mode (default: SHARE UPDATE EXCLUSIVE)
        timeout: Lock timeout in PostgreSQL duration syntax (default: "5s")

    Yields:
        The session with locked tables

    Raises:
        SQLAlchemyError: If lock cannot be acquired within timeout

    Example:
        from fastapi_toolsets.db import lock_tables, LockMode

        async with lock_tables(session, [User, Account]):
            # Tables are locked with SHARE UPDATE EXCLUSIVE mode
            user = await UserCrud.get(session, [User.id == 1])
            user.balance += 100

        # With custom lock mode
        async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
            # Exclusive lock - no other transactions can access
            await process_order(session, order_id)
    """
    # Always double-quote identifiers (doubling any embedded quote) so that
    # reserved words and mixed-case names refer to exactly the named table
    # and cannot break out of the identifier position.
    quoted_names = ", ".join(
        '"' + table.__tablename__.replace('"', '""') + '"' for table in tables
    )

    async with get_transaction(session):
        if quoted_names:
            # SET does not accept bind parameters; set_config() with
            # is_local=true is the function form of SET LOCAL and does.
            await session.execute(
                text("SELECT set_config('lock_timeout', :timeout, true)"),
                {"timeout": timeout},
            )
            # mode.value comes from the LockMode enum, a fixed keyword set.
            await session.execute(text(f"LOCK {quoted_names} IN {mode.value} MODE"))
        yield session
|