mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat: rework async event system (#196)
* feat: rework async event system * docs: add v3 migration guide * feat: add cache * enhancements
This commit is contained in:
@@ -7,15 +7,15 @@ from .columns import (
|
||||
UUIDv7Mixin,
|
||||
UpdatedAtMixin,
|
||||
)
|
||||
from .watched import ModelEvent, WatchedFieldsMixin, watch
|
||||
from .watched import EventSession, ModelEvent, listens_for
|
||||
|
||||
__all__ = [
|
||||
"EventSession",
|
||||
"ModelEvent",
|
||||
"UUIDMixin",
|
||||
"UUIDv7Mixin",
|
||||
"CreatedAtMixin",
|
||||
"UpdatedAtMixin",
|
||||
"TimestampMixin",
|
||||
"WatchedFieldsMixin",
|
||||
"watch",
|
||||
"listens_for",
|
||||
]
|
||||
|
||||
@@ -6,14 +6,6 @@ from datetime import datetime
|
||||
from sqlalchemy import DateTime, Uuid, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
__all__ = [
|
||||
"UUIDMixin",
|
||||
"UUIDv7Mixin",
|
||||
"CreatedAtMixin",
|
||||
"UpdatedAtMixin",
|
||||
"TimestampMixin",
|
||||
]
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin that adds a UUID primary key auto-generated by the database."""
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
"""Field-change monitoring via SQLAlchemy session events."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import weakref
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
@@ -14,49 +12,81 @@ from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_v
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
|
||||
|
||||
_logger = get_logger()
|
||||
_T = TypeVar("_T")
|
||||
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
|
||||
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
|
||||
weakref.WeakKeyDictionary()
|
||||
)
|
||||
_SESSION_PENDING_NEW = "_ft_pending_new"
|
||||
_SESSION_CREATES = "_ft_creates"
|
||||
_SESSION_DELETES = "_ft_deletes"
|
||||
_SESSION_UPDATES = "_ft_updates"
|
||||
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
|
||||
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
||||
|
||||
|
||||
class ModelEvent(str, Enum):
|
||||
"""Event types emitted by :class:`WatchedFieldsMixin`."""
|
||||
"""Event types dispatched by :class:`EventSession`."""
|
||||
|
||||
CREATE = "create"
|
||||
DELETE = "delete"
|
||||
UPDATE = "update"
|
||||
|
||||
|
||||
def watch(*fields: str) -> Any:
|
||||
"""Class decorator to filter which fields trigger ``on_update``.
|
||||
_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
|
||||
_SESSION_CREATES = "_ft_creates"
|
||||
_SESSION_DELETES = "_ft_deletes"
|
||||
_SESSION_UPDATES = "_ft_updates"
|
||||
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
||||
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||
_WATCHED_MODELS: set[type] = set()
|
||||
_WATCHED_CACHE: dict[type, bool] = {}
|
||||
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||
|
||||
|
||||
def _invalidate_caches() -> None:
|
||||
"""Clear lookup caches after handler registration."""
|
||||
_WATCHED_CACHE.clear()
|
||||
_HANDLER_CACHE.clear()
|
||||
|
||||
|
||||
def listens_for(
|
||||
model_class: type,
|
||||
event_types: list[ModelEvent] | None = None,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Register a callback for one or more model lifecycle events.
|
||||
|
||||
Args:
|
||||
*fields: One or more field names to watch. At least one name is required.
|
||||
|
||||
Raises:
|
||||
ValueError: If called with no field names.
|
||||
model_class: The SQLAlchemy model class to listen on.
|
||||
event_types: List of :class:`ModelEvent` values to listen for.
|
||||
Defaults to all event types.
|
||||
"""
|
||||
if not fields:
|
||||
raise ValueError("@watch requires at least one field name.")
|
||||
evs = event_types if event_types is not None else list(ModelEvent)
|
||||
|
||||
def decorator(cls: type[_T]) -> type[_T]:
|
||||
_WATCHED_FIELDS[cls] = list(fields)
|
||||
return cls
|
||||
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
for ev in evs:
|
||||
_EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
|
||||
_WATCHED_MODELS.add(model_class)
|
||||
_invalidate_caches()
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_watched(obj: Any) -> bool:
|
||||
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
|
||||
cls = type(obj)
|
||||
try:
|
||||
return _WATCHED_CACHE[cls]
|
||||
except KeyError:
|
||||
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
|
||||
_WATCHED_CACHE[cls] = result
|
||||
return result
|
||||
|
||||
|
||||
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
|
||||
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
|
||||
key = (cls, ev)
|
||||
try:
|
||||
return _HANDLER_CACHE[key]
|
||||
except KeyError:
|
||||
handlers: list[Callable[..., Any]] = []
|
||||
for klass in cls.__mro__:
|
||||
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
|
||||
_HANDLER_CACHE[key] = handlers
|
||||
return handlers
|
||||
|
||||
|
||||
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||
"""Read currently-loaded column values into a plain dict."""
|
||||
state = sa_inspect(obj) # InstanceState
|
||||
@@ -65,7 +95,7 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||
for prop in state.mapper.column_attrs:
|
||||
if prop.key in state_dict:
|
||||
snapshot[prop.key] = state_dict[prop.key]
|
||||
elif (
|
||||
elif ( # pragma: no cover
|
||||
not state.expired
|
||||
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
|
||||
and all(
|
||||
@@ -79,12 +109,17 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||
return snapshot
|
||||
|
||||
|
||||
def _get_watched_fields(cls: type) -> list[str] | None:
|
||||
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
|
||||
for klass in cls.__mro__:
|
||||
if klass in _WATCHED_FIELDS:
|
||||
return _WATCHED_FIELDS[klass]
|
||||
return None
|
||||
def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
|
||||
"""Return the watched fields for *cls*."""
|
||||
fields = getattr(cls, "__watched_fields__", None)
|
||||
if fields is not None and (
|
||||
not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
|
||||
):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
|
||||
f"got {type(fields).__name__}"
|
||||
)
|
||||
return fields
|
||||
|
||||
|
||||
def _upsert_changes(
|
||||
@@ -105,50 +140,32 @@ def _upsert_changes(
|
||||
pending[key] = (obj, changes)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
|
||||
def _after_transaction_create(session: Any, transaction: Any) -> None:
|
||||
if transaction.nested:
|
||||
session.info[_SESSION_SAVEPOINT_DEPTH] = (
|
||||
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
|
||||
def _after_transaction_end(session: Any, transaction: Any) -> None:
|
||||
if transaction.nested:
|
||||
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
|
||||
if depth > 0: # pragma: no branch
|
||||
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
||||
def _after_flush(session: Any, flush_context: Any) -> None:
|
||||
# New objects: capture references while session.new is still populated.
|
||||
# Values are read in _after_flush_postexec once RETURNING has been processed.
|
||||
# New objects: capture reference. Attributes will be refreshed after commit.
|
||||
for obj in session.new:
|
||||
if isinstance(obj, WatchedFieldsMixin):
|
||||
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
|
||||
if _is_watched(obj):
|
||||
session.info.setdefault(_SESSION_CREATES, []).append(obj)
|
||||
|
||||
# Deleted objects: capture before they leave the identity map.
|
||||
# Deleted objects: snapshot now while attributes are still loaded.
|
||||
for obj in session.deleted:
|
||||
if isinstance(obj, WatchedFieldsMixin):
|
||||
session.info.setdefault(_SESSION_DELETES, []).append(obj)
|
||||
if _is_watched(obj):
|
||||
snapshot = _snapshot_column_attrs(obj)
|
||||
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
|
||||
|
||||
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
||||
for obj in session.dirty:
|
||||
if not isinstance(obj, WatchedFieldsMixin):
|
||||
if not _is_watched(obj):
|
||||
continue
|
||||
|
||||
# None = not in dict = watch all fields; list = specific fields only
|
||||
watched = _get_watched_fields(type(obj))
|
||||
changes: dict[str, dict[str, Any]] = {}
|
||||
|
||||
inst_attrs = sa_inspect(obj).attrs
|
||||
attrs = (
|
||||
# Specific fields
|
||||
((field, sa_inspect(obj).attrs[field]) for field in watched)
|
||||
((field, inst_attrs[field]) for field in watched)
|
||||
if watched is not None
|
||||
# All mapped fields
|
||||
else ((s.key, s) for s in sa_inspect(obj).attrs)
|
||||
else ((s.key, s) for s in inst_attrs)
|
||||
)
|
||||
for field, attr_state in attrs:
|
||||
history = attr_state.history
|
||||
@@ -166,116 +183,101 @@ def _after_flush(session: Any, flush_context: Any) -> None:
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
|
||||
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
|
||||
# New objects are now persistent and RETURNING values have been applied,
|
||||
# so server defaults (id, created_at, …) are available via getattr.
|
||||
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
|
||||
if not pending_new:
|
||||
return
|
||||
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
||||
def _after_rollback(session: Any) -> None:
|
||||
session.info.pop(_SESSION_PENDING_NEW, None)
|
||||
if session.in_transaction():
|
||||
return
|
||||
session.info.pop(_SESSION_CREATES, None)
|
||||
session.info.pop(_SESSION_DELETES, None)
|
||||
session.info.pop(_SESSION_UPDATES, None)
|
||||
|
||||
|
||||
def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
||||
if not task.cancelled() and (exc := task.exception()):
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
|
||||
def _schedule_with_snapshot(
|
||||
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
|
||||
async def _invoke_callback(
|
||||
fn: Callable[..., Any],
|
||||
obj: Any,
|
||||
event_type: ModelEvent,
|
||||
changes: dict[str, dict[str, Any]] | None,
|
||||
) -> None:
|
||||
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
|
||||
then schedule a coroutine that restores the snapshot and calls *fn*.
|
||||
"""
|
||||
snapshot = _snapshot_column_attrs(obj)
|
||||
|
||||
async def _run(
|
||||
obj: Any = obj,
|
||||
fn: Any = fn,
|
||||
snapshot: dict[str, Any] = snapshot,
|
||||
args: tuple = args,
|
||||
) -> None:
|
||||
for key, value in snapshot.items():
|
||||
_sa_set_committed_value(obj, key, value)
|
||||
try:
|
||||
result = fn(*args)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
task = loop.create_task(_run())
|
||||
task.add_done_callback(_task_error_handler)
|
||||
"""Call *fn* and await the result if it is awaitable."""
|
||||
result = fn(obj, event_type, changes)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
||||
def _after_commit(session: Any) -> None:
|
||||
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
|
||||
return
|
||||
class EventSession(AsyncSession):
|
||||
"""AsyncSession subclass that dispatches lifecycle callbacks after commit."""
|
||||
|
||||
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
|
||||
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
|
||||
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
|
||||
_SESSION_UPDATES, {}
|
||||
)
|
||||
async def commit(self) -> None: # noqa: C901
|
||||
await super().commit()
|
||||
|
||||
if creates and deletes:
|
||||
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
|
||||
if transient_ids:
|
||||
creates = [o for o in creates if id(o) not in transient_ids]
|
||||
deletes = [o for o in deletes if id(o) not in transient_ids]
|
||||
creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
|
||||
deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
|
||||
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
|
||||
_SESSION_UPDATES, {}
|
||||
)
|
||||
|
||||
if not creates and not deletes and not field_changes:
|
||||
return
|
||||
|
||||
# Suppress transient objects (created + deleted in same transaction).
|
||||
if creates and deletes:
|
||||
created_ids = {id(o) for o in creates}
|
||||
deleted_ids = {id(o) for o, _ in deletes}
|
||||
transient_ids = created_ids & deleted_ids
|
||||
if transient_ids:
|
||||
creates = [o for o in creates if id(o) not in transient_ids]
|
||||
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
|
||||
field_changes = {
|
||||
k: v for k, v in field_changes.items() if k not in transient_ids
|
||||
}
|
||||
|
||||
# Suppress updates for newly created objects (CREATE-only semantics).
|
||||
if creates and field_changes:
|
||||
create_ids = {id(o) for o in creates}
|
||||
field_changes = {
|
||||
k: v for k, v in field_changes.items() if k not in transient_ids
|
||||
k: v for k, v in field_changes.items() if k not in create_ids
|
||||
}
|
||||
|
||||
if not creates and not deletes and not field_changes:
|
||||
return
|
||||
# Dispatch CREATE callbacks.
|
||||
for obj in creates:
|
||||
try:
|
||||
state = sa_inspect(obj, raiseerr=False)
|
||||
if (
|
||||
state is None or state.detached or state.transient
|
||||
): # pragma: no cover
|
||||
continue
|
||||
await self.refresh(obj)
|
||||
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
|
||||
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
# Dispatch DELETE callbacks (restore snapshot; row is gone).
|
||||
for obj, snapshot in deletes:
|
||||
try:
|
||||
for key, value in snapshot.items():
|
||||
_sa_set_committed_value(obj, key, value)
|
||||
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
|
||||
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
for obj in creates:
|
||||
_schedule_with_snapshot(loop, obj, obj.on_create)
|
||||
# Dispatch UPDATE callbacks.
|
||||
for obj, changes in field_changes.values():
|
||||
try:
|
||||
state = sa_inspect(obj, raiseerr=False)
|
||||
if (
|
||||
state is None or state.detached or state.transient
|
||||
): # pragma: no cover
|
||||
continue
|
||||
await self.refresh(obj)
|
||||
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
|
||||
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
for obj in deletes:
|
||||
_schedule_with_snapshot(loop, obj, obj.on_delete)
|
||||
|
||||
for obj, changes in field_changes.values():
|
||||
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
|
||||
|
||||
|
||||
class WatchedFieldsMixin:
|
||||
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
|
||||
|
||||
def on_event(
|
||||
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
|
||||
) -> Awaitable[None] | None:
|
||||
"""Catch-all callback fired for every lifecycle event.
|
||||
|
||||
Args:
|
||||
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
|
||||
or :attr:`ModelEvent.UPDATE`).
|
||||
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
|
||||
"""
|
||||
|
||||
def on_create(self) -> Awaitable[None] | None:
|
||||
"""Called after INSERT commit."""
|
||||
return self.on_event(ModelEvent.CREATE)
|
||||
|
||||
def on_delete(self) -> Awaitable[None] | None:
|
||||
"""Called after DELETE commit."""
|
||||
return self.on_event(ModelEvent.DELETE)
|
||||
|
||||
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
|
||||
"""Called after UPDATE commit when watched fields change."""
|
||||
return self.on_event(ModelEvent.UPDATE, changes=changes)
|
||||
async def rollback(self) -> None:
|
||||
await super().rollback()
|
||||
self.info.pop(_SESSION_CREATES, None)
|
||||
self.info.pop(_SESSION_DELETES, None)
|
||||
self.info.pop(_SESSION_UPDATES, None)
|
||||
|
||||
Reference in New Issue
Block a user