From 1890d696bf20f5cee78c34652f025162e7d13113 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Mon, 30 Mar 2026 18:24:36 +0200 Subject: [PATCH] feat: rework async event system (#196) * feat: rework async event system * docs: add v3 migration guide * feat: add cache * enhancements --- README.md | 3 +- docs/index.md | 3 +- docs/migration/v3.md | 93 ++ docs/module/models.md | 167 ++-- docs/reference/models.md | 10 +- src/fastapi_toolsets/db.py | 13 +- src/fastapi_toolsets/models/__init__.py | 6 +- src/fastapi_toolsets/models/columns.py | 8 - src/fastapi_toolsets/models/watched.py | 322 ++++--- src/fastapi_toolsets/pytest/utils.py | 5 +- tests/test_models.py | 1177 +++++++++++++++-------- zensical.toml | 1 + 12 files changed, 1149 insertions(+), 659 deletions(-) create mode 100644 docs/migration/v3.md diff --git a/README.md b/README.md index 652d34d..0250824 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]" - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Fixtures**: Fixture system with dependency management, context support, and pytest integration -- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`, `@watch`) that fire after commit for insert, update, and delete events +- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) +- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations - **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`. - **Exception Handling**: Structured error responses with automatic OpenAPI documentation - **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger` diff --git a/docs/index.md b/docs/index.md index 6c4dd40..ac923e4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]" - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Fixtures**: Fixture system with dependency management, context support, and pytest integration -- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`) that fire after commit for insert, update, and delete events. +- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`). +- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations. - **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`. - **Exception Handling**: Structured error responses with automatic OpenAPI documentation - **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger` diff --git a/docs/migration/v3.md b/docs/migration/v3.md new file mode 100644 index 0000000..56ea99c --- /dev/null +++ b/docs/migration/v3.md @@ -0,0 +1,93 @@ +# Migrating to v3.0 + +This page covers every breaking change introduced in **v3.0** and the steps required to update your code. + +--- + +## Models + +The lifecycle event system has been rewritten. Callbacks are now registered with a module-level [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator and dispatched by [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession), replacing the mixin-based approach from `v2`. + +### `WatchedFieldsMixin` and `@watch` removed + +Importing `WatchedFieldsMixin` or `watch` will raise `ImportError`. + +Model method callbacks (`on_create`, `on_delete`, `on_update`) and the `@watch` decorator are replaced by: + +1. **`__watched_fields__`** — a plain class attribute to restrict which field changes trigger `UPDATE` events (replaces `@watch`). +2. **`@listens_for`** — a module-level decorator to register callbacks for one or more [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) types (replaces `on_create` / `on_delete` / `on_update` methods). + +=== "Before (`v2`)" + + ```python + from fastapi_toolsets.models import WatchedFieldsMixin, watch + + @watch("status") + class Order(Base, UUIDMixin, WatchedFieldsMixin): + __tablename__ = "orders" + + status: Mapped[str] + + async def on_create(self): + await notify_new_order(self.id) + + async def on_update(self, changes): + if "status" in changes: + await notify_status_change(self.id, changes["status"]) + + async def on_delete(self): + await notify_order_cancelled(self.id) + ``` + +=== "Now (`v3`)" + + ```python + from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for + + class Order(Base, UUIDMixin): + __tablename__ = "orders" + __watched_fields__ = ("status",) + + status: Mapped[str] + + @listens_for(Order, [ModelEvent.CREATE]) + async def on_order_created(order: Order, event_type: ModelEvent, changes: None): + await notify_new_order(order.id) + + @listens_for(Order, [ModelEvent.UPDATE]) + async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict): + if "status" in changes: + await notify_status_change(order.id, changes["status"]) + + @listens_for(Order, [ModelEvent.DELETE]) + async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None): + await notify_order_cancelled(order.id) + ``` + +### `EventSession` now required + +Without `EventSession`, lifecycle callbacks will silently stop firing. + +Callbacks are now dispatched inside `EventSession.commit()` rather than via background tasks. Pass it as the session class when creating your session factory: + +=== "Before (`v2`)" + + ```python + from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + + engine = create_async_engine("postgresql+asyncpg://...") + SessionLocal = async_sessionmaker(engine, expire_on_commit=False) + ``` + +=== "Now (`v3`)" + + ```python + from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + from fastapi_toolsets.models import EventSession + + engine = create_async_engine("postgresql+asyncpg://...") + SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession) + ``` + +!!! note + If you use `create_db_session` from `fastapi_toolsets.pytest`, the session already uses `EventSession` — no changes needed in tests. diff --git a/docs/module/models.md b/docs/module/models.md index 5ed4d77..eae194c 100644 --- a/docs/module/models.md +++ b/docs/module/models.md @@ -117,139 +117,118 @@ class Article(Base, UUIDMixin, TimestampMixin): title: Mapped[str] ``` -### [`WatchedFieldsMixin`](../reference/models.md#fastapi_toolsets.models.WatchedFieldsMixin) +## Lifecycle events -!!! info "Added in `v2.4`" +The event system provides lifecycle callbacks that fire **after commit**. If the transaction rolls back, no callback fires. -`WatchedFieldsMixin` provides lifecycle callbacks that fire **after commit** — meaning the row is durably persisted when your callback runs. If the transaction rolls back, no callback fires. +### Setup -Three callbacks are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value: - -| Callback | Event | Trigger | -|---|---|---| -| `on_create()` | `ModelEvent.CREATE` | After `INSERT` | -| `on_delete()` | `ModelEvent.DELETE` | After `DELETE` | -| `on_update(changes)` | `ModelEvent.UPDATE` | After `UPDATE` on a watched field | - -Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callbacks. All callbacks support both `async def` and plain `def`. Use `@watch` to restrict which fields trigger `on_update`: - -| Decorator | `on_update` behaviour | -|---|---| -| `@watch("status", "role")` | Only fires when `status` or `role` changes | -| *(no decorator)* | Fires when **any** mapped field changes | - -`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter: +Event dispatch requires [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession). Pass it as the session class when creating your session factory: ```python -@watch("status") -class Order(Base, UUIDMixin, WatchedFieldsMixin): +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from fastapi_toolsets.models import EventSession + +engine = create_async_engine("postgresql+asyncpg://...") +SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession) +``` + +!!! info "Callbacks fire on `session.commit()` only — not on savepoints." + Savepoints created by [`get_transaction`](db.md) or `begin_nested()` do **not** + trigger callbacks. All events accumulated across flushes are dispatched once + when the outermost `commit()` is called. + +### Events + +Three event types are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value: + +| Event | Trigger | +|---|---| +| `ModelEvent.CREATE` | After `INSERT` commit | +| `ModelEvent.DELETE` | After `DELETE` commit | +| `ModelEvent.UPDATE` | After `UPDATE` commit on a watched field | + +!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected." + +### Watched fields + +Set `__watched_fields__` on the model to restrict which field changes trigger `UPDATE` events. It must be a `tuple[str, ...]` — any other type raises `TypeError`: + +| Class attribute | `UPDATE` behaviour | +|---|---| +| `__watched_fields__ = ("status", "role")` | Only fires when `status` or `role` changes | +| *(not set)* | Fires when **any** mapped field changes | + +`__watched_fields__` is inherited through the class hierarchy via normal Python MRO. A subclass can override it: + +```python +class Order(Base, UUIDMixin): + __watched_fields__ = ("status",) ... class UrgentOrder(Order): - # inherits @watch("status") — on_update fires only for status changes + # inherits __watched_fields__ = ("status",) ... -@watch("priority") class PriorityOrder(Order): - # overrides parent — on_update fires only for priority changes + __watched_fields__ = ("priority",) + # overrides parent — UPDATE fires only for priority changes ... ``` -#### Option 1 — catch-all with `on_event` +### Registering handlers -Override `on_event` to handle all event types in one place. The specific methods delegate here by default: +Register handlers with the [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator. Every callback receives three arguments: the model instance, the [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) that triggered it, and a `changes` dict (`None` for `CREATE` and `DELETE`): ```python -from fastapi_toolsets.models import ModelEvent, UUIDMixin, WatchedFieldsMixin, watch +from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for -@watch("status") -class Order(Base, UUIDMixin, WatchedFieldsMixin): +class Order(Base, UUIDMixin): __tablename__ = "orders" + __watched_fields__ = ("status",) status: Mapped[str] - async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None: - if event == ModelEvent.CREATE: - await notify_new_order(self.id) - elif event == ModelEvent.DELETE: - await notify_order_cancelled(self.id) - elif event == ModelEvent.UPDATE: - await notify_status_change(self.id, changes["status"]) +@listens_for(Order, [ModelEvent.CREATE]) +async def on_order_created(order: Order, event_type: ModelEvent, changes: None): + await notify_new_order(order.id) + +@listens_for(Order, [ModelEvent.DELETE]) +async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None): + await notify_order_cancelled(order.id) + +@listens_for(Order, [ModelEvent.UPDATE]) +async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict): + if "status" in changes: + await notify_status_change(order.id, changes["status"]) ``` -#### Option 2 — targeted overrides +Multiple handlers can be registered for the same model and event. Handlers registered on a parent class also fire for subclass instances. -Override individual methods for more focused logic: +A single handler can listen for multiple events at once. When `event_types` is omitted, the handler fires for all events: ```python -@watch("status") -class Order(Base, UUIDMixin, WatchedFieldsMixin): - __tablename__ = "orders" +@listens_for(Order, [ModelEvent.CREATE, ModelEvent.UPDATE]) +async def on_order_changed(order: Order, event_type: ModelEvent, changes: dict | None): + await invalidate_cache(order.id) - status: Mapped[str] - - async def on_create(self) -> None: - await notify_new_order(self.id) - - async def on_delete(self) -> None: - await notify_order_cancelled(self.id) - - async def on_update(self, changes: dict) -> None: - if "status" in changes: - old = changes["status"]["old"] - new = changes["status"]["new"] - await notify_status_change(self.id, old, new) +@listens_for(Order) # all events +async def on_any_order_event(order: Order, event_type: ModelEvent, changes: dict | None): + await audit_log(order.id, event_type) ``` -#### Field changes format +### Field changes format -The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included: +The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included. For `CREATE` and `DELETE` events, `changes` is `None`: ```python +# CREATE / DELETE → changes is None # status changed → {"status": {"old": "pending", "new": "shipped"}} # two fields changed → {"status": {...}, "assigned_to": {...}} ``` !!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit." -!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected." - -!!! warning "Callbacks fire when the **outermost active context** (savepoint or transaction) commits." - If you create several related objects using `CrudFactory.create` and need - callbacks to see all of them (including associations), wrap the whole - operation in a single [`get_transaction`](db.md) or [`lock_tables`](db.md) - block. Without it, each `create` call commits its own savepoint and - `on_create` fires before the remaining objects exist. - - ```python - from fastapi_toolsets.db import get_transaction - - async with get_transaction(session): - order = await OrderCrud.create(session, order_data) - item = await ItemCrud.create(session, item_data) - await session.refresh(order, attribute_names=["items"]) - order.items.append(item) - # on_create fires here for both order and item, - # with the full association already committed. - ``` - -## Composing mixins - -All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model. - -```python -from fastapi_toolsets.models import UUIDMixin, TimestampMixin - -class Event(Base, UUIDMixin, TimestampMixin): - __tablename__ = "events" - name: Mapped[str] - -class Counter(Base, UpdatedAtMixin): - __tablename__ = "counters" - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - value: Mapped[int] -``` - --- [:material-api: API Reference](../reference/models.md) diff --git a/docs/reference/models.md b/docs/reference/models.md index a30d5b6..accbc94 100644 --- a/docs/reference/models.md +++ b/docs/reference/models.md @@ -6,17 +6,19 @@ You can import them directly from `fastapi_toolsets.models`: ```python from fastapi_toolsets.models import ( + EventSession, ModelEvent, UUIDMixin, UUIDv7Mixin, CreatedAtMixin, UpdatedAtMixin, TimestampMixin, - WatchedFieldsMixin, - watch, + listens_for, ) ``` +## ::: fastapi_toolsets.models.EventSession + ## ::: fastapi_toolsets.models.ModelEvent ## ::: fastapi_toolsets.models.UUIDMixin @@ -29,6 +31,4 @@ from fastapi_toolsets.models import ( ## ::: fastapi_toolsets.models.TimestampMixin -## ::: fastapi_toolsets.models.WatchedFieldsMixin - -## ::: fastapi_toolsets.models.watch +## ::: fastapi_toolsets.models.listens_for diff --git a/src/fastapi_toolsets/db.py b/src/fastapi_toolsets/db.py index 67dffd2..ac93a21 100644 --- a/src/fastapi_toolsets/db.py +++ b/src/fastapi_toolsets/db.py @@ -24,9 +24,12 @@ __all__ = [ ] +_SessionT = TypeVar("_SessionT", bound=AsyncSession) + + def create_db_dependency( - session_maker: async_sessionmaker[AsyncSession], -) -> Callable[[], AsyncGenerator[AsyncSession, None]]: + session_maker: async_sessionmaker[_SessionT], +) -> Callable[[], AsyncGenerator[_SessionT, None]]: """Create a FastAPI dependency for database sessions. Creates a dependency function that yields a session and auto-commits @@ -54,7 +57,7 @@ def create_db_dependency( ``` """ - async def get_db() -> AsyncGenerator[AsyncSession, None]: + async def get_db() -> AsyncGenerator[_SessionT, None]: async with session_maker() as session: await session.connection() yield session @@ -65,8 +68,8 @@ def create_db_dependency( def create_db_context( - session_maker: async_sessionmaker[AsyncSession], -) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]: + session_maker: async_sessionmaker[_SessionT], +) -> Callable[[], AbstractAsyncContextManager[_SessionT]]: """Create a context manager for database sessions. Creates a context manager for use outside of FastAPI request handlers, diff --git a/src/fastapi_toolsets/models/__init__.py b/src/fastapi_toolsets/models/__init__.py index 5af6821..a54459d 100644 --- a/src/fastapi_toolsets/models/__init__.py +++ b/src/fastapi_toolsets/models/__init__.py @@ -7,15 +7,15 @@ from .columns import ( UUIDv7Mixin, UpdatedAtMixin, ) -from .watched import ModelEvent, WatchedFieldsMixin, watch +from .watched import EventSession, ModelEvent, listens_for __all__ = [ + "EventSession", "ModelEvent", "UUIDMixin", "UUIDv7Mixin", "CreatedAtMixin", "UpdatedAtMixin", "TimestampMixin", - "WatchedFieldsMixin", - "watch", + "listens_for", ] diff --git a/src/fastapi_toolsets/models/columns.py b/src/fastapi_toolsets/models/columns.py index bdbc38d..4253daa 100644 --- a/src/fastapi_toolsets/models/columns.py +++ b/src/fastapi_toolsets/models/columns.py @@ -6,14 +6,6 @@ from datetime import datetime from sqlalchemy import DateTime, Uuid, text from sqlalchemy.orm import Mapped, mapped_column -__all__ = [ - "UUIDMixin", - "UUIDv7Mixin", - "CreatedAtMixin", - "UpdatedAtMixin", - "TimestampMixin", -] - class UUIDMixin: """Mixin that adds a UUID primary key auto-generated by the database.""" diff --git a/src/fastapi_toolsets/models/watched.py b/src/fastapi_toolsets/models/watched.py index 65852bf..35774d1 100644 --- a/src/fastapi_toolsets/models/watched.py +++ b/src/fastapi_toolsets/models/watched.py @@ -1,11 +1,9 @@ """Field-change monitoring via SQLAlchemy session events.""" -import asyncio import inspect -import weakref -from collections.abc import Awaitable +from collections.abc import Callable from enum import Enum -from typing import Any, TypeVar +from typing import Any from sqlalchemy import event from sqlalchemy import inspect as sa_inspect @@ -14,49 +12,81 @@ from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_v from ..logger import get_logger -__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"] - _logger = get_logger() -_T = TypeVar("_T") -_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception" -_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = ( - weakref.WeakKeyDictionary() -) -_SESSION_PENDING_NEW = "_ft_pending_new" -_SESSION_CREATES = "_ft_creates" -_SESSION_DELETES = "_ft_deletes" -_SESSION_UPDATES = "_ft_updates" -_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth" -_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True)) class ModelEvent(str, Enum): - """Event types emitted by :class:`WatchedFieldsMixin`.""" + """Event types dispatched by :class:`EventSession`.""" CREATE = "create" DELETE = "delete" UPDATE = "update" -def watch(*fields: str) -> Any: - """Class decorator to filter which fields trigger ``on_update``. +_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception" +_SESSION_CREATES = "_ft_creates" +_SESSION_DELETES = "_ft_deletes" +_SESSION_UPDATES = "_ft_updates" +_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True)) +_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {} +_WATCHED_MODELS: set[type] = set() +_WATCHED_CACHE: dict[type, bool] = {} +_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {} + + +def _invalidate_caches() -> None: + """Clear lookup caches after handler registration.""" + _WATCHED_CACHE.clear() + _HANDLER_CACHE.clear() + + +def listens_for( + model_class: type, + event_types: list[ModelEvent] | None = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Register a callback for one or more model lifecycle events. Args: - *fields: One or more field names to watch. At least one name is required. - - Raises: - ValueError: If called with no field names. + model_class: The SQLAlchemy model class to listen on. + event_types: List of :class:`ModelEvent` values to listen for. + Defaults to all event types. """ - if not fields: - raise ValueError("@watch requires at least one field name.") + evs = event_types if event_types is not None else list(ModelEvent) - def decorator(cls: type[_T]) -> type[_T]: - _WATCHED_FIELDS[cls] = list(fields) - return cls + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + for ev in evs: + _EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn) + _WATCHED_MODELS.add(model_class) + _invalidate_caches() + return fn return decorator +def _is_watched(obj: Any) -> bool: + """Return True if *obj*'s type (or any ancestor) has registered handlers.""" + cls = type(obj) + try: + return _WATCHED_CACHE[cls] + except KeyError: + result = any(klass in _WATCHED_MODELS for klass in cls.__mro__) + _WATCHED_CACHE[cls] = result + return result + + +def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]: + """Return registered handlers for *cls* and *ev*, walking the MRO.""" + key = (cls, ev) + try: + return _HANDLER_CACHE[key] + except KeyError: + handlers: list[Callable[..., Any]] = [] + for klass in cls.__mro__: + handlers.extend(_EVENT_HANDLERS.get((klass, ev), [])) + _HANDLER_CACHE[key] = handlers + return handlers + + def _snapshot_column_attrs(obj: Any) -> dict[str, Any]: """Read currently-loaded column values into a plain dict.""" state = sa_inspect(obj) # InstanceState @@ -65,7 +95,7 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]: for prop in state.mapper.column_attrs: if prop.key in state_dict: snapshot[prop.key] = state_dict[prop.key] - elif ( + elif ( # pragma: no cover not state.expired and prop.strategy_key != _DEFERRED_STRATEGY_KEY and all( @@ -79,12 +109,17 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]: return snapshot -def _get_watched_fields(cls: type) -> list[str] | None: - """Return the watched fields for *cls*, walking the MRO to inherit from parents.""" - for klass in cls.__mro__: - if klass in _WATCHED_FIELDS: - return _WATCHED_FIELDS[klass] - return None +def _get_watched_fields(cls: type) -> tuple[str, ...] | None: + """Return the watched fields for *cls*.""" + fields = getattr(cls, "__watched_fields__", None) + if fields is not None and ( + not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields) + ): + raise TypeError( + f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], " + f"got {type(fields).__name__}" + ) + return fields def _upsert_changes( @@ -105,50 +140,32 @@ def _upsert_changes( pending[key] = (obj, changes) -@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create") -def _after_transaction_create(session: Any, transaction: Any) -> None: - if transaction.nested: - session.info[_SESSION_SAVEPOINT_DEPTH] = ( - session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1 - ) - - -@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end") -def _after_transaction_end(session: Any, transaction: Any) -> None: - if transaction.nested: - depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) - if depth > 0: # pragma: no branch - session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1 - - @event.listens_for(AsyncSession.sync_session_class, "after_flush") def _after_flush(session: Any, flush_context: Any) -> None: - # New objects: capture references while session.new is still populated. - # Values are read in _after_flush_postexec once RETURNING has been processed. + # New objects: capture reference. Attributes will be refreshed after commit. for obj in session.new: - if isinstance(obj, WatchedFieldsMixin): - session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj) + if _is_watched(obj): + session.info.setdefault(_SESSION_CREATES, []).append(obj) - # Deleted objects: capture before they leave the identity map. + # Deleted objects: snapshot now while attributes are still loaded. for obj in session.deleted: - if isinstance(obj, WatchedFieldsMixin): - session.info.setdefault(_SESSION_DELETES, []).append(obj) + if _is_watched(obj): + snapshot = _snapshot_column_attrs(obj) + session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot)) # Dirty objects: read old/new from SQLAlchemy attribute history. for obj in session.dirty: - if not isinstance(obj, WatchedFieldsMixin): + if not _is_watched(obj): continue - # None = not in dict = watch all fields; list = specific fields only watched = _get_watched_fields(type(obj)) changes: dict[str, dict[str, Any]] = {} + inst_attrs = sa_inspect(obj).attrs attrs = ( - # Specific fields - ((field, sa_inspect(obj).attrs[field]) for field in watched) + ((field, inst_attrs[field]) for field in watched) if watched is not None - # All mapped fields - else ((s.key, s) for s in sa_inspect(obj).attrs) + else ((s.key, s) for s in inst_attrs) ) for field, attr_state in attrs: history = attr_state.history @@ -166,116 +183,101 @@ def _after_flush(session: Any, flush_context: Any) -> None: ) -@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec") -def _after_flush_postexec(session: Any, flush_context: Any) -> None: - # New objects are now persistent and RETURNING values have been applied, - # so server defaults (id, created_at, …) are available via getattr. - pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, []) - if not pending_new: - return - session.info.setdefault(_SESSION_CREATES, []).extend(pending_new) - - @event.listens_for(AsyncSession.sync_session_class, "after_rollback") def _after_rollback(session: Any) -> None: - session.info.pop(_SESSION_PENDING_NEW, None) + if session.in_transaction(): + return session.info.pop(_SESSION_CREATES, None) session.info.pop(_SESSION_DELETES, None) session.info.pop(_SESSION_UPDATES, None) -def _task_error_handler(task: asyncio.Task[Any]) -> None: - if not task.cancelled() and (exc := task.exception()): - _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) - - -def _schedule_with_snapshot( - loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any +async def _invoke_callback( + fn: Callable[..., Any], + obj: Any, + event_type: ModelEvent, + changes: dict[str, dict[str, Any]] | None, ) -> None: - """Snapshot *obj*'s column attrs now (before expire_on_commit wipes them), - then schedule a coroutine that restores the snapshot and calls *fn*. - """ - snapshot = _snapshot_column_attrs(obj) - - async def _run( - obj: Any = obj, - fn: Any = fn, - snapshot: dict[str, Any] = snapshot, - args: tuple = args, - ) -> None: - for key, value in snapshot.items(): - _sa_set_committed_value(obj, key, value) - try: - result = fn(*args) - if inspect.isawaitable(result): - await result - except Exception as exc: - _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) - - task = loop.create_task(_run()) - task.add_done_callback(_task_error_handler) + """Call *fn* and await the result if it is awaitable.""" + result = fn(obj, event_type, changes) + if inspect.isawaitable(result): + await result -@event.listens_for(AsyncSession.sync_session_class, "after_commit") -def _after_commit(session: Any) -> None: - if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0: - return +class EventSession(AsyncSession): + """AsyncSession subclass that dispatches lifecycle callbacks after commit.""" - creates: list[Any] = session.info.pop(_SESSION_CREATES, []) - deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) - field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop( - _SESSION_UPDATES, {} - ) + async def commit(self) -> None: # noqa: C901 + await super().commit() - if creates and deletes: - transient_ids = {id(o) for o in creates} & {id(o) for o in deletes} - if transient_ids: - creates = [o for o in creates if id(o) not in transient_ids] - deletes = [o for o in deletes if id(o) not in transient_ids] + creates: list[Any] = self.info.pop(_SESSION_CREATES, []) + deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, []) + field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop( + _SESSION_UPDATES, {} + ) + + if not creates and not deletes and not field_changes: + return + + # Suppress transient objects (created + deleted in same transaction). + if creates and deletes: + created_ids = {id(o) for o in creates} + deleted_ids = {id(o) for o, _ in deletes} + transient_ids = created_ids & deleted_ids + if transient_ids: + creates = [o for o in creates if id(o) not in transient_ids] + deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids] + field_changes = { + k: v for k, v in field_changes.items() if k not in transient_ids + } + + # Suppress updates for newly created objects (CREATE-only semantics). + if creates and field_changes: + create_ids = {id(o) for o in creates} field_changes = { - k: v for k, v in field_changes.items() if k not in transient_ids + k: v for k, v in field_changes.items() if k not in create_ids } - if not creates and not deletes and not field_changes: - return + # Dispatch CREATE callbacks. + for obj in creates: + try: + state = sa_inspect(obj, raiseerr=False) + if ( + state is None or state.detached or state.transient + ): # pragma: no cover + continue + await self.refresh(obj) + for handler in _get_handlers(type(obj), ModelEvent.CREATE): + await _invoke_callback(handler, obj, ModelEvent.CREATE, None) + except Exception as exc: + _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) - try: - loop = asyncio.get_running_loop() - except RuntimeError: - return + # Dispatch DELETE callbacks (restore snapshot; row is gone). + for obj, snapshot in deletes: + try: + for key, value in snapshot.items(): + _sa_set_committed_value(obj, key, value) + for handler in _get_handlers(type(obj), ModelEvent.DELETE): + await _invoke_callback(handler, obj, ModelEvent.DELETE, None) + except Exception as exc: + _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) - for obj in creates: - _schedule_with_snapshot(loop, obj, obj.on_create) + # Dispatch UPDATE callbacks. + for obj, changes in field_changes.values(): + try: + state = sa_inspect(obj, raiseerr=False) + if ( + state is None or state.detached or state.transient + ): # pragma: no cover + continue + await self.refresh(obj) + for handler in _get_handlers(type(obj), ModelEvent.UPDATE): + await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes) + except Exception as exc: + _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) - for obj in deletes: - _schedule_with_snapshot(loop, obj, obj.on_delete) - - for obj, changes in field_changes.values(): - _schedule_with_snapshot(loop, obj, obj.on_update, changes) - - -class WatchedFieldsMixin: - """Mixin that enables lifecycle callbacks for SQLAlchemy models.""" - - def on_event( - self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None - ) -> Awaitable[None] | None: - """Catch-all callback fired for every lifecycle event. - - Args: - event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`, - or :attr:`ModelEvent.UPDATE`). - changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise. - """ - - def on_create(self) -> Awaitable[None] | None: - """Called after INSERT commit.""" - return self.on_event(ModelEvent.CREATE) - - def on_delete(self) -> Awaitable[None] | None: - """Called after DELETE commit.""" - return self.on_event(ModelEvent.DELETE) - - def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None: - """Called after UPDATE commit when watched fields change.""" - return self.on_event(ModelEvent.UPDATE, changes=changes) + async def rollback(self) -> None: + await super().rollback() + self.info.pop(_SESSION_CREATES, None) + self.info.pop(_SESSION_DELETES, None) + self.info.pop(_SESSION_UPDATES, None) diff --git a/src/fastapi_toolsets/pytest/utils.py b/src/fastapi_toolsets/pytest/utils.py index db5f118..f97316c 100644 --- a/src/fastapi_toolsets/pytest/utils.py +++ b/src/fastapi_toolsets/pytest/utils.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import DeclarativeBase from ..db import cleanup_tables as _cleanup_tables from ..db import create_database +from ..models.watched import EventSession async def cleanup_tables( @@ -265,7 +266,9 @@ async def create_db_session( async with engine.begin() as conn: await conn.run_sync(base.metadata.create_all) - session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit) + session_maker = async_sessionmaker( + engine, expire_on_commit=expire_on_commit, class_=EventSession + ) async with session_maker() as session: yield session diff --git a/tests/test_models.py b/tests/test_models.py index e9d517a..b3fd2d8 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,7 +2,6 @@ import asyncio import uuid -from contextlib import suppress from types import SimpleNamespace from unittest.mock import patch @@ -10,8 +9,6 @@ import pytest from sqlalchemy import String from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from fastapi_toolsets.pytest import create_db_session - import fastapi_toolsets.models.watched as _watched_module from fastapi_toolsets.models import ( CreatedAtMixin, @@ -20,22 +17,23 @@ from fastapi_toolsets.models import ( UpdatedAtMixin, UUIDMixin, UUIDv7Mixin, - WatchedFieldsMixin, - watch, + listens_for, ) from fastapi_toolsets.models.watched import ( + _EVENT_HANDLERS, _SESSION_CREATES, _SESSION_DELETES, - _SESSION_PENDING_NEW, _SESSION_UPDATES, - _after_commit, + _WATCHED_MODELS, _after_flush, - _after_flush_postexec, _after_rollback, + _get_watched_fields, + _invalidate_caches, + _is_watched, _snapshot_column_attrs, - _task_error_handler, _upsert_changes, ) +from fastapi_toolsets.pytest import create_db_session from .conftest import DATABASE_URL @@ -86,56 +84,64 @@ class FullMixinModel(MixinBase, UUIDMixin, UpdatedAtMixin): _test_events: list[dict] = [] -@watch("status") -class WatchedModel(MixinBase, UUIDMixin, WatchedFieldsMixin): +class WatchedModel(MixinBase, UUIDMixin): __tablename__ = "mixin_watched_models" + __watched_fields__ = ("status",) status: Mapped[str] = mapped_column(String(50)) other: Mapped[str] = mapped_column(String(50)) - async def on_create(self) -> None: - _test_events.append({"event": "create", "obj_id": self.id}) - async def on_delete(self) -> None: - _test_events.append({"event": "delete", "obj_id": self.id}) - - async def on_update(self, changes: dict) -> None: - _test_events.append({"event": "update", "obj_id": self.id, "changes": changes}) +@listens_for(WatchedModel, [ModelEvent.CREATE]) +async def _watched_on_create(obj, event_type, changes): + _test_events.append({"event": "create", "obj_id": obj.id}) -@watch("value") -class OnEventModel(MixinBase, UUIDMixin, WatchedFieldsMixin): - """Model that only overrides on_event to test the catch-all path.""" - - __tablename__ = "mixin_on_event_models" - - value: Mapped[str] = mapped_column(String(50)) - - async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None: - _test_events.append({"event": event, "obj_id": self.id, "changes": changes}) +@listens_for(WatchedModel, [ModelEvent.DELETE]) +async def _watched_on_delete(obj, event_type, changes): + _test_events.append({"event": "delete", "obj_id": obj.id}) -class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin): - """Model without @watch — watches all mapped fields by default.""" +@listens_for(WatchedModel, [ModelEvent.UPDATE]) +async def _watched_on_update(obj, event_type, changes): + _test_events.append({"event": "update", "obj_id": obj.id, "changes": changes}) + + +class WatchAllModel(MixinBase, UUIDMixin): + """Model without __watched_fields__ — watches all mapped fields by default.""" __tablename__ = "mixin_watch_all_models" status: Mapped[str] = mapped_column(String(50)) other: Mapped[str] = mapped_column(String(50)) - async def on_update(self, changes: dict) -> None: - _test_events.append({"event": "update", "obj_id": self.id, "changes": changes}) + +@listens_for(WatchAllModel, [ModelEvent.UPDATE]) +async def _watch_all_on_update(obj, event_type, changes): + _test_events.append({"event": "update", "obj_id": obj.id, "changes": changes}) -class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): - """Model whose on_create always raises to test exception logging.""" +class FailingCallbackModel(MixinBase, UUIDMixin): + """Model whose CREATE handler always raises to test exception logging.""" __tablename__ = "mixin_failing_callback_models" name: Mapped[str] = mapped_column(String(50)) - async def on_create(self) -> None: - raise RuntimeError("callback intentionally failed") + +@listens_for(FailingCallbackModel, [ModelEvent.CREATE]) +async def _failing_on_create(obj, event_type, changes): + raise RuntimeError("callback intentionally failed") + + +@listens_for(FailingCallbackModel, [ModelEvent.DELETE]) +async def _failing_on_delete(obj, event_type, changes): + raise RuntimeError("delete callback intentionally failed") + + +@listens_for(FailingCallbackModel, [ModelEvent.UPDATE]) +async def _failing_on_update(obj, event_type, changes): + raise RuntimeError("update callback intentionally failed") class NonWatchedModel(MixinBase): @@ -148,7 +154,7 @@ class NonWatchedModel(MixinBase): _poly_events: list[dict] = [] -class PolyAnimal(MixinBase, UUIDMixin, WatchedFieldsMixin): +class PolyAnimal(MixinBase, UUIDMixin): """Base class for STI polymorphism tests.""" __tablename__ = "mixin_poly_animals" @@ -157,15 +163,19 @@ class PolyAnimal(MixinBase, UUIDMixin, WatchedFieldsMixin): kind: Mapped[str] = mapped_column(String(50)) name: Mapped[str] = mapped_column(String(50)) - async def on_create(self) -> None: - _poly_events.append( - {"event": "create", "type": type(self).__name__, "obj_id": self.id} - ) - async def on_delete(self) -> None: - _poly_events.append( - {"event": "delete", "type": type(self).__name__, "obj_id": self.id} - ) +@listens_for(PolyAnimal, [ModelEvent.CREATE]) +async def _poly_on_create(obj, event_type, changes): + _poly_events.append( + {"event": "create", "type": type(obj).__name__, "obj_id": obj.id} + ) + + +@listens_for(PolyAnimal, [ModelEvent.DELETE]) +async def _poly_on_delete(obj, event_type, changes): + _poly_events.append( + {"event": "delete", "type": type(obj).__name__, "obj_id": obj.id} + ) class PolyDog(PolyAnimal): @@ -177,30 +187,33 @@ class PolyDog(PolyAnimal): _watch_inherit_events: list[dict] = [] -@watch("status") -class WatchParent(MixinBase, UUIDMixin, WatchedFieldsMixin): - """Base class with @watch("status") — subclasses should inherit this filter.""" +class WatchParent(MixinBase, UUIDMixin): + """Base class with __watched_fields__ = ("status",) — subclasses inherit.""" __tablename__ = "mixin_watch_parent" + __watched_fields__ = ("status",) __mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "parent"} kind: Mapped[str] = mapped_column(String(50)) status: Mapped[str] = mapped_column(String(50)) other: Mapped[str] = mapped_column(String(50)) - async def on_update(self, changes: dict) -> None: - _watch_inherit_events.append({"type": type(self).__name__, "changes": changes}) + +@listens_for(WatchParent, [ModelEvent.UPDATE]) +async def _watch_parent_on_update(obj, event_type, changes): + _watch_inherit_events.append({"type": type(obj).__name__, "changes": changes}) class WatchChild(WatchParent): - """STI subclass that does NOT redeclare @watch — should inherit parent's filter.""" + """STI subclass that does NOT redeclare __watched_fields__ — inherits parent's filter.""" __mapper_args__ = {"polymorphic_identity": "child"} -@watch("other") class WatchOverride(WatchParent): - """STI subclass that overrides @watch with a different field.""" + """STI subclass that overrides __watched_fields__ with a different field.""" + + __watched_fields__ = ("other",) __mapper_args__ = {"polymorphic_identity": "override"} @@ -208,79 +221,106 @@ class WatchOverride(WatchParent): _attr_access_events: list[dict] = [] -class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin): - """Model used to verify that self attributes are accessible in every callback.""" +class AttrAccessModel(MixinBase, UUIDMixin): + """Model used to verify that attributes are accessible in every callback.""" __tablename__ = "mixin_attr_access_models" name: Mapped[str] = mapped_column(String(50)) callback_url: Mapped[str | None] = mapped_column(String(200), nullable=True) - async def on_create(self) -> None: - _attr_access_events.append( - { - "event": "create", - "id": self.id, - "name": self.name, - "callback_url": self.callback_url, - } - ) - async def on_delete(self) -> None: - _attr_access_events.append( - { - "event": "delete", - "id": self.id, - "name": self.name, - "callback_url": self.callback_url, - } - ) +@listens_for(AttrAccessModel, [ModelEvent.CREATE]) +async def _attr_on_create(obj, event_type, changes): + _attr_access_events.append( + { + "event": "create", + "id": obj.id, + "name": obj.name, + "callback_url": obj.callback_url, + } + ) - async def on_update(self, changes: dict) -> None: - _attr_access_events.append( - { - "event": "update", - "id": self.id, - "name": self.name, - "callback_url": self.callback_url, - } - ) + +@listens_for(AttrAccessModel, [ModelEvent.DELETE]) +async def _attr_on_delete(obj, event_type, changes): + _attr_access_events.append( + { + "event": "delete", + "id": obj.id, + "name": obj.name, + "callback_url": obj.callback_url, + } + ) + + +@listens_for(AttrAccessModel, [ModelEvent.UPDATE]) +async def _attr_on_update(obj, event_type, changes): + _attr_access_events.append( + { + "event": "update", + "id": obj.id, + "name": obj.name, + "callback_url": obj.callback_url, + } + ) _sync_events: list[dict] = [] _future_events: list[str] = [] -@watch("status") -class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): - """Model with plain (sync) on_* callbacks.""" +class SyncCallbackModel(MixinBase, UUIDMixin): + """Model with plain (sync) callbacks.""" __tablename__ = "mixin_sync_callback_models" + __watched_fields__ = ("status",) status: Mapped[str] = mapped_column(String(50)) - def on_create(self) -> None: - _sync_events.append({"event": "create", "obj_id": self.id}) - def on_delete(self) -> None: - _sync_events.append({"event": "delete", "obj_id": self.id}) - - def on_update(self, changes: dict) -> None: - _sync_events.append({"event": "update", "changes": changes}) +@listens_for(SyncCallbackModel, [ModelEvent.CREATE]) +def _sync_on_create(obj, event_type, changes): + _sync_events.append({"event": "create", "obj_id": obj.id}) -class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): - """Model whose on_create returns an asyncio.Task (awaitable, not a coroutine).""" +@listens_for(SyncCallbackModel, [ModelEvent.DELETE]) +def _sync_on_delete(obj, event_type, changes): + _sync_events.append({"event": "delete", "obj_id": obj.id}) + + +@listens_for(SyncCallbackModel, [ModelEvent.UPDATE]) +def _sync_on_update(obj, event_type, changes): + _sync_events.append({"event": "update", "changes": changes}) + + +class FutureCallbackModel(MixinBase, UUIDMixin): + """Model whose CREATE handler returns an asyncio.Task (awaitable, not a coroutine).""" __tablename__ = "mixin_future_callback_models" name: Mapped[str] = mapped_column(String(50)) - def on_create(self) -> "asyncio.Task[None]": - async def _work() -> None: - _future_events.append("created") - return asyncio.ensure_future(_work()) +@listens_for(FutureCallbackModel, [ModelEvent.CREATE]) +def _future_on_create(obj, event_type, changes): + async def _work(): + _future_events.append("created") + + return asyncio.ensure_future(_work()) + + +class ListenerModel(MixinBase, UUIDMixin): + """Model for testing the listens_for decorator with dynamic registration.""" + + __tablename__ = "mixin_listener_models" + __watched_fields__ = ("status",) + + status: Mapped[str] = mapped_column(String(50)) + other: Mapped[str] = mapped_column(String(50)) + + +_listener_events: list[dict] = [] @pytest.fixture(scope="function") @@ -517,29 +557,34 @@ class TestFullMixinModel: assert obj.updated_at.tzinfo is not None -class TestWatchDecorator: - def test_registers_specific_fields(self): - """@watch("field") stores the field list in _WATCHED_FIELDS.""" - assert _watched_module._WATCHED_FIELDS.get(WatchedModel) == ["status"] +class TestWatchedFields: + def test_specific_fields_set(self): + """__watched_fields__ stores the watched field tuple.""" + assert WatchedModel.__watched_fields__ == ("status",) - def test_no_decorator_not_in_watched_fields(self): - """A model without @watch has no entry in _WATCHED_FIELDS (watch all).""" - assert WatchAllModel not in _watched_module._WATCHED_FIELDS + def test_no_watched_fields_means_all(self): + """A model without __watched_fields__ watches all fields.""" + assert _get_watched_fields(WatchAllModel) is None - def test_preserves_class_identity(self): - """watch returns the same class unchanged.""" + def test_inherits_from_parent(self): + """Subclass without __watched_fields__ inherits parent's value.""" + assert WatchChild.__watched_fields__ == ("status",) - class _Dummy(WatchedFieldsMixin): - pass + def test_override_takes_precedence(self): + """Subclass __watched_fields__ overrides parent's value.""" + assert WatchOverride.__watched_fields__ == ("other",) - result = watch("x")(_Dummy) - assert result is _Dummy - del _watched_module._WATCHED_FIELDS[_Dummy] + def test_invalid_watched_fields_raises_type_error(self): + """__watched_fields__ must be a tuple of strings.""" - def test_raises_when_no_fields_given(self): - """@watch() with no field names raises ValueError.""" - with pytest.raises(ValueError, match="@watch requires at least one field name"): - watch() + class BadModel(MixinBase, UUIDMixin): + __tablename__ = "mixin_bad_watched_fields" + __watched_fields__ = ["status"] # list, not tuple + + status: Mapped[str] = mapped_column(String(50)) + + with pytest.raises(TypeError, match="must be a tuple"): + _get_watched_fields(BadModel) class TestWatchInheritance: @@ -551,29 +596,25 @@ class TestWatchInheritance: @pytest.mark.anyio async def test_child_inherits_parent_watch_filter(self, mixin_session): - """Subclass without @watch inherits the parent's field filter.""" + """Subclass without __watched_fields__ inherits the parent's field filter.""" obj = WatchChild(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) - obj.other = "changed" # not watched by parent's @watch("status") + obj.other = "changed" # not watched by parent's __watched_fields__ await mixin_session.commit() - await asyncio.sleep(0) assert _watch_inherit_events == [] @pytest.mark.anyio async def test_child_triggers_on_watched_field(self, mixin_session): - """Subclass without @watch triggers on_update for the parent's watched field.""" + """Subclass without __watched_fields__ triggers handler for the parent's watched field.""" obj = WatchChild(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) obj.status = "updated" await mixin_session.commit() - await asyncio.sleep(0) assert len(_watch_inherit_events) == 1 assert _watch_inherit_events[0]["type"] == "WatchChild" @@ -581,28 +622,39 @@ class TestWatchInheritance: @pytest.mark.anyio async def test_subclass_override_takes_precedence(self, mixin_session): - """Subclass @watch overrides the parent's field filter.""" + """Subclass __watched_fields__ overrides the parent's field filter.""" obj = WatchOverride(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) - obj.status = ( - "changed" # watched by parent but overridden by child's @watch("other") - ) + obj.status = "changed" # overridden by child's __watched_fields__ = ("other",) await mixin_session.commit() - await asyncio.sleep(0) assert _watch_inherit_events == [] obj.other = "changed" await mixin_session.commit() - await asyncio.sleep(0) assert len(_watch_inherit_events) == 1 assert "other" in _watch_inherit_events[0]["changes"] +class TestIsWatched: + def test_watched_model_is_watched(self): + """_is_watched returns True for models with registered handlers.""" + obj = WatchedModel(status="x", other="y") + assert _is_watched(obj) is True + + def test_non_watched_model_is_not_watched(self): + """_is_watched returns False for models without registered handlers.""" + assert _is_watched(object()) is False + + def test_subclass_of_watched_model_is_watched(self): + """_is_watched returns True for subclasses of watched models (via MRO).""" + dog = PolyDog(name="Rex") + assert _is_watched(dog) is True + + class TestUpsertChanges: def test_inserts_new_entry(self): """New key is inserted with the full changes dict.""" @@ -640,221 +692,166 @@ class TestAfterFlush: _after_flush(session, None) assert session.info == {} - def test_captures_new_watched_mixin_objects(self): - """New WatchedFieldsMixin instances are added to _SESSION_PENDING_NEW.""" - obj = WatchedFieldsMixin() + def test_captures_new_watched_objects(self): + """New watched objects are added to _SESSION_CREATES.""" + obj = object() session = SimpleNamespace(new=[obj], deleted=[], dirty=[], info={}) - _after_flush(session, None) - assert session.info[_SESSION_PENDING_NEW] == [obj] + with patch("fastapi_toolsets.models.watched._is_watched", return_value=True): + _after_flush(session, None) + assert session.info[_SESSION_CREATES] == [obj] - def test_ignores_new_non_mixin_objects(self): - """New objects that are not WatchedFieldsMixin are not captured.""" + def test_ignores_new_non_watched_objects(self): + """New objects that are not watched are not captured.""" obj = object() session = SimpleNamespace(new=[obj], deleted=[], dirty=[], info={}) _after_flush(session, None) - assert _SESSION_PENDING_NEW not in session.info + assert _SESSION_CREATES not in session.info - def test_captures_deleted_watched_mixin_objects(self): - """Deleted WatchedFieldsMixin instances are added to _SESSION_DELETES.""" - obj = WatchedFieldsMixin() + def test_captures_deleted_watched_objects(self): + """Deleted watched objects are stored as (obj, snapshot) tuples.""" + obj = object() session = SimpleNamespace(new=[], deleted=[obj], dirty=[], info={}) - _after_flush(session, None) - assert session.info[_SESSION_DELETES] == [obj] + with ( + patch("fastapi_toolsets.models.watched._is_watched", return_value=True), + patch( + "fastapi_toolsets.models.watched._snapshot_column_attrs", + return_value={"id": 1}, + ), + ): + _after_flush(session, None) + assert len(session.info[_SESSION_DELETES]) == 1 + assert session.info[_SESSION_DELETES][0][0] is obj + assert session.info[_SESSION_DELETES][0][1] == {"id": 1} - def test_ignores_deleted_non_mixin_objects(self): - """Deleted objects that are not WatchedFieldsMixin are not captured.""" + def test_ignores_deleted_non_watched_objects(self): + """Deleted objects that are not watched are not captured.""" obj = object() session = SimpleNamespace(new=[], deleted=[obj], dirty=[], info={}) _after_flush(session, None) assert _SESSION_DELETES not in session.info -class TestAfterFlushPostexec: - def test_does_nothing_when_no_pending_new(self): - """_after_flush_postexec does nothing when _SESSION_PENDING_NEW is absent.""" - session = SimpleNamespace(info={}) - _after_flush_postexec(session, None) - assert _SESSION_CREATES not in session.info - - def test_moves_pending_new_to_creates(self): - """Objects from _SESSION_PENDING_NEW are moved to _SESSION_CREATES.""" - obj = object() - session = SimpleNamespace(info={_SESSION_PENDING_NEW: [obj]}) - _after_flush_postexec(session, None) - assert _SESSION_PENDING_NEW not in session.info - assert session.info[_SESSION_CREATES] == [obj] - - def test_extends_existing_creates(self): - """Multiple flushes accumulate in _SESSION_CREATES.""" - a, b = object(), object() - session = SimpleNamespace( - info={_SESSION_PENDING_NEW: [b], _SESSION_CREATES: [a]} - ) - _after_flush_postexec(session, None) - assert session.info[_SESSION_CREATES] == [a, b] - - class TestAfterRollback: def test_clears_all_session_info_keys(self): - """_after_rollback removes all four tracking keys from session.info.""" + """_after_rollback removes all three tracking keys on full rollback.""" session = SimpleNamespace( info={ - _SESSION_PENDING_NEW: [object()], _SESSION_CREATES: [object()], _SESSION_DELETES: [object()], _SESSION_UPDATES: {1: ("obj", {"f": {"old": "a", "new": "b"}})}, - } + }, + in_transaction=lambda: False, ) _after_rollback(session) - assert _SESSION_PENDING_NEW not in session.info assert _SESSION_CREATES not in session.info assert _SESSION_DELETES not in session.info assert _SESSION_UPDATES not in session.info def test_tolerates_missing_keys(self): """_after_rollback does not raise when session.info has no pending data.""" - session = SimpleNamespace(info={}) + session = SimpleNamespace(info={}, in_transaction=lambda: False) _after_rollback(session) # must not raise - -class TestTaskErrorHandler: - @pytest.mark.anyio - async def test_logs_exception_from_failed_task(self): - """_task_error_handler calls _logger.error when the task raised.""" - - async def failing() -> None: - raise ValueError("boom") - - task = asyncio.create_task(failing()) - await asyncio.sleep(0) - - with patch.object(_watched_module._logger, "error") as mock_error: - _task_error_handler(task) - mock_error.assert_called_once() - - @pytest.mark.anyio - async def test_ignores_cancelled_task(self): - """_task_error_handler does not log when the task was cancelled.""" - - async def slow() -> None: - await asyncio.sleep(100) - - task = asyncio.create_task(slow()) - task.cancel() - with suppress(asyncio.CancelledError): - await task - - with patch.object(_watched_module._logger, "error") as mock_error: - _task_error_handler(task) - mock_error.assert_not_called() + def test_preserves_events_on_savepoint_rollback(self): + """_after_rollback keeps events when still in a transaction (savepoint).""" + creates = [object()] + session = SimpleNamespace( + info={ + _SESSION_CREATES: creates, + _SESSION_DELETES: [], + _SESSION_UPDATES: {}, + }, + in_transaction=lambda: True, + ) + _after_rollback(session) + assert session.info[_SESSION_CREATES] is creates -class TestAfterCommitNoLoop: - def test_no_task_scheduled_when_no_running_loop(self): - """_after_commit silently returns when called outside an async context.""" - called = [] - obj = SimpleNamespace(on_create=lambda: called.append("create")) - session = SimpleNamespace(info={_SESSION_CREATES: [obj]}) - _after_commit(session) - assert called == [] - - def test_returns_early_when_all_pending_empty(self): - """_after_commit does nothing when all pending lists are empty.""" - session = SimpleNamespace(info={}) - _after_commit(session) # should not raise - - -class TestWatchedFieldsMixin: +class TestEventCallbacks: @pytest.fixture(autouse=True) def clear_events(self): _test_events.clear() yield _test_events.clear() - # --- on_create --- + # --- CREATE --- @pytest.mark.anyio - async def test_on_create_fires_after_insert(self, mixin_session): - """on_create is called after INSERT commit.""" + async def test_create_fires_after_insert(self, mixin_session): + """CREATE handler is called after INSERT commit.""" obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) creates = [e for e in _test_events if e["event"] == "create"] assert len(creates) == 1 @pytest.mark.anyio - async def test_on_create_server_defaults_populated(self, mixin_session): - """id (server default via RETURNING) is available inside on_create.""" + async def test_create_server_defaults_populated(self, mixin_session): + """id (server default via RETURNING) is available inside CREATE handler.""" obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) creates = [e for e in _test_events if e["event"] == "create"] assert creates[0]["obj_id"] is not None assert isinstance(creates[0]["obj_id"], uuid.UUID) @pytest.mark.anyio - async def test_on_create_not_fired_on_update(self, mixin_session): - """on_create is NOT called when an existing row is updated.""" + async def test_create_not_fired_on_update(self, mixin_session): + """CREATE handler is NOT called when an existing row is updated.""" obj = WatchedModel(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() obj.status = "updated" await mixin_session.commit() - await asyncio.sleep(0) assert not any(e["event"] == "create" for e in _test_events) - # --- on_delete --- + # --- DELETE --- @pytest.mark.anyio - async def test_on_delete_fires_after_delete(self, mixin_session): - """on_delete is called after DELETE commit.""" + async def test_delete_fires_after_delete(self, mixin_session): + """DELETE handler is called after DELETE commit.""" obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + saved_id = obj.id _test_events.clear() await mixin_session.delete(obj) await mixin_session.commit() - await asyncio.sleep(0) deletes = [e for e in _test_events if e["event"] == "delete"] assert len(deletes) == 1 assert deletes[0]["obj_id"] == saved_id @pytest.mark.anyio - async def test_on_delete_not_fired_on_insert(self, mixin_session): - """on_delete is NOT called when a new row is inserted.""" + async def test_delete_not_fired_on_insert(self, mixin_session): + """DELETE handler is NOT called when a new row is inserted.""" obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) assert not any(e["event"] == "delete" for e in _test_events) - # --- on_update --- + # --- UPDATE --- @pytest.mark.anyio - async def test_on_update_fires_on_update(self, mixin_session): - """on_update reports the correct before/after values on UPDATE.""" + async def test_update_fires_on_update(self, mixin_session): + """UPDATE handler reports the correct before/after values.""" obj = WatchedModel(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() obj.status = "updated" await mixin_session.commit() - await asyncio.sleep(0) changes_events = [e for e in _test_events if e["event"] == "update"] assert len(changes_events) == 1 @@ -864,27 +861,40 @@ class TestWatchedFieldsMixin: } @pytest.mark.anyio - async def test_on_update_not_fired_on_insert(self, mixin_session): - """on_update is NOT called on INSERT (on_create handles that).""" + async def test_update_not_fired_on_insert(self, mixin_session): + """UPDATE handler is NOT called on INSERT (CREATE handles that).""" obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) assert not any(e["event"] == "update" for e in _test_events) + @pytest.mark.anyio + async def test_create_and_update_in_same_tx_only_fires_create(self, mixin_session): + """Modifying a watched field before commit only fires CREATE, not UPDATE.""" + obj = WatchedModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.flush() + + obj.status = "updated-before-commit" + await mixin_session.commit() + + creates = [e for e in _test_events if e["event"] == "create"] + updates = [e for e in _test_events if e["event"] == "update"] + assert len(creates) == 1 + assert updates == [] + @pytest.mark.anyio async def test_unwatched_field_update_no_callback(self, mixin_session): - """Changing a field not listed in @update does not fire on_update.""" + """Changing a field not in __watched_fields__ does not fire UPDATE handler.""" obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() obj.other = "changed" await mixin_session.commit() - await asyncio.sleep(0) assert _test_events == [] @@ -894,7 +904,7 @@ class TestWatchedFieldsMixin: obj = WatchedModel(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() obj.status = "intermediate" @@ -902,7 +912,6 @@ class TestWatchedFieldsMixin: obj.status = "final" await mixin_session.commit() - await asyncio.sleep(0) changes_events = [e for e in _test_events if e["event"] == "update"] assert len(changes_events) == 1 @@ -917,89 +926,62 @@ class TestWatchedFieldsMixin: obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() obj.status = "changed" await mixin_session.flush() await mixin_session.rollback() - await asyncio.sleep(0) assert _test_events == [] @pytest.mark.anyio - async def test_callback_exception_is_logged(self, mixin_session): - """Exceptions raised inside on_create are logged, not propagated.""" + async def test_create_callback_exception_is_logged(self, mixin_session): + """Exceptions raised inside a CREATE handler are logged, not propagated.""" obj = FailingCallbackModel(name="boom") mixin_session.add(obj) with patch.object(_watched_module._logger, "error") as mock_error: await mixin_session.commit() - await asyncio.sleep(0) + + mock_error.assert_called_once() + + @pytest.mark.anyio + async def test_delete_callback_exception_is_logged(self, mixin_session): + """Exceptions raised inside a DELETE handler are logged, not propagated.""" + obj = FailingCallbackModel(name="boom") + mixin_session.add(obj) + await mixin_session.commit() # CREATE handler fails (logged) + + await mixin_session.delete(obj) + with patch.object(_watched_module._logger, "error") as mock_error: + await mixin_session.commit() + + mock_error.assert_called_once() + + @pytest.mark.anyio + async def test_update_callback_exception_is_logged(self, mixin_session): + """Exceptions raised inside an UPDATE handler are logged, not propagated.""" + obj = FailingCallbackModel(name="boom") + mixin_session.add(obj) + await mixin_session.commit() # CREATE handler fails (logged) + + obj.name = "changed" + with patch.object(_watched_module._logger, "error") as mock_error: + await mixin_session.commit() + mock_error.assert_called_once() @pytest.mark.anyio async def test_non_watched_model_no_callback(self, mixin_session): - """Dirty objects whose type is not a WatchedFieldsMixin are skipped.""" + """Dirty objects whose type has no registered handlers are skipped.""" nw = NonWatchedModel(value="x") mixin_session.add(nw) await mixin_session.flush() nw.value = "y" await mixin_session.commit() - await asyncio.sleep(0) assert _test_events == [] - # --- on_event (catch-all) --- - - @pytest.mark.anyio - async def test_on_event_receives_create(self, mixin_session): - """on_event is called with ModelEvent.CREATE on INSERT when only on_event is overridden.""" - obj = OnEventModel(value="x") - mixin_session.add(obj) - await mixin_session.commit() - await asyncio.sleep(0) - - creates = [e for e in _test_events if e["event"] == ModelEvent.CREATE] - assert len(creates) == 1 - assert creates[0]["changes"] is None - - @pytest.mark.anyio - async def test_on_event_receives_delete(self, mixin_session): - """on_event is called with ModelEvent.DELETE on DELETE when only on_event is overridden.""" - obj = OnEventModel(value="x") - mixin_session.add(obj) - await mixin_session.commit() - await asyncio.sleep(0) - _test_events.clear() - - await mixin_session.delete(obj) - await mixin_session.commit() - await asyncio.sleep(0) - - deletes = [e for e in _test_events if e["event"] == ModelEvent.DELETE] - assert len(deletes) == 1 - assert deletes[0]["changes"] is None - - @pytest.mark.anyio - async def test_on_event_receives_field_change(self, mixin_session): - """on_event is called with ModelEvent.UPDATE on UPDATE when only on_event is overridden.""" - obj = OnEventModel(value="initial") - mixin_session.add(obj) - await mixin_session.commit() - await asyncio.sleep(0) - _test_events.clear() - - obj.value = "updated" - await mixin_session.commit() - await asyncio.sleep(0) - - changes_events = [e for e in _test_events if e["event"] == ModelEvent.UPDATE] - assert len(changes_events) == 1 - assert changes_events[0]["changes"]["value"] == { - "old": "initial", - "new": "updated", - } - class TestTransientObject: """Create + delete within the same transaction should fire no events.""" @@ -1014,19 +996,18 @@ class TestTransientObject: async def test_no_events_when_created_and_deleted_in_same_transaction( self, mixin_session ): - """Neither on_create nor on_delete fires when the object never survives a commit.""" + """Neither CREATE nor DELETE fires when the object never survives a commit.""" obj = WatchedModel(status="active", other="x") mixin_session.add(obj) await mixin_session.flush() await mixin_session.delete(obj) await mixin_session.commit() - await asyncio.sleep(0) assert _test_events == [] @pytest.mark.anyio async def test_other_objects_unaffected(self, mixin_session): - """on_create still fires for objects that are not deleted in the same transaction.""" + """CREATE still fires for objects that are not deleted in the same transaction.""" survivor = WatchedModel(status="active", other="x") transient = WatchedModel(status="gone", other="y") mixin_session.add(survivor) @@ -1034,7 +1015,6 @@ class TestTransientObject: await mixin_session.flush() await mixin_session.delete(transient) await mixin_session.commit() - await asyncio.sleep(0) creates = [e for e in _test_events if e["event"] == "create"] deletes = [e for e in _test_events if e["event"] == "delete"] @@ -1044,18 +1024,17 @@ class TestTransientObject: @pytest.mark.anyio async def test_distinct_create_and_delete_both_fire(self, mixin_session): - """on_create and on_delete both fire when different objects are created and deleted.""" + """CREATE and DELETE both fire when different objects are created and deleted.""" existing = WatchedModel(status="old", other="x") mixin_session.add(existing) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() new_obj = WatchedModel(status="new", other="y") mixin_session.add(new_obj) await mixin_session.delete(existing) await mixin_session.commit() - await asyncio.sleep(0) creates = [e for e in _test_events if e["event"] == "create"] deletes = [e for e in _test_events if e["event"] == "delete"] @@ -1064,7 +1043,7 @@ class TestTransientObject: class TestPolymorphism: - """WatchedFieldsMixin with STI (Single Table Inheritance).""" + """Event dispatch with STI (Single Table Inheritance).""" @pytest.fixture(autouse=True) def clear_events(self): @@ -1073,29 +1052,27 @@ class TestPolymorphism: _poly_events.clear() @pytest.mark.anyio - async def test_on_create_fires_once_for_subclass(self, mixin_session): - """on_create fires exactly once for a STI subclass instance.""" + async def test_create_fires_once_for_subclass(self, mixin_session): + """CREATE fires exactly once for a STI subclass instance.""" dog = PolyDog(name="Rex") mixin_session.add(dog) await mixin_session.commit() - await asyncio.sleep(0) assert len(_poly_events) == 1 assert _poly_events[0]["event"] == "create" assert _poly_events[0]["type"] == "PolyDog" @pytest.mark.anyio - async def test_on_delete_fires_for_subclass(self, mixin_session): - """on_delete fires for a STI subclass instance.""" + async def test_delete_fires_for_subclass(self, mixin_session): + """DELETE fires for a STI subclass instance.""" dog = PolyDog(name="Rex") mixin_session.add(dog) await mixin_session.commit() - await asyncio.sleep(0) + _poly_events.clear() await mixin_session.delete(dog) await mixin_session.commit() - await asyncio.sleep(0) assert len(_poly_events) == 1 assert _poly_events[0]["event"] == "delete" @@ -1109,7 +1086,6 @@ class TestPolymorphism: await mixin_session.flush() await mixin_session.delete(dog) await mixin_session.commit() - await asyncio.sleep(0) assert _poly_events == [] @@ -1123,16 +1099,15 @@ class TestWatchAll: @pytest.mark.anyio async def test_watch_all_fires_for_any_field(self, mixin_session): - """Model without @watch fires on_update for any changed field.""" + """Model without __watched_fields__ fires UPDATE for any changed field.""" obj = WatchAllModel(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() obj.other = "changed" await mixin_session.commit() - await asyncio.sleep(0) changes_events = [e for e in _test_events if e["event"] == "update"] assert len(changes_events) == 1 @@ -1140,17 +1115,16 @@ class TestWatchAll: @pytest.mark.anyio async def test_watch_all_captures_multiple_fields(self, mixin_session): - """Model without @watch captures all fields changed in a single commit.""" + """Model without __watched_fields__ captures all fields changed in a single commit.""" obj = WatchAllModel(status="initial", other="x") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _test_events.clear() obj.status = "updated" obj.other = "changed" await mixin_session.commit() - await asyncio.sleep(0) changes_events = [e for e in _test_events if e["event"] == "update"] assert len(changes_events) == 1 @@ -1166,45 +1140,42 @@ class TestSyncCallbacks: _sync_events.clear() @pytest.mark.anyio - async def test_sync_on_create_fires(self, mixin_session): - """Sync on_create is called after INSERT commit.""" + async def test_sync_create_fires(self, mixin_session): + """Sync CREATE handler is called after INSERT commit.""" obj = SyncCallbackModel(status="active") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) creates = [e for e in _sync_events if e["event"] == "create"] assert len(creates) == 1 assert isinstance(creates[0]["obj_id"], uuid.UUID) @pytest.mark.anyio - async def test_sync_on_delete_fires(self, mixin_session): - """Sync on_delete is called after DELETE commit.""" + async def test_sync_delete_fires(self, mixin_session): + """Sync DELETE handler is called after DELETE commit.""" obj = SyncCallbackModel(status="active") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _sync_events.clear() await mixin_session.delete(obj) await mixin_session.commit() - await asyncio.sleep(0) deletes = [e for e in _sync_events if e["event"] == "delete"] assert len(deletes) == 1 @pytest.mark.anyio - async def test_sync_on_update_fires(self, mixin_session): - """Sync on_update is called after UPDATE commit with correct changes.""" + async def test_sync_update_fires(self, mixin_session): + """Sync UPDATE handler is called after UPDATE commit with correct changes.""" obj = SyncCallbackModel(status="initial") mixin_session.add(obj) await mixin_session.commit() - await asyncio.sleep(0) + _sync_events.clear() obj.status = "updated" await mixin_session.commit() - await asyncio.sleep(0) updates = [e for e in _sync_events if e["event"] == "update"] assert len(updates) == 1 @@ -1222,22 +1193,19 @@ class TestFutureCallbacks: @pytest.mark.anyio async def test_task_callback_is_awaited(self, mixin_session): - """on_create returning an asyncio.Task is awaited and its work completes.""" + """CREATE handler returning an asyncio.Task is awaited and its work completes.""" obj = FutureCallbackModel(name="test") mixin_session.add(obj) await mixin_session.commit() - # Two turns: one for _run() to execute, one for the inner _work() task. - await asyncio.sleep(0) - await asyncio.sleep(0) assert _future_events == ["created"] class TestAttributeAccessInCallbacks: - """Verify that self attributes are accessible inside every callback type. + """Verify that object attributes are accessible inside every callback type. Uses expire_on_commit=True (the SQLAlchemy default) so the tests would fail - without the snapshot-restore logic in _schedule_with_snapshot. + without the refresh/snapshot-restore logic in EventSession.commit(). """ @pytest.fixture(autouse=True) @@ -1247,12 +1215,11 @@ class TestAttributeAccessInCallbacks: _attr_access_events.clear() @pytest.mark.anyio - async def test_on_create_pk_and_field_accessible(self, mixin_session_expire): - """id (server default) and regular fields are readable inside on_create.""" + async def test_create_pk_and_field_accessible(self, mixin_session_expire): + """id (server default) and regular fields are readable inside CREATE handler.""" obj = AttrAccessModel(name="hello") mixin_session_expire.add(obj) await mixin_session_expire.commit() - await asyncio.sleep(0) events = [e for e in _attr_access_events if e["event"] == "create"] assert len(events) == 1 @@ -1260,17 +1227,16 @@ class TestAttributeAccessInCallbacks: assert events[0]["name"] == "hello" @pytest.mark.anyio - async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire): - """id and regular fields are readable inside on_delete.""" + async def test_delete_pk_and_field_accessible(self, mixin_session_expire): + """id and regular fields are readable inside DELETE handler.""" obj = AttrAccessModel(name="to-delete") mixin_session_expire.add(obj) await mixin_session_expire.commit() - await asyncio.sleep(0) + _attr_access_events.clear() await mixin_session_expire.delete(obj) await mixin_session_expire.commit() - await asyncio.sleep(0) events = [e for e in _attr_access_events if e["event"] == "delete"] assert len(events) == 1 @@ -1278,19 +1244,16 @@ class TestAttributeAccessInCallbacks: assert events[0]["name"] == "to-delete" @pytest.mark.anyio - async def test_on_update_pk_and_updated_field_accessible( - self, mixin_session_expire - ): - """id and the new field value are readable inside on_update.""" + async def test_update_pk_and_updated_field_accessible(self, mixin_session_expire): + """id and the new field value are readable inside UPDATE handler.""" obj = AttrAccessModel(name="original") mixin_session_expire.add(obj) await mixin_session_expire.commit() - await asyncio.sleep(0) + _attr_access_events.clear() obj.name = "updated" await mixin_session_expire.commit() - await asyncio.sleep(0) events = [e for e in _attr_access_events if e["event"] == "update"] assert len(events) == 1 @@ -1298,28 +1261,26 @@ class TestAttributeAccessInCallbacks: assert events[0]["name"] == "updated" @pytest.mark.anyio - async def test_nullable_column_none_accessible_in_on_create( + async def test_nullable_column_none_accessible_in_create( self, mixin_session_expire ): - """Nullable column left as None is accessible in on_create without greenlet error.""" + """Nullable column left as None is accessible in CREATE handler without greenlet error.""" obj = AttrAccessModel(name="no-url") # callback_url not set → None mixin_session_expire.add(obj) await mixin_session_expire.commit() - await asyncio.sleep(0) events = [e for e in _attr_access_events if e["event"] == "create"] assert len(events) == 1 assert events[0]["callback_url"] is None @pytest.mark.anyio - async def test_nullable_column_with_value_accessible_in_on_create( + async def test_nullable_column_with_value_accessible_in_create( self, mixin_session_expire ): - """Nullable column set to a value is accessible in on_create without greenlet error.""" + """Nullable column set to a value is accessible in CREATE handler without greenlet error.""" obj = AttrAccessModel(name="with-url", callback_url="https://example.com/hook") mixin_session_expire.add(obj) await mixin_session_expire.commit() - await asyncio.sleep(0) events = [e for e in _attr_access_events if e["event"] == "create"] assert len(events) == 1 @@ -1329,34 +1290,488 @@ class TestAttributeAccessInCallbacks: async def test_nullable_column_accessible_after_update_to_none( self, mixin_session_expire ): - """Nullable column updated to None is accessible in on_update without greenlet error.""" + """Nullable column updated to None is accessible in UPDATE handler without greenlet error.""" obj = AttrAccessModel(name="x", callback_url="https://example.com/hook") mixin_session_expire.add(obj) await mixin_session_expire.commit() - await asyncio.sleep(0) + _attr_access_events.clear() obj.callback_url = None await mixin_session_expire.commit() - await asyncio.sleep(0) events = [e for e in _attr_access_events if e["event"] == "update"] assert len(events) == 1 assert events[0]["callback_url"] is None @pytest.mark.anyio - async def test_expired_nullable_column_not_inferred_as_none( + async def test_snapshot_on_loaded_object_captures_nullable_column( self, mixin_session_expire ): - """A nullable column with a real value that is expired (by a prior - expire_on_commit) must not be inferred as None in the snapshot — its - actual value is unknown without a DB refresh.""" + """_snapshot_column_attrs on a loaded (non-expired) object captures + nullable columns correctly — used for delete snapshots at flush time.""" obj = AttrAccessModel(name="original", callback_url="https://example.com/hook") mixin_session_expire.add(obj) - await mixin_session_expire.commit() - # expire_on_commit fired → obj.state.expired=True, callback_url not in state.dict + await mixin_session_expire.flush() + # Object is loaded (just flushed) — snapshot should capture everything. snapshot = _snapshot_column_attrs(obj) + assert snapshot["callback_url"] == "https://example.com/hook" + assert snapshot["name"] == "original" - # callback_url has a real DB value but is expired — must not be snapshotted as None. - assert "callback_url" not in snapshot + +class TestListensFor: + """Test the listens_for decorator for external handler registration.""" + + @pytest.fixture(autouse=True) + def clear_events(self): + _listener_events.clear() + yield + _listener_events.clear() + # Clean up registered handlers for ListenerModel. + for key in list(_EVENT_HANDLERS): + if key[0] is ListenerModel: + del _EVENT_HANDLERS[key] + _WATCHED_MODELS.discard(ListenerModel) + _invalidate_caches() + + @pytest.mark.anyio + async def test_create_handler_fires(self, mixin_session): + """Registered CREATE handler is called after INSERT commit.""" + + @listens_for(ListenerModel, [ModelEvent.CREATE]) + async def _on_create(obj, event_type, changes): + _listener_events.append({"event": "create", "id": obj.id}) + + obj = ListenerModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + + creates = [e for e in _listener_events if e["event"] == "create"] + assert len(creates) == 1 + assert isinstance(creates[0]["id"], uuid.UUID) + + @pytest.mark.anyio + async def test_delete_handler_fires(self, mixin_session): + """Registered DELETE handler is called after DELETE commit.""" + + @listens_for(ListenerModel, [ModelEvent.DELETE]) + async def _on_delete(obj, event_type, changes): + _listener_events.append({"event": "delete", "id": obj.id}) + + obj = ListenerModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + saved_id = obj.id + + await mixin_session.delete(obj) + await mixin_session.commit() + + deletes = [e for e in _listener_events if e["event"] == "delete"] + assert len(deletes) == 1 + assert deletes[0]["id"] == saved_id + + @pytest.mark.anyio + async def test_update_handler_receives_changes(self, mixin_session): + """Registered UPDATE handler receives the object and changes dict.""" + + @listens_for(ListenerModel, [ModelEvent.UPDATE]) + async def _on_update(obj, event_type, changes): + _listener_events.append( + {"event": "update", "id": obj.id, "changes": changes} + ) + + obj = ListenerModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + + obj.status = "updated" + await mixin_session.commit() + + updates = [e for e in _listener_events if e["event"] == "update"] + assert len(updates) == 1 + assert updates[0]["changes"]["status"] == { + "old": "initial", + "new": "updated", + } + + @pytest.mark.anyio + async def test_default_all_event_types(self, mixin_session): + """listens_for defaults to all event types when none specified.""" + + @listens_for(ListenerModel) + async def _on_any(obj, event_type, changes): + _listener_events.append({"event": "any"}) + + obj = ListenerModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + + obj.status = "updated" + await mixin_session.commit() + + await mixin_session.delete(obj) + await mixin_session.commit() + + assert len(_listener_events) == 3 + + @pytest.mark.anyio + async def test_multiple_handlers_all_fire(self, mixin_session): + """Multiple handlers registered for the same event all fire.""" + + @listens_for(ListenerModel, [ModelEvent.CREATE]) + async def _handler_a(obj, event_type, changes): + _listener_events.append({"handler": "a"}) + + @listens_for(ListenerModel, [ModelEvent.CREATE]) + async def _handler_b(obj, event_type, changes): + _listener_events.append({"handler": "b"}) + + obj = ListenerModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + + handlers = [e["handler"] for e in _listener_events] + assert "a" in handlers + assert "b" in handlers + + @pytest.mark.anyio + async def test_sync_handler_works(self, mixin_session): + """Sync (non-async) registered handler is called.""" + + @listens_for(ListenerModel, [ModelEvent.CREATE]) + def _on_create(obj, event_type, changes): + _listener_events.append({"event": "create", "id": obj.id}) + + obj = ListenerModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + + assert len(_listener_events) == 1 + + @pytest.mark.anyio + async def test_multiple_event_types(self, mixin_session): + """listens_for accepts multiple event types and registers for all of them.""" + + @listens_for(ListenerModel, [ModelEvent.CREATE, ModelEvent.UPDATE]) + async def _on_change(obj, event_type, changes): + _listener_events.append({"event": "change", "id": obj.id}) + + obj = ListenerModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + + obj.status = "updated" + await mixin_session.commit() + + assert len(_listener_events) == 2 + assert all(e["event"] == "change" for e in _listener_events) + + +class TestEventSessionWithGetTransaction: + """Verify callbacks fire correctly when using get_transaction / lock_tables.""" + + @pytest.fixture(autouse=True) + def clear_events(self): + _test_events.clear() + yield + _test_events.clear() + + @pytest.mark.anyio + async def test_callbacks_fire_after_outer_commit_not_savepoint(self, mixin_session): + """get_transaction creates a savepoint; callbacks fire only on outer commit.""" + from fastapi_toolsets.db import get_transaction + + async with get_transaction(mixin_session): + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + + # Still inside the session's outer transaction — savepoint committed, + # but EventSession.commit() hasn't been called yet. + assert _test_events == [] + + await mixin_session.commit() + + creates = [e for e in _test_events if e["event"] == "create"] + assert len(creates) == 1 + + @pytest.mark.anyio + async def test_nested_transactions_accumulate_events(self, mixin_session): + """Multiple get_transaction blocks accumulate events for a single commit.""" + from fastapi_toolsets.db import get_transaction + + async with get_transaction(mixin_session): + obj1 = WatchedModel(status="first", other="x") + mixin_session.add(obj1) + + async with get_transaction(mixin_session): + obj2 = WatchedModel(status="second", other="y") + mixin_session.add(obj2) + + assert _test_events == [] + + await mixin_session.commit() + + creates = [e for e in _test_events if e["event"] == "create"] + assert len(creates) == 2 + + @pytest.mark.anyio + async def test_savepoint_rollback_suppresses_events(self, mixin_session): + """Objects from a rolled-back savepoint don't fire callbacks.""" + from fastapi_toolsets.db import get_transaction + + survivor = WatchedModel(status="kept", other="x") + mixin_session.add(survivor) + await mixin_session.flush() + + try: + async with get_transaction(mixin_session): + doomed = WatchedModel(status="doomed", other="y") + mixin_session.add(doomed) + await mixin_session.flush() + raise ValueError("rollback this savepoint") + except ValueError: + pass + + await mixin_session.commit() + + creates = [e for e in _test_events if e["event"] == "create"] + assert len(creates) == 1 + assert creates[0]["obj_id"] == survivor.id + + @pytest.mark.anyio + async def test_lock_tables_with_events(self, mixin_session): + """Events fire correctly after lock_tables context.""" + from fastapi_toolsets.db import lock_tables + + async with lock_tables(mixin_session, [WatchedModel]): + obj = WatchedModel(status="locked", other="x") + mixin_session.add(obj) + + await mixin_session.commit() + + creates = [e for e in _test_events if e["event"] == "create"] + assert len(creates) == 1 + + @pytest.mark.anyio + async def test_update_inside_get_transaction(self, mixin_session): + """UPDATE events fire with correct changes after get_transaction commit.""" + from fastapi_toolsets.db import get_transaction + + obj = WatchedModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + + _test_events.clear() + + async with get_transaction(mixin_session): + obj.status = "updated" + + await mixin_session.commit() + + updates = [e for e in _test_events if e["event"] == "update"] + assert len(updates) == 1 + assert updates[0]["changes"]["status"] == { + "old": "initial", + "new": "updated", + } + + +class TestEventSessionWithNullableFields: + """Regression tests for nullable field access in callbacks (the original bug).""" + + @pytest.fixture(autouse=True) + def clear_events(self): + _attr_access_events.clear() + yield + _attr_access_events.clear() + + @pytest.mark.anyio + async def test_nullable_field_none_in_create(self, mixin_session_expire): + """Nullable field left as None is accessible in CREATE callback (expire_on_commit=True).""" + obj = AttrAccessModel(name="test") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() + + events = [e for e in _attr_access_events if e["event"] == "create"] + assert len(events) == 1 + assert events[0]["callback_url"] is None + assert events[0]["name"] == "test" + + @pytest.mark.anyio + async def test_nullable_field_set_in_create(self, mixin_session_expire): + """Nullable field with a value is accessible in CREATE callback (expire_on_commit=True).""" + obj = AttrAccessModel(name="test", callback_url="https://hook.example.com") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() + + events = [e for e in _attr_access_events if e["event"] == "create"] + assert len(events) == 1 + assert events[0]["callback_url"] == "https://hook.example.com" + + @pytest.mark.anyio + async def test_nullable_field_in_delete(self, mixin_session_expire): + """Nullable field is accessible in DELETE callback via snapshot restore.""" + obj = AttrAccessModel(name="to-delete", callback_url="https://hook.example.com") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() + + _attr_access_events.clear() + + await mixin_session_expire.delete(obj) + await mixin_session_expire.commit() + + events = [e for e in _attr_access_events if e["event"] == "delete"] + assert len(events) == 1 + assert events[0]["callback_url"] == "https://hook.example.com" + assert events[0]["name"] == "to-delete" + + @pytest.mark.anyio + async def test_nullable_field_updated_to_none(self, mixin_session_expire): + """Nullable field changed to None is accessible in UPDATE callback.""" + obj = AttrAccessModel(name="x", callback_url="https://hook.example.com") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() + + _attr_access_events.clear() + + obj.callback_url = None + await mixin_session_expire.commit() + + events = [e for e in _attr_access_events if e["event"] == "update"] + assert len(events) == 1 + assert events[0]["callback_url"] is None + + @pytest.mark.anyio + async def test_nullable_field_updated_from_none(self, mixin_session_expire): + """Nullable field changed from None to a value is accessible in UPDATE callback.""" + obj = AttrAccessModel(name="x") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() + + _attr_access_events.clear() + + obj.callback_url = "https://new-hook.example.com" + await mixin_session_expire.commit() + + events = [e for e in _attr_access_events if e["event"] == "update"] + assert len(events) == 1 + assert events[0]["callback_url"] == "https://new-hook.example.com" + + +class TestEventSessionWithFastAPIDependency: + """Verify EventSession works when session comes from create_db_dependency.""" + + @pytest.fixture(autouse=True) + def clear_events(self): + _test_events.clear() + yield + _test_events.clear() + + @pytest.mark.anyio + async def test_create_event_fires_via_dependency(self): + """CREATE callback fires when session is provided by create_db_dependency.""" + from fastapi import Depends, FastAPI + from httpx import ASGITransport, AsyncClient + from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, + ) + + from fastapi_toolsets.db import create_db_dependency + from fastapi_toolsets.models import EventSession + + engine = create_async_engine(DATABASE_URL, echo=False) + session_factory = async_sessionmaker( + engine, expire_on_commit=False, class_=EventSession + ) + + async with engine.begin() as conn: + await conn.run_sync(MixinBase.metadata.create_all) + + get_db = create_db_dependency(session_factory) + app = FastAPI() + + @app.post("/watched") + async def create_watched(session: AsyncSession = Depends(get_db)): + obj = WatchedModel(status="from-api", other="x") + session.add(obj) + return {"id": str(obj.id)} + + try: + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, base_url="http://test" + ) as client: + response = await client.post("/watched") + + assert response.status_code == 200 + + creates = [e for e in _test_events if e["event"] == "create"] + assert len(creates) == 1 + finally: + async with engine.begin() as conn: + await conn.run_sync(MixinBase.metadata.drop_all) + await engine.dispose() + + @pytest.mark.anyio + async def test_update_event_fires_via_dependency(self): + """UPDATE callback fires when session is provided by create_db_dependency.""" + from fastapi import Depends, FastAPI + from httpx import ASGITransport, AsyncClient + from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, + ) + + from fastapi_toolsets.db import create_db_dependency + from fastapi_toolsets.models import EventSession + + engine = create_async_engine(DATABASE_URL, echo=False) + session_factory = async_sessionmaker( + engine, expire_on_commit=False, class_=EventSession + ) + + async with engine.begin() as conn: + await conn.run_sync(MixinBase.metadata.create_all) + + get_db = create_db_dependency(session_factory) + app = FastAPI() + + # Pre-seed an object. + async with session_factory() as seed_session: + obj = WatchedModel(status="initial", other="x") + seed_session.add(obj) + await seed_session.commit() + obj_id = obj.id + + _test_events.clear() + + @app.put("/watched/{item_id}") + async def update_watched(item_id: str, session: AsyncSession = Depends(get_db)): + from sqlalchemy import select + + stmt = select(WatchedModel).where(WatchedModel.id == item_id) + result = await session.execute(stmt) + item = result.scalar_one() + item.status = "updated-via-api" + return {"ok": True} + + try: + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, base_url="http://test" + ) as client: + response = await client.put(f"/watched/{obj_id}") + + assert response.status_code == 200 + + updates = [e for e in _test_events if e["event"] == "update"] + assert len(updates) == 1 + assert updates[0]["changes"]["status"]["new"] == "updated-via-api" + finally: + async with engine.begin() as conn: + await conn.run_sync(MixinBase.metadata.drop_all) + await engine.dispose() diff --git a/zensical.toml b/zensical.toml index 888944a..c884a3d 100644 --- a/zensical.toml +++ b/zensical.toml @@ -140,6 +140,7 @@ Examples = [ [[project.nav]] Migration = [ + {"v3.0" = "migration/v3.md"}, {"v2.0" = "migration/v2.md"}, ]