fix: await any awaitable callback return value, not only coroutines (#168)

This commit is contained in:
d3vyce
2026-03-23 18:58:48 +01:00
committed by GitHub
parent bcb5b0bfda
commit 0c7a99039c
2 changed files with 39 additions and 1 deletions

View File

@@ -1,6 +1,7 @@
"""Field-change monitoring via SQLAlchemy session events.""" """Field-change monitoring via SQLAlchemy session events."""
import asyncio import asyncio
import inspect
import weakref import weakref
from collections.abc import Awaitable from collections.abc import Awaitable
from enum import Enum from enum import Enum
@@ -169,7 +170,7 @@ def _schedule_with_snapshot(
_sa_set_committed_value(obj, key, value) _sa_set_committed_value(obj, key, value)
try: try:
result = fn(*args) result = fn(*args)
if asyncio.iscoroutine(result): if inspect.isawaitable(result):
await result await result
except Exception as exc: except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)

View File

@@ -199,6 +199,7 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events: list[dict] = [] _sync_events: list[dict] = []
_future_events: list[str] = []
@watch("status") @watch("status")
@@ -219,6 +220,20 @@ class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events.append({"event": "update", "changes": changes}) _sync_events.append({"event": "update", "changes": changes})
class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model whose on_create returns an asyncio.Task (awaitable, not a coroutine)."""
__tablename__ = "mixin_future_callback_models"
name: Mapped[str] = mapped_column(String(50))
def on_create(self) -> "asyncio.Task[None]":
async def _work() -> None:
_future_events.append("created")
return asyncio.ensure_future(_work())
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mixin_session(): async def mixin_session():
engine = create_async_engine(DATABASE_URL, echo=False) engine = create_async_engine(DATABASE_URL, echo=False)
@@ -1108,6 +1123,28 @@ class TestSyncCallbacks:
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
class TestFutureCallbacks:
"""Callbacks returning a non-coroutine awaitable (asyncio.Task / Future)."""
@pytest.fixture(autouse=True)
def clear_events(self):
_future_events.clear()
yield
_future_events.clear()
@pytest.mark.anyio
async def test_task_callback_is_awaited(self, mixin_session):
"""on_create returning an asyncio.Task is awaited and its work completes."""
obj = FutureCallbackModel(name="test")
mixin_session.add(obj)
await mixin_session.commit()
# Two turns: one for _run() to execute, one for the inner _work() task.
await asyncio.sleep(0)
await asyncio.sleep(0)
assert _future_events == ["created"]
class TestAttributeAccessInCallbacks: class TestAttributeAccessInCallbacks:
"""Verify that self attributes are accessible inside every callback type. """Verify that self attributes are accessible inside every callback type.