fix: resolve MissingGreenlet error when accessing self attributes in WatchedFieldsMixin callbacks (#154)

This commit is contained in:
d3vyce
2026-03-20 20:52:11 +01:00
committed by GitHub
parent f82225f995
commit fc9cd1f034
2 changed files with 160 additions and 67 deletions

View File

@@ -9,6 +9,7 @@ from typing import Any, TypeVar
from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
from ..logger import get_logger
@@ -53,6 +54,17 @@ def watch(*fields: str) -> Any:
return decorator
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState
state_dict = state.dict
return {
prop.key: state_dict[prop.key]
for prop in state.mapper.column_attrs
if prop.key in state_dict
}
def _upsert_changes(
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
obj: Any,
@@ -139,15 +151,30 @@ def _task_error_handler(task: asyncio.Task[Any]) -> None:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None:
"""Dispatch *fn* with *args*, handling both sync and async callables."""
def _schedule_with_snapshot(
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
) -> None:
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
then schedule a coroutine that restores the snapshot and calls *fn*.
"""
snapshot = _snapshot_column_attrs(obj)
async def _run(
obj: Any = obj,
fn: Any = fn,
snapshot: dict[str, Any] = snapshot,
args: tuple = args,
) -> None:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
try:
result = fn(*args)
if asyncio.iscoroutine(result):
await result
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
return
if asyncio.iscoroutine(result):
task = loop.create_task(result)
task = loop.create_task(_run())
task.add_done_callback(_task_error_handler)
@@ -168,13 +195,13 @@ def _after_commit(session: Any) -> None:
return
for obj in creates:
_call_callback(loop, obj.on_create)
_schedule_with_snapshot(loop, obj, obj.on_create)
for obj in deletes:
_call_callback(loop, obj.on_delete)
_schedule_with_snapshot(loop, obj, obj.on_delete)
for obj, changes in field_changes.values():
_call_callback(loop, obj.on_update, changes)
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
class WatchedFieldsMixin:

View File

@@ -31,7 +31,6 @@ from fastapi_toolsets.models.watched import (
_after_flush,
_after_flush_postexec,
_after_rollback,
_call_callback,
_task_error_handler,
_upsert_changes,
)
@@ -128,6 +127,17 @@ class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model whose on_create always raises to test exception logging."""
__tablename__ = "mixin_failing_callback_models"
name: Mapped[str] = mapped_column(String(50))
async def on_create(self) -> None:
raise RuntimeError("callback intentionally failed")
class NonWatchedModel(MixinBase):
__tablename__ = "mixin_non_watched_models"
@@ -135,6 +145,32 @@ class NonWatchedModel(MixinBase):
value: Mapped[str] = mapped_column(String(50))
_attr_access_events: list[dict] = []
class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model used to verify that self attributes are accessible in every callback."""
__tablename__ = "mixin_attr_access_models"
name: Mapped[str] = mapped_column(String(50))
async def on_create(self) -> None:
_attr_access_events.append(
{"event": "create", "id": self.id, "name": self.name}
)
async def on_delete(self) -> None:
_attr_access_events.append(
{"event": "delete", "id": self.id, "name": self.name}
)
async def on_update(self, changes: dict) -> None:
_attr_access_events.append(
{"event": "update", "id": self.id, "name": self.name}
)
_sync_events: list[dict] = []
@@ -174,6 +210,25 @@ async def mixin_session():
await engine.dispose()
@pytest.fixture(scope="function")
async def mixin_session_expire():
"""Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=True)
session = session_factory()
try:
yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.drop_all)
await engine.dispose()
class TestUUIDMixin:
@pytest.mark.anyio
async def test_uuid_generated_by_db(self, mixin_session):
@@ -742,6 +797,16 @@ class TestWatchedFieldsMixin:
assert _test_events == []
@pytest.mark.anyio
async def test_callback_exception_is_logged(self, mixin_session):
"""Exceptions raised inside on_create are logged, not propagated."""
obj = FailingCallbackModel(name="boom")
mixin_session.add(obj)
with patch.object(_watched_module._logger, "error") as mock_error:
await mixin_session.commit()
await asyncio.sleep(0)
mock_error.assert_called_once()
@pytest.mark.anyio
async def test_non_watched_model_no_callback(self, mixin_session):
"""Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
@@ -903,65 +968,66 @@ class TestSyncCallbacks:
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
class TestCallCallback:
class TestAttributeAccessInCallbacks:
"""Verify that self attributes are accessible inside every callback type.
Uses expire_on_commit=True (the SQLAlchemy default) so the tests would fail
without the snapshot-restore logic in _schedule_with_snapshot.
"""
@pytest.fixture(autouse=True)
def clear_events(self):
_attr_access_events.clear()
yield
_attr_access_events.clear()
@pytest.mark.anyio
async def test_async_callback_scheduled_as_task(self):
"""_call_callback schedules async functions as tasks."""
called = []
async def async_fn() -> None:
called.append("async")
loop = asyncio.get_running_loop()
_call_callback(loop, async_fn)
async def test_on_create_pk_and_field_accessible(self, mixin_session_expire):
"""id (server default) and regular fields are readable inside on_create."""
obj = AttrAccessModel(name="hello")
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
assert called == ["async"]
events = [e for e in _attr_access_events if e["event"] == "create"]
assert len(events) == 1
assert isinstance(events[0]["id"], uuid.UUID)
assert events[0]["name"] == "hello"
@pytest.mark.anyio
async def test_sync_callback_called_directly(self):
"""_call_callback invokes sync functions immediately."""
called = []
def sync_fn() -> None:
called.append("sync")
loop = asyncio.get_running_loop()
_call_callback(loop, sync_fn)
assert called == ["sync"]
@pytest.mark.anyio
async def test_sync_callback_exception_logged(self):
"""_call_callback logs exceptions from sync callbacks."""
def failing_fn() -> None:
raise RuntimeError("sync error")
loop = asyncio.get_running_loop()
with patch.object(_watched_module._logger, "error") as mock_error:
_call_callback(loop, failing_fn)
mock_error.assert_called_once()
@pytest.mark.anyio
async def test_async_callback_with_args(self):
"""_call_callback passes arguments to async callbacks."""
received = []
async def async_fn(changes: dict) -> None:
received.append(changes)
loop = asyncio.get_running_loop()
_call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}})
async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire):
"""id and regular fields are readable inside on_delete."""
obj = AttrAccessModel(name="to-delete")
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
assert received == [{"status": {"old": "a", "new": "b"}}]
_attr_access_events.clear()
await mixin_session_expire.delete(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
events = [e for e in _attr_access_events if e["event"] == "delete"]
assert len(events) == 1
assert isinstance(events[0]["id"], uuid.UUID)
assert events[0]["name"] == "to-delete"
@pytest.mark.anyio
async def test_sync_callback_with_args(self):
"""_call_callback passes arguments to sync callbacks."""
received = []
async def test_on_update_pk_and_updated_field_accessible(
self, mixin_session_expire
):
"""id and the new field value are readable inside on_update."""
obj = AttrAccessModel(name="original")
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
_attr_access_events.clear()
def sync_fn(changes: dict) -> None:
received.append(changes)
obj.name = "updated"
await mixin_session_expire.commit()
await asyncio.sleep(0)
loop = asyncio.get_running_loop()
_call_callback(loop, sync_fn, {"x": 1})
assert received == [{"x": 1}]
events = [e for e in _attr_access_events if e["event"] == "update"]
assert len(events) == 1
assert isinstance(events[0]["id"], uuid.UUID)
assert events[0]["name"] == "updated"