diff --git a/src/fastapi_toolsets/models/watched.py b/src/fastapi_toolsets/models/watched.py index f5e815b..69caf01 100644 --- a/src/fastapi_toolsets/models/watched.py +++ b/src/fastapi_toolsets/models/watched.py @@ -1,6 +1,7 @@ """Field-change monitoring via SQLAlchemy session events.""" import asyncio +import inspect import weakref from collections.abc import Awaitable from enum import Enum @@ -169,7 +170,7 @@ def _schedule_with_snapshot( _sa_set_committed_value(obj, key, value) try: result = fn(*args) - if asyncio.iscoroutine(result): + if inspect.isawaitable(result): await result except Exception as exc: _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) diff --git a/tests/test_models.py b/tests/test_models.py index b9cf7a8..18e8a8b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -199,6 +199,7 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin): _sync_events: list[dict] = [] +_future_events: list[str] = [] @watch("status") @@ -219,6 +220,20 @@ class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): _sync_events.append({"event": "update", "changes": changes}) +class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): + """Model whose on_create returns an asyncio.Task (awaitable, not a coroutine).""" + + __tablename__ = "mixin_future_callback_models" + + name: Mapped[str] = mapped_column(String(50)) + + def on_create(self) -> "asyncio.Task[None]": + async def _work() -> None: + _future_events.append("created") + + return asyncio.ensure_future(_work()) + + @pytest.fixture(scope="function") async def mixin_session(): engine = create_async_engine(DATABASE_URL, echo=False) @@ -1108,6 +1123,28 @@ class TestSyncCallbacks: assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} +class TestFutureCallbacks: + """Callbacks returning a non-coroutine awaitable (asyncio.Task / Future).""" + + @pytest.fixture(autouse=True) + def clear_events(self): + _future_events.clear() + yield + _future_events.clear() + + @pytest.mark.anyio + async def test_task_callback_is_awaited(self, mixin_session): + """on_create returning an asyncio.Task is awaited and its work completes.""" + obj = FutureCallbackModel(name="test") + mixin_session.add(obj) + await mixin_session.commit() + # Two turns: one for _run() to execute, one for the inner _work() task. + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert _future_events == ["created"] + + class TestAttributeAccessInCallbacks: """Verify that self attributes are accessible inside every callback type.