Compare commits

..

2 Commits

2 changed files with 193 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
"""Field-change monitoring via SQLAlchemy session events.""" """Field-change monitoring via SQLAlchemy session events."""
import asyncio import asyncio
import inspect
import weakref import weakref
from collections.abc import Awaitable from collections.abc import Awaitable
from enum import Enum from enum import Enum
@@ -169,7 +170,7 @@ def _schedule_with_snapshot(
_sa_set_committed_value(obj, key, value) _sa_set_committed_value(obj, key, value)
try: try:
result = fn(*args) result = fn(*args)
if asyncio.iscoroutine(result): if inspect.isawaitable(result):
await result await result
except Exception as exc: except Exception as exc:
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
@@ -186,6 +187,15 @@ def _after_commit(session: Any) -> None:
_SESSION_UPDATES, {} _SESSION_UPDATES, {}
) )
if creates and deletes:
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
if transient_ids:
creates = [o for o in creates if id(o) not in transient_ids]
deletes = [o for o in deletes if id(o) not in transient_ids]
field_changes = {
k: v for k, v in field_changes.items() if k not in transient_ids
}
if not creates and not deletes and not field_changes: if not creates and not deletes and not field_changes:
return return

View File

@@ -6,27 +6,27 @@ from contextlib import suppress
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import patch
import fastapi_toolsets.models.watched as _watched_module
import pytest import pytest
from sqlalchemy import String from sqlalchemy import String
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
import fastapi_toolsets.models.watched as _watched_module
from fastapi_toolsets.models import ( from fastapi_toolsets.models import (
CreatedAtMixin, CreatedAtMixin,
ModelEvent, ModelEvent,
TimestampMixin, TimestampMixin,
UpdatedAtMixin,
UUIDMixin, UUIDMixin,
UUIDv7Mixin, UUIDv7Mixin,
UpdatedAtMixin,
WatchedFieldsMixin, WatchedFieldsMixin,
watch, watch,
) )
from fastapi_toolsets.models.watched import ( from fastapi_toolsets.models.watched import (
_SESSION_CREATES, _SESSION_CREATES,
_SESSION_DELETES, _SESSION_DELETES,
_SESSION_UPDATES,
_SESSION_PENDING_NEW, _SESSION_PENDING_NEW,
_SESSION_UPDATES,
_after_commit, _after_commit,
_after_flush, _after_flush,
_after_flush_postexec, _after_flush_postexec,
@@ -81,8 +81,6 @@ class FullMixinModel(MixinBase, UUIDMixin, UpdatedAtMixin):
name: Mapped[str] = mapped_column(String(50)) name: Mapped[str] = mapped_column(String(50))
# --- WatchedFieldsMixin test models ---
_test_events: list[dict] = [] _test_events: list[dict] = []
@@ -145,6 +143,35 @@ class NonWatchedModel(MixinBase):
value: Mapped[str] = mapped_column(String(50)) value: Mapped[str] = mapped_column(String(50))
_poly_events: list[dict] = []
class PolyAnimal(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Base class for STI polymorphism tests."""
__tablename__ = "mixin_poly_animals"
__mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "animal"}
kind: Mapped[str] = mapped_column(String(50))
name: Mapped[str] = mapped_column(String(50))
async def on_create(self) -> None:
_poly_events.append(
{"event": "create", "type": type(self).__name__, "obj_id": self.id}
)
async def on_delete(self) -> None:
_poly_events.append(
{"event": "delete", "type": type(self).__name__, "obj_id": self.id}
)
class PolyDog(PolyAnimal):
"""STI subclass — shares the same table as PolyAnimal."""
__mapper_args__ = {"polymorphic_identity": "dog"}
_attr_access_events: list[dict] = [] _attr_access_events: list[dict] = []
@@ -172,6 +199,7 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events: list[dict] = [] _sync_events: list[dict] = []
_future_events: list[str] = []
@watch("status") @watch("status")
@@ -192,6 +220,20 @@ class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
_sync_events.append({"event": "update", "changes": changes}) _sync_events.append({"event": "update", "changes": changes})
class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
"""Model whose on_create returns an asyncio.Task (awaitable, not a coroutine)."""
__tablename__ = "mixin_future_callback_models"
name: Mapped[str] = mapped_column(String(50))
def on_create(self) -> "asyncio.Task[None]":
async def _work() -> None:
_future_events.append("created")
return asyncio.ensure_future(_work())
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mixin_session(): async def mixin_session():
engine = create_async_engine(DATABASE_URL, echo=False) engine = create_async_engine(DATABASE_URL, echo=False)
@@ -871,6 +913,119 @@ class TestWatchedFieldsMixin:
} }
class TestTransientObject:
"""Create + delete within the same transaction should fire no events."""
@pytest.fixture(autouse=True)
def clear_events(self):
_test_events.clear()
yield
_test_events.clear()
@pytest.mark.anyio
async def test_no_events_when_created_and_deleted_in_same_transaction(
self, mixin_session
):
"""Neither on_create nor on_delete fires when the object never survives a commit."""
obj = WatchedModel(status="active", other="x")
mixin_session.add(obj)
await mixin_session.flush()
await mixin_session.delete(obj)
await mixin_session.commit()
await asyncio.sleep(0)
assert _test_events == []
@pytest.mark.anyio
async def test_other_objects_unaffected(self, mixin_session):
"""on_create still fires for objects that are not deleted in the same transaction."""
survivor = WatchedModel(status="active", other="x")
transient = WatchedModel(status="gone", other="y")
mixin_session.add(survivor)
mixin_session.add(transient)
await mixin_session.flush()
await mixin_session.delete(transient)
await mixin_session.commit()
await asyncio.sleep(0)
creates = [e for e in _test_events if e["event"] == "create"]
deletes = [e for e in _test_events if e["event"] == "delete"]
assert len(creates) == 1
assert creates[0]["obj_id"] == survivor.id
assert deletes == []
@pytest.mark.anyio
async def test_distinct_create_and_delete_both_fire(self, mixin_session):
"""on_create and on_delete both fire when different objects are created and deleted."""
existing = WatchedModel(status="old", other="x")
mixin_session.add(existing)
await mixin_session.commit()
await asyncio.sleep(0)
_test_events.clear()
new_obj = WatchedModel(status="new", other="y")
mixin_session.add(new_obj)
await mixin_session.delete(existing)
await mixin_session.commit()
await asyncio.sleep(0)
creates = [e for e in _test_events if e["event"] == "create"]
deletes = [e for e in _test_events if e["event"] == "delete"]
assert len(creates) == 1
assert len(deletes) == 1
class TestPolymorphism:
"""WatchedFieldsMixin with STI (Single Table Inheritance)."""
@pytest.fixture(autouse=True)
def clear_events(self):
_poly_events.clear()
yield
_poly_events.clear()
@pytest.mark.anyio
async def test_on_create_fires_once_for_subclass(self, mixin_session):
"""on_create fires exactly once for a STI subclass instance."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_poly_events) == 1
assert _poly_events[0]["event"] == "create"
assert _poly_events[0]["type"] == "PolyDog"
@pytest.mark.anyio
async def test_on_delete_fires_for_subclass(self, mixin_session):
"""on_delete fires for a STI subclass instance."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.commit()
await asyncio.sleep(0)
_poly_events.clear()
await mixin_session.delete(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert len(_poly_events) == 1
assert _poly_events[0]["event"] == "delete"
assert _poly_events[0]["type"] == "PolyDog"
@pytest.mark.anyio
async def test_transient_subclass_fires_no_events(self, mixin_session):
"""Create + delete of a STI subclass in one transaction fires no events."""
dog = PolyDog(name="Rex")
mixin_session.add(dog)
await mixin_session.flush()
await mixin_session.delete(dog)
await mixin_session.commit()
await asyncio.sleep(0)
assert _poly_events == []
class TestWatchAll: class TestWatchAll:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_events(self): def clear_events(self):
@@ -968,6 +1123,28 @@ class TestSyncCallbacks:
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
class TestFutureCallbacks:
"""Callbacks returning a non-coroutine awaitable (asyncio.Task / Future)."""
@pytest.fixture(autouse=True)
def clear_events(self):
_future_events.clear()
yield
_future_events.clear()
@pytest.mark.anyio
async def test_task_callback_is_awaited(self, mixin_session):
"""on_create returning an asyncio.Task is awaited and its work completes."""
obj = FutureCallbackModel(name="test")
mixin_session.add(obj)
await mixin_session.commit()
# Two turns: one for _run() to execute, one for the inner _work() task.
await asyncio.sleep(0)
await asyncio.sleep(0)
assert _future_events == ["created"]
class TestAttributeAccessInCallbacks: class TestAttributeAccessInCallbacks:
"""Verify that self attributes are accessible inside every callback type. """Verify that self attributes are accessible inside every callback type.