mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
fix: await any awaitable callback return value, not only coroutines (#168)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
"""Field-change monitoring via SQLAlchemy session events."""
|
"""Field-change monitoring via SQLAlchemy session events."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
import weakref
|
import weakref
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -169,7 +170,7 @@ def _schedule_with_snapshot(
|
|||||||
_sa_set_committed_value(obj, key, value)
|
_sa_set_committed_value(obj, key, value)
|
||||||
try:
|
try:
|
||||||
result = fn(*args)
|
result = fn(*args)
|
||||||
if asyncio.iscoroutine(result):
|
if inspect.isawaitable(result):
|
||||||
await result
|
await result
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|||||||
@@ -199,6 +199,7 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
|||||||
|
|
||||||
|
|
||||||
_sync_events: list[dict] = []
|
_sync_events: list[dict] = []
|
||||||
|
_future_events: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
@watch("status")
|
@watch("status")
|
||||||
@@ -219,6 +220,20 @@ class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
|||||||
_sync_events.append({"event": "update", "changes": changes})
|
_sync_events.append({"event": "update", "changes": changes})
|
||||||
|
|
||||||
|
|
||||||
|
class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
"""Model whose on_create returns an asyncio.Task (awaitable, not a coroutine)."""
|
||||||
|
|
||||||
|
__tablename__ = "mixin_future_callback_models"
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
def on_create(self) -> "asyncio.Task[None]":
|
||||||
|
async def _work() -> None:
|
||||||
|
_future_events.append("created")
|
||||||
|
|
||||||
|
return asyncio.ensure_future(_work())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def mixin_session():
|
async def mixin_session():
|
||||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
@@ -1108,6 +1123,28 @@ class TestSyncCallbacks:
|
|||||||
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
|
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestFutureCallbacks:
|
||||||
|
"""Callbacks returning a non-coroutine awaitable (asyncio.Task / Future)."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_events(self):
|
||||||
|
_future_events.clear()
|
||||||
|
yield
|
||||||
|
_future_events.clear()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_task_callback_is_awaited(self, mixin_session):
|
||||||
|
"""on_create returning an asyncio.Task is awaited and its work completes."""
|
||||||
|
obj = FutureCallbackModel(name="test")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
# Two turns: one for _run() to execute, one for the inner _work() task.
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert _future_events == ["created"]
|
||||||
|
|
||||||
|
|
||||||
class TestAttributeAccessInCallbacks:
|
class TestAttributeAccessInCallbacks:
|
||||||
"""Verify that self attributes are accessible inside every callback type.
|
"""Verify that self attributes are accessible inside every callback type.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user