mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Initial commit
This commit is contained in:
175
src/fastapi_toolsets/db.py
Normal file
175
src/fastapi_toolsets/db.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Database utilities: sessions, transactions, and locks."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
__all__ = [
|
||||
"LockMode",
|
||||
"create_db_context",
|
||||
"create_db_dependency",
|
||||
"lock_tables",
|
||||
"get_transaction",
|
||||
]
|
||||
|
||||
|
||||
def create_db_dependency(
|
||||
session_maker: async_sessionmaker[AsyncSession],
|
||||
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
|
||||
"""Create a FastAPI dependency for database sessions.
|
||||
|
||||
Creates a dependency function that yields a session and auto-commits
|
||||
if a transaction is active when the request completes.
|
||||
|
||||
Args:
|
||||
session_maker: Async session factory from create_session_factory()
|
||||
|
||||
Returns:
|
||||
An async generator function usable with FastAPI's Depends()
|
||||
|
||||
Example:
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from fastapi_toolsets.db import create_db_dependency
|
||||
|
||||
engine = create_async_engine("postgresql+asyncpg://...")
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
get_db = create_db_dependency(SessionLocal)
|
||||
|
||||
@app.get("/users")
|
||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||
...
|
||||
"""
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with session_maker() as session:
|
||||
yield session
|
||||
if session.in_transaction():
|
||||
await session.commit()
|
||||
|
||||
return get_db
|
||||
|
||||
|
||||
def create_db_context(
|
||||
session_maker: async_sessionmaker[AsyncSession],
|
||||
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
|
||||
"""Create a context manager for database sessions.
|
||||
|
||||
Creates a context manager for use outside of FastAPI request handlers,
|
||||
such as in background tasks, CLI commands, or tests.
|
||||
|
||||
Args:
|
||||
session_maker: Async session factory from create_session_factory()
|
||||
|
||||
Returns:
|
||||
An async context manager function
|
||||
|
||||
Example:
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from fastapi_toolsets.db import create_db_context
|
||||
|
||||
engine = create_async_engine("postgresql+asyncpg://...")
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
get_db_context = create_db_context(SessionLocal)
|
||||
|
||||
async def background_task():
|
||||
async with get_db_context() as session:
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
...
|
||||
"""
|
||||
get_db = create_db_dependency(session_maker)
|
||||
return asynccontextmanager(get_db)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_transaction(
|
||||
session: AsyncSession,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a transaction context, handling nested transactions.
|
||||
|
||||
If already in a transaction, creates a savepoint (nested transaction).
|
||||
Otherwise, starts a new transaction.
|
||||
|
||||
Args:
|
||||
session: AsyncSession instance
|
||||
|
||||
Yields:
|
||||
The session within the transaction context
|
||||
|
||||
Example:
|
||||
async with get_transaction(session):
|
||||
session.add(model)
|
||||
# Auto-commits on exit, rolls back on exception
|
||||
"""
|
||||
if session.in_transaction():
|
||||
async with session.begin_nested():
|
||||
yield session
|
||||
else:
|
||||
async with session.begin():
|
||||
yield session
|
||||
|
||||
|
||||
class LockMode(str, Enum):
|
||||
"""PostgreSQL table lock modes.
|
||||
|
||||
See: https://www.postgresql.org/docs/current/explicit-locking.html
|
||||
"""
|
||||
|
||||
ACCESS_SHARE = "ACCESS SHARE"
|
||||
ROW_SHARE = "ROW SHARE"
|
||||
ROW_EXCLUSIVE = "ROW EXCLUSIVE"
|
||||
SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
|
||||
SHARE = "SHARE"
|
||||
SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
|
||||
EXCLUSIVE = "EXCLUSIVE"
|
||||
ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock_tables(
|
||||
session: AsyncSession,
|
||||
tables: list[type[DeclarativeBase]],
|
||||
*,
|
||||
mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
|
||||
timeout: str = "5s",
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Lock PostgreSQL tables for the duration of a transaction.
|
||||
|
||||
Acquires table-level locks that are held until the transaction ends.
|
||||
Useful for preventing concurrent modifications during critical operations.
|
||||
|
||||
Args:
|
||||
session: AsyncSession instance
|
||||
tables: List of SQLAlchemy model classes to lock
|
||||
mode: Lock mode (default: SHARE UPDATE EXCLUSIVE)
|
||||
timeout: Lock timeout (default: "5s")
|
||||
|
||||
Yields:
|
||||
The session with locked tables
|
||||
|
||||
Raises:
|
||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.db import lock_tables, LockMode
|
||||
|
||||
async with lock_tables(session, [User, Account]):
|
||||
# Tables are locked with SHARE UPDATE EXCLUSIVE mode
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
user.balance += 100
|
||||
|
||||
# With custom lock mode
|
||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||
# Exclusive lock - no other transactions can access
|
||||
await process_order(session, order_id)
|
||||
"""
|
||||
table_names = ",".join(table.__tablename__ for table in tables)
|
||||
|
||||
async with get_transaction(session):
|
||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||
yield session
|
||||
Reference in New Issue
Block a user