mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
fix: resolve MissingGreenlet error when accessing self attributes in WatchedFieldsMixin callbacks (#154)
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import Any, TypeVar
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
@@ -53,6 +54,17 @@ def watch(*fields: str) -> Any:
|
||||
return decorator
|
||||
|
||||
|
||||
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||
"""Read currently-loaded column values into a plain dict."""
|
||||
state = sa_inspect(obj) # InstanceState
|
||||
state_dict = state.dict
|
||||
return {
|
||||
prop.key: state_dict[prop.key]
|
||||
for prop in state.mapper.column_attrs
|
||||
if prop.key in state_dict
|
||||
}
|
||||
|
||||
|
||||
def _upsert_changes(
|
||||
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||
obj: Any,
|
||||
@@ -139,16 +151,31 @@ def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
|
||||
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None:
|
||||
"""Dispatch *fn* with *args*, handling both sync and async callables."""
|
||||
try:
|
||||
result = fn(*args)
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
return
|
||||
if asyncio.iscoroutine(result):
|
||||
task = loop.create_task(result)
|
||||
task.add_done_callback(_task_error_handler)
|
||||
def _schedule_with_snapshot(
|
||||
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
|
||||
) -> None:
|
||||
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
|
||||
then schedule a coroutine that restores the snapshot and calls *fn*.
|
||||
"""
|
||||
snapshot = _snapshot_column_attrs(obj)
|
||||
|
||||
async def _run(
|
||||
obj: Any = obj,
|
||||
fn: Any = fn,
|
||||
snapshot: dict[str, Any] = snapshot,
|
||||
args: tuple = args,
|
||||
) -> None:
|
||||
for key, value in snapshot.items():
|
||||
_sa_set_committed_value(obj, key, value)
|
||||
try:
|
||||
result = fn(*args)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
task = loop.create_task(_run())
|
||||
task.add_done_callback(_task_error_handler)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
||||
@@ -168,13 +195,13 @@ def _after_commit(session: Any) -> None:
|
||||
return
|
||||
|
||||
for obj in creates:
|
||||
_call_callback(loop, obj.on_create)
|
||||
_schedule_with_snapshot(loop, obj, obj.on_create)
|
||||
|
||||
for obj in deletes:
|
||||
_call_callback(loop, obj.on_delete)
|
||||
_schedule_with_snapshot(loop, obj, obj.on_delete)
|
||||
|
||||
for obj, changes in field_changes.values():
|
||||
_call_callback(loop, obj.on_update, changes)
|
||||
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
|
||||
|
||||
|
||||
class WatchedFieldsMixin:
|
||||
|
||||
Reference in New Issue
Block a user