Files
fastapi-toolsets/src/fastapi_toolsets/models/watched.py

270 lines
9.0 KiB
Python

"""Field-change monitoring via SQLAlchemy session events."""
import asyncio
import inspect
import weakref
from collections.abc import Awaitable
from enum import Enum
from typing import Any, TypeVar
from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
from ..logger import get_logger
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
_logger = get_logger()
_T = TypeVar("_T")
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
weakref.WeakKeyDictionary()
)
_SESSION_PENDING_NEW = "_ft_pending_new"
_SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
class ModelEvent(str, Enum):
"""Event types emitted by :class:`WatchedFieldsMixin`."""
CREATE = "create"
DELETE = "delete"
UPDATE = "update"
def watch(*fields: str) -> Any:
"""Class decorator to filter which fields trigger ``on_update``.
Args:
*fields: One or more field names to watch. At least one name is required.
Raises:
ValueError: If called with no field names.
"""
if not fields:
raise ValueError("@watch requires at least one field name.")
def decorator(cls: type[_T]) -> type[_T]:
_WATCHED_FIELDS[cls] = list(fields)
return cls
return decorator
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState
state_dict = state.dict
return {
prop.key: state_dict[prop.key]
for prop in state.mapper.column_attrs
if prop.key in state_dict
}
def _get_watched_fields(cls: type) -> list[str] | None:
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
for klass in cls.__mro__:
if klass in _WATCHED_FIELDS:
return _WATCHED_FIELDS[klass]
return None
def _upsert_changes(
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
obj: Any,
changes: dict[str, dict[str, Any]],
) -> None:
"""Insert or merge *changes* into *pending* for *obj*."""
key = id(obj)
if key in pending:
existing = pending[key][1]
for field, change in changes.items():
if field in existing:
existing[field]["new"] = change["new"]
else:
existing[field] = change
else:
pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
def _after_transaction_create(session: Any, transaction: Any) -> None:
if transaction.nested:
session.info[_SESSION_SAVEPOINT_DEPTH] = (
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
def _after_transaction_end(session: Any, transaction: Any) -> None:
if transaction.nested:
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
if depth > 0: # pragma: no branch
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture references while session.new is still populated.
# Values are read in _after_flush_postexec once RETURNING has been processed.
for obj in session.new:
if isinstance(obj, WatchedFieldsMixin):
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
# Deleted objects: capture before they leave the identity map.
for obj in session.deleted:
if isinstance(obj, WatchedFieldsMixin):
session.info.setdefault(_SESSION_DELETES, []).append(obj)
# Dirty objects: read old/new from SQLAlchemy attribute history.
for obj in session.dirty:
if not isinstance(obj, WatchedFieldsMixin):
continue
# None = not in dict = watch all fields; list = specific fields only
watched = _get_watched_fields(type(obj))
changes: dict[str, dict[str, Any]] = {}
attrs = (
# Specific fields
((field, sa_inspect(obj).attrs[field]) for field in watched)
if watched is not None
# All mapped fields
else ((s.key, s) for s in sa_inspect(obj).attrs)
)
for field, attr_state in attrs:
history = attr_state.history
if history.has_changes() and history.deleted:
changes[field] = {
"old": history.deleted[0],
"new": history.added[0] if history.added else None,
}
if changes:
_upsert_changes(
session.info.setdefault(_SESSION_UPDATES, {}),
obj,
changes,
)
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
# New objects are now persistent and RETURNING values have been applied,
# so server defaults (id, created_at, …) are available via getattr.
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
if not pending_new:
return
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
def _after_rollback(session: Any) -> None:
session.info.pop(_SESSION_PENDING_NEW, None)
session.info.pop(_SESSION_CREATES, None)
session.info.pop(_SESSION_DELETES, None)
session.info.pop(_SESSION_UPDATES, None)
def _task_error_handler(task: asyncio.Task[Any]) -> None:
if not task.cancelled() and (exc := task.exception()):
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
def _schedule_with_snapshot(
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
) -> None:
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
then schedule a coroutine that restores the snapshot and calls *fn*.
"""
snapshot = _snapshot_column_attrs(obj)
async def _run(
obj: Any = obj,
fn: Any = fn,
snapshot: dict[str, Any] = snapshot,
args: tuple = args,
) -> None:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
try:
result = fn(*args)
if inspect.isawaitable(result):
await result
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
task = loop.create_task(_run())
task.add_done_callback(_task_error_handler)
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
def _after_commit(session: Any) -> None:
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
return
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
_SESSION_UPDATES, {}
)
if creates and deletes:
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [o for o in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
if not creates and not deletes and not field_changes:
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
for obj in creates:
_schedule_with_snapshot(loop, obj, obj.on_create)
for obj in deletes:
_schedule_with_snapshot(loop, obj, obj.on_delete)
for obj, changes in field_changes.values():
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
class WatchedFieldsMixin:
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
def on_event(
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
) -> Awaitable[None] | None:
"""Catch-all callback fired for every lifecycle event.
Args:
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
or :attr:`ModelEvent.UPDATE`).
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
"""
def on_create(self) -> Awaitable[None] | None:
"""Called after INSERT commit."""
return self.on_event(ModelEvent.CREATE)
def on_delete(self) -> Awaitable[None] | None:
"""Called after DELETE commit."""
return self.on_event(ModelEvent.DELETE)
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
"""Called after UPDATE commit when watched fields change."""
return self.on_event(ModelEvent.UPDATE, changes=changes)