feat: add wait_for_row_change db helper (#49)

This commit is contained in:
d3vyce
2026-02-10 21:46:59 +01:00
committed by GitHub
parent ced1a655f2
commit 1ea316bef4
2 changed files with 171 additions and 0 deletions

View File

@@ -1,8 +1,10 @@
"""Database utilities: sessions, transactions, and locks."""
import asyncio
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from enum import Enum
from typing import Any, TypeVar
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -14,6 +16,7 @@ __all__ = [
"create_db_dependency",
"lock_tables",
"get_transaction",
"wait_for_row_change",
]
@@ -173,3 +176,69 @@ async def lock_tables(
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
yield session
_M = TypeVar("_M", bound=DeclarativeBase)
async def wait_for_row_change(
session: AsyncSession,
model: type[_M],
pk_value: Any,
*,
columns: list[str] | None = None,
interval: float = 0.5,
timeout: float | None = None,
) -> _M:
"""Poll a database row until a change is detected.
Queries the row every ``interval`` seconds and returns the model instance
once a change is detected in any column (or only the specified ``columns``).
Args:
session: AsyncSession instance
model: SQLAlchemy model class
pk_value: Primary key value of the row to watch
columns: Optional list of column names to watch. If None, all columns
are watched.
interval: Polling interval in seconds (default: 0.5)
timeout: Maximum time to wait in seconds. None means wait forever.
Returns:
The refreshed model instance with updated values
Raises:
LookupError: If the row does not exist or is deleted during polling
TimeoutError: If timeout expires before a change is detected
"""
instance = await session.get(model, pk_value)
if instance is None:
raise LookupError(f"{model.__name__} with pk={pk_value!r} not found")
if columns is not None:
watch_cols = columns
else:
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
initial = {col: getattr(instance, col) for col in watch_cols}
elapsed = 0.0
while True:
await asyncio.sleep(interval)
elapsed += interval
if timeout is not None and elapsed >= timeout:
raise TimeoutError(
f"No change detected on {model.__name__} "
f"with pk={pk_value!r} within {timeout}s"
)
session.expunge(instance)
instance = await session.get(model, pk_value)
if instance is None:
raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted")
current = {col: getattr(instance, col) for col in watch_cols}
if current != initial:
return instance