diff --git a/src/fastapi_toolsets/db.py b/src/fastapi_toolsets/db.py index d60987d..6939b30 100644 --- a/src/fastapi_toolsets/db.py +++ b/src/fastapi_toolsets/db.py @@ -1,8 +1,10 @@ """Database utilities: sessions, transactions, and locks.""" +import asyncio from collections.abc import AsyncGenerator, Callable from contextlib import AbstractAsyncContextManager, asynccontextmanager from enum import Enum +from typing import Any, TypeVar from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker @@ -14,6 +16,7 @@ __all__ = [ "create_db_dependency", "lock_tables", "get_transaction", + "wait_for_row_change", ] @@ -173,3 +176,69 @@ async def lock_tables( await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'")) await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE")) yield session + + +_M = TypeVar("_M", bound=DeclarativeBase) + + +async def wait_for_row_change( + session: AsyncSession, + model: type[_M], + pk_value: Any, + *, + columns: list[str] | None = None, + interval: float = 0.5, + timeout: float | None = None, +) -> _M: + """Poll a database row until a change is detected. + + Queries the row every ``interval`` seconds and returns the model instance + once a change is detected in any column (or only the specified ``columns``). + + Args: + session: AsyncSession instance + model: SQLAlchemy model class + pk_value: Primary key value of the row to watch + columns: Optional list of column names to watch. If None, all columns + are watched. + interval: Polling interval in seconds (default: 0.5) + timeout: Maximum time to wait in seconds. None means wait forever. + + Returns: + The refreshed model instance with updated values + + Raises: + LookupError: If the row does not exist or is deleted during polling + TimeoutError: If timeout expires before a change is detected + """ + instance = await session.get(model, pk_value) + if instance is None: + raise LookupError(f"{model.__name__} with pk={pk_value!r} not found") + + if columns is not None: + watch_cols = columns + else: + watch_cols = [attr.key for attr in model.__mapper__.column_attrs] + + initial = {col: getattr(instance, col) for col in watch_cols} + + elapsed = 0.0 + while True: + await asyncio.sleep(interval) + elapsed += interval + + if timeout is not None and elapsed >= timeout: + raise TimeoutError( + f"No change detected on {model.__name__} " + f"with pk={pk_value!r} within {timeout}s" + ) + + session.expunge(instance) + instance = await session.get(model, pk_value) + + if instance is None: + raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted") + + current = {col: getattr(instance, col) for col in watch_cols} + if current != initial: + return instance diff --git a/tests/test_db.py b/tests/test_db.py index 2cfda03..dabc3e7 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,5 +1,8 @@ """Tests for fastapi_toolsets.db module.""" +import asyncio +import uuid + import pytest from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -9,6 +12,7 @@ from fastapi_toolsets.db import ( create_db_dependency, get_transaction, lock_tables, + wait_for_row_change, ) from .conftest import DATABASE_URL, Base, Role, RoleCrud, User @@ -241,3 +245,101 @@ class TestLockTables: result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"]) assert result is None + + +class TestWaitForRowChange: + """Tests for wait_for_row_change polling function.""" + + @pytest.mark.anyio + async def test_detects_update(self, db_session: AsyncSession, engine): + """Returns updated instance when a column value changes.""" + role = Role(name="watch_role") + db_session.add(role) + await db_session.commit() + + async def update_later(): + await asyncio.sleep(0.15) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as other: + r = await other.get(Role, role.id) + assert r is not None + r.name = "updated_role" + await other.commit() + + update_task = asyncio.create_task(update_later()) + result = await wait_for_row_change(db_session, Role, role.id, interval=0.05) + await update_task + + assert result.name == "updated_role" + + @pytest.mark.anyio + async def test_watches_specific_columns(self, db_session: AsyncSession, engine): + """Only triggers on changes to specified columns.""" + user = User(username="testuser", email="test@example.com") + db_session.add(user) + await db_session.commit() + + async def update_later(): + factory = async_sessionmaker(engine, expire_on_commit=False) + # First: change email (not watched) — should not trigger + await asyncio.sleep(0.15) + async with factory() as other: + u = await other.get(User, user.id) + assert u is not None + u.email = "new@example.com" + await other.commit() + # Second: change username (watched) — should trigger + await asyncio.sleep(0.15) + async with factory() as other: + u = await other.get(User, user.id) + assert u is not None + u.username = "newuser" + await other.commit() + + update_task = asyncio.create_task(update_later()) + result = await wait_for_row_change( + db_session, User, user.id, columns=["username"], interval=0.05 + ) + await update_task + + assert result.username == "newuser" + assert result.email == "new@example.com" + + @pytest.mark.anyio + async def test_nonexistent_row_raises(self, db_session: AsyncSession): + """Raises LookupError when the row does not exist.""" + fake_id = uuid.uuid4() + with pytest.raises(LookupError, match="not found"): + await wait_for_row_change(db_session, Role, fake_id, interval=0.05) + + @pytest.mark.anyio + async def test_timeout_raises(self, db_session: AsyncSession): + """Raises TimeoutError when no change is detected within timeout.""" + role = Role(name="timeout_role") + db_session.add(role) + await db_session.commit() + + with pytest.raises(TimeoutError): + await wait_for_row_change( + db_session, Role, role.id, interval=0.05, timeout=0.2 + ) + + @pytest.mark.anyio + async def test_deleted_row_raises(self, db_session: AsyncSession, engine): + """Raises LookupError when the row is deleted during polling.""" + role = Role(name="delete_role") + db_session.add(role) + await db_session.commit() + + async def delete_later(): + await asyncio.sleep(0.15) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as other: + r = await other.get(Role, role.id) + await other.delete(r) + await other.commit() + + delete_task = asyncio.create_task(delete_later()) + with pytest.raises(LookupError): + await wait_for_row_change(db_session, Role, role.id, interval=0.05) + await delete_task