feat: rework async event system (#196)

* feat: rework async event system

* docs: add v3 migration guide

* feat: add cache

* enhancements
This commit is contained in:
d3vyce
2026-03-30 18:24:36 +02:00
committed by GitHub
parent 104285c6e5
commit 1890d696bf
12 changed files with 1149 additions and 659 deletions

View File

@@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]"
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration - **Fixtures**: Fixture system with dependency management, context support, and pytest integration
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`, `@watch`) that fire after commit for insert, update, and delete events - **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`. - **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation - **Exception Handling**: Structured error responses with automatic OpenAPI documentation
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger` - **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`

View File

@@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]"
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration - **Fixtures**: Fixture system with dependency management, context support, and pytest integration
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`) that fire after commit for insert, update, and delete events. - **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`).
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations.
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`. - **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation - **Exception Handling**: Structured error responses with automatic OpenAPI documentation
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger` - **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`

93
docs/migration/v3.md Normal file
View File

@@ -0,0 +1,93 @@
# Migrating to v3.0
This page covers every breaking change introduced in **v3.0** and the steps required to update your code.
---
## Models
The lifecycle event system has been rewritten. Callbacks are now registered with a module-level [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator and dispatched by [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession), replacing the mixin-based approach from `v2`.
### `WatchedFieldsMixin` and `@watch` removed
Importing `WatchedFieldsMixin` or `watch` will raise `ImportError`.
Model method callbacks (`on_create`, `on_delete`, `on_update`) and the `@watch` decorator are replaced by:
1. **`__watched_fields__`** — a plain class attribute to restrict which field changes trigger `UPDATE` events (replaces `@watch`).
2. **`@listens_for`** — a module-level decorator to register callbacks for one or more [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) types (replaces `on_create` / `on_delete` / `on_update` methods).
=== "Before (`v2`)"
```python
from fastapi_toolsets.models import WatchedFieldsMixin, watch
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
__tablename__ = "orders"
status: Mapped[str]
async def on_create(self):
await notify_new_order(self.id)
async def on_update(self, changes):
if "status" in changes:
await notify_status_change(self.id, changes["status"])
async def on_delete(self):
await notify_order_cancelled(self.id)
```
=== "Now (`v3`)"
```python
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
class Order(Base, UUIDMixin):
__tablename__ = "orders"
__watched_fields__ = ("status",)
status: Mapped[str]
@listens_for(Order, [ModelEvent.CREATE])
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
await notify_new_order(order.id)
@listens_for(Order, [ModelEvent.UPDATE])
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
if "status" in changes:
await notify_status_change(order.id, changes["status"])
@listens_for(Order, [ModelEvent.DELETE])
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
await notify_order_cancelled(order.id)
```
### `EventSession` now required
Without `EventSession`, lifecycle callbacks will silently stop firing.
Callbacks are now dispatched inside `EventSession.commit()` rather than via background tasks. Pass it as the session class when creating your session factory:
=== "Before (`v2`)"
```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
```
=== "Now (`v3`)"
```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from fastapi_toolsets.models import EventSession
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
```
!!! note
If you use `create_db_session` from `fastapi_toolsets.pytest`, the session already uses `EventSession` — no changes needed in tests.

View File

@@ -117,139 +117,118 @@ class Article(Base, UUIDMixin, TimestampMixin):
title: Mapped[str] title: Mapped[str]
``` ```
### [`WatchedFieldsMixin`](../reference/models.md#fastapi_toolsets.models.WatchedFieldsMixin) ## Lifecycle events
!!! info "Added in `v2.4`" The event system provides lifecycle callbacks that fire **after commit**. If the transaction rolls back, no callback fires.
`WatchedFieldsMixin` provides lifecycle callbacks that fire **after commit** — meaning the row is durably persisted when your callback runs. If the transaction rolls back, no callback fires. ### Setup
Three callbacks are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value: Event dispatch requires [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession). Pass it as the session class when creating your session factory:
| Callback | Event | Trigger |
|---|---|---|
| `on_create()` | `ModelEvent.CREATE` | After `INSERT` |
| `on_delete()` | `ModelEvent.DELETE` | After `DELETE` |
| `on_update(changes)` | `ModelEvent.UPDATE` | After `UPDATE` on a watched field |
Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callbacks. All callbacks support both `async def` and plain `def`. Use `@watch` to restrict which fields trigger `on_update`:
| Decorator | `on_update` behaviour |
|---|---|
| `@watch("status", "role")` | Only fires when `status` or `role` changes |
| *(no decorator)* | Fires when **any** mapped field changes |
`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter:
```python ```python
@watch("status") from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
class Order(Base, UUIDMixin, WatchedFieldsMixin): from fastapi_toolsets.models import EventSession
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
```
!!! info "Callbacks fire on `session.commit()` only — not on savepoints."
Savepoints created by [`get_transaction`](db.md) or `begin_nested()` do **not**
trigger callbacks. All events accumulated across flushes are dispatched once
when the outermost `commit()` is called.
### Events
Three event types are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
| Event | Trigger |
|---|---|
| `ModelEvent.CREATE` | After `INSERT` commit |
| `ModelEvent.DELETE` | After `DELETE` commit |
| `ModelEvent.UPDATE` | After `UPDATE` commit on a watched field |
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
### Watched fields
Set `__watched_fields__` on the model to restrict which field changes trigger `UPDATE` events. It must be a `tuple[str, ...]` — any other type raises `TypeError`:
| Class attribute | `UPDATE` behaviour |
|---|---|
| `__watched_fields__ = ("status", "role")` | Only fires when `status` or `role` changes |
| *(not set)* | Fires when **any** mapped field changes |
`__watched_fields__` is inherited through the class hierarchy via normal Python MRO. A subclass can override it:
```python
class Order(Base, UUIDMixin):
__watched_fields__ = ("status",)
... ...
class UrgentOrder(Order): class UrgentOrder(Order):
# inherits @watch("status") — on_update fires only for status changes # inherits __watched_fields__ = ("status",)
... ...
@watch("priority")
class PriorityOrder(Order): class PriorityOrder(Order):
# overrides parent — on_update fires only for priority changes __watched_fields__ = ("priority",)
# overrides parent — UPDATE fires only for priority changes
... ...
``` ```
#### Option 1 — catch-all with `on_event` ### Registering handlers
Override `on_event` to handle all event types in one place. The specific methods delegate here by default: Register handlers with the [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator. Every callback receives three arguments: the model instance, the [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) that triggered it, and a `changes` dict (`None` for `CREATE` and `DELETE`):
```python ```python
from fastapi_toolsets.models import ModelEvent, UUIDMixin, WatchedFieldsMixin, watch from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
@watch("status") class Order(Base, UUIDMixin):
class Order(Base, UUIDMixin, WatchedFieldsMixin):
__tablename__ = "orders" __tablename__ = "orders"
__watched_fields__ = ("status",)
status: Mapped[str] status: Mapped[str]
async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None: @listens_for(Order, [ModelEvent.CREATE])
if event == ModelEvent.CREATE: async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
await notify_new_order(self.id) await notify_new_order(order.id)
elif event == ModelEvent.DELETE:
await notify_order_cancelled(self.id) @listens_for(Order, [ModelEvent.DELETE])
elif event == ModelEvent.UPDATE: async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
await notify_status_change(self.id, changes["status"]) await notify_order_cancelled(order.id)
@listens_for(Order, [ModelEvent.UPDATE])
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
if "status" in changes:
await notify_status_change(order.id, changes["status"])
``` ```
#### Option 2 — targeted overrides Multiple handlers can be registered for the same model and event. Handlers registered on a parent class also fire for subclass instances.
Override individual methods for more focused logic: A single handler can listen for multiple events at once. When `event_types` is omitted, the handler fires for all events:
```python ```python
@watch("status") @listens_for(Order, [ModelEvent.CREATE, ModelEvent.UPDATE])
class Order(Base, UUIDMixin, WatchedFieldsMixin): async def on_order_changed(order: Order, event_type: ModelEvent, changes: dict | None):
__tablename__ = "orders" await invalidate_cache(order.id)
status: Mapped[str] @listens_for(Order) # all events
async def on_any_order_event(order: Order, event_type: ModelEvent, changes: dict | None):
async def on_create(self) -> None: await audit_log(order.id, event_type)
await notify_new_order(self.id)
async def on_delete(self) -> None:
await notify_order_cancelled(self.id)
async def on_update(self, changes: dict) -> None:
if "status" in changes:
old = changes["status"]["old"]
new = changes["status"]["new"]
await notify_status_change(self.id, old, new)
``` ```
#### Field changes format ### Field changes format
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included: The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included. For `CREATE` and `DELETE` events, `changes` is `None`:
```python ```python
# CREATE / DELETE → changes is None
# status changed → {"status": {"old": "pending", "new": "shipped"}} # status changed → {"status": {"old": "pending", "new": "shipped"}}
# two fields changed → {"status": {...}, "assigned_to": {...}} # two fields changed → {"status": {...}, "assigned_to": {...}}
``` ```
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit." !!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
!!! warning "Callbacks fire when the **outermost active context** (savepoint or transaction) commits."
If you create several related objects using `CrudFactory.create` and need
callbacks to see all of them (including associations), wrap the whole
operation in a single [`get_transaction`](db.md) or [`lock_tables`](db.md)
block. Without it, each `create` call commits its own savepoint and
`on_create` fires before the remaining objects exist.
```python
from fastapi_toolsets.db import get_transaction
async with get_transaction(session):
order = await OrderCrud.create(session, order_data)
item = await ItemCrud.create(session, item_data)
await session.refresh(order, attribute_names=["items"])
order.items.append(item)
# on_create fires here for both order and item,
# with the full association already committed.
```
## Composing mixins
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.
```python
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
class Event(Base, UUIDMixin, TimestampMixin):
__tablename__ = "events"
name: Mapped[str]
class Counter(Base, UpdatedAtMixin):
__tablename__ = "counters"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
value: Mapped[int]
```
--- ---
[:material-api: API Reference](../reference/models.md) [:material-api: API Reference](../reference/models.md)

View File

@@ -6,17 +6,19 @@ You can import them directly from `fastapi_toolsets.models`:
```python ```python
from fastapi_toolsets.models import ( from fastapi_toolsets.models import (
EventSession,
ModelEvent, ModelEvent,
UUIDMixin, UUIDMixin,
UUIDv7Mixin, UUIDv7Mixin,
CreatedAtMixin, CreatedAtMixin,
UpdatedAtMixin, UpdatedAtMixin,
TimestampMixin, TimestampMixin,
WatchedFieldsMixin, listens_for,
watch,
) )
``` ```
## ::: fastapi_toolsets.models.EventSession
## ::: fastapi_toolsets.models.ModelEvent ## ::: fastapi_toolsets.models.ModelEvent
## ::: fastapi_toolsets.models.UUIDMixin ## ::: fastapi_toolsets.models.UUIDMixin
@@ -29,6 +31,4 @@ from fastapi_toolsets.models import (
## ::: fastapi_toolsets.models.TimestampMixin ## ::: fastapi_toolsets.models.TimestampMixin
## ::: fastapi_toolsets.models.WatchedFieldsMixin ## ::: fastapi_toolsets.models.listens_for
## ::: fastapi_toolsets.models.watch

View File

@@ -24,9 +24,12 @@ __all__ = [
] ]
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
def create_db_dependency( def create_db_dependency(
session_maker: async_sessionmaker[AsyncSession], session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]: ) -> Callable[[], AsyncGenerator[_SessionT, None]]:
"""Create a FastAPI dependency for database sessions. """Create a FastAPI dependency for database sessions.
Creates a dependency function that yields a session and auto-commits Creates a dependency function that yields a session and auto-commits
@@ -54,7 +57,7 @@ def create_db_dependency(
``` ```
""" """
async def get_db() -> AsyncGenerator[AsyncSession, None]: async def get_db() -> AsyncGenerator[_SessionT, None]:
async with session_maker() as session: async with session_maker() as session:
await session.connection() await session.connection()
yield session yield session
@@ -65,8 +68,8 @@ def create_db_dependency(
def create_db_context( def create_db_context(
session_maker: async_sessionmaker[AsyncSession], session_maker: async_sessionmaker[_SessionT],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]: ) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
"""Create a context manager for database sessions. """Create a context manager for database sessions.
Creates a context manager for use outside of FastAPI request handlers, Creates a context manager for use outside of FastAPI request handlers,

View File

@@ -7,15 +7,15 @@ from .columns import (
UUIDv7Mixin, UUIDv7Mixin,
UpdatedAtMixin, UpdatedAtMixin,
) )
from .watched import ModelEvent, WatchedFieldsMixin, watch from .watched import EventSession, ModelEvent, listens_for
__all__ = [ __all__ = [
"EventSession",
"ModelEvent", "ModelEvent",
"UUIDMixin", "UUIDMixin",
"UUIDv7Mixin", "UUIDv7Mixin",
"CreatedAtMixin", "CreatedAtMixin",
"UpdatedAtMixin", "UpdatedAtMixin",
"TimestampMixin", "TimestampMixin",
"WatchedFieldsMixin", "listens_for",
"watch",
] ]

View File

@@ -6,14 +6,6 @@ from datetime import datetime
from sqlalchemy import DateTime, Uuid, text from sqlalchemy import DateTime, Uuid, text
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
__all__ = [
"UUIDMixin",
"UUIDv7Mixin",
"CreatedAtMixin",
"UpdatedAtMixin",
"TimestampMixin",
]
class UUIDMixin: class UUIDMixin:
"""Mixin that adds a UUID primary key auto-generated by the database.""" """Mixin that adds a UUID primary key auto-generated by the database."""

View File

@@ -1,11 +1,9 @@
"""Field-change monitoring via SQLAlchemy session events.""" """Field-change monitoring via SQLAlchemy session events."""
import asyncio
import inspect import inspect
import weakref from collections.abc import Callable
from collections.abc import Awaitable
from enum import Enum from enum import Enum
from typing import Any, TypeVar from typing import Any
from sqlalchemy import event from sqlalchemy import event
from sqlalchemy import inspect as sa_inspect from sqlalchemy import inspect as sa_inspect
@@ -14,49 +12,81 @@ from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_v
from ..logger import get_logger from ..logger import get_logger
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
_logger = get_logger() _logger = get_logger()
_T = TypeVar("_T")
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
weakref.WeakKeyDictionary()
)
_SESSION_PENDING_NEW = "_ft_pending_new"
_SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
class ModelEvent(str, Enum): class ModelEvent(str, Enum):
"""Event types emitted by :class:`WatchedFieldsMixin`.""" """Event types dispatched by :class:`EventSession`."""
CREATE = "create" CREATE = "create"
DELETE = "delete" DELETE = "delete"
UPDATE = "update" UPDATE = "update"
def watch(*fields: str) -> Any: _CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
"""Class decorator to filter which fields trigger ``on_update``. _SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
_WATCHED_MODELS: set[type] = set()
_WATCHED_CACHE: dict[type, bool] = {}
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
def _invalidate_caches() -> None:
"""Clear lookup caches after handler registration."""
_WATCHED_CACHE.clear()
_HANDLER_CACHE.clear()
def listens_for(
model_class: type,
event_types: list[ModelEvent] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Register a callback for one or more model lifecycle events.
Args: Args:
*fields: One or more field names to watch. At least one name is required. model_class: The SQLAlchemy model class to listen on.
event_types: List of :class:`ModelEvent` values to listen for.
Raises: Defaults to all event types.
ValueError: If called with no field names.
""" """
if not fields: evs = event_types if event_types is not None else list(ModelEvent)
raise ValueError("@watch requires at least one field name.")
def decorator(cls: type[_T]) -> type[_T]: def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
_WATCHED_FIELDS[cls] = list(fields) for ev in evs:
return cls _EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
_WATCHED_MODELS.add(model_class)
_invalidate_caches()
return fn
return decorator return decorator
def _is_watched(obj: Any) -> bool:
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
cls = type(obj)
try:
return _WATCHED_CACHE[cls]
except KeyError:
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
_WATCHED_CACHE[cls] = result
return result
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
key = (cls, ev)
try:
return _HANDLER_CACHE[key]
except KeyError:
handlers: list[Callable[..., Any]] = []
for klass in cls.__mro__:
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
_HANDLER_CACHE[key] = handlers
return handlers
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]: def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict.""" """Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState state = sa_inspect(obj) # InstanceState
@@ -65,7 +95,7 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
for prop in state.mapper.column_attrs: for prop in state.mapper.column_attrs:
if prop.key in state_dict: if prop.key in state_dict:
snapshot[prop.key] = state_dict[prop.key] snapshot[prop.key] = state_dict[prop.key]
elif ( elif ( # pragma: no cover
not state.expired not state.expired
and prop.strategy_key != _DEFERRED_STRATEGY_KEY and prop.strategy_key != _DEFERRED_STRATEGY_KEY
and all( and all(
@@ -79,12 +109,17 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
return snapshot return snapshot
def _get_watched_fields(cls: type) -> list[str] | None: def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
"""Return the watched fields for *cls*, walking the MRO to inherit from parents.""" """Return the watched fields for *cls*."""
for klass in cls.__mro__: fields = getattr(cls, "__watched_fields__", None)
if klass in _WATCHED_FIELDS: if fields is not None and (
return _WATCHED_FIELDS[klass] not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
return None ):
raise TypeError(
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
f"got {type(fields).__name__}"
)
return fields
def _upsert_changes( def _upsert_changes(
@@ -105,50 +140,32 @@ def _upsert_changes(
pending[key] = (obj, changes) pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
def _after_transaction_create(session: Any, transaction: Any) -> None:
if transaction.nested:
session.info[_SESSION_SAVEPOINT_DEPTH] = (
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
def _after_transaction_end(session: Any, transaction: Any) -> None:
if transaction.nested:
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
if depth > 0: # pragma: no branch
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
@event.listens_for(AsyncSession.sync_session_class, "after_flush") @event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None: def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture references while session.new is still populated. # New objects: capture reference. Attributes will be refreshed after commit.
# Values are read in _after_flush_postexec once RETURNING has been processed.
for obj in session.new: for obj in session.new:
if isinstance(obj, WatchedFieldsMixin): if _is_watched(obj):
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj) session.info.setdefault(_SESSION_CREATES, []).append(obj)
# Deleted objects: capture before they leave the identity map. # Deleted objects: snapshot now while attributes are still loaded.
for obj in session.deleted: for obj in session.deleted:
if isinstance(obj, WatchedFieldsMixin): if _is_watched(obj):
session.info.setdefault(_SESSION_DELETES, []).append(obj) snapshot = _snapshot_column_attrs(obj)
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
# Dirty objects: read old/new from SQLAlchemy attribute history. # Dirty objects: read old/new from SQLAlchemy attribute history.
for obj in session.dirty: for obj in session.dirty:
if not isinstance(obj, WatchedFieldsMixin): if not _is_watched(obj):
continue continue
# None = not in dict = watch all fields; list = specific fields only
watched = _get_watched_fields(type(obj)) watched = _get_watched_fields(type(obj))
changes: dict[str, dict[str, Any]] = {} changes: dict[str, dict[str, Any]] = {}
inst_attrs = sa_inspect(obj).attrs
attrs = ( attrs = (
# Specific fields ((field, inst_attrs[field]) for field in watched)
((field, sa_inspect(obj).attrs[field]) for field in watched)
if watched is not None if watched is not None
# All mapped fields else ((s.key, s) for s in inst_attrs)
else ((s.key, s) for s in sa_inspect(obj).attrs)
) )
for field, attr_state in attrs: for field, attr_state in attrs:
history = attr_state.history history = attr_state.history
@@ -166,116 +183,101 @@ def _after_flush(session: Any, flush_context: Any) -> None:
) )
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
# New objects are now persistent and RETURNING values have been applied,
# so server defaults (id, created_at, …) are available via getattr.
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
if not pending_new:
return
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
@event.listens_for(AsyncSession.sync_session_class, "after_rollback") @event.listens_for(AsyncSession.sync_session_class, "after_rollback")
def _after_rollback(session: Any) -> None: def _after_rollback(session: Any) -> None:
session.info.pop(_SESSION_PENDING_NEW, None) if session.in_transaction():
return
session.info.pop(_SESSION_CREATES, None) session.info.pop(_SESSION_CREATES, None)
session.info.pop(_SESSION_DELETES, None) session.info.pop(_SESSION_DELETES, None)
session.info.pop(_SESSION_UPDATES, None) session.info.pop(_SESSION_UPDATES, None)
def _task_error_handler(task: asyncio.Task[Any]) -> None: async def _invoke_callback(
if not task.cancelled() and (exc := task.exception()): fn: Callable[..., Any],
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) obj: Any,
event_type: ModelEvent,
changes: dict[str, dict[str, Any]] | None,
def _schedule_with_snapshot(
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
) -> None: ) -> None:
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them), """Call *fn* and await the result if it is awaitable."""
then schedule a coroutine that restores the snapshot and calls *fn*. result = fn(obj, event_type, changes)
""" if inspect.isawaitable(result):
snapshot = _snapshot_column_attrs(obj) await result
async def _run(
obj: Any = obj,
fn: Any = fn,
snapshot: dict[str, Any] = snapshot,
args: tuple = args,
) -> None:
for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
try:
result = fn(*args)
if inspect.isawaitable(result):
await result
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
task = loop.create_task(_run())
task.add_done_callback(_task_error_handler)
@event.listens_for(AsyncSession.sync_session_class, "after_commit") class EventSession(AsyncSession):
def _after_commit(session: Any) -> None: """AsyncSession subclass that dispatches lifecycle callbacks after commit."""
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
return
creates: list[Any] = session.info.pop(_SESSION_CREATES, []) async def commit(self) -> None: # noqa: C901
deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) await super().commit()
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
_SESSION_UPDATES, {}
)
if creates and deletes: creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes} deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
if transient_ids: field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
creates = [o for o in creates if id(o) not in transient_ids] _SESSION_UPDATES, {}
deletes = [o for o in deletes if id(o) not in transient_ids] )
if not creates and not deletes and not field_changes:
return
# Suppress transient objects (created + deleted in same transaction).
if creates and deletes:
created_ids = {id(o) for o in creates}
deleted_ids = {id(o) for o, _ in deletes}
transient_ids = created_ids & deleted_ids
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
# Suppress updates for newly created objects (CREATE-only semantics).
if creates and field_changes:
create_ids = {id(o) for o in creates}
field_changes = { field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids k: v for k, v in field_changes.items() if k not in create_ids
} }
if not creates and not deletes and not field_changes: # Dispatch CREATE callbacks.
return for obj in creates:
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
try: # Dispatch DELETE callbacks (restore snapshot; row is gone).
loop = asyncio.get_running_loop() for obj, snapshot in deletes:
except RuntimeError: try:
return for key, value in snapshot.items():
_sa_set_committed_value(obj, key, value)
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
for obj in creates: # Dispatch UPDATE callbacks.
_schedule_with_snapshot(loop, obj, obj.on_create) for obj, changes in field_changes.values():
try:
state = sa_inspect(obj, raiseerr=False)
if (
state is None or state.detached or state.transient
): # pragma: no cover
continue
await self.refresh(obj)
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
for obj in deletes: async def rollback(self) -> None:
_schedule_with_snapshot(loop, obj, obj.on_delete) await super().rollback()
self.info.pop(_SESSION_CREATES, None)
for obj, changes in field_changes.values(): self.info.pop(_SESSION_DELETES, None)
_schedule_with_snapshot(loop, obj, obj.on_update, changes) self.info.pop(_SESSION_UPDATES, None)
class WatchedFieldsMixin:
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
def on_event(
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
) -> Awaitable[None] | None:
"""Catch-all callback fired for every lifecycle event.
Args:
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
or :attr:`ModelEvent.UPDATE`).
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
"""
def on_create(self) -> Awaitable[None] | None:
"""Called after INSERT commit."""
return self.on_event(ModelEvent.CREATE)
def on_delete(self) -> Awaitable[None] | None:
"""Called after DELETE commit."""
return self.on_event(ModelEvent.DELETE)
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
"""Called after UPDATE commit when watched fields change."""
return self.on_event(ModelEvent.UPDATE, changes=changes)

View File

@@ -18,6 +18,7 @@ from sqlalchemy.orm import DeclarativeBase
from ..db import cleanup_tables as _cleanup_tables from ..db import cleanup_tables as _cleanup_tables
from ..db import create_database from ..db import create_database
from ..models.watched import EventSession
async def cleanup_tables( async def cleanup_tables(
@@ -265,7 +266,9 @@ async def create_db_session(
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all) await conn.run_sync(base.metadata.create_all)
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit) session_maker = async_sessionmaker(
engine, expire_on_commit=expire_on_commit, class_=EventSession
)
async with session_maker() as session: async with session_maker() as session:
yield session yield session

File diff suppressed because it is too large Load Diff

View File

@@ -140,6 +140,7 @@ Examples = [
[[project.nav]] [[project.nav]]
Migration = [ Migration = [
{"v3.0" = "migration/v3.md"},
{"v2.0" = "migration/v2.md"}, {"v2.0" = "migration/v2.md"},
] ]