feat: rework async event system (#196)

* feat: rework async event system

* docs: add v3 migration guide

* feat: add cache

* chore: miscellaneous enhancements
This commit is contained in:
d3vyce
2026-03-30 18:24:36 +02:00
committed by GitHub
parent 104285c6e5
commit 1890d696bf
12 changed files with 1149 additions and 659 deletions

View File

@@ -24,9 +24,12 @@ __all__ = [
]
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
def create_db_dependency(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AsyncGenerator[_SessionT, None]]:
"""Create a FastAPI dependency for database sessions.
Creates a dependency function that yields a session and auto-commits
@@ -54,7 +57,7 @@ def create_db_dependency(
```
"""
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async def get_db() -> AsyncGenerator[_SessionT, None]:
async with session_maker() as session:
await session.connection()
yield session
@@ -65,8 +68,8 @@ def create_db_dependency(
def create_db_context(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
"""Create a context manager for database sessions.
Creates a context manager for use outside of FastAPI request handlers,

View File

@@ -7,15 +7,15 @@ from .columns import (
UUIDv7Mixin,
UpdatedAtMixin,
)
from .watched import ModelEvent, WatchedFieldsMixin, watch
from .watched import EventSession, ModelEvent, listens_for
__all__ = [
"EventSession",
"ModelEvent",
"UUIDMixin",
"UUIDv7Mixin",
"CreatedAtMixin",
"UpdatedAtMixin",
"TimestampMixin",
"WatchedFieldsMixin",
"watch",
"listens_for",
]

View File

@@ -6,14 +6,6 @@ from datetime import datetime
from sqlalchemy import DateTime, Uuid, text
from sqlalchemy.orm import Mapped, mapped_column
__all__ = [
"UUIDMixin",
"UUIDv7Mixin",
"CreatedAtMixin",
"UpdatedAtMixin",
"TimestampMixin",
]
class UUIDMixin:
"""Mixin that adds a UUID primary key auto-generated by the database."""

View File

@@ -1,11 +1,9 @@
"""Field-change monitoring via SQLAlchemy session events."""
import asyncio
import inspect
import weakref
from collections.abc import Awaitable
from collections.abc import Callable
from enum import Enum
from typing import Any, TypeVar
from typing import Any
from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect
@@ -14,49 +12,81 @@ from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_v
from ..logger import get_logger
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
_logger = get_logger()
_T = TypeVar("_T")
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
weakref.WeakKeyDictionary()
)
_SESSION_PENDING_NEW = "_ft_pending_new"
_SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
class ModelEvent(str, Enum):
"""Event types emitted by :class:`WatchedFieldsMixin`."""
"""Event types dispatched by :class:`EventSession`."""
CREATE = "create"
DELETE = "delete"
UPDATE = "update"
def watch(*fields: str) -> Any:
"""Class decorator to filter which fields trigger ``on_update``.
_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
_SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
_WATCHED_MODELS: set[type] = set()
_WATCHED_CACHE: dict[type, bool] = {}
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
def _invalidate_caches() -> None:
"""Clear lookup caches after handler registration."""
_WATCHED_CACHE.clear()
_HANDLER_CACHE.clear()
def listens_for(
model_class: type,
event_types: list[ModelEvent] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Register a callback for one or more model lifecycle events.
Args:
*fields: One or more field names to watch. At least one name is required.
Raises:
ValueError: If called with no field names.
model_class: The SQLAlchemy model class to listen on.
event_types: List of :class:`ModelEvent` values to listen for.
Defaults to all event types.
"""
if not fields:
raise ValueError("@watch requires at least one field name.")
evs = event_types if event_types is not None else list(ModelEvent)
def decorator(cls: type[_T]) -> type[_T]:
_WATCHED_FIELDS[cls] = list(fields)
return cls
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
for ev in evs:
_EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
_WATCHED_MODELS.add(model_class)
_invalidate_caches()
return fn
return decorator
def _is_watched(obj: Any) -> bool:
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
cls = type(obj)
try:
return _WATCHED_CACHE[cls]
except KeyError:
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
_WATCHED_CACHE[cls] = result
return result
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
key = (cls, ev)
try:
return _HANDLER_CACHE[key]
except KeyError:
handlers: list[Callable[..., Any]] = []
for klass in cls.__mro__:
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
_HANDLER_CACHE[key] = handlers
return handlers
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState
@@ -65,7 +95,7 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
for prop in state.mapper.column_attrs:
if prop.key in state_dict:
snapshot[prop.key] = state_dict[prop.key]
elif (
elif ( # pragma: no cover
not state.expired
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
and all(
@@ -79,12 +109,17 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
return snapshot
def _get_watched_fields(cls: type) -> list[str] | None:
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
for klass in cls.__mro__:
if klass in _WATCHED_FIELDS:
return _WATCHED_FIELDS[klass]
return None
def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
"""Return the watched fields for *cls*."""
fields = getattr(cls, "__watched_fields__", None)
if fields is not None and (
not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
):
raise TypeError(
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
f"got {type(fields).__name__}"
)
return fields
def _upsert_changes(
@@ -105,50 +140,32 @@ def _upsert_changes(
pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
def _after_transaction_create(session: Any, transaction: Any) -> None:
if transaction.nested:
session.info[_SESSION_SAVEPOINT_DEPTH] = (
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
def _after_transaction_end(session: Any, transaction: Any) -> None:
if transaction.nested:
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
if depth > 0: # pragma: no branch
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture references while session.new is still populated.
# Values are read in _after_flush_postexec once RETURNING has been processed.
# New objects: capture reference. Attributes will be refreshed after commit.
for obj in session.new:
if isinstance(obj, WatchedFieldsMixin):
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
if _is_watched(obj):
session.info.setdefault(_SESSION_CREATES, []).append(obj)
# Deleted objects: capture before they leave the identity map.
# Deleted objects: snapshot now while attributes are still loaded.
for obj in session.deleted:
if isinstance(obj, WatchedFieldsMixin):
session.info.setdefault(_SESSION_DELETES, []).append(obj)
if _is_watched(obj):
snapshot = _snapshot_column_attrs(obj)
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
# Dirty objects: read old/new from SQLAlchemy attribute history.
for obj in session.dirty:
if not isinstance(obj, WatchedFieldsMixin):
if not _is_watched(obj):
continue
# None = not in dict = watch all fields; list = specific fields only
watched = _get_watched_fields(type(obj))
changes: dict[str, dict[str, Any]] = {}
inst_attrs = sa_inspect(obj).attrs
attrs = (
# Specific fields
((field, sa_inspect(obj).attrs[field]) for field in watched)
((field, inst_attrs[field]) for field in watched)
if watched is not None
# All mapped fields
else ((s.key, s) for s in sa_inspect(obj).attrs)
else ((s.key, s) for s in inst_attrs)
)
for field, attr_state in attrs:
history = attr_state.history
@@ -166,116 +183,101 @@ def _after_flush(session: Any, flush_context: Any) -> None:
)
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
# New objects are now persistent and RETURNING values have been applied,
# so server defaults (id, created_at, …) are available via getattr.
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
if not pending_new:
return
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
def _after_rollback(session: Any) -> None:
session.info.pop(_SESSION_PENDING_NEW, None)
if session.in_transaction():
return
session.info.pop(_SESSION_CREATES, None)
session.info.pop(_SESSION_DELETES, None)
session.info.pop(_SESSION_UPDATES, None)
def _task_error_handler(task: asyncio.Task[Any]) -> None:
if not task.cancelled() and (exc := task.exception()):
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
def _schedule_with_snapshot(
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
async def _invoke_callback(
fn: Callable[..., Any],
obj: Any,
event_type: ModelEvent,
changes: dict[str, dict[str, Any]] | None,
) -> None:
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
then schedule a coroutine that restores the snapshot and calls *fn*.
"""
snapshot = _snapshot_column_attrs(obj)
async def _run(
obj: Any = obj,
fn: Any = fn,
snapshot: dict[str, Any] = snapshot,
args: tuple = args,
) -> None:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
try:
result = fn(*args)
if inspect.isawaitable(result):
await result
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
task = loop.create_task(_run())
task.add_done_callback(_task_error_handler)
"""Call *fn* and await the result if it is awaitable."""
result = fn(obj, event_type, changes)
if inspect.isawaitable(result):
await result
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
def _after_commit(session: Any) -> None:
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
return
class EventSession(AsyncSession):
"""AsyncSession subclass that dispatches lifecycle callbacks after commit."""
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
_SESSION_UPDATES, {}
)
async def commit(self) -> None: # noqa: C901
await super().commit()
if creates and deletes:
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [o for o in deletes if id(o) not in transient_ids]
creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
_SESSION_UPDATES, {}
)
if not creates and not deletes and not field_changes:
return
# Suppress transient objects (created + deleted in same transaction).
if creates and deletes:
created_ids = {id(o) for o in creates}
deleted_ids = {id(o) for o, _ in deletes}
transient_ids = created_ids & deleted_ids
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
# Suppress updates for newly created objects (CREATE-only semantics).
if creates and field_changes:
create_ids = {id(o) for o in creates}
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
k: v for k, v in field_changes.items() if k not in create_ids
}
if not creates and not deletes and not field_changes:
return
# Dispatch CREATE callbacks.
for obj in creates:
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
# Dispatch DELETE callbacks (restore snapshot; row is gone).
for obj, snapshot in deletes:
try:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
for obj in creates:
_schedule_with_snapshot(loop, obj, obj.on_create)
# Dispatch UPDATE callbacks.
for obj, changes in field_changes.values():
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
for obj in deletes:
_schedule_with_snapshot(loop, obj, obj.on_delete)
for obj, changes in field_changes.values():
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
class WatchedFieldsMixin:
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
def on_event(
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
) -> Awaitable[None] | None:
"""Catch-all callback fired for every lifecycle event.
Args:
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
or :attr:`ModelEvent.UPDATE`).
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
"""
def on_create(self) -> Awaitable[None] | None:
"""Called after INSERT commit."""
return self.on_event(ModelEvent.CREATE)
def on_delete(self) -> Awaitable[None] | None:
"""Called after DELETE commit."""
return self.on_event(ModelEvent.DELETE)
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
"""Called after UPDATE commit when watched fields change."""
return self.on_event(ModelEvent.UPDATE, changes=changes)
async def rollback(self) -> None:
await super().rollback()
self.info.pop(_SESSION_CREATES, None)
self.info.pop(_SESSION_DELETES, None)
self.info.pop(_SESSION_UPDATES, None)

View File

@@ -18,6 +18,7 @@ from sqlalchemy.orm import DeclarativeBase
from ..db import cleanup_tables as _cleanup_tables
from ..db import create_database
from ..models.watched import EventSession
async def cleanup_tables(
@@ -265,7 +266,9 @@ async def create_db_session(
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
session_maker = async_sessionmaker(
engine, expire_on_commit=expire_on_commit, class_=EventSession
)
async with session_maker() as session:
yield session