mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add wait_for_row_change db helper (#49)
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
"""Database utilities: sessions, transactions, and locks."""
|
"""Database utilities: sessions, transactions, and locks."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
@@ -14,6 +16,7 @@ __all__ = [
|
|||||||
"create_db_dependency",
|
"create_db_dependency",
|
||||||
"lock_tables",
|
"lock_tables",
|
||||||
"get_transaction",
|
"get_transaction",
|
||||||
|
"wait_for_row_change",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -173,3 +176,69 @@ async def lock_tables(
|
|||||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_row_change(
|
||||||
|
session: AsyncSession,
|
||||||
|
model: type[_M],
|
||||||
|
pk_value: Any,
|
||||||
|
*,
|
||||||
|
columns: list[str] | None = None,
|
||||||
|
interval: float = 0.5,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> _M:
|
||||||
|
"""Poll a database row until a change is detected.
|
||||||
|
|
||||||
|
Queries the row every ``interval`` seconds and returns the model instance
|
||||||
|
once a change is detected in any column (or only the specified ``columns``).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: AsyncSession instance
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
pk_value: Primary key value of the row to watch
|
||||||
|
columns: Optional list of column names to watch. If None, all columns
|
||||||
|
are watched.
|
||||||
|
interval: Polling interval in seconds (default: 0.5)
|
||||||
|
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The refreshed model instance with updated values
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LookupError: If the row does not exist or is deleted during polling
|
||||||
|
TimeoutError: If timeout expires before a change is detected
|
||||||
|
"""
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||||
|
|
||||||
|
if columns is not None:
|
||||||
|
watch_cols = columns
|
||||||
|
else:
|
||||||
|
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||||
|
|
||||||
|
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
elapsed += interval
|
||||||
|
|
||||||
|
if timeout is not None and elapsed >= timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"No change detected on {model.__name__} "
|
||||||
|
f"with pk={pk_value!r} within {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
session.expunge(instance)
|
||||||
|
instance = await session.get(model, pk_value)
|
||||||
|
|
||||||
|
if instance is None:
|
||||||
|
raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||||
|
|
||||||
|
current = {col: getattr(instance, col) for col in watch_cols}
|
||||||
|
if current != initial:
|
||||||
|
return instance
|
||||||
|
|||||||
102
tests/test_db.py
102
tests/test_db.py
@@ -1,5 +1,8 @@
|
|||||||
"""Tests for fastapi_toolsets.db module."""
|
"""Tests for fastapi_toolsets.db module."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
@@ -9,6 +12,7 @@ from fastapi_toolsets.db import (
|
|||||||
create_db_dependency,
|
create_db_dependency,
|
||||||
get_transaction,
|
get_transaction,
|
||||||
lock_tables,
|
lock_tables,
|
||||||
|
wait_for_row_change,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
||||||
@@ -241,3 +245,101 @@ class TestLockTables:
|
|||||||
|
|
||||||
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWaitForRowChange:
|
||||||
|
"""Tests for wait_for_row_change polling function."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_detects_update(self, db_session: AsyncSession, engine):
|
||||||
|
"""Returns updated instance when a column value changes."""
|
||||||
|
role = Role(name="watch_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
assert r is not None
|
||||||
|
r.name = "updated_role"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.name == "updated_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
|
||||||
|
"""Only triggers on changes to specified columns."""
|
||||||
|
user = User(username="testuser", email="test@example.com")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def update_later():
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
# First: change email (not watched) — should not trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.email = "new@example.com"
|
||||||
|
await other.commit()
|
||||||
|
# Second: change username (watched) — should trigger
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
async with factory() as other:
|
||||||
|
u = await other.get(User, user.id)
|
||||||
|
assert u is not None
|
||||||
|
u.username = "newuser"
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
update_task = asyncio.create_task(update_later())
|
||||||
|
result = await wait_for_row_change(
|
||||||
|
db_session, User, user.id, columns=["username"], interval=0.05
|
||||||
|
)
|
||||||
|
await update_task
|
||||||
|
|
||||||
|
assert result.username == "newuser"
|
||||||
|
assert result.email == "new@example.com"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises LookupError when the row does not exist."""
|
||||||
|
fake_id = uuid.uuid4()
|
||||||
|
with pytest.raises(LookupError, match="not found"):
|
||||||
|
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_timeout_raises(self, db_session: AsyncSession):
|
||||||
|
"""Raises TimeoutError when no change is detected within timeout."""
|
||||||
|
role = Role(name="timeout_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
await wait_for_row_change(
|
||||||
|
db_session, Role, role.id, interval=0.05, timeout=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
|
||||||
|
"""Raises LookupError when the row is deleted during polling."""
|
||||||
|
role = Role(name="delete_role")
|
||||||
|
db_session.add(role)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
async def delete_later():
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with factory() as other:
|
||||||
|
r = await other.get(Role, role.id)
|
||||||
|
await other.delete(r)
|
||||||
|
await other.commit()
|
||||||
|
|
||||||
|
delete_task = asyncio.create_task(delete_later())
|
||||||
|
with pytest.raises(LookupError):
|
||||||
|
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||||
|
await delete_task
|
||||||
|
|||||||
Reference in New Issue
Block a user