mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: add WatchedFieldsMixin (#148)
* feat/add WatchedFieldsMixin and watch_fields decorator for field-change monitoring * docs: add WatchedFieldsMixin * feat: add on_event, on_create and on_delete * docs: update README
This commit is contained in:
@@ -48,7 +48,7 @@ uv add "fastapi-toolsets[all]"
|
|||||||
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`, `@watch`) that fire after commit for insert, update, and delete events
|
||||||
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ uv add "fastapi-toolsets[all]"
|
|||||||
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`) that fire after commit for insert, update, and delete events.
|
||||||
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|||||||
@@ -117,6 +117,86 @@ class Article(Base, UUIDMixin, TimestampMixin):
|
|||||||
title: Mapped[str]
|
title: Mapped[str]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### [`WatchedFieldsMixin`](../reference/models.md#fastapi_toolsets.models.WatchedFieldsMixin)
|
||||||
|
|
||||||
|
!!! info "Added in `v2.4`"
|
||||||
|
|
||||||
|
`WatchedFieldsMixin` provides lifecycle callbacks that fire **after commit** — meaning the row is durably persisted when your callback runs. If the transaction rolls back, no callback fires.
|
||||||
|
|
||||||
|
Three callbacks are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
|
||||||
|
|
||||||
|
| Callback | Event | Trigger |
|
||||||
|
|---|---|---|
|
||||||
|
| `on_create()` | `ModelEvent.CREATE` | After `INSERT` |
|
||||||
|
| `on_delete()` | `ModelEvent.DELETE` | After `DELETE` |
|
||||||
|
| `on_update(changes)` | `ModelEvent.UPDATE` | After `UPDATE` on a watched field |
|
||||||
|
|
||||||
|
Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callbacks. All callbacks support both `async def` and plain `def`. Use `@watch` to restrict which fields trigger `on_update`:
|
||||||
|
|
||||||
|
| Decorator | `on_update` behaviour |
|
||||||
|
|---|---|
|
||||||
|
| `@watch("status", "role")` | Only fires when `status` or `role` changes |
|
||||||
|
| *(no decorator)* | Fires when **any** mapped field changes |
|
||||||
|
|
||||||
|
#### Option 1 — catch-all with `on_event`
|
||||||
|
|
||||||
|
Override `on_event` to handle all event types in one place. The specific methods delegate here by default:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import ModelEvent, UUIDMixin, WatchedFieldsMixin, watch
|
||||||
|
|
||||||
|
@watch("status")
|
||||||
|
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None:
|
||||||
|
if event == ModelEvent.CREATE:
|
||||||
|
await notify_new_order(self.id)
|
||||||
|
elif event == ModelEvent.DELETE:
|
||||||
|
await notify_order_cancelled(self.id)
|
||||||
|
elif event == ModelEvent.UPDATE:
|
||||||
|
await notify_status_change(self.id, changes["status"])
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Option 2 — targeted overrides
|
||||||
|
|
||||||
|
Override individual methods for more focused logic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@watch("status")
|
||||||
|
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
async def on_create(self) -> None:
|
||||||
|
await notify_new_order(self.id)
|
||||||
|
|
||||||
|
async def on_delete(self) -> None:
|
||||||
|
await notify_order_cancelled(self.id)
|
||||||
|
|
||||||
|
async def on_update(self, changes: dict) -> None:
|
||||||
|
if "status" in changes:
|
||||||
|
old = changes["status"]["old"]
|
||||||
|
new = changes["status"]["new"]
|
||||||
|
await notify_status_change(self.id, old, new)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Field changes format
|
||||||
|
|
||||||
|
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# status changed → {"status": {"old": "pending", "new": "shipped"}}
|
||||||
|
# two fields changed → {"status": {...}, "assigned_to": {...}}
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
|
||||||
|
|
||||||
|
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
|
||||||
|
|
||||||
## Composing mixins
|
## Composing mixins
|
||||||
|
|
||||||
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.
|
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.
|
||||||
|
|||||||
@@ -6,14 +6,19 @@ You can import them directly from `fastapi_toolsets.models`:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from fastapi_toolsets.models import (
|
from fastapi_toolsets.models import (
|
||||||
|
ModelEvent,
|
||||||
UUIDMixin,
|
UUIDMixin,
|
||||||
UUIDv7Mixin,
|
UUIDv7Mixin,
|
||||||
CreatedAtMixin,
|
CreatedAtMixin,
|
||||||
UpdatedAtMixin,
|
UpdatedAtMixin,
|
||||||
TimestampMixin,
|
TimestampMixin,
|
||||||
|
WatchedFieldsMixin,
|
||||||
|
watch,
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.ModelEvent
|
||||||
|
|
||||||
## ::: fastapi_toolsets.models.UUIDMixin
|
## ::: fastapi_toolsets.models.UUIDMixin
|
||||||
|
|
||||||
## ::: fastapi_toolsets.models.UUIDv7Mixin
|
## ::: fastapi_toolsets.models.UUIDv7Mixin
|
||||||
@@ -23,3 +28,7 @@ from fastapi_toolsets.models import (
|
|||||||
## ::: fastapi_toolsets.models.UpdatedAtMixin
|
## ::: fastapi_toolsets.models.UpdatedAtMixin
|
||||||
|
|
||||||
## ::: fastapi_toolsets.models.TimestampMixin
|
## ::: fastapi_toolsets.models.TimestampMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.WatchedFieldsMixin
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.watch
|
||||||
|
|||||||
21
src/fastapi_toolsets/models/__init__.py
Normal file
21
src/fastapi_toolsets/models/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""SQLAlchemy model mixins for common column patterns."""
|
||||||
|
|
||||||
|
from .columns import (
|
||||||
|
CreatedAtMixin,
|
||||||
|
TimestampMixin,
|
||||||
|
UUIDMixin,
|
||||||
|
UUIDv7Mixin,
|
||||||
|
UpdatedAtMixin,
|
||||||
|
)
|
||||||
|
from .watched import ModelEvent, WatchedFieldsMixin, watch
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelEvent",
|
||||||
|
"UUIDMixin",
|
||||||
|
"UUIDv7Mixin",
|
||||||
|
"CreatedAtMixin",
|
||||||
|
"UpdatedAtMixin",
|
||||||
|
"TimestampMixin",
|
||||||
|
"WatchedFieldsMixin",
|
||||||
|
"watch",
|
||||||
|
]
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""SQLAlchemy model mixins for common column patterns."""
|
"""SQLAlchemy column mixins for common column patterns."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
204
src/fastapi_toolsets/models/watched.py
Normal file
204
src/fastapi_toolsets/models/watched.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""Field-change monitoring via SQLAlchemy session events."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import weakref
|
||||||
|
from collections.abc import Awaitable
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
|
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
|
||||||
|
|
||||||
|
_logger = get_logger()
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
|
||||||
|
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
|
||||||
|
weakref.WeakKeyDictionary()
|
||||||
|
)
|
||||||
|
_SESSION_PENDING_NEW = "_ft_pending_new"
|
||||||
|
_SESSION_CREATES = "_ft_creates"
|
||||||
|
_SESSION_DELETES = "_ft_deletes"
|
||||||
|
_SESSION_UPDATES = "_ft_updates"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEvent(str, Enum):
|
||||||
|
"""Event types emitted by :class:`WatchedFieldsMixin`."""
|
||||||
|
|
||||||
|
CREATE = "create"
|
||||||
|
DELETE = "delete"
|
||||||
|
UPDATE = "update"
|
||||||
|
|
||||||
|
|
||||||
|
def watch(*fields: str) -> Any:
|
||||||
|
"""Class decorator to filter which fields trigger ``on_update``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*fields: One or more field names to watch. At least one name is required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If called with no field names.
|
||||||
|
"""
|
||||||
|
if not fields:
|
||||||
|
raise ValueError("@watch requires at least one field name.")
|
||||||
|
|
||||||
|
def decorator(cls: type[_T]) -> type[_T]:
|
||||||
|
_WATCHED_FIELDS[cls] = list(fields)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _upsert_changes(
|
||||||
|
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||||
|
obj: Any,
|
||||||
|
changes: dict[str, dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Insert or merge *changes* into *pending* for *obj*."""
|
||||||
|
key = id(obj)
|
||||||
|
if key in pending:
|
||||||
|
existing = pending[key][1]
|
||||||
|
for field, change in changes.items():
|
||||||
|
if field in existing:
|
||||||
|
existing[field]["new"] = change["new"]
|
||||||
|
else:
|
||||||
|
existing[field] = change
|
||||||
|
else:
|
||||||
|
pending[key] = (obj, changes)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
||||||
|
def _after_flush(session: Any, flush_context: Any) -> None:
|
||||||
|
# New objects: capture references while session.new is still populated.
|
||||||
|
# Values are read in _after_flush_postexec once RETURNING has been processed.
|
||||||
|
for obj in session.new:
|
||||||
|
if isinstance(obj, WatchedFieldsMixin):
|
||||||
|
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
|
||||||
|
|
||||||
|
# Deleted objects: capture before they leave the identity map.
|
||||||
|
for obj in session.deleted:
|
||||||
|
if isinstance(obj, WatchedFieldsMixin):
|
||||||
|
session.info.setdefault(_SESSION_DELETES, []).append(obj)
|
||||||
|
|
||||||
|
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
||||||
|
for obj in session.dirty:
|
||||||
|
if not isinstance(obj, WatchedFieldsMixin):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# None = not in dict = watch all fields; list = specific fields only
|
||||||
|
watched = _WATCHED_FIELDS.get(type(obj))
|
||||||
|
changes: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
attrs = (
|
||||||
|
# Specific fields
|
||||||
|
((field, sa_inspect(obj).attrs[field]) for field in watched)
|
||||||
|
if watched is not None
|
||||||
|
# All mapped fields
|
||||||
|
else ((s.key, s) for s in sa_inspect(obj).attrs)
|
||||||
|
)
|
||||||
|
for field, attr_state in attrs:
|
||||||
|
history = attr_state.history
|
||||||
|
if history.has_changes() and history.deleted:
|
||||||
|
changes[field] = {
|
||||||
|
"old": history.deleted[0],
|
||||||
|
"new": history.added[0] if history.added else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if changes:
|
||||||
|
_upsert_changes(
|
||||||
|
session.info.setdefault(_SESSION_UPDATES, {}),
|
||||||
|
obj,
|
||||||
|
changes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
|
||||||
|
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
|
||||||
|
# New objects are now persistent and RETURNING values have been applied,
|
||||||
|
# so server defaults (id, created_at, …) are available via getattr.
|
||||||
|
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
|
||||||
|
if not pending_new:
|
||||||
|
return
|
||||||
|
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
||||||
|
def _after_rollback(session: Any) -> None:
|
||||||
|
session.info.pop(_SESSION_PENDING_NEW, None)
|
||||||
|
session.info.pop(_SESSION_CREATES, None)
|
||||||
|
session.info.pop(_SESSION_DELETES, None)
|
||||||
|
session.info.pop(_SESSION_UPDATES, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
||||||
|
if not task.cancelled() and (exc := task.exception()):
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None:
|
||||||
|
"""Dispatch *fn* with *args*, handling both sync and async callables."""
|
||||||
|
try:
|
||||||
|
result = fn(*args)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
return
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
task = loop.create_task(result)
|
||||||
|
task.add_done_callback(_task_error_handler)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
||||||
|
def _after_commit(session: Any) -> None:
|
||||||
|
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
|
||||||
|
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
|
||||||
|
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
|
||||||
|
_SESSION_UPDATES, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not creates and not deletes and not field_changes:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return
|
||||||
|
|
||||||
|
for obj in creates:
|
||||||
|
_call_callback(loop, obj.on_create)
|
||||||
|
|
||||||
|
for obj in deletes:
|
||||||
|
_call_callback(loop, obj.on_delete)
|
||||||
|
|
||||||
|
for obj, changes in field_changes.values():
|
||||||
|
_call_callback(loop, obj.on_update, changes)
|
||||||
|
|
||||||
|
|
||||||
|
class WatchedFieldsMixin:
|
||||||
|
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
|
||||||
|
|
||||||
|
def on_event(
|
||||||
|
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
|
||||||
|
) -> Awaitable[None] | None:
|
||||||
|
"""Catch-all callback fired for every lifecycle event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
|
||||||
|
or :attr:`ModelEvent.UPDATE`).
|
||||||
|
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_create(self) -> Awaitable[None] | None:
|
||||||
|
"""Called after INSERT commit."""
|
||||||
|
return self.on_event(ModelEvent.CREATE)
|
||||||
|
|
||||||
|
def on_delete(self) -> Awaitable[None] | None:
|
||||||
|
"""Called after DELETE commit."""
|
||||||
|
return self.on_event(ModelEvent.DELETE)
|
||||||
|
|
||||||
|
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
|
||||||
|
"""Called after UPDATE commit when watched fields change."""
|
||||||
|
return self.on_event(ModelEvent.UPDATE, changes=changes)
|
||||||
@@ -1,7 +1,12 @@
|
|||||||
"""Tests for fastapi_toolsets.models mixins."""
|
"""Tests for fastapi_toolsets.models mixins."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
from contextlib import suppress
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import fastapi_toolsets.models.watched as _watched_module
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import String
|
from sqlalchemy import String
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
@@ -9,10 +14,26 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|||||||
|
|
||||||
from fastapi_toolsets.models import (
|
from fastapi_toolsets.models import (
|
||||||
CreatedAtMixin,
|
CreatedAtMixin,
|
||||||
|
ModelEvent,
|
||||||
TimestampMixin,
|
TimestampMixin,
|
||||||
UUIDMixin,
|
UUIDMixin,
|
||||||
UUIDv7Mixin,
|
UUIDv7Mixin,
|
||||||
UpdatedAtMixin,
|
UpdatedAtMixin,
|
||||||
|
WatchedFieldsMixin,
|
||||||
|
watch,
|
||||||
|
)
|
||||||
|
from fastapi_toolsets.models.watched import (
|
||||||
|
_SESSION_CREATES,
|
||||||
|
_SESSION_DELETES,
|
||||||
|
_SESSION_UPDATES,
|
||||||
|
_SESSION_PENDING_NEW,
|
||||||
|
_after_commit,
|
||||||
|
_after_flush,
|
||||||
|
_after_flush_postexec,
|
||||||
|
_after_rollback,
|
||||||
|
_call_callback,
|
||||||
|
_task_error_handler,
|
||||||
|
_upsert_changes,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .conftest import DATABASE_URL
|
from .conftest import DATABASE_URL
|
||||||
@@ -61,6 +82,80 @@ class FullMixinModel(MixinBase, UUIDMixin, UpdatedAtMixin):
|
|||||||
name: Mapped[str] = mapped_column(String(50))
|
name: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
|
||||||
|
# --- WatchedFieldsMixin test models ---
|
||||||
|
|
||||||
|
_test_events: list[dict] = []
|
||||||
|
|
||||||
|
|
||||||
|
@watch("status")
|
||||||
|
class WatchedModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
__tablename__ = "mixin_watched_models"
|
||||||
|
|
||||||
|
status: Mapped[str] = mapped_column(String(50))
|
||||||
|
other: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
async def on_create(self) -> None:
|
||||||
|
_test_events.append({"event": "create", "obj_id": self.id})
|
||||||
|
|
||||||
|
async def on_delete(self) -> None:
|
||||||
|
_test_events.append({"event": "delete", "obj_id": self.id})
|
||||||
|
|
||||||
|
async def on_update(self, changes: dict) -> None:
|
||||||
|
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
|
||||||
|
|
||||||
|
|
||||||
|
@watch("value")
|
||||||
|
class OnEventModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
"""Model that only overrides on_event to test the catch-all path."""
|
||||||
|
|
||||||
|
__tablename__ = "mixin_on_event_models"
|
||||||
|
|
||||||
|
value: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None:
|
||||||
|
_test_events.append({"event": event, "obj_id": self.id, "changes": changes})
|
||||||
|
|
||||||
|
|
||||||
|
class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
"""Model without @watch — watches all mapped fields by default."""
|
||||||
|
|
||||||
|
__tablename__ = "mixin_watch_all_models"
|
||||||
|
|
||||||
|
status: Mapped[str] = mapped_column(String(50))
|
||||||
|
other: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
async def on_update(self, changes: dict) -> None:
|
||||||
|
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
|
||||||
|
|
||||||
|
|
||||||
|
class NonWatchedModel(MixinBase):
|
||||||
|
__tablename__ = "mixin_non_watched_models"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||||
|
value: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
|
||||||
|
_sync_events: list[dict] = []
|
||||||
|
|
||||||
|
|
||||||
|
@watch("status")
|
||||||
|
class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
"""Model with plain (sync) on_* callbacks."""
|
||||||
|
|
||||||
|
__tablename__ = "mixin_sync_callback_models"
|
||||||
|
|
||||||
|
status: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
def on_create(self) -> None:
|
||||||
|
_sync_events.append({"event": "create", "obj_id": self.id})
|
||||||
|
|
||||||
|
def on_delete(self) -> None:
|
||||||
|
_sync_events.append({"event": "delete", "obj_id": self.id})
|
||||||
|
|
||||||
|
def on_update(self, changes: dict) -> None:
|
||||||
|
_sync_events.append({"event": "update", "changes": changes})
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def mixin_session():
|
async def mixin_session():
|
||||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
@@ -124,7 +219,7 @@ class TestUpdatedAtMixin:
|
|||||||
await mixin_session.refresh(obj)
|
await mixin_session.refresh(obj)
|
||||||
|
|
||||||
assert obj.updated_at is not None
|
assert obj.updated_at is not None
|
||||||
assert obj.updated_at.tzinfo is not None # timezone-aware
|
assert obj.updated_at.tzinfo is not None
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_updated_at_changes_on_update(self, mixin_session):
|
async def test_updated_at_changes_on_update(self, mixin_session):
|
||||||
@@ -171,7 +266,7 @@ class TestCreatedAtMixin:
|
|||||||
await mixin_session.refresh(obj)
|
await mixin_session.refresh(obj)
|
||||||
|
|
||||||
assert obj.created_at is not None
|
assert obj.created_at is not None
|
||||||
assert obj.created_at.tzinfo is not None # timezone-aware
|
assert obj.created_at.tzinfo is not None
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_created_at_not_changed_on_update(self, mixin_session):
|
async def test_created_at_not_changed_on_update(self, mixin_session):
|
||||||
@@ -296,3 +391,577 @@ class TestFullMixinModel:
|
|||||||
assert isinstance(obj.id, uuid.UUID)
|
assert isinstance(obj.id, uuid.UUID)
|
||||||
assert obj.updated_at is not None
|
assert obj.updated_at is not None
|
||||||
assert obj.updated_at.tzinfo is not None
|
assert obj.updated_at.tzinfo is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWatchDecorator:
|
||||||
|
def test_registers_specific_fields(self):
|
||||||
|
"""@watch("field") stores the field list in _WATCHED_FIELDS."""
|
||||||
|
assert _watched_module._WATCHED_FIELDS.get(WatchedModel) == ["status"]
|
||||||
|
|
||||||
|
def test_no_decorator_not_in_watched_fields(self):
|
||||||
|
"""A model without @watch has no entry in _WATCHED_FIELDS (watch all)."""
|
||||||
|
assert WatchAllModel not in _watched_module._WATCHED_FIELDS
|
||||||
|
|
||||||
|
def test_preserves_class_identity(self):
|
||||||
|
"""watch returns the same class unchanged."""
|
||||||
|
|
||||||
|
class _Dummy(WatchedFieldsMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = watch("x")(_Dummy)
|
||||||
|
assert result is _Dummy
|
||||||
|
del _watched_module._WATCHED_FIELDS[_Dummy]
|
||||||
|
|
||||||
|
def test_raises_when_no_fields_given(self):
|
||||||
|
"""@watch() with no field names raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="@watch requires at least one field name"):
|
||||||
|
watch()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpsertChanges:
|
||||||
|
def test_inserts_new_entry(self):
|
||||||
|
"""New key is inserted with the full changes dict."""
|
||||||
|
pending: dict = {}
|
||||||
|
obj = object()
|
||||||
|
changes = {"status": {"old": None, "new": "active"}}
|
||||||
|
_upsert_changes(pending, obj, changes)
|
||||||
|
assert pending[id(obj)] == (obj, changes)
|
||||||
|
|
||||||
|
def test_merges_existing_field_keeps_old_updates_new(self):
|
||||||
|
"""When the field already exists, old is preserved and new is overwritten."""
|
||||||
|
obj = object()
|
||||||
|
pending = {
|
||||||
|
id(obj): (obj, {"status": {"old": "initial", "new": "intermediate"}})
|
||||||
|
}
|
||||||
|
_upsert_changes(
|
||||||
|
pending, obj, {"status": {"old": "intermediate", "new": "final"}}
|
||||||
|
)
|
||||||
|
assert pending[id(obj)][1]["status"] == {"old": "initial", "new": "final"}
|
||||||
|
|
||||||
|
def test_adds_new_field_to_existing_entry(self):
|
||||||
|
"""A previously unseen field is added alongside existing ones."""
|
||||||
|
obj = object()
|
||||||
|
pending = {id(obj): (obj, {"status": {"old": "a", "new": "b"}})}
|
||||||
|
_upsert_changes(pending, obj, {"role": {"old": "user", "new": "admin"}})
|
||||||
|
fields = pending[id(obj)][1]
|
||||||
|
assert fields["status"] == {"old": "a", "new": "b"}
|
||||||
|
assert fields["role"] == {"old": "user", "new": "admin"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAfterFlush:
|
||||||
|
def test_does_nothing_with_empty_session(self):
|
||||||
|
"""_after_flush writes nothing to session.info when all collections are empty."""
|
||||||
|
session = SimpleNamespace(new=[], deleted=[], dirty=[], info={})
|
||||||
|
_after_flush(session, None)
|
||||||
|
assert session.info == {}
|
||||||
|
|
||||||
|
def test_captures_new_watched_mixin_objects(self):
|
||||||
|
"""New WatchedFieldsMixin instances are added to _SESSION_PENDING_NEW."""
|
||||||
|
obj = WatchedFieldsMixin()
|
||||||
|
session = SimpleNamespace(new=[obj], deleted=[], dirty=[], info={})
|
||||||
|
_after_flush(session, None)
|
||||||
|
assert session.info[_SESSION_PENDING_NEW] == [obj]
|
||||||
|
|
||||||
|
def test_ignores_new_non_mixin_objects(self):
|
||||||
|
"""New objects that are not WatchedFieldsMixin are not captured."""
|
||||||
|
obj = object()
|
||||||
|
session = SimpleNamespace(new=[obj], deleted=[], dirty=[], info={})
|
||||||
|
_after_flush(session, None)
|
||||||
|
assert _SESSION_PENDING_NEW not in session.info
|
||||||
|
|
||||||
|
def test_captures_deleted_watched_mixin_objects(self):
|
||||||
|
"""Deleted WatchedFieldsMixin instances are added to _SESSION_DELETES."""
|
||||||
|
obj = WatchedFieldsMixin()
|
||||||
|
session = SimpleNamespace(new=[], deleted=[obj], dirty=[], info={})
|
||||||
|
_after_flush(session, None)
|
||||||
|
assert session.info[_SESSION_DELETES] == [obj]
|
||||||
|
|
||||||
|
def test_ignores_deleted_non_mixin_objects(self):
|
||||||
|
"""Deleted objects that are not WatchedFieldsMixin are not captured."""
|
||||||
|
obj = object()
|
||||||
|
session = SimpleNamespace(new=[], deleted=[obj], dirty=[], info={})
|
||||||
|
_after_flush(session, None)
|
||||||
|
assert _SESSION_DELETES not in session.info
|
||||||
|
|
||||||
|
|
||||||
|
class TestAfterFlushPostexec:
|
||||||
|
def test_does_nothing_when_no_pending_new(self):
|
||||||
|
"""_after_flush_postexec does nothing when _SESSION_PENDING_NEW is absent."""
|
||||||
|
session = SimpleNamespace(info={})
|
||||||
|
_after_flush_postexec(session, None)
|
||||||
|
assert _SESSION_CREATES not in session.info
|
||||||
|
|
||||||
|
def test_moves_pending_new_to_creates(self):
|
||||||
|
"""Objects from _SESSION_PENDING_NEW are moved to _SESSION_CREATES."""
|
||||||
|
obj = object()
|
||||||
|
session = SimpleNamespace(info={_SESSION_PENDING_NEW: [obj]})
|
||||||
|
_after_flush_postexec(session, None)
|
||||||
|
assert _SESSION_PENDING_NEW not in session.info
|
||||||
|
assert session.info[_SESSION_CREATES] == [obj]
|
||||||
|
|
||||||
|
def test_extends_existing_creates(self):
|
||||||
|
"""Multiple flushes accumulate in _SESSION_CREATES."""
|
||||||
|
a, b = object(), object()
|
||||||
|
session = SimpleNamespace(
|
||||||
|
info={_SESSION_PENDING_NEW: [b], _SESSION_CREATES: [a]}
|
||||||
|
)
|
||||||
|
_after_flush_postexec(session, None)
|
||||||
|
assert session.info[_SESSION_CREATES] == [a, b]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAfterRollback:
|
||||||
|
def test_clears_all_session_info_keys(self):
|
||||||
|
"""_after_rollback removes all four tracking keys from session.info."""
|
||||||
|
session = SimpleNamespace(
|
||||||
|
info={
|
||||||
|
_SESSION_PENDING_NEW: [object()],
|
||||||
|
_SESSION_CREATES: [object()],
|
||||||
|
_SESSION_DELETES: [object()],
|
||||||
|
_SESSION_UPDATES: {1: ("obj", {"f": {"old": "a", "new": "b"}})},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
_after_rollback(session)
|
||||||
|
assert _SESSION_PENDING_NEW not in session.info
|
||||||
|
assert _SESSION_CREATES not in session.info
|
||||||
|
assert _SESSION_DELETES not in session.info
|
||||||
|
assert _SESSION_UPDATES not in session.info
|
||||||
|
|
||||||
|
def test_tolerates_missing_keys(self):
|
||||||
|
"""_after_rollback does not raise when session.info has no pending data."""
|
||||||
|
session = SimpleNamespace(info={})
|
||||||
|
_after_rollback(session) # must not raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskErrorHandler:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_logs_exception_from_failed_task(self):
|
||||||
|
"""_task_error_handler calls _logger.error when the task raised."""
|
||||||
|
|
||||||
|
async def failing() -> None:
|
||||||
|
raise ValueError("boom")
|
||||||
|
|
||||||
|
task = asyncio.create_task(failing())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||||
|
_task_error_handler(task)
|
||||||
|
mock_error.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_ignores_cancelled_task(self):
|
||||||
|
"""_task_error_handler does not log when the task was cancelled."""
|
||||||
|
|
||||||
|
async def slow() -> None:
|
||||||
|
await asyncio.sleep(100)
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow())
|
||||||
|
task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||||
|
_task_error_handler(task)
|
||||||
|
mock_error.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAfterCommitNoLoop:
|
||||||
|
def test_no_task_scheduled_when_no_running_loop(self):
|
||||||
|
"""_after_commit silently returns when called outside an async context."""
|
||||||
|
called = []
|
||||||
|
obj = SimpleNamespace(on_create=lambda: called.append("create"))
|
||||||
|
session = SimpleNamespace(info={_SESSION_CREATES: [obj]})
|
||||||
|
_after_commit(session)
|
||||||
|
assert called == []
|
||||||
|
|
||||||
|
def test_returns_early_when_all_pending_empty(self):
|
||||||
|
"""_after_commit does nothing when all pending lists are empty."""
|
||||||
|
session = SimpleNamespace(info={})
|
||||||
|
_after_commit(session) # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestWatchedFieldsMixin:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_events(self):
|
||||||
|
_test_events.clear()
|
||||||
|
yield
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
# --- on_create ---
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_create_fires_after_insert(self, mixin_session):
|
||||||
|
"""on_create is called after INSERT commit."""
|
||||||
|
obj = WatchedModel(status="active", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
creates = [e for e in _test_events if e["event"] == "create"]
|
||||||
|
assert len(creates) == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_create_server_defaults_populated(self, mixin_session):
|
||||||
|
"""id (server default via RETURNING) is available inside on_create."""
|
||||||
|
obj = WatchedModel(status="active", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
creates = [e for e in _test_events if e["event"] == "create"]
|
||||||
|
assert creates[0]["obj_id"] is not None
|
||||||
|
assert isinstance(creates[0]["obj_id"], uuid.UUID)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_create_not_fired_on_update(self, mixin_session):
|
||||||
|
"""on_create is NOT called when an existing row is updated."""
|
||||||
|
obj = WatchedModel(status="initial", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.status = "updated"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert not any(e["event"] == "create" for e in _test_events)
|
||||||
|
|
||||||
|
# --- on_delete ---
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_delete_fires_after_delete(self, mixin_session):
|
||||||
|
"""on_delete is called after DELETE commit."""
|
||||||
|
obj = WatchedModel(status="active", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
saved_id = obj.id
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
await mixin_session.delete(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
deletes = [e for e in _test_events if e["event"] == "delete"]
|
||||||
|
assert len(deletes) == 1
|
||||||
|
assert deletes[0]["obj_id"] == saved_id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_delete_not_fired_on_insert(self, mixin_session):
|
||||||
|
"""on_delete is NOT called when a new row is inserted."""
|
||||||
|
obj = WatchedModel(status="active", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert not any(e["event"] == "delete" for e in _test_events)
|
||||||
|
|
||||||
|
# --- on_update ---
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_update_fires_on_update(self, mixin_session):
|
||||||
|
"""on_update reports the correct before/after values on UPDATE."""
|
||||||
|
obj = WatchedModel(status="initial", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.status = "updated"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
changes_events = [e for e in _test_events if e["event"] == "update"]
|
||||||
|
assert len(changes_events) == 1
|
||||||
|
assert changes_events[0]["changes"]["status"] == {
|
||||||
|
"old": "initial",
|
||||||
|
"new": "updated",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_update_not_fired_on_insert(self, mixin_session):
|
||||||
|
"""on_update is NOT called on INSERT (on_create handles that)."""
|
||||||
|
obj = WatchedModel(status="active", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert not any(e["event"] == "update" for e in _test_events)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_unwatched_field_update_no_callback(self, mixin_session):
|
||||||
|
"""Changing a field not listed in @update does not fire on_update."""
|
||||||
|
obj = WatchedModel(status="active", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.other = "changed"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert _test_events == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_flushes_merge_earliest_old_latest_new(self, mixin_session):
|
||||||
|
"""Two flushes in one transaction produce a single callback with earliest old / latest new."""
|
||||||
|
obj = WatchedModel(status="initial", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.status = "intermediate"
|
||||||
|
await mixin_session.flush()
|
||||||
|
|
||||||
|
obj.status = "final"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
changes_events = [e for e in _test_events if e["event"] == "update"]
|
||||||
|
assert len(changes_events) == 1
|
||||||
|
assert changes_events[0]["changes"]["status"] == {
|
||||||
|
"old": "initial",
|
||||||
|
"new": "final",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_rollback_suppresses_all_callbacks(self, mixin_session):
|
||||||
|
"""No callbacks are fired when the transaction is rolled back."""
|
||||||
|
obj = WatchedModel(status="active", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.status = "changed"
|
||||||
|
await mixin_session.flush()
|
||||||
|
await mixin_session.rollback()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert _test_events == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_non_watched_model_no_callback(self, mixin_session):
|
||||||
|
"""Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
|
||||||
|
nw = NonWatchedModel(value="x")
|
||||||
|
mixin_session.add(nw)
|
||||||
|
await mixin_session.flush()
|
||||||
|
nw.value = "y"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert _test_events == []
|
||||||
|
|
||||||
|
# --- on_event (catch-all) ---
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_event_receives_create(self, mixin_session):
|
||||||
|
"""on_event is called with ModelEvent.CREATE on INSERT when only on_event is overridden."""
|
||||||
|
obj = OnEventModel(value="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
creates = [e for e in _test_events if e["event"] == ModelEvent.CREATE]
|
||||||
|
assert len(creates) == 1
|
||||||
|
assert creates[0]["changes"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_event_receives_delete(self, mixin_session):
|
||||||
|
"""on_event is called with ModelEvent.DELETE on DELETE when only on_event is overridden."""
|
||||||
|
obj = OnEventModel(value="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
await mixin_session.delete(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
deletes = [e for e in _test_events if e["event"] == ModelEvent.DELETE]
|
||||||
|
assert len(deletes) == 1
|
||||||
|
assert deletes[0]["changes"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_on_event_receives_field_change(self, mixin_session):
|
||||||
|
"""on_event is called with ModelEvent.UPDATE on UPDATE when only on_event is overridden."""
|
||||||
|
obj = OnEventModel(value="initial")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.value = "updated"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
changes_events = [e for e in _test_events if e["event"] == ModelEvent.UPDATE]
|
||||||
|
assert len(changes_events) == 1
|
||||||
|
assert changes_events[0]["changes"]["value"] == {
|
||||||
|
"old": "initial",
|
||||||
|
"new": "updated",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestWatchAll:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_events(self):
|
||||||
|
_test_events.clear()
|
||||||
|
yield
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_watch_all_fires_for_any_field(self, mixin_session):
|
||||||
|
"""Model without @watch fires on_update for any changed field."""
|
||||||
|
obj = WatchAllModel(status="initial", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.other = "changed"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
changes_events = [e for e in _test_events if e["event"] == "update"]
|
||||||
|
assert len(changes_events) == 1
|
||||||
|
assert "other" in changes_events[0]["changes"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_watch_all_captures_multiple_fields(self, mixin_session):
|
||||||
|
"""Model without @watch captures all fields changed in a single commit."""
|
||||||
|
obj = WatchAllModel(status="initial", other="x")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_test_events.clear()
|
||||||
|
|
||||||
|
obj.status = "updated"
|
||||||
|
obj.other = "changed"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
changes_events = [e for e in _test_events if e["event"] == "update"]
|
||||||
|
assert len(changes_events) == 1
|
||||||
|
assert "status" in changes_events[0]["changes"]
|
||||||
|
assert "other" in changes_events[0]["changes"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncCallbacks:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_events(self):
|
||||||
|
_sync_events.clear()
|
||||||
|
yield
|
||||||
|
_sync_events.clear()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sync_on_create_fires(self, mixin_session):
|
||||||
|
"""Sync on_create is called after INSERT commit."""
|
||||||
|
obj = SyncCallbackModel(status="active")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
creates = [e for e in _sync_events if e["event"] == "create"]
|
||||||
|
assert len(creates) == 1
|
||||||
|
assert isinstance(creates[0]["obj_id"], uuid.UUID)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sync_on_delete_fires(self, mixin_session):
|
||||||
|
"""Sync on_delete is called after DELETE commit."""
|
||||||
|
obj = SyncCallbackModel(status="active")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_sync_events.clear()
|
||||||
|
|
||||||
|
await mixin_session.delete(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
deletes = [e for e in _sync_events if e["event"] == "delete"]
|
||||||
|
assert len(deletes) == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sync_on_update_fires(self, mixin_session):
|
||||||
|
"""Sync on_update is called after UPDATE commit with correct changes."""
|
||||||
|
obj = SyncCallbackModel(status="initial")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_sync_events.clear()
|
||||||
|
|
||||||
|
obj.status = "updated"
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
updates = [e for e in _sync_events if e["event"] == "update"]
|
||||||
|
assert len(updates) == 1
|
||||||
|
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallCallback:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_async_callback_scheduled_as_task(self):
|
||||||
|
"""_call_callback schedules async functions as tasks."""
|
||||||
|
called = []
|
||||||
|
|
||||||
|
async def async_fn() -> None:
|
||||||
|
called.append("async")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
_call_callback(loop, async_fn)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert called == ["async"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sync_callback_called_directly(self):
|
||||||
|
"""_call_callback invokes sync functions immediately."""
|
||||||
|
called = []
|
||||||
|
|
||||||
|
def sync_fn() -> None:
|
||||||
|
called.append("sync")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
_call_callback(loop, sync_fn)
|
||||||
|
assert called == ["sync"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sync_callback_exception_logged(self):
|
||||||
|
"""_call_callback logs exceptions from sync callbacks."""
|
||||||
|
|
||||||
|
def failing_fn() -> None:
|
||||||
|
raise RuntimeError("sync error")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||||
|
_call_callback(loop, failing_fn)
|
||||||
|
mock_error.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_async_callback_with_args(self):
|
||||||
|
"""_call_callback passes arguments to async callbacks."""
|
||||||
|
received = []
|
||||||
|
|
||||||
|
async def async_fn(changes: dict) -> None:
|
||||||
|
received.append(changes)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
_call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}})
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert received == [{"status": {"old": "a", "new": "b"}}]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_sync_callback_with_args(self):
|
||||||
|
"""_call_callback passes arguments to sync callbacks."""
|
||||||
|
received = []
|
||||||
|
|
||||||
|
def sync_fn(changes: dict) -> None:
|
||||||
|
received.append(changes)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
_call_callback(loop, sync_fn, {"x": 1})
|
||||||
|
assert received == [{"x": 1}]
|
||||||
|
|||||||
Reference in New Issue
Block a user