Compare commits

..

3 Commits

3 changed files with 330 additions and 7 deletions

View File

@@ -138,6 +138,23 @@ Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callba
| `@watch("status", "role")` | Only fires when `status` or `role` changes | | `@watch("status", "role")` | Only fires when `status` or `role` changes |
| *(no decorator)* | Fires when **any** mapped field changes | | *(no decorator)* | Fires when **any** mapped field changes |
`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter:
```python
@watch("status")
class Order(Base, UUIDMixin, WatchedFieldsMixin):
...
class UrgentOrder(Order):
# inherits @watch("status") — on_update fires only for status changes
...
@watch("priority")
class PriorityOrder(Order):
# overrides parent — on_update fires only for priority changes
...
```
#### Option 1 — catch-all with `on_event` #### Option 1 — catch-all with `on_event`
Override `on_event` to handle all event types in one place. The specific methods delegate here by default: Override `on_event` to handle all event types in one place. The specific methods delegate here by default:
@@ -197,6 +214,25 @@ The `changes` dict maps each watched field that changed to `{"old": ..., "new":
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected." !!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
!!! warning "Callbacks fire after the **outermost** transaction commits."
If you create several related objects using `CrudFactory.create` and need
callbacks to see all of them (including associations), wrap the whole
operation in a single [`get_transaction`](db.md) block. Without it, each
`create` call commits independently and `on_create` fires before the
remaining objects exist.
```python
from fastapi_toolsets.db import get_transaction
async with get_transaction(session):
order = await OrderCrud.create(session, order_data)
item = await ItemCrud.create(session, item_data)
await session.refresh(order, attribute_names=["items"])
order.items.append(item)
# on_create fires here for both order and item,
# with the full association already committed.
```
## Composing mixins ## Composing mixins
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model. All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.

View File

@@ -1,6 +1,7 @@
"""Field-change monitoring via SQLAlchemy session events.""" """Field-change monitoring via SQLAlchemy session events."""
import asyncio import asyncio
import inspect
import weakref import weakref
from collections.abc import Awaitable from collections.abc import Awaitable
from enum import Enum from enum import Enum
@@ -65,6 +66,14 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
} }
def _get_watched_fields(cls: type) -> list[str] | None:
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
for klass in cls.__mro__:
if klass in _WATCHED_FIELDS:
return _WATCHED_FIELDS[klass]
return None
def _upsert_changes( def _upsert_changes(
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]], pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
obj: Any, obj: Any,
@@ -102,7 +111,7 @@ def _after_flush(session: Any, flush_context: Any) -> None:
continue continue
# None = not in dict = watch all fields; list = specific fields only # None = not in dict = watch all fields; list = specific fields only
watched = _WATCHED_FIELDS.get(type(obj)) watched = _get_watched_fields(type(obj))
changes: dict[str, dict[str, Any]] = {} changes: dict[str, dict[str, Any]] = {}
attrs = ( attrs = (
@@ -169,7 +178,7 @@ def _schedule_with_snapshot(
_sa_set_committed_value(obj, key, value) _sa_set_committed_value(obj, key, value)
try: try:
result = fn(*args) result = fn(*args)
if asyncio.iscoroutine(result): if inspect.isawaitable(result):
await result await result
except Exception as exc: except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
@@ -186,6 +195,15 @@ def _after_commit(session: Any) -> None:
_SESSION_UPDATES, {} _SESSION_UPDATES, {}
) )
if creates and deletes:
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [o for o in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
if not creates and not deletes and not field_changes: if not creates and not deletes and not field_changes:
return return

View File

@@ -6,27 +6,27 @@ from contextlib import suppress
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import patch
import fastapi_toolsets.models.watched as _watched_module
import pytest import pytest
from sqlalchemy import String from sqlalchemy import String
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
import fastapi_toolsets.models.watched as _watched_module
from fastapi_toolsets.models import ( from fastapi_toolsets.models import (
CreatedAtMixin, CreatedAtMixin,
ModelEvent, ModelEvent,
TimestampMixin, TimestampMixin,
UpdatedAtMixin,
UUIDMixin, UUIDMixin,
UUIDv7Mixin, UUIDv7Mixin,
UpdatedAtMixin,
WatchedFieldsMixin, WatchedFieldsMixin,
watch, watch,
) )
from fastapi_toolsets.models.watched import ( from fastapi_toolsets.models.watched import (
_SESSION_CREATES, _SESSION_CREATES,
_SESSION_DELETES, _SESSION_DELETES,
_SESSION_UPDATES,
_SESSION_PENDING_NEW, _SESSION_PENDING_NEW,
_SESSION_UPDATES,
_after_commit, _after_commit,
_after_flush, _after_flush,
_after_flush_postexec, _after_flush_postexec,
@@ -81,8 +81,6 @@ class FullMixinModel(MixinBase, UUIDMixin, UpdatedAtMixin):
name: Mapped[str] = mapped_column(String(50)) name: Mapped[str] = mapped_column(String(50))
# --- WatchedFieldsMixin test models ---
_test_events: list[dict] = [] _test_events: list[dict] = []
@@ -145,6 +143,66 @@ class NonWatchedModel(MixinBase):
value: Mapped[str] = mapped_column(String(50)) value: Mapped[str] = mapped_column(String(50))
_poly_events: list[dict] = []
class PolyAnimal(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Base class for STI polymorphism tests."""
__tablename__ = "mixin_poly_animals"
__mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "animal"}
kind: Mapped[str] = mapped_column(String(50))
name: Mapped[str] = mapped_column(String(50))
async def on_create(self) -> None:
_poly_events.append(
{"event": "create", "type": type(self).__name__, "obj_id": self.id}
)
async def on_delete(self) -> None:
_poly_events.append(
{"event": "delete", "type": type(self).__name__, "obj_id": self.id}
)
class PolyDog(PolyAnimal):
"""STI subclass — shares the same table as PolyAnimal."""
__mapper_args__ = {"polymorphic_identity": "dog"}
_watch_inherit_events: list[dict] = []
@watch("status")
class WatchParent(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Base class with @watch("status") — subclasses should inherit this filter."""
__tablename__ = "mixin_watch_parent"
__mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "parent"}
kind: Mapped[str] = mapped_column(String(50))
status: Mapped[str] = mapped_column(String(50))
other: Mapped[str] = mapped_column(String(50))
async def on_update(self, changes: dict) -> None:
_watch_inherit_events.append({"type": type(self).__name__, "changes": changes})
class WatchChild(WatchParent):
"""STI subclass that does NOT redeclare @watch — should inherit parent's filter."""
__mapper_args__ = {"polymorphic_identity": "child"}
@watch("other")
class WatchOverride(WatchParent):
"""STI subclass that overrides @watch with a different field."""
__mapper_args__ = {"polymorphic_identity": "override"}
_attr_access_events: list[dict] = [] _attr_access_events: list[dict] = []
@@ -172,6 +230,7 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events: list[dict] = [] _sync_events: list[dict] = []
_future_events: list[str] = []
@watch("status") @watch("status")
@@ -192,6 +251,20 @@ class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events.append({"event": "update", "changes": changes}) _sync_events.append({"event": "update", "changes": changes})
class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model whose on_create returns an asyncio.Task (awaitable, not a coroutine)."""
__tablename__ = "mixin_future_callback_models"
name: Mapped[str] = mapped_column(String(50))
def on_create(self) -> "asyncio.Task[None]":
async def _work() -> None:
_future_events.append("created")
return asyncio.ensure_future(_work())
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mixin_session(): async def mixin_session():
engine = create_async_engine(DATABASE_URL, echo=False) engine = create_async_engine(DATABASE_URL, echo=False)
@@ -473,6 +546,67 @@ class TestWatchDecorator:
watch() watch()
class TestWatchInheritance:
@pytest.fixture(autouse=True)
def clear_events(self):
_watch_inherit_events.clear()
yield
_watch_inherit_events.clear()
@pytest.mark.anyio
async def test_child_inherits_parent_watch_filter(self, mixin_session):
"""Subclass without @watch inherits the parent's field filter."""
obj = WatchChild(status="initial", other="x")
mixin_session.add(obj)
await mixin_session.commit()
await asyncio.sleep(0)
obj.other = "changed" # not watched by parent's @watch("status")
await mixin_session.commit()
await asyncio.sleep(0)
assert _watch_inherit_events == []
@pytest.mark.anyio
async def test_child_triggers_on_watched_field(self, mixin_session):
"""Subclass without @watch triggers on_update for the parent's watched field."""
obj = WatchChild(status="initial", other="x")
mixin_session.add(obj)
await mixin_session.commit()
await asyncio.sleep(0)
obj.status = "updated"
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_watch_inherit_events) == 1
assert _watch_inherit_events[0]["type"] == "WatchChild"
assert "status" in _watch_inherit_events[0]["changes"]
@pytest.mark.anyio
async def test_subclass_override_takes_precedence(self, mixin_session):
"""Subclass @watch overrides the parent's field filter."""
obj = WatchOverride(status="initial", other="x")
mixin_session.add(obj)
await mixin_session.commit()
await asyncio.sleep(0)
obj.status = (
"changed" # watched by parent but overridden by child's @watch("other")
)
await mixin_session.commit()
await asyncio.sleep(0)
assert _watch_inherit_events == []
obj.other = "changed"
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_watch_inherit_events) == 1
assert "other" in _watch_inherit_events[0]["changes"]
class TestUpsertChanges: class TestUpsertChanges:
def test_inserts_new_entry(self): def test_inserts_new_entry(self):
"""New key is inserted with the full changes dict.""" """New key is inserted with the full changes dict."""
@@ -871,6 +1005,119 @@ class TestWatchedFieldsMixin:
} }
class TestTransientObject:
"""Create + delete within the same transaction should fire no events."""
@pytest.fixture(autouse=True)
def clear_events(self):
_test_events.clear()
yield
_test_events.clear()
@pytest.mark.anyio
async def test_no_events_when_created_and_deleted_in_same_transaction(
self, mixin_session
):
"""Neither on_create nor on_delete fires when the object never survives a commit."""
obj = WatchedModel(status="active", other="x")
mixin_session.add(obj)
await mixin_session.flush()
await mixin_session.delete(obj)
await mixin_session.commit()
await asyncio.sleep(0)
assert _test_events == []
@pytest.mark.anyio
async def test_other_objects_unaffected(self, mixin_session):
"""on_create still fires for objects that are not deleted in the same transaction."""
survivor = WatchedModel(status="active", other="x")
transient = WatchedModel(status="gone", other="y")
mixin_session.add(survivor)
mixin_session.add(transient)
await mixin_session.flush()
await mixin_session.delete(transient)
await mixin_session.commit()
await asyncio.sleep(0)
creates = [e for e in _test_events if e["event"] == "create"]
deletes = [e for e in _test_events if e["event"] == "delete"]
assert len(creates) == 1
assert creates[0]["obj_id"] == survivor.id
assert deletes == []
@pytest.mark.anyio
async def test_distinct_create_and_delete_both_fire(self, mixin_session):
"""on_create and on_delete both fire when different objects are created and deleted."""
existing = WatchedModel(status="old", other="x")
mixin_session.add(existing)
await mixin_session.commit()
await asyncio.sleep(0)
_test_events.clear()
new_obj = WatchedModel(status="new", other="y")
mixin_session.add(new_obj)
await mixin_session.delete(existing)
await mixin_session.commit()
await asyncio.sleep(0)
creates = [e for e in _test_events if e["event"] == "create"]
deletes = [e for e in _test_events if e["event"] == "delete"]
assert len(creates) == 1
assert len(deletes) == 1
class TestPolymorphism:
"""WatchedFieldsMixin with STI (Single Table Inheritance)."""
@pytest.fixture(autouse=True)
def clear_events(self):
_poly_events.clear()
yield
_poly_events.clear()
@pytest.mark.anyio
async def test_on_create_fires_once_for_subclass(self, mixin_session):
"""on_create fires exactly once for a STI subclass instance."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_poly_events) == 1
assert _poly_events[0]["event"] == "create"
assert _poly_events[0]["type"] == "PolyDog"
@pytest.mark.anyio
async def test_on_delete_fires_for_subclass(self, mixin_session):
"""on_delete fires for a STI subclass instance."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.commit()
await asyncio.sleep(0)
_poly_events.clear()
await mixin_session.delete(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_poly_events) == 1
assert _poly_events[0]["event"] == "delete"
assert _poly_events[0]["type"] == "PolyDog"
@pytest.mark.anyio
async def test_transient_subclass_fires_no_events(self, mixin_session):
"""Create + delete of a STI subclass in one transaction fires no events."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.flush()
await mixin_session.delete(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert _poly_events == []
class TestWatchAll: class TestWatchAll:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_events(self): def clear_events(self):
@@ -968,6 +1215,28 @@ class TestSyncCallbacks:
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
class TestFutureCallbacks:
"""Callbacks returning a non-coroutine awaitable (asyncio.Task / Future)."""
@pytest.fixture(autouse=True)
def clear_events(self):
_future_events.clear()
yield
_future_events.clear()
@pytest.mark.anyio
async def test_task_callback_is_awaited(self, mixin_session):
"""on_create returning an asyncio.Task is awaited and its work completes."""
obj = FutureCallbackModel(name="test")
mixin_session.add(obj)
await mixin_session.commit()
# Two turns: one for _run() to execute, one for the inner _work() task.
await asyncio.sleep(0)
await asyncio.sleep(0)
assert _future_events == ["created"]
class TestAttributeAccessInCallbacks: class TestAttributeAccessInCallbacks:
"""Verify that self attributes are accessible inside every callback type. """Verify that self attributes are accessible inside every callback type.