feat: rework async event system (#196)

* feat: rework async event system

* docs: add v3 migration guide

* feat: add cache

* enhancements
This commit is contained in:
d3vyce
2026-03-30 18:24:36 +02:00
committed by GitHub
parent 104285c6e5
commit 1890d696bf
12 changed files with 1149 additions and 659 deletions

View File

@@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]"
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`, `@watch`) that fire after commit for insert, update, and delete events
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`

View File

@@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]"
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`) that fire after commit for insert, update, and delete events.
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`).
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations.
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`

93
docs/migration/v3.md Normal file
View File

@@ -0,0 +1,93 @@
# Migrating to v3.0
This page covers every breaking change introduced in **v3.0** and the steps required to update your code.
---
## Models
The lifecycle event system has been rewritten. Callbacks are now registered with a module-level [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator and dispatched by [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession), replacing the mixin-based approach from `v2`.
### `WatchedFieldsMixin` and `@watch` removed
Importing `WatchedFieldsMixin` or `watch` will raise `ImportError`.
Model method callbacks (`on_create`, `on_delete`, `on_update`) and the `@watch` decorator are replaced by:
1. **`__watched_fields__`** — a plain class attribute to restrict which field changes trigger `UPDATE` events (replaces `@watch`).
2. **`@listens_for`** — a module-level decorator to register callbacks for one or more [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) types (replaces `on_create` / `on_delete` / `on_update` methods).
=== "Before (`v2`)"
```python
from fastapi_toolsets.models import WatchedFieldsMixin, watch
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
__tablename__ = "orders"
status: Mapped[str]
async def on_create(self):
await notify_new_order(self.id)
async def on_update(self, changes):
if "status" in changes:
await notify_status_change(self.id, changes["status"])
async def on_delete(self):
await notify_order_cancelled(self.id)
```
=== "Now (`v3`)"
```python
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
class Order(Base, UUIDMixin):
__tablename__ = "orders"
__watched_fields__ = ("status",)
status: Mapped[str]
@listens_for(Order, [ModelEvent.CREATE])
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
await notify_new_order(order.id)
@listens_for(Order, [ModelEvent.UPDATE])
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
if "status" in changes:
await notify_status_change(order.id, changes["status"])
@listens_for(Order, [ModelEvent.DELETE])
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
await notify_order_cancelled(order.id)
```
### `EventSession` now required
Without `EventSession`, lifecycle callbacks will silently stop firing.
Callbacks are now dispatched inside `EventSession.commit()` rather than via background tasks. Pass it as the session class when creating your session factory:
=== "Before (`v2`)"
```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
```
=== "Now (`v3`)"
```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from fastapi_toolsets.models import EventSession
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
```
!!! note
If you use `create_db_session` from `fastapi_toolsets.pytest`, the session already uses `EventSession` — no changes needed in tests.

View File

@@ -117,139 +117,118 @@ class Article(Base, UUIDMixin, TimestampMixin):
title: Mapped[str]
```
### [`WatchedFieldsMixin`](../reference/models.md#fastapi_toolsets.models.WatchedFieldsMixin)
## Lifecycle events
!!! info "Added in `v2.4`"
The event system provides lifecycle callbacks that fire **after commit**. If the transaction rolls back, no callback fires.
`WatchedFieldsMixin` provides lifecycle callbacks that fire **after commit** — meaning the row is durably persisted when your callback runs. If the transaction rolls back, no callback fires.
### Setup
Three callbacks are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
| Callback | Event | Trigger |
|---|---|---|
| `on_create()` | `ModelEvent.CREATE` | After `INSERT` |
| `on_delete()` | `ModelEvent.DELETE` | After `DELETE` |
| `on_update(changes)` | `ModelEvent.UPDATE` | After `UPDATE` on a watched field |
Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callbacks. All callbacks support both `async def` and plain `def`. Use `@watch` to restrict which fields trigger `on_update`:
| Decorator | `on_update` behaviour |
|---|---|
| `@watch("status", "role")` | Only fires when `status` or `role` changes |
| *(no decorator)* | Fires when **any** mapped field changes |
`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter:
Event dispatch requires [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession). Pass it as the session class when creating your session factory:
```python
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from fastapi_toolsets.models import EventSession
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
```
!!! info "Callbacks fire on `session.commit()` only — not on savepoints."
Savepoints created by [`get_transaction`](db.md) or `begin_nested()` do **not**
trigger callbacks. All events accumulated across flushes are dispatched once
when the outermost `commit()` is called.
### Events
Three event types are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
| Event | Trigger |
|---|---|
| `ModelEvent.CREATE` | After `INSERT` commit |
| `ModelEvent.DELETE` | After `DELETE` commit |
| `ModelEvent.UPDATE` | After `UPDATE` commit on a watched field |
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
### Watched fields
Set `__watched_fields__` on the model to restrict which field changes trigger `UPDATE` events. It must be a `tuple[str, ...]` — any other type raises `TypeError`:
| Class attribute | `UPDATE` behaviour |
|---|---|
| `__watched_fields__ = ("status", "role")` | Only fires when `status` or `role` changes |
| *(not set)* | Fires when **any** mapped field changes |
`__watched_fields__` is inherited through the class hierarchy via normal Python MRO. A subclass can override it:
```python
class Order(Base, UUIDMixin):
__watched_fields__ = ("status",)
...
class UrgentOrder(Order):
# inherits @watch("status") — on_update fires only for status changes
# inherits __watched_fields__ = ("status",)
...
@watch("priority")
class PriorityOrder(Order):
# overrides parent — on_update fires only for priority changes
__watched_fields__ = ("priority",)
# overrides parent — UPDATE fires only for priority changes
...
```
#### Option 1 — catch-all with `on_event`
### Registering handlers
Override `on_event` to handle all event types in one place. The specific methods delegate here by default:
Register handlers with the [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator. Every callback receives three arguments: the model instance, the [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) that triggered it, and a `changes` dict (`None` for `CREATE` and `DELETE`):
```python
from fastapi_toolsets.models import ModelEvent, UUIDMixin, WatchedFieldsMixin, watch
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
class Order(Base, UUIDMixin):
__tablename__ = "orders"
__watched_fields__ = ("status",)
status: Mapped[str]
async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None:
if event == ModelEvent.CREATE:
await notify_new_order(self.id)
elif event == ModelEvent.DELETE:
await notify_order_cancelled(self.id)
elif event == ModelEvent.UPDATE:
await notify_status_change(self.id, changes["status"])
@listens_for(Order, [ModelEvent.CREATE])
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
await notify_new_order(order.id)
@listens_for(Order, [ModelEvent.DELETE])
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
await notify_order_cancelled(order.id)
@listens_for(Order, [ModelEvent.UPDATE])
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
if "status" in changes:
await notify_status_change(order.id, changes["status"])
```
#### Option 2 — targeted overrides
Multiple handlers can be registered for the same model and event. Handlers registered on a parent class also fire for subclass instances.
Override individual methods for more focused logic:
A single handler can listen for multiple events at once. When `event_types` is omitted, the handler fires for all events:
```python
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
__tablename__ = "orders"
@listens_for(Order, [ModelEvent.CREATE, ModelEvent.UPDATE])
async def on_order_changed(order: Order, event_type: ModelEvent, changes: dict | None):
await invalidate_cache(order.id)
status: Mapped[str]
async def on_create(self) -> None:
await notify_new_order(self.id)
async def on_delete(self) -> None:
await notify_order_cancelled(self.id)
async def on_update(self, changes: dict) -> None:
if "status" in changes:
old = changes["status"]["old"]
new = changes["status"]["new"]
await notify_status_change(self.id, old, new)
@listens_for(Order) # all events
async def on_any_order_event(order: Order, event_type: ModelEvent, changes: dict | None):
await audit_log(order.id, event_type)
```
#### Field changes format
### Field changes format
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included:
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included. For `CREATE` and `DELETE` events, `changes` is `None`:
```python
# CREATE / DELETE → changes is None
# status changed → {"status": {"old": "pending", "new": "shipped"}}
# two fields changed → {"status": {...}, "assigned_to": {...}}
```
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
!!! warning "Callbacks fire when the **outermost active context** (savepoint or transaction) commits."
If you create several related objects using `CrudFactory.create` and need
callbacks to see all of them (including associations), wrap the whole
operation in a single [`get_transaction`](db.md) or [`lock_tables`](db.md)
block. Without it, each `create` call commits its own savepoint and
`on_create` fires before the remaining objects exist.
```python
from fastapi_toolsets.db import get_transaction
async with get_transaction(session):
order = await OrderCrud.create(session, order_data)
item = await ItemCrud.create(session, item_data)
await session.refresh(order, attribute_names=["items"])
order.items.append(item)
# on_create fires here for both order and item,
# with the full association already committed.
```
## Composing mixins
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.
```python
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
class Event(Base, UUIDMixin, TimestampMixin):
__tablename__ = "events"
name: Mapped[str]
class Counter(Base, UpdatedAtMixin):
__tablename__ = "counters"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
value: Mapped[int]
```
---
[:material-api: API Reference](../reference/models.md)

View File

@@ -6,17 +6,19 @@ You can import them directly from `fastapi_toolsets.models`:
```python
from fastapi_toolsets.models import (
EventSession,
ModelEvent,
UUIDMixin,
UUIDv7Mixin,
CreatedAtMixin,
UpdatedAtMixin,
TimestampMixin,
WatchedFieldsMixin,
watch,
listens_for,
)
```
## ::: fastapi_toolsets.models.EventSession
## ::: fastapi_toolsets.models.ModelEvent
## ::: fastapi_toolsets.models.UUIDMixin
@@ -29,6 +31,4 @@ from fastapi_toolsets.models import (
## ::: fastapi_toolsets.models.TimestampMixin
## ::: fastapi_toolsets.models.WatchedFieldsMixin
## ::: fastapi_toolsets.models.watch
## ::: fastapi_toolsets.models.listens_for

View File

@@ -24,9 +24,12 @@ __all__ = [
]
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
def create_db_dependency(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AsyncGenerator[_SessionT, None]]:
"""Create a FastAPI dependency for database sessions.
Creates a dependency function that yields a session and auto-commits
@@ -54,7 +57,7 @@ def create_db_dependency(
```
"""
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async def get_db() -> AsyncGenerator[_SessionT, None]:
async with session_maker() as session:
await session.connection()
yield session
@@ -65,8 +68,8 @@ def create_db_dependency(
def create_db_context(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
"""Create a context manager for database sessions.
Creates a context manager for use outside of FastAPI request handlers,

View File

@@ -7,15 +7,15 @@ from .columns import (
UUIDv7Mixin,
UpdatedAtMixin,
)
from .watched import ModelEvent, WatchedFieldsMixin, watch
from .watched import EventSession, ModelEvent, listens_for
__all__ = [
"EventSession",
"ModelEvent",
"UUIDMixin",
"UUIDv7Mixin",
"CreatedAtMixin",
"UpdatedAtMixin",
"TimestampMixin",
"WatchedFieldsMixin",
"watch",
"listens_for",
]

View File

@@ -6,14 +6,6 @@ from datetime import datetime
from sqlalchemy import DateTime, Uuid, text
from sqlalchemy.orm import Mapped, mapped_column
__all__ = [
"UUIDMixin",
"UUIDv7Mixin",
"CreatedAtMixin",
"UpdatedAtMixin",
"TimestampMixin",
]
class UUIDMixin:
"""Mixin that adds a UUID primary key auto-generated by the database."""

View File

@@ -1,11 +1,9 @@
"""Field-change monitoring via SQLAlchemy session events."""
import asyncio
import inspect
import weakref
from collections.abc import Awaitable
from collections.abc import Callable
from enum import Enum
from typing import Any, TypeVar
from typing import Any
from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect
@@ -14,49 +12,81 @@ from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_v
from ..logger import get_logger
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
_logger = get_logger()
_T = TypeVar("_T")
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
weakref.WeakKeyDictionary()
)
_SESSION_PENDING_NEW = "_ft_pending_new"
_SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
class ModelEvent(str, Enum):
"""Event types emitted by :class:`WatchedFieldsMixin`."""
"""Event types dispatched by :class:`EventSession`."""
CREATE = "create"
DELETE = "delete"
UPDATE = "update"
def watch(*fields: str) -> Any:
"""Class decorator to filter which fields trigger ``on_update``.
_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
_SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
_WATCHED_MODELS: set[type] = set()
_WATCHED_CACHE: dict[type, bool] = {}
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
def _invalidate_caches() -> None:
"""Clear lookup caches after handler registration."""
_WATCHED_CACHE.clear()
_HANDLER_CACHE.clear()
def listens_for(
model_class: type,
event_types: list[ModelEvent] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Register a callback for one or more model lifecycle events.
Args:
*fields: One or more field names to watch. At least one name is required.
Raises:
ValueError: If called with no field names.
model_class: The SQLAlchemy model class to listen on.
event_types: List of :class:`ModelEvent` values to listen for.
Defaults to all event types.
"""
if not fields:
raise ValueError("@watch requires at least one field name.")
evs = event_types if event_types is not None else list(ModelEvent)
def decorator(cls: type[_T]) -> type[_T]:
_WATCHED_FIELDS[cls] = list(fields)
return cls
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
for ev in evs:
_EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
_WATCHED_MODELS.add(model_class)
_invalidate_caches()
return fn
return decorator
def _is_watched(obj: Any) -> bool:
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
cls = type(obj)
try:
return _WATCHED_CACHE[cls]
except KeyError:
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
_WATCHED_CACHE[cls] = result
return result
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
key = (cls, ev)
try:
return _HANDLER_CACHE[key]
except KeyError:
handlers: list[Callable[..., Any]] = []
for klass in cls.__mro__:
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
_HANDLER_CACHE[key] = handlers
return handlers
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState
@@ -65,7 +95,7 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
for prop in state.mapper.column_attrs:
if prop.key in state_dict:
snapshot[prop.key] = state_dict[prop.key]
elif (
elif ( # pragma: no cover
not state.expired
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
and all(
@@ -79,12 +109,17 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
return snapshot
def _get_watched_fields(cls: type) -> list[str] | None:
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
for klass in cls.__mro__:
if klass in _WATCHED_FIELDS:
return _WATCHED_FIELDS[klass]
return None
def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
"""Return the watched fields for *cls*."""
fields = getattr(cls, "__watched_fields__", None)
if fields is not None and (
not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
):
raise TypeError(
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
f"got {type(fields).__name__}"
)
return fields
def _upsert_changes(
@@ -105,50 +140,32 @@ def _upsert_changes(
pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
def _after_transaction_create(session: Any, transaction: Any) -> None:
if transaction.nested:
session.info[_SESSION_SAVEPOINT_DEPTH] = (
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
def _after_transaction_end(session: Any, transaction: Any) -> None:
if transaction.nested:
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
if depth > 0: # pragma: no branch
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture references while session.new is still populated.
# Values are read in _after_flush_postexec once RETURNING has been processed.
# New objects: capture reference. Attributes will be refreshed after commit.
for obj in session.new:
if isinstance(obj, WatchedFieldsMixin):
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
if _is_watched(obj):
session.info.setdefault(_SESSION_CREATES, []).append(obj)
# Deleted objects: capture before they leave the identity map.
# Deleted objects: snapshot now while attributes are still loaded.
for obj in session.deleted:
if isinstance(obj, WatchedFieldsMixin):
session.info.setdefault(_SESSION_DELETES, []).append(obj)
if _is_watched(obj):
snapshot = _snapshot_column_attrs(obj)
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
# Dirty objects: read old/new from SQLAlchemy attribute history.
for obj in session.dirty:
if not isinstance(obj, WatchedFieldsMixin):
if not _is_watched(obj):
continue
# None = not in dict = watch all fields; list = specific fields only
watched = _get_watched_fields(type(obj))
changes: dict[str, dict[str, Any]] = {}
inst_attrs = sa_inspect(obj).attrs
attrs = (
# Specific fields
((field, sa_inspect(obj).attrs[field]) for field in watched)
((field, inst_attrs[field]) for field in watched)
if watched is not None
# All mapped fields
else ((s.key, s) for s in sa_inspect(obj).attrs)
else ((s.key, s) for s in inst_attrs)
)
for field, attr_state in attrs:
history = attr_state.history
@@ -166,116 +183,101 @@ def _after_flush(session: Any, flush_context: Any) -> None:
)
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
# New objects are now persistent and RETURNING values have been applied,
# so server defaults (id, created_at, …) are available via getattr.
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
if not pending_new:
return
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
def _after_rollback(session: Any) -> None:
session.info.pop(_SESSION_PENDING_NEW, None)
if session.in_transaction():
return
session.info.pop(_SESSION_CREATES, None)
session.info.pop(_SESSION_DELETES, None)
session.info.pop(_SESSION_UPDATES, None)
def _task_error_handler(task: asyncio.Task[Any]) -> None:
if not task.cancelled() and (exc := task.exception()):
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
def _schedule_with_snapshot(
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
async def _invoke_callback(
fn: Callable[..., Any],
obj: Any,
event_type: ModelEvent,
changes: dict[str, dict[str, Any]] | None,
) -> None:
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
then schedule a coroutine that restores the snapshot and calls *fn*.
"""
snapshot = _snapshot_column_attrs(obj)
async def _run(
obj: Any = obj,
fn: Any = fn,
snapshot: dict[str, Any] = snapshot,
args: tuple = args,
) -> None:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
try:
result = fn(*args)
if inspect.isawaitable(result):
await result
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
task = loop.create_task(_run())
task.add_done_callback(_task_error_handler)
"""Call *fn* and await the result if it is awaitable."""
result = fn(obj, event_type, changes)
if inspect.isawaitable(result):
await result
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
def _after_commit(session: Any) -> None:
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
return
class EventSession(AsyncSession):
"""AsyncSession subclass that dispatches lifecycle callbacks after commit."""
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
_SESSION_UPDATES, {}
)
async def commit(self) -> None: # noqa: C901
await super().commit()
if creates and deletes:
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [o for o in deletes if id(o) not in transient_ids]
creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
_SESSION_UPDATES, {}
)
if not creates and not deletes and not field_changes:
return
# Suppress transient objects (created + deleted in same transaction).
if creates and deletes:
created_ids = {id(o) for o in creates}
deleted_ids = {id(o) for o, _ in deletes}
transient_ids = created_ids & deleted_ids
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
# Suppress updates for newly created objects (CREATE-only semantics).
if creates and field_changes:
create_ids = {id(o) for o in creates}
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
k: v for k, v in field_changes.items() if k not in create_ids
}
if not creates and not deletes and not field_changes:
return
# Dispatch CREATE callbacks.
for obj in creates:
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
# Dispatch DELETE callbacks (restore snapshot; row is gone).
for obj, snapshot in deletes:
try:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
for obj in creates:
_schedule_with_snapshot(loop, obj, obj.on_create)
# Dispatch UPDATE callbacks.
for obj, changes in field_changes.values():
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
for obj in deletes:
_schedule_with_snapshot(loop, obj, obj.on_delete)
for obj, changes in field_changes.values():
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
class WatchedFieldsMixin:
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
def on_event(
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
) -> Awaitable[None] | None:
"""Catch-all callback fired for every lifecycle event.
Args:
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
or :attr:`ModelEvent.UPDATE`).
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
"""
def on_create(self) -> Awaitable[None] | None:
"""Called after INSERT commit."""
return self.on_event(ModelEvent.CREATE)
def on_delete(self) -> Awaitable[None] | None:
"""Called after DELETE commit."""
return self.on_event(ModelEvent.DELETE)
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
"""Called after UPDATE commit when watched fields change."""
return self.on_event(ModelEvent.UPDATE, changes=changes)
async def rollback(self) -> None:
await super().rollback()
self.info.pop(_SESSION_CREATES, None)
self.info.pop(_SESSION_DELETES, None)
self.info.pop(_SESSION_UPDATES, None)

View File

@@ -18,6 +18,7 @@ from sqlalchemy.orm import DeclarativeBase
from ..db import cleanup_tables as _cleanup_tables
from ..db import create_database
from ..models.watched import EventSession
async def cleanup_tables(
@@ -265,7 +266,9 @@ async def create_db_session(
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
session_maker = async_sessionmaker(
engine, expire_on_commit=expire_on_commit, class_=EventSession
)
async with session_maker() as session:
yield session

File diff suppressed because it is too large Load Diff

View File

@@ -140,6 +140,7 @@ Examples = [
[[project.nav]]
Migration = [
{"v3.0" = "migration/v3.md"},
{"v2.0" = "migration/v2.md"},
]