fix: resolve MissingGreenlet error when accessing self attributes in WatchedFieldsMixin callbacks (#154)

This commit is contained in:
d3vyce
2026-03-20 20:52:11 +01:00
committed by GitHub
parent f82225f995
commit fc9cd1f034
2 changed files with 160 additions and 67 deletions

View File

@@ -9,6 +9,7 @@ from typing import Any, TypeVar
from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
from ..logger import get_logger
@@ -53,6 +54,17 @@ def watch(*fields: str) -> Any:
return decorator
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState
state_dict = state.dict
return {
prop.key: state_dict[prop.key]
for prop in state.mapper.column_attrs
if prop.key in state_dict
}
def _upsert_changes(
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
obj: Any,
@@ -139,16 +151,31 @@ def _task_error_handler(task: asyncio.Task[Any]) -> None:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None:
"""Dispatch *fn* with *args*, handling both sync and async callables."""
try:
result = fn(*args)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
return
if asyncio.iscoroutine(result):
task = loop.create_task(result)
task.add_done_callback(_task_error_handler)
def _schedule_with_snapshot(
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
) -> None:
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
then schedule a coroutine that restores the snapshot and calls *fn*.
"""
snapshot = _snapshot_column_attrs(obj)
async def _run(
obj: Any = obj,
fn: Any = fn,
snapshot: dict[str, Any] = snapshot,
args: tuple = args,
) -> None:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
try:
result = fn(*args)
if asyncio.iscoroutine(result):
await result
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
task = loop.create_task(_run())
task.add_done_callback(_task_error_handler)
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
@@ -168,13 +195,13 @@ def _after_commit(session: Any) -> None:
return
for obj in creates:
_call_callback(loop, obj.on_create)
_schedule_with_snapshot(loop, obj, obj.on_create)
for obj in deletes:
_call_callback(loop, obj.on_delete)
_schedule_with_snapshot(loop, obj, obj.on_delete)
for obj, changes in field_changes.values():
_call_callback(loop, obj.on_update, changes)
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
class WatchedFieldsMixin: