From 6981c33dc8641a747da09f158c2ba067fc8f95a0 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Mon, 23 Mar 2026 19:08:17 +0100 Subject: [PATCH] fix: inherit @watch field filter from parent classes via MRO traversal (#170) --- docs/module/models.md | 36 ++++++++++ src/fastapi_toolsets/models/watched.py | 10 ++- tests/test_models.py | 92 ++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) diff --git a/docs/module/models.md b/docs/module/models.md index 8376cea..5613900 100644 --- a/docs/module/models.md +++ b/docs/module/models.md @@ -138,6 +138,23 @@ Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callba | `@watch("status", "role")` | Only fires when `status` or `role` changes | | *(no decorator)* | Fires when **any** mapped field changes | +`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter: + +```python +@watch("status") +class Order(Base, UUIDMixin, WatchedFieldsMixin): + ... + +class UrgentOrder(Order): + # inherits @watch("status") — on_update fires only for status changes + ... + +@watch("priority") +class PriorityOrder(Order): + # overrides parent — on_update fires only for priority changes + ... +``` + #### Option 1 — catch-all with `on_event` Override `on_event` to handle all event types in one place. The specific methods delegate here by default: @@ -197,6 +214,25 @@ The `changes` dict maps each watched field that changed to `{"old": ..., "new": !!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected." +!!! warning "Callbacks fire after the **outermost** transaction commits." + If you create several related objects using `CrudFactory.create` and need + callbacks to see all of them (including associations), wrap the whole + operation in a single [`get_transaction`](db.md) block. Without it, each + `create` call commits independently and `on_create` fires before the + remaining objects exist. + + ```python + from fastapi_toolsets.db import get_transaction + + async with get_transaction(session): + order = await OrderCrud.create(session, order_data) + item = await ItemCrud.create(session, item_data) + await session.refresh(order, attribute_names=["items"]) + order.items.append(item) + # on_create fires here for both order and item, + # with the full association already committed. + ``` + ## Composing mixins All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model. diff --git a/src/fastapi_toolsets/models/watched.py b/src/fastapi_toolsets/models/watched.py index 69caf01..00e1e5c 100644 --- a/src/fastapi_toolsets/models/watched.py +++ b/src/fastapi_toolsets/models/watched.py @@ -66,6 +66,14 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]: } +def _get_watched_fields(cls: type) -> list[str] | None: + """Return the watched fields for *cls*, walking the MRO to inherit from parents.""" + for klass in cls.__mro__: + if klass in _WATCHED_FIELDS: + return _WATCHED_FIELDS[klass] + return None + + def _upsert_changes( pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]], obj: Any, @@ -103,7 +111,7 @@ def _after_flush(session: Any, flush_context: Any) -> None: continue # None = not in dict = watch all fields; list = specific fields only - watched = _WATCHED_FIELDS.get(type(obj)) + watched = _get_watched_fields(type(obj)) changes: dict[str, dict[str, Any]] = {} attrs = ( diff --git a/tests/test_models.py b/tests/test_models.py index 18e8a8b..1c91742 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -172,6 +172,37 @@ class PolyDog(PolyAnimal): __mapper_args__ = {"polymorphic_identity": "dog"} +_watch_inherit_events: list[dict] = [] + + +@watch("status") +class WatchParent(MixinBase, UUIDMixin, WatchedFieldsMixin): + """Base class with @watch("status") — subclasses should inherit this filter.""" + + __tablename__ = "mixin_watch_parent" + __mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "parent"} + + kind: Mapped[str] = mapped_column(String(50)) + status: Mapped[str] = mapped_column(String(50)) + other: Mapped[str] = mapped_column(String(50)) + + async def on_update(self, changes: dict) -> None: + _watch_inherit_events.append({"type": type(self).__name__, "changes": changes}) + + +class WatchChild(WatchParent): + """STI subclass that does NOT redeclare @watch — should inherit parent's filter.""" + + __mapper_args__ = {"polymorphic_identity": "child"} + + +@watch("other") +class WatchOverride(WatchParent): + """STI subclass that overrides @watch with a different field.""" + + __mapper_args__ = {"polymorphic_identity": "override"} + + _attr_access_events: list[dict] = [] @@ -515,6 +546,67 @@ class TestWatchDecorator: watch() +class TestWatchInheritance: + @pytest.fixture(autouse=True) + def clear_events(self): + _watch_inherit_events.clear() + yield + _watch_inherit_events.clear() + + @pytest.mark.anyio + async def test_child_inherits_parent_watch_filter(self, mixin_session): + """Subclass without @watch inherits the parent's field filter.""" + obj = WatchChild(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + obj.other = "changed" # not watched by parent's @watch("status") + await mixin_session.commit() + await asyncio.sleep(0) + + assert _watch_inherit_events == [] + + @pytest.mark.anyio + async def test_child_triggers_on_watched_field(self, mixin_session): + """Subclass without @watch triggers on_update for the parent's watched field.""" + obj = WatchChild(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + obj.status = "updated" + await mixin_session.commit() + await asyncio.sleep(0) + + assert len(_watch_inherit_events) == 1 + assert _watch_inherit_events[0]["type"] == "WatchChild" + assert "status" in _watch_inherit_events[0]["changes"] + + @pytest.mark.anyio + async def test_subclass_override_takes_precedence(self, mixin_session): + """Subclass @watch overrides the parent's field filter.""" + obj = WatchOverride(status="initial", other="x") + mixin_session.add(obj) + await mixin_session.commit() + await asyncio.sleep(0) + + obj.status = ( + "changed" # watched by parent but overridden by child's @watch("other") + ) + await mixin_session.commit() + await asyncio.sleep(0) + + assert _watch_inherit_events == [] + + obj.other = "changed" + await mixin_session.commit() + await asyncio.sleep(0) + + assert len(_watch_inherit_events) == 1 + assert "other" in _watch_inherit_events[0]["changes"] + + class TestUpsertChanges: def test_inserts_new_entry(self): """New key is inserted with the full changes dict."""