diff --git a/README.md b/README.md index a556461..652d34d 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ uv add "fastapi-toolsets[all]" - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Fixtures**: Fixture system with dependency management, context support, and pytest integration -- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) +- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`, `@watch`) that fire after commit for insert, update, and delete events - **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`. - **Exception Handling**: Structured error responses with automatic OpenAPI documentation - **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger` diff --git a/docs/index.md b/docs/index.md index a556461..6c4dd40 100644 --- a/docs/index.md +++ b/docs/index.md @@ -48,7 +48,7 @@ uv add "fastapi-toolsets[all]" - **Database**: Session management, transaction helpers, table locking, and polling-based row change detection - **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters - **Fixtures**: Fixture system with dependency management, context support, and pytest integration -- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) +- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`) that fire after commit for insert, update, and delete events. - **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`. - **Exception Handling**: Structured error responses with automatic OpenAPI documentation - **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger` diff --git a/docs/module/models.md b/docs/module/models.md index ff17c4b..8376cea 100644 --- a/docs/module/models.md +++ b/docs/module/models.md @@ -117,6 +117,86 @@ class Article(Base, UUIDMixin, TimestampMixin): title: Mapped[str] ``` +### [`WatchedFieldsMixin`](../reference/models.md#fastapi_toolsets.models.WatchedFieldsMixin) + +!!! info "Added in `v2.4`" + +`WatchedFieldsMixin` provides lifecycle callbacks that fire **after commit** — meaning the row is durably persisted when your callback runs. If the transaction rolls back, no callback fires. + +Three callbacks are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value: + +| Callback | Event | Trigger | +|---|---|---| +| `on_create()` | `ModelEvent.CREATE` | After `INSERT` | +| `on_delete()` | `ModelEvent.DELETE` | After `DELETE` | +| `on_update(changes)` | `ModelEvent.UPDATE` | After `UPDATE` on a watched field | + +Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callbacks. All callbacks support both `async def` and plain `def`. Use `@watch` to restrict which fields trigger `on_update`: + +| Decorator | `on_update` behaviour | +|---|---| +| `@watch("status", "role")` | Only fires when `status` or `role` changes | +| *(no decorator)* | Fires when **any** mapped field changes | + +#### Option 1 — catch-all with `on_event` + +Override `on_event` to handle all event types in one place. The specific methods delegate here by default: + +```python +from fastapi_toolsets.models import ModelEvent, UUIDMixin, WatchedFieldsMixin, watch + +@watch("status") +class Order(Base, UUIDMixin, WatchedFieldsMixin): + __tablename__ = "orders" + + status: Mapped[str] + + async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None: + if event == ModelEvent.CREATE: + await notify_new_order(self.id) + elif event == ModelEvent.DELETE: + await notify_order_cancelled(self.id) + elif event == ModelEvent.UPDATE: + await notify_status_change(self.id, changes["status"]) +``` + +#### Option 2 — targeted overrides + +Override individual methods for more focused logic: + +```python +@watch("status") +class Order(Base, UUIDMixin, WatchedFieldsMixin): + __tablename__ = "orders" + + status: Mapped[str] + + async def on_create(self) -> None: + await notify_new_order(self.id) + + async def on_delete(self) -> None: + await notify_order_cancelled(self.id) + + async def on_update(self, changes: dict) -> None: + if "status" in changes: + old = changes["status"]["old"] + new = changes["status"]["new"] + await notify_status_change(self.id, old, new) +``` + +#### Field changes format + +The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included: + +```python +# status changed → {"status": {"old": "pending", "new": "shipped"}} +# two fields changed → {"status": {...}, "assigned_to": {...}} +``` + +!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit." + +!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected." + ## Composing mixins All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model. diff --git a/docs/reference/models.md b/docs/reference/models.md index 7e0fb53..a30d5b6 100644 --- a/docs/reference/models.md +++ b/docs/reference/models.md @@ -6,14 +6,19 @@ You can import them directly from `fastapi_toolsets.models`: ```python from fastapi_toolsets.models import ( + ModelEvent, UUIDMixin, UUIDv7Mixin, CreatedAtMixin, UpdatedAtMixin, TimestampMixin, + WatchedFieldsMixin, + watch, ) ``` +## ::: fastapi_toolsets.models.ModelEvent + ## ::: fastapi_toolsets.models.UUIDMixin ## ::: fastapi_toolsets.models.UUIDv7Mixin @@ -23,3 +28,7 @@ from fastapi_toolsets.models import ( ## ::: fastapi_toolsets.models.UpdatedAtMixin ## ::: fastapi_toolsets.models.TimestampMixin + +## ::: fastapi_toolsets.models.WatchedFieldsMixin + +## ::: fastapi_toolsets.models.watch diff --git a/src/fastapi_toolsets/models/__init__.py b/src/fastapi_toolsets/models/__init__.py new file mode 100644 index 0000000..5af6821 --- /dev/null +++ b/src/fastapi_toolsets/models/__init__.py @@ -0,0 +1,21 @@ +"""SQLAlchemy model mixins for common column patterns.""" + +from .columns import ( + CreatedAtMixin, + TimestampMixin, + UUIDMixin, + UUIDv7Mixin, + UpdatedAtMixin, +) +from .watched import ModelEvent, WatchedFieldsMixin, watch + +__all__ = [ + "ModelEvent", + "UUIDMixin", + "UUIDv7Mixin", + "CreatedAtMixin", + "UpdatedAtMixin", + "TimestampMixin", + "WatchedFieldsMixin", + "watch", +] diff --git a/src/fastapi_toolsets/models.py b/src/fastapi_toolsets/models/columns.py similarity index 95% rename from src/fastapi_toolsets/models.py rename to src/fastapi_toolsets/models/columns.py index 9752e28..bdbc38d 100644 --- a/src/fastapi_toolsets/models.py +++ b/src/fastapi_toolsets/models/columns.py @@ -1,4 +1,4 @@ -"""SQLAlchemy model mixins for common column patterns.""" +"""SQLAlchemy column mixins for common column patterns.""" import uuid from datetime import datetime diff --git a/src/fastapi_toolsets/models/watched.py b/src/fastapi_toolsets/models/watched.py new file mode 100644 index 0000000..ff3e706 --- /dev/null +++ b/src/fastapi_toolsets/models/watched.py @@ -0,0 +1,204 @@ +"""Field-change monitoring via SQLAlchemy session events.""" + +import asyncio +import weakref +from collections.abc import Awaitable +from enum import Enum +from typing import Any, TypeVar + +from sqlalchemy import event +from sqlalchemy import inspect as sa_inspect +from sqlalchemy.ext.asyncio import AsyncSession + +from ..logger import get_logger + +__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"] + +_logger = get_logger() +_T = TypeVar("_T") +_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception" +_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = ( + weakref.WeakKeyDictionary() +) +_SESSION_PENDING_NEW = "_ft_pending_new" +_SESSION_CREATES = "_ft_creates" +_SESSION_DELETES = "_ft_deletes" +_SESSION_UPDATES = "_ft_updates" + + +class ModelEvent(str, Enum): + """Event types emitted by :class:`WatchedFieldsMixin`.""" + + CREATE = "create" + DELETE = "delete" + UPDATE = "update" + + +def watch(*fields: str) -> Any: + """Class decorator to filter which fields trigger ``on_update``. + + Args: + *fields: One or more field names to watch. At least one name is required. + + Raises: + ValueError: If called with no field names. + """ + if not fields: + raise ValueError("@watch requires at least one field name.") + + def decorator(cls: type[_T]) -> type[_T]: + _WATCHED_FIELDS[cls] = list(fields) + return cls + + return decorator + + +def _upsert_changes( + pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]], + obj: Any, + changes: dict[str, dict[str, Any]], +) -> None: + """Insert or merge *changes* into *pending* for *obj*.""" + key = id(obj) + if key in pending: + existing = pending[key][1] + for field, change in changes.items(): + if field in existing: + existing[field]["new"] = change["new"] + else: + existing[field] = change + else: + pending[key] = (obj, changes) + + +@event.listens_for(AsyncSession.sync_session_class, "after_flush") +def _after_flush(session: Any, flush_context: Any) -> None: + # New objects: capture references while session.new is still populated. + # Values are read in _after_flush_postexec once RETURNING has been processed. + for obj in session.new: + if isinstance(obj, WatchedFieldsMixin): + session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj) + + # Deleted objects: capture before they leave the identity map. + for obj in session.deleted: + if isinstance(obj, WatchedFieldsMixin): + session.info.setdefault(_SESSION_DELETES, []).append(obj) + + # Dirty objects: read old/new from SQLAlchemy attribute history. + for obj in session.dirty: + if not isinstance(obj, WatchedFieldsMixin): + continue + + # None = not in dict = watch all fields; list = specific fields only + watched = _WATCHED_FIELDS.get(type(obj)) + changes: dict[str, dict[str, Any]] = {} + + attrs = ( + # Specific fields + ((field, sa_inspect(obj).attrs[field]) for field in watched) + if watched is not None + # All mapped fields + else ((s.key, s) for s in sa_inspect(obj).attrs) + ) + for field, attr_state in attrs: + history = attr_state.history + if history.has_changes() and history.deleted: + changes[field] = { + "old": history.deleted[0], + "new": history.added[0] if history.added else None, + } + + if changes: + _upsert_changes( + session.info.setdefault(_SESSION_UPDATES, {}), + obj, + changes, + ) + + +@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec") +def _after_flush_postexec(session: Any, flush_context: Any) -> None: + # New objects are now persistent and RETURNING values have been applied, + # so server defaults (id, created_at, …) are available via getattr. + pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, []) + if not pending_new: + return + session.info.setdefault(_SESSION_CREATES, []).extend(pending_new) + + +@event.listens_for(AsyncSession.sync_session_class, "after_rollback") +def _after_rollback(session: Any) -> None: + session.info.pop(_SESSION_PENDING_NEW, None) + session.info.pop(_SESSION_CREATES, None) + session.info.pop(_SESSION_DELETES, None) + session.info.pop(_SESSION_UPDATES, None) + + +def _task_error_handler(task: asyncio.Task[Any]) -> None: + if not task.cancelled() and (exc := task.exception()): + _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) + + +def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None: + """Dispatch *fn* with *args*, handling both sync and async callables.""" + try: + result = fn(*args) + except Exception as exc: + _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) + return + if asyncio.iscoroutine(result): + task = loop.create_task(result) + task.add_done_callback(_task_error_handler) + + +@event.listens_for(AsyncSession.sync_session_class, "after_commit") +def _after_commit(session: Any) -> None: + creates: list[Any] = session.info.pop(_SESSION_CREATES, []) + deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) + field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop( + _SESSION_UPDATES, {} + ) + + if not creates and not deletes and not field_changes: + return + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + + for obj in creates: + _call_callback(loop, obj.on_create) + + for obj in deletes: + _call_callback(loop, obj.on_delete) + + for obj, changes in field_changes.values(): + _call_callback(loop, obj.on_update, changes) + + +class WatchedFieldsMixin: + """Mixin that enables lifecycle callbacks for SQLAlchemy models.""" + + def on_event( + self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None + ) -> Awaitable[None] | None: + """Catch-all callback fired for every lifecycle event. + + Args: + event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`, + or :attr:`ModelEvent.UPDATE`). + changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise. + """ + + def on_create(self) -> Awaitable[None] | None: + """Called after INSERT commit.""" + return self.on_event(ModelEvent.CREATE) + + def on_delete(self) -> Awaitable[None] | None: + """Called after DELETE commit.""" + return self.on_event(ModelEvent.DELETE) + + def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None: + """Called after UPDATE commit when watched fields change.""" + return self.on_event(ModelEvent.UPDATE, changes=changes) diff --git a/tests/test_models.py b/tests/test_models.py index 953d113..fb1b9a4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,12 @@ """Tests for fastapi_toolsets.models mixins.""" +import asyncio import uuid +from contextlib import suppress +from types import SimpleNamespace +from unittest.mock import patch +import fastapi_toolsets.models.watched as _watched_module import pytest from sqlalchemy import String from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -9,10 +14,26 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from fastapi_toolsets.models import ( CreatedAtMixin, + ModelEvent, TimestampMixin, UUIDMixin, UUIDv7Mixin, UpdatedAtMixin, + WatchedFieldsMixin, + watch, +) +from fastapi_toolsets.models.watched import ( + _SESSION_CREATES, + _SESSION_DELETES, + _SESSION_UPDATES, + _SESSION_PENDING_NEW, + _after_commit, + _after_flush, + _after_flush_postexec, + _after_rollback, + _call_callback, + _task_error_handler, + _upsert_changes, ) from .conftest import DATABASE_URL @@ -61,6 +82,80 @@ class FullMixinModel(MixinBase, UUIDMixin, UpdatedAtMixin): name: Mapped[str] = mapped_column(String(50)) +# --- WatchedFieldsMixin test models --- + +_test_events: list[dict] = [] + + +@watch("status") +class WatchedModel(MixinBase, UUIDMixin, WatchedFieldsMixin): + __tablename__ = "mixin_watched_models" + + status: Mapped[str] = mapped_column(String(50)) + other: Mapped[str] = mapped_column(String(50)) + + async def on_create(self) -> None: + _test_events.append({"event": "create", "obj_id": self.id}) + + async def on_delete(self) -> None: + _test_events.append({"event": "delete", "obj_id": self.id}) + + async def on_update(self, changes: dict) -> None: + _test_events.append({"event": "update", "obj_id": self.id, "changes": changes}) + + +@watch("value") +class OnEventModel(MixinBase, UUIDMixin, WatchedFieldsMixin): + """Model that only overrides on_event to test the catch-all path.""" + + __tablename__ = "mixin_on_event_models" + + value: Mapped[str] = mapped_column(String(50)) + + async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None: + _test_events.append({"event": event, "obj_id": self.id, "changes": changes}) + + +class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin): + """Model without @watch — watches all mapped fields by default.""" + + __tablename__ = "mixin_watch_all_models" + + status: Mapped[str] = mapped_column(String(50)) + other: Mapped[str] = mapped_column(String(50)) + + async def on_update(self, changes: dict) -> None: + _test_events.append({"event": "update", "obj_id": self.id, "changes": changes}) + + +class NonWatchedModel(MixinBase): + __tablename__ = "mixin_non_watched_models" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + value: Mapped[str] = mapped_column(String(50)) + + +_sync_events: list[dict] = [] + + +@watch("status") +class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): + """Model with plain (sync) on_* callbacks.""" + + __tablename__ = "mixin_sync_callback_models" + + status: Mapped[str] = mapped_column(String(50)) + + def on_create(self) -> None: + _sync_events.append({"event": "create", "obj_id": self.id}) + + def on_delete(self) -> None: + _sync_events.append({"event": "delete", "obj_id": self.id}) + + def on_update(self, changes: dict) -> None: + _sync_events.append({"event": "update", "changes": changes}) + + @pytest.fixture(scope="function") async def mixin_session(): engine = create_async_engine(DATABASE_URL, echo=False) @@ -124,7 +219,7 @@ class TestUpdatedAtMixin: await mixin_session.refresh(obj) assert obj.updated_at is not None - assert obj.updated_at.tzinfo is not None # timezone-aware + assert obj.updated_at.tzinfo is not None @pytest.mark.anyio async def test_updated_at_changes_on_update(self, mixin_session): @@ -171,7 +266,7 @@ class TestCreatedAtMixin: await mixin_session.refresh(obj) assert obj.created_at is not None - assert obj.created_at.tzinfo is not None # timezone-aware + assert obj.created_at.tzinfo is not None @pytest.mark.anyio async def test_created_at_not_changed_on_update(self, mixin_session): @@ -296,3 +391,577 @@ class TestFullMixinModel: assert isinstance(obj.id, uuid.UUID) assert obj.updated_at is not None assert obj.updated_at.tzinfo is not None + + +class TestWatchDecorator: + def test_registers_specific_fields(self): + """@watch("field") stores the field list in _WATCHED_FIELDS.""" + assert _watched_module._WATCHED_FIELDS.get(WatchedModel) == ["status"] + + def test_no_decorator_not_in_watched_fields(self): + """A model without @watch has no entry in _WATCHED_FIELDS (watch all).""" + assert WatchAllModel not in _watched_module._WATCHED_FIELDS + + def test_preserves_class_identity(self): + """watch returns the same class unchanged.""" + + class _Dummy(WatchedFieldsMixin): + pass + + result = watch("x")(_Dummy) + assert result is _Dummy + del _watched_module._WATCHED_FIELDS[_Dummy] + + def test_raises_when_no_fields_given(self): + """@watch() with no field names raises ValueError.""" + with pytest.raises(ValueError, match="@watch requires at least one field name"): + watch() + + +class TestUpsertChanges: + def test_inserts_new_entry(self): + """New key is inserted with the full changes dict.""" + pending: dict = {} + obj = object() + changes = {"status": {"old": None, "new": "active"}} + _upsert_changes(pending, obj, changes) + assert pending[id(obj)] == (obj, changes) + + def test_merges_existing_field_keeps_old_updates_new(self): + """When the field already exists, old is preserved and new is overwritten.""" + obj = object() + pending = { + id(obj): (obj, {"status": {"old": "initial", "new": "intermediate"}}) + } + _upsert_changes( + pending, obj, {"status": {"old": "intermediate", "new": "final"}} + ) + assert pending[id(obj)][1]["status"] == {"old": "initial", "new": "final"} + + def test_adds_new_field_to_existing_entry(self): + """A previously unseen field is added alongside existing ones.""" + obj = object() + pending = {id(obj): (obj, {"status": {"old": "a", "new": "b"}})} + _upsert_changes(pending, obj, {"role": {"old": "user", "new": "admin"}}) + fields = pending[id(obj)][1] + assert fields["status"] == {"old": "a", "new": "b"} + assert fields["role"] == {"old": "user", "new": "admin"} + + +class TestAfterFlush: + def test_does_nothing_with_empty_session(self): + """_after_flush writes nothing to session.info when all collections are empty.""" + session = SimpleNamespace(new=[], deleted=[], dirty=[], info={}) + _after_flush(session, None) + assert session.info == {} + + def test_captures_new_watched_mixin_objects(self): + """New WatchedFieldsMixin instances are added to _SESSION_PENDING_NEW.""" + obj = WatchedFieldsMixin() + session = SimpleNamespace(new=[obj], deleted=[], dirty=[], info={}) + _after_flush(session, None) + assert session.info[_SESSION_PENDING_NEW] == [obj] + + def test_ignores_new_non_mixin_objects(self): + """New objects that are not WatchedFieldsMixin are not captured.""" + obj = object() + session = SimpleNamespace(new=[obj], deleted=[], dirty=[], info={}) + _after_flush(session, None) + assert _SESSION_PENDING_NEW not in session.info + + def test_captures_deleted_watched_mixin_objects(self): + """Deleted WatchedFieldsMixin instances are added to _SESSION_DELETES.""" + obj = WatchedFieldsMixin() + session = SimpleNamespace(new=[], deleted=[obj], dirty=[], info={}) + _after_flush(session, None) + assert session.info[_SESSION_DELETES] == [obj] + + def test_ignores_deleted_non_mixin_objects(self): + """Deleted objects that are not WatchedFieldsMixin are not captured.""" + obj = object() + session = SimpleNamespace(new=[], deleted=[obj], dirty=[], info={}) + _after_flush(session, None) + assert _SESSION_DELETES not in session.info + + +class TestAfterFlushPostexec: + def test_does_nothing_when_no_pending_new(self): + """_after_flush_postexec does nothing when _SESSION_PENDING_NEW is absent.""" + session = SimpleNamespace(info={}) + _after_flush_postexec(session, None) + assert _SESSION_CREATES not in session.info + + def test_moves_pending_new_to_creates(self): + """Objects from _SESSION_PENDING_NEW are moved to _SESSION_CREATES.""" + obj = object() + session = SimpleNamespace(info={_SESSION_PENDING_NEW: [obj]}) + _after_flush_postexec(session, None) + assert _SESSION_PENDING_NEW not in session.info + assert session.info[_SESSION_CREATES] == [obj] + + def test_extends_existing_creates(self): + """Multiple flushes accumulate in _SESSION_CREATES.""" + a, b = object(), object() + session = SimpleNamespace( + info={_SESSION_PENDING_NEW: [b], _SESSION_CREATES: [a]} + ) + _after_flush_postexec(session, None) + assert session.info[_SESSION_CREATES] == [a, b] + + +class TestAfterRollback: + def test_clears_all_session_info_keys(self): + """_after_rollback removes all four tracking keys from session.info.""" + session = SimpleNamespace( + info={ + _SESSION_PENDING_NEW: [object()], + _SESSION_CREATES: [object()], + _SESSION_DELETES: [object()], + _SESSION_UPDATES: {1: ("obj", {"f": {"old": "a", "new": "b"}})}, + } + ) + _after_rollback(session) + assert _SESSION_PENDING_NEW not in session.info + assert _SESSION_CREATES not in session.info + assert _SESSION_DELETES not in session.info + assert _SESSION_UPDATES not in session.info + + def test_tolerates_missing_keys(self): + """_after_rollback does not raise when session.info has no pending data.""" + session = SimpleNamespace(info={}) + _after_rollback(session) # must not raise + + +class TestTaskErrorHandler: + @pytest.mark.anyio + async def test_logs_exception_from_failed_task(self): + """_task_error_handler calls _logger.error when the task raised.""" + + async def failing() -> None: + raise ValueError("boom") + + task = asyncio.create_task(failing()) + await asyncio.sleep(0) + + with patch.object(_watched_module._logger, "error") as mock_error: + _task_error_handler(task) + mock_error.assert_called_once() + + @pytest.mark.anyio + async def test_ignores_cancelled_task(self): + """_task_error_handler does not log when the task was cancelled.""" + + async def slow() -> None: + await asyncio.sleep(100) + + task = asyncio.create_task(slow()) + task.cancel() + with suppress(asyncio.CancelledError): + await task + + with patch.object(_watched_module._logger, "error") as mock_error: + _task_error_handler(task) + mock_error.assert_not_called() + + +class TestAfterCommitNoLoop: + def test_no_task_scheduled_when_no_running_loop(self): + """_after_commit silently returns when called outside an async context.""" + called = [] + obj = SimpleNamespace(on_create=lambda: called.append("create")) + session = SimpleNamespace(info={_SESSION_CREATES: [obj]}) + _after_commit(session) + assert called == [] + + def test_returns_early_when_all_pending_empty(self): + """_after_commit does nothing when all pending lists are empty.""" + session = SimpleNamespace(info={}) + _after_commit(session) # should not raise + + +class TestWatchedFieldsMixin: + @pytest.fixture(autouse=True) + def clear_events(self): + _test_events.clear() + yield + _test_events.clear() + + # --- on_create --- + + @pytest.mark.anyio + async def test_on_create_fires_after_insert(self, mixin_session): + """on_create is called after INSERT commit.""" + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + creates = [e for e in _test_events if e["event"] == "create"] + assert len(creates) == 1 + + @pytest.mark.anyio + async def test_on_create_server_defaults_populated(self, mixin_session): + """id (server default via RETURNING) is available inside on_create.""" + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + creates = [e for e in _test_events if e["event"] == "create"] + assert creates[0]["obj_id"] is not None + assert isinstance(creates[0]["obj_id"], uuid.UUID) + + @pytest.mark.anyio + async def test_on_create_not_fired_on_update(self, mixin_session): + """on_create is NOT called when an existing row is updated.""" + obj = WatchedModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.status = "updated" + await mixin_session.commit() + await asyncio.sleep(0) + + assert not any(e["event"] == "create" for e in _test_events) + + # --- on_delete --- + + @pytest.mark.anyio + async def test_on_delete_fires_after_delete(self, mixin_session): + """on_delete is called after DELETE commit.""" + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + saved_id = obj.id + _test_events.clear() + + await mixin_session.delete(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + deletes = [e for e in _test_events if e["event"] == "delete"] + assert len(deletes) == 1 + assert deletes[0]["obj_id"] == saved_id + + @pytest.mark.anyio + async def test_on_delete_not_fired_on_insert(self, mixin_session): + """on_delete is NOT called when a new row is inserted.""" + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + assert not any(e["event"] == "delete" for e in _test_events) + + # --- on_update --- + + @pytest.mark.anyio + async def test_on_update_fires_on_update(self, mixin_session): + """on_update reports the correct before/after values on UPDATE.""" + obj = WatchedModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.status = "updated" + await mixin_session.commit() + await asyncio.sleep(0) + + changes_events = [e for e in _test_events if e["event"] == "update"] + assert len(changes_events) == 1 + assert changes_events[0]["changes"]["status"] == { + "old": "initial", + "new": "updated", + } + + @pytest.mark.anyio + async def test_on_update_not_fired_on_insert(self, mixin_session): + """on_update is NOT called on INSERT (on_create handles that).""" + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + assert not any(e["event"] == "update" for e in _test_events) + + @pytest.mark.anyio + async def test_unwatched_field_update_no_callback(self, mixin_session): + """Changing a field not listed in @update does not fire on_update.""" + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.other = "changed" + await mixin_session.commit() + await asyncio.sleep(0) + + assert _test_events == [] + + @pytest.mark.anyio + async def test_multiple_flushes_merge_earliest_old_latest_new(self, mixin_session): + """Two flushes in one transaction produce a single callback with earliest old / latest new.""" + obj = WatchedModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.status = "intermediate" + await mixin_session.flush() + + obj.status = "final" + await mixin_session.commit() + await asyncio.sleep(0) + + changes_events = [e for e in _test_events if e["event"] == "update"] + assert len(changes_events) == 1 + assert changes_events[0]["changes"]["status"] == { + "old": "initial", + "new": "final", + } + + @pytest.mark.anyio + async def test_rollback_suppresses_all_callbacks(self, mixin_session): + """No callbacks are fired when the transaction is rolled back.""" + obj = WatchedModel(status="active", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.status = "changed" + await mixin_session.flush() + await mixin_session.rollback() + await asyncio.sleep(0) + + assert _test_events == [] + + @pytest.mark.anyio + async def test_non_watched_model_no_callback(self, mixin_session): + """Dirty objects whose type is not a WatchedFieldsMixin are skipped.""" + nw = NonWatchedModel(value="x") + mixin_session.add(nw) + await mixin_session.flush() + nw.value = "y" + await mixin_session.commit() + await asyncio.sleep(0) + + assert _test_events == [] + + # --- on_event (catch-all) --- + + @pytest.mark.anyio + async def test_on_event_receives_create(self, mixin_session): + """on_event is called with ModelEvent.CREATE on INSERT when only on_event is overridden.""" + obj = OnEventModel(value="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + creates = [e for e in _test_events if e["event"] == ModelEvent.CREATE] + assert len(creates) == 1 + assert creates[0]["changes"] is None + + @pytest.mark.anyio + async def test_on_event_receives_delete(self, mixin_session): + """on_event is called with ModelEvent.DELETE on DELETE when only on_event is overridden.""" + obj = OnEventModel(value="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + await mixin_session.delete(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + deletes = [e for e in _test_events if e["event"] == ModelEvent.DELETE] + assert len(deletes) == 1 + assert deletes[0]["changes"] is None + + @pytest.mark.anyio + async def test_on_event_receives_field_change(self, mixin_session): + """on_event is called with ModelEvent.UPDATE on UPDATE when only on_event is overridden.""" + obj = OnEventModel(value="initial") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.value = "updated" + await mixin_session.commit() + await asyncio.sleep(0) + + changes_events = [e for e in _test_events if e["event"] == ModelEvent.UPDATE] + assert len(changes_events) == 1 + assert changes_events[0]["changes"]["value"] == { + "old": "initial", + "new": "updated", + } + + +class TestWatchAll: + @pytest.fixture(autouse=True) + def clear_events(self): + _test_events.clear() + yield + _test_events.clear() + + @pytest.mark.anyio + async def test_watch_all_fires_for_any_field(self, mixin_session): + """Model without @watch fires on_update for any changed field.""" + obj = WatchAllModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.other = "changed" + await mixin_session.commit() + await asyncio.sleep(0) + + changes_events = [e for e in _test_events if e["event"] == "update"] + assert len(changes_events) == 1 + assert "other" in changes_events[0]["changes"] + + @pytest.mark.anyio + async def test_watch_all_captures_multiple_fields(self, mixin_session): + """Model without @watch captures all fields changed in a single commit.""" + obj = WatchAllModel(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _test_events.clear() + + obj.status = "updated" + obj.other = "changed" + await mixin_session.commit() + await asyncio.sleep(0) + + changes_events = [e for e in _test_events if e["event"] == "update"] + assert len(changes_events) == 1 + assert "status" in changes_events[0]["changes"] + assert "other" in changes_events[0]["changes"] + + +class TestSyncCallbacks: + @pytest.fixture(autouse=True) + def clear_events(self): + _sync_events.clear() + yield + _sync_events.clear() + + @pytest.mark.anyio + async def test_sync_on_create_fires(self, mixin_session): + """Sync on_create is called after INSERT commit.""" + obj = SyncCallbackModel(status="active") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + creates = [e for e in _sync_events if e["event"] == "create"] + assert len(creates) == 1 + assert isinstance(creates[0]["obj_id"], uuid.UUID) + + @pytest.mark.anyio + async def test_sync_on_delete_fires(self, mixin_session): + """Sync on_delete is called after DELETE commit.""" + obj = SyncCallbackModel(status="active") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _sync_events.clear() + + await mixin_session.delete(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + deletes = [e for e in _sync_events if e["event"] == "delete"] + assert len(deletes) == 1 + + @pytest.mark.anyio + async def test_sync_on_update_fires(self, mixin_session): + """Sync on_update is called after UPDATE commit with correct changes.""" + obj = SyncCallbackModel(status="initial") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + _sync_events.clear() + + obj.status = "updated" + await mixin_session.commit() + await asyncio.sleep(0) + + updates = [e for e in _sync_events if e["event"] == "update"] + assert len(updates) == 1 + assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} + + +class TestCallCallback: + @pytest.mark.anyio + async def test_async_callback_scheduled_as_task(self): + """_call_callback schedules async functions as tasks.""" + called = [] + + async def async_fn() -> None: + called.append("async") + + loop = asyncio.get_running_loop() + _call_callback(loop, async_fn) + await asyncio.sleep(0) + assert called == ["async"] + + @pytest.mark.anyio + async def test_sync_callback_called_directly(self): + """_call_callback invokes sync functions immediately.""" + called = [] + + def sync_fn() -> None: + called.append("sync") + + loop = asyncio.get_running_loop() + _call_callback(loop, sync_fn) + assert called == ["sync"] + + @pytest.mark.anyio + async def test_sync_callback_exception_logged(self): + """_call_callback logs exceptions from sync callbacks.""" + + def failing_fn() -> None: + raise RuntimeError("sync error") + + loop = asyncio.get_running_loop() + with patch.object(_watched_module._logger, "error") as mock_error: + _call_callback(loop, failing_fn) + mock_error.assert_called_once() + + @pytest.mark.anyio + async def test_async_callback_with_args(self): + """_call_callback passes arguments to async callbacks.""" + received = [] + + async def async_fn(changes: dict) -> None: + received.append(changes) + + loop = asyncio.get_running_loop() + _call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}}) + await asyncio.sleep(0) + assert received == [{"status": {"old": "a", "new": "b"}}] + + @pytest.mark.anyio + async def test_sync_callback_with_args(self): + """_call_callback passes arguments to sync callbacks.""" + received = [] + + def sync_fn(changes: dict) -> None: + received.append(changes) + + loop = asyncio.get_running_loop() + _call_callback(loop, sync_fn, {"x": 1}) + assert received == [{"x": 1}]