Files
fastapi-toolsets/tests/test_models.py

1363 lines
46 KiB
Python

"""Tests for fastapi_toolsets.models mixins."""
import asyncio
import uuid
from contextlib import suppress
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from fastapi_toolsets.pytest import create_db_session
import fastapi_toolsets.models.watched as _watched_module
from fastapi_toolsets.models import (
CreatedAtMixin,
ModelEvent,
TimestampMixin,
UpdatedAtMixin,
UUIDMixin,
UUIDv7Mixin,
WatchedFieldsMixin,
watch,
)
from fastapi_toolsets.models.watched import (
_SESSION_CREATES,
_SESSION_DELETES,
_SESSION_PENDING_NEW,
_SESSION_UPDATES,
_after_commit,
_after_flush,
_after_flush_postexec,
_after_rollback,
_snapshot_column_attrs,
_task_error_handler,
_upsert_changes,
)
from .conftest import DATABASE_URL
class MixinBase(DeclarativeBase):
    """Declarative base shared by every test model in this module.

    The session fixtures pass it to ``create_db_session`` together with
    ``DATABASE_URL``.
    """

    pass
class UUIDModel(MixinBase, UUIDMixin):
    """Minimal model for UUIDMixin: its ``id`` primary key is generated by the database."""

    __tablename__ = "mixin_uuid_models"
    name: Mapped[str] = mapped_column(String(50))
class UpdatedAtModel(MixinBase, UpdatedAtMixin):
    """Minimal model for UpdatedAtMixin (adds the ``updated_at`` column)."""

    __tablename__ = "mixin_updated_at_models"
    # Explicit integer primary key, since this model does not use UUIDMixin.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(50))
class CreatedAtModel(MixinBase, CreatedAtMixin):
    """Minimal model for CreatedAtMixin (adds the ``created_at`` column)."""

    __tablename__ = "mixin_created_at_models"
    # Explicit integer primary key, since this model does not use UUIDMixin.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(50))
class TimestampModel(MixinBase, TimestampMixin):
    """Minimal model for TimestampMixin (``created_at`` plus ``updated_at``)."""

    __tablename__ = "mixin_timestamp_models"
    # Explicit integer primary key, since this model does not use UUIDMixin.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(50))
class UUIDv7Model(MixinBase, UUIDv7Mixin):
    """Minimal model for UUIDv7Mixin: ``id`` defaults to the server-side ``uuidv7()``."""

    __tablename__ = "mixin_uuidv7_models"
    name: Mapped[str] = mapped_column(String(50))
class FullMixinModel(MixinBase, UUIDMixin, UpdatedAtMixin):
    """Combines UUIDMixin (``id``) and UpdatedAtMixin (``updated_at``) on one table."""

    __tablename__ = "mixin_full_models"
    name: Mapped[str] = mapped_column(String(50))
# Event log shared by the callbacks of WatchedModel, OnEventModel and
# WatchAllModel. The test classes using it clear it before and after each test.
_test_events: list[dict] = []
@watch("status")
class WatchedModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Watched model whose ``on_update`` fires only for changes to ``status``.

    Each lifecycle callback appends a record to ``_test_events``.
    """

    __tablename__ = "mixin_watched_models"
    status: Mapped[str] = mapped_column(String(50))
    # Not listed in @watch: changing only this field must not trigger on_update.
    other: Mapped[str] = mapped_column(String(50))

    async def on_create(self) -> None:
        """Record a create event after the INSERT has been committed."""
        _test_events.append({"event": "create", "obj_id": self.id})

    async def on_delete(self) -> None:
        """Record a delete event after the DELETE has been committed."""
        _test_events.append({"event": "delete", "obj_id": self.id})

    async def on_update(self, changes: dict) -> None:
        """Record an update event.

        Args:
            changes: Mapping of field name to ``{"old": ..., "new": ...}``.
        """
        _test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
@watch("value")
class OnEventModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Model that only overrides on_event to test the catch-all path.

    No ``on_create``/``on_delete``/``on_update`` is defined here, so every
    lifecycle event is expected to be routed to ``on_event``.
    """

    __tablename__ = "mixin_on_event_models"
    value: Mapped[str] = mapped_column(String(50))

    async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None:
        """Record the event kind and its changes.

        The tests expect ``changes`` to be ``None`` for CREATE/DELETE and a
        field-change mapping for UPDATE.
        """
        _test_events.append({"event": event, "obj_id": self.id, "changes": changes})
class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Model without @watch — watches all mapped fields by default."""

    __tablename__ = "mixin_watch_all_models"
    status: Mapped[str] = mapped_column(String(50))
    other: Mapped[str] = mapped_column(String(50))

    async def on_update(self, changes: dict) -> None:
        """Record an update event with the field-change mapping."""
        _test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Model whose on_create always raises, used to test exception logging.

    The tests expect the error to be logged rather than propagated out of
    ``commit()``.
    """

    __tablename__ = "mixin_failing_callback_models"
    name: Mapped[str] = mapped_column(String(50))

    async def on_create(self) -> None:
        """Always fail so the error-handling path is exercised."""
        raise RuntimeError("callback intentionally failed")
class NonWatchedModel(MixinBase):
    """Plain model without WatchedFieldsMixin; its changes must never trigger callbacks."""

    __tablename__ = "mixin_non_watched_models"
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    value: Mapped[str] = mapped_column(String(50))
# Event log for the single-table-inheritance models (PolyAnimal / PolyDog).
_poly_events: list[dict] = []
class PolyAnimal(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Base class for STI polymorphism tests.

    ``kind`` is the discriminator column. The callbacks record
    ``type(self).__name__`` so tests can check which concrete class fired.
    """

    __tablename__ = "mixin_poly_animals"
    __mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "animal"}
    kind: Mapped[str] = mapped_column(String(50))
    name: Mapped[str] = mapped_column(String(50))

    async def on_create(self) -> None:
        """Record a create event tagged with the runtime class name."""
        _poly_events.append(
            {"event": "create", "type": type(self).__name__, "obj_id": self.id}
        )

    async def on_delete(self) -> None:
        """Record a delete event tagged with the runtime class name."""
        _poly_events.append(
            {"event": "delete", "type": type(self).__name__, "obj_id": self.id}
        )
class PolyDog(PolyAnimal):
    """STI subclass — shares the same table as PolyAnimal (``kind == "dog"``)."""

    __mapper_args__ = {"polymorphic_identity": "dog"}
# Event log for the @watch inheritance models (WatchParent and its subclasses).
_watch_inherit_events: list[dict] = []
@watch("status")
class WatchParent(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Base class with @watch("status") — subclasses should inherit this filter.

    ``kind`` is the STI discriminator for WatchChild and WatchOverride.
    """

    __tablename__ = "mixin_watch_parent"
    __mapper_args__ = {"polymorphic_on": "kind", "polymorphic_identity": "parent"}
    kind: Mapped[str] = mapped_column(String(50))
    status: Mapped[str] = mapped_column(String(50))
    other: Mapped[str] = mapped_column(String(50))

    async def on_update(self, changes: dict) -> None:
        """Record the concrete class name and the field changes."""
        _watch_inherit_events.append({"type": type(self).__name__, "changes": changes})
class WatchChild(WatchParent):
    """STI subclass that does NOT redeclare @watch — should inherit the parent's filter."""

    __mapper_args__ = {"polymorphic_identity": "child"}
@watch("other")
class WatchOverride(WatchParent):
    """STI subclass whose @watch("other") replaces the parent's filter.

    ``status`` changes are therefore expected NOT to fire on_update.
    """

    __mapper_args__ = {"polymorphic_identity": "override"}
# Snapshot of instance attributes captured inside AttrAccessModel's callbacks.
_attr_access_events: list[dict] = []
class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Model used to verify that self attributes are accessible in every callback.

    Each callback reads ``id``, ``name`` and ``callback_url`` from ``self``.
    NOTE(review): presumably exercised with the ``mixin_session_expire``
    fixture, so attributes are read after ``expire_on_commit`` — confirm
    against the tests further down the file.
    """

    __tablename__ = "mixin_attr_access_models"
    name: Mapped[str] = mapped_column(String(50))
    callback_url: Mapped[str | None] = mapped_column(String(200), nullable=True)

    async def on_create(self) -> None:
        """Capture attribute values seen during the create callback."""
        _attr_access_events.append(
            {
                "event": "create",
                "id": self.id,
                "name": self.name,
                "callback_url": self.callback_url,
            }
        )

    async def on_delete(self) -> None:
        """Capture attribute values seen during the delete callback."""
        _attr_access_events.append(
            {
                "event": "delete",
                "id": self.id,
                "name": self.name,
                "callback_url": self.callback_url,
            }
        )

    async def on_update(self, changes: dict) -> None:
        """Capture attribute values seen during the update callback (``changes`` is ignored)."""
        _attr_access_events.append(
            {
                "event": "update",
                "id": self.id,
                "name": self.name,
                "callback_url": self.callback_url,
            }
        )
# Event log for SyncCallbackModel (plain, non-async callbacks).
_sync_events: list[dict] = []
# Markers appended by the task that FutureCallbackModel.on_create returns.
_future_events: list[str] = []
@watch("status")
class SyncCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Model with plain (sync) on_* callbacks instead of coroutines."""

    __tablename__ = "mixin_sync_callback_models"
    status: Mapped[str] = mapped_column(String(50))

    def on_create(self) -> None:
        """Record a create event synchronously."""
        _sync_events.append({"event": "create", "obj_id": self.id})

    def on_delete(self) -> None:
        """Record a delete event synchronously."""
        _sync_events.append({"event": "delete", "obj_id": self.id})

    def on_update(self, changes: dict) -> None:
        """Record an update event (no ``obj_id``) synchronously."""
        _sync_events.append({"event": "update", "changes": changes})
class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
    """Model whose on_create returns an asyncio.Task (awaitable, not a coroutine)."""

    __tablename__ = "mixin_future_callback_models"
    name: Mapped[str] = mapped_column(String(50))

    def on_create(self) -> "asyncio.Task[None]":
        """Schedule the actual work as a task and return it (sync def, awaitable result)."""

        async def _work() -> None:
            _future_events.append("created")

        # ensure_future wraps the coroutine in a Task on the running loop.
        return asyncio.ensure_future(_work())
@pytest.fixture(scope="function")
async def mixin_session():
    """Provide a per-test async session with the ``MixinBase`` tables available."""
    session_cm = create_db_session(DATABASE_URL, MixinBase)
    async with session_cm as db_session:
        yield db_session
@pytest.fixture(scope="function")
async def mixin_session_expire():
    """Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
    session_cm = create_db_session(DATABASE_URL, MixinBase, expire_on_commit=True)
    async with session_cm as db_session:
        yield db_session
class TestUUIDMixin:
    """UUIDMixin: database-generated UUID primary key (runtime and schema checks)."""

    @pytest.mark.anyio
    async def test_uuid_generated_by_db(self, mixin_session):
        """UUID is generated server-side and populated after flush."""
        obj = UUIDModel(name="test")
        mixin_session.add(obj)
        await mixin_session.flush()
        assert obj.id is not None
        assert isinstance(obj.id, uuid.UUID)

    @pytest.mark.anyio
    async def test_uuid_is_primary_key(self):
        """UUIDMixin adds id as primary key column."""
        pk_cols = [c.name for c in UUIDModel.__table__.primary_key]
        assert pk_cols == ["id"]

    @pytest.mark.anyio
    async def test_each_row_gets_unique_uuid(self, mixin_session):
        """Each inserted row gets a distinct UUID."""
        a = UUIDModel(name="a")
        b = UUIDModel(name="b")
        mixin_session.add_all([a, b])
        await mixin_session.flush()
        assert a.id != b.id

    @pytest.mark.anyio
    async def test_uuid_server_default_set(self):
        """Column has gen_random_uuid() as server default."""
        col = UUIDModel.__table__.c["id"]
        assert col.server_default is not None
        assert "gen_random_uuid" in str(col.server_default.arg)
class TestUpdatedAtMixin:
    """UpdatedAtMixin: ``updated_at`` column behaviour and schema."""

    @pytest.mark.anyio
    async def test_updated_at_set_on_insert(self, mixin_session):
        """updated_at is populated after insert."""
        obj = UpdatedAtModel(name="initial")
        mixin_session.add(obj)
        await mixin_session.flush()
        # refresh() reloads server-generated values from the database.
        await mixin_session.refresh(obj)
        assert obj.updated_at is not None
        assert obj.updated_at.tzinfo is not None

    @pytest.mark.anyio
    async def test_updated_at_changes_on_update(self, mixin_session):
        """updated_at is updated when the row is modified."""
        obj = UpdatedAtModel(name="initial")
        mixin_session.add(obj)
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        original_ts = obj.updated_at
        obj.name = "modified"
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        # NOTE(review): ``>=`` rather than ``>`` because both flushes happen in
        # the same transaction and the timestamp may not advance (e.g. if the
        # server clock function is transaction-scoped). As written, this check
        # does not prove the value changed — confirm whether that is intended.
        assert obj.updated_at >= original_ts

    @pytest.mark.anyio
    async def test_updated_at_column_is_not_nullable(self):
        """updated_at column is non-nullable."""
        col = UpdatedAtModel.__table__.c["updated_at"]
        assert not col.nullable

    @pytest.mark.anyio
    async def test_updated_at_has_server_default(self):
        """updated_at column has a server-side default."""
        col = UpdatedAtModel.__table__.c["updated_at"]
        assert col.server_default is not None

    @pytest.mark.anyio
    async def test_updated_at_has_onupdate(self):
        """updated_at column has an onupdate clause."""
        col = UpdatedAtModel.__table__.c["updated_at"]
        assert col.onupdate is not None
class TestCreatedAtMixin:
    """CreatedAtMixin: ``created_at`` is set once on insert and never updated."""

    @pytest.mark.anyio
    async def test_created_at_set_on_insert(self, mixin_session):
        """created_at is populated after insert."""
        obj = CreatedAtModel(name="new")
        mixin_session.add(obj)
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        assert obj.created_at is not None
        assert obj.created_at.tzinfo is not None

    @pytest.mark.anyio
    async def test_created_at_not_changed_on_update(self, mixin_session):
        """created_at is not modified when the row is updated."""
        obj = CreatedAtModel(name="original")
        mixin_session.add(obj)
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        original_ts = obj.created_at
        obj.name = "updated"
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        assert obj.created_at == original_ts

    @pytest.mark.anyio
    async def test_created_at_column_is_not_nullable(self):
        """created_at column is non-nullable."""
        col = CreatedAtModel.__table__.c["created_at"]
        assert not col.nullable

    @pytest.mark.anyio
    async def test_created_at_has_no_onupdate(self):
        """created_at column has no onupdate clause."""
        col = CreatedAtModel.__table__.c["created_at"]
        assert col.onupdate is None
class TestTimestampMixin:
    """TimestampMixin: combined ``created_at`` / ``updated_at`` behaviour."""

    @pytest.mark.anyio
    async def test_both_columns_set_on_insert(self, mixin_session):
        """created_at and updated_at are both populated after insert."""
        obj = TimestampModel(name="new")
        mixin_session.add(obj)
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        assert obj.created_at is not None
        assert obj.updated_at is not None

    @pytest.mark.anyio
    async def test_created_at_stable_updated_at_changes_on_update(self, mixin_session):
        """On update: created_at stays the same, updated_at advances."""
        obj = TimestampModel(name="original")
        mixin_session.add(obj)
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        original_created = obj.created_at
        original_updated = obj.updated_at
        obj.name = "modified"
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        assert obj.created_at == original_created
        # NOTE(review): ``>=`` tolerates an unchanged value (same transaction),
        # so "advances" is not strictly verified here — confirm intent.
        assert obj.updated_at >= original_updated

    @pytest.mark.anyio
    async def test_timestamp_mixin_has_both_columns(self):
        """TimestampModel exposes both created_at and updated_at columns."""
        col_names = {c.name for c in TimestampModel.__table__.columns}
        assert "created_at" in col_names
        assert "updated_at" in col_names
class TestUUIDv7Mixin:
    """UUIDv7Mixin: database-generated version-7 UUID primary key.

    The runtime tests need the test database to provide the ``uuidv7()``
    function used as the column's server default.
    """

    @pytest.mark.anyio
    async def test_uuid7_generated_by_db(self, mixin_session):
        """UUIDv7 is generated server-side and populated after flush."""
        obj = UUIDv7Model(name="test")
        mixin_session.add(obj)
        await mixin_session.flush()
        assert obj.id is not None
        assert isinstance(obj.id, uuid.UUID)

    @pytest.mark.anyio
    async def test_uuid7_is_primary_key(self):
        """UUIDv7Mixin adds id as primary key column."""
        pk_cols = [c.name for c in UUIDv7Model.__table__.primary_key]
        assert pk_cols == ["id"]

    @pytest.mark.anyio
    async def test_each_row_gets_unique_uuid7(self, mixin_session):
        """Each inserted row gets a distinct UUIDv7."""
        a = UUIDv7Model(name="a")
        b = UUIDv7Model(name="b")
        mixin_session.add_all([a, b])
        await mixin_session.flush()
        assert a.id != b.id

    @pytest.mark.anyio
    async def test_uuid7_version(self, mixin_session):
        """Generated UUIDs have version 7."""
        obj = UUIDv7Model(name="test")
        mixin_session.add(obj)
        await mixin_session.flush()
        assert obj.id.version == 7

    @pytest.mark.anyio
    async def test_uuid7_server_default_set(self):
        """Column has uuidv7() as server default."""
        col = UUIDv7Model.__table__.c["id"]
        assert col.server_default is not None
        assert "uuidv7" in str(col.server_default.arg)
class TestFullMixinModel:
    """Mixins compose: UUIDMixin and UpdatedAtMixin on the same table."""

    @pytest.mark.anyio
    async def test_combined_mixins_work_together(self, mixin_session):
        """UUIDMixin and UpdatedAtMixin can be combined on the same model."""
        obj = FullMixinModel(name="combined")
        mixin_session.add(obj)
        await mixin_session.flush()
        await mixin_session.refresh(obj)
        assert isinstance(obj.id, uuid.UUID)
        assert obj.updated_at is not None
        assert obj.updated_at.tzinfo is not None
class TestWatchDecorator:
    """Unit tests for the ``@watch`` decorator and its ``_WATCHED_FIELDS`` registry."""

    def test_registers_specific_fields(self):
        """@watch("field") stores the field list in _WATCHED_FIELDS."""
        assert _watched_module._WATCHED_FIELDS.get(WatchedModel) == ["status"]

    def test_no_decorator_not_in_watched_fields(self):
        """A model without @watch has no entry in _WATCHED_FIELDS (watch all)."""
        assert WatchAllModel not in _watched_module._WATCHED_FIELDS

    def test_preserves_class_identity(self):
        """watch returns the same class unchanged."""

        class _Dummy(WatchedFieldsMixin):
            pass

        try:
            result = watch("x")(_Dummy)
            assert result is _Dummy
        finally:
            # Always unregister the throwaway class, even when the assertion
            # fails, so the module-level registry is not polluted for later
            # tests. pop() with a default also avoids a KeyError masking the
            # real failure if watch() itself raised before registering.
            _watched_module._WATCHED_FIELDS.pop(_Dummy, None)

    def test_raises_when_no_fields_given(self):
        """@watch() with no field names raises ValueError."""
        with pytest.raises(ValueError, match="@watch requires at least one field name"):
            watch()
class TestWatchInheritance:
    """@watch filters are inherited by STI subclasses unless a subclass redeclares them.

    ``await asyncio.sleep(0)`` after each commit yields to the event loop so
    that callbacks scheduled at commit time can run before the assertions.
    """

    @pytest.fixture(autouse=True)
    def clear_events(self):
        _watch_inherit_events.clear()
        yield
        _watch_inherit_events.clear()

    @pytest.mark.anyio
    async def test_child_inherits_parent_watch_filter(self, mixin_session):
        """Subclass without @watch inherits the parent's field filter."""
        obj = WatchChild(status="initial", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)

        obj.other = "changed"  # not watched by parent's @watch("status")
        await mixin_session.commit()
        await asyncio.sleep(0)

        assert _watch_inherit_events == []

    @pytest.mark.anyio
    async def test_child_triggers_on_watched_field(self, mixin_session):
        """Subclass without @watch triggers on_update for the parent's watched field."""
        obj = WatchChild(status="initial", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)

        obj.status = "updated"
        await mixin_session.commit()
        await asyncio.sleep(0)

        assert len(_watch_inherit_events) == 1
        assert _watch_inherit_events[0]["type"] == "WatchChild"
        assert "status" in _watch_inherit_events[0]["changes"]

    @pytest.mark.anyio
    async def test_subclass_override_takes_precedence(self, mixin_session):
        """Subclass @watch overrides the parent's field filter."""
        obj = WatchOverride(status="initial", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)

        obj.status = (
            "changed"  # watched by parent but overridden by child's @watch("other")
        )
        await mixin_session.commit()
        await asyncio.sleep(0)

        assert _watch_inherit_events == []

        obj.other = "changed"
        await mixin_session.commit()
        await asyncio.sleep(0)

        assert len(_watch_inherit_events) == 1
        assert "other" in _watch_inherit_events[0]["changes"]
class TestUpsertChanges:
    """Unit tests for ``_upsert_changes``, which merges per-object field changes.

    Entries are keyed by ``id(obj)`` and hold ``(obj, {field: {"old", "new"}})``.
    """

    def test_inserts_new_entry(self):
        """New key is inserted with the full changes dict."""
        target = object()
        field_changes = {"status": {"old": None, "new": "active"}}
        pending: dict = {}

        _upsert_changes(pending, target, field_changes)

        assert pending[id(target)] == (target, field_changes)

    def test_merges_existing_field_keeps_old_updates_new(self):
        """When the field already exists, old is preserved and new is overwritten."""
        target = object()
        first_flush = {"status": {"old": "initial", "new": "intermediate"}}
        second_flush = {"status": {"old": "intermediate", "new": "final"}}
        pending = {id(target): (target, first_flush)}

        _upsert_changes(pending, target, second_flush)

        _, merged = pending[id(target)]
        assert merged["status"] == {"old": "initial", "new": "final"}

    def test_adds_new_field_to_existing_entry(self):
        """A previously unseen field is added alongside existing ones."""
        target = object()
        pending = {id(target): (target, {"status": {"old": "a", "new": "b"}})}
        role_change = {"role": {"old": "user", "new": "admin"}}

        _upsert_changes(pending, target, role_change)

        _, merged = pending[id(target)]
        assert merged["status"] == {"old": "a", "new": "b"}
        assert merged["role"] == {"old": "user", "new": "admin"}
class TestAfterFlush:
    """Unit tests for ``_after_flush`` using SimpleNamespace stand-ins for a Session."""

    @staticmethod
    def _fake_session(new=(), deleted=()):
        """Build a fake session with the given ``new``/``deleted`` objects, no dirty objects and empty ``info``."""
        return SimpleNamespace(new=list(new), deleted=list(deleted), dirty=[], info={})

    def test_does_nothing_with_empty_session(self):
        """_after_flush writes nothing to session.info when all collections are empty."""
        fake_session = self._fake_session()
        _after_flush(fake_session, None)
        assert fake_session.info == {}

    def test_captures_new_watched_mixin_objects(self):
        """New WatchedFieldsMixin instances are added to _SESSION_PENDING_NEW."""
        instance = WatchedFieldsMixin()
        fake_session = self._fake_session(new=[instance])
        _after_flush(fake_session, None)
        assert fake_session.info[_SESSION_PENDING_NEW] == [instance]

    def test_ignores_new_non_mixin_objects(self):
        """New objects that are not WatchedFieldsMixin are not captured."""
        fake_session = self._fake_session(new=[object()])
        _after_flush(fake_session, None)
        assert _SESSION_PENDING_NEW not in fake_session.info

    def test_captures_deleted_watched_mixin_objects(self):
        """Deleted WatchedFieldsMixin instances are added to _SESSION_DELETES."""
        instance = WatchedFieldsMixin()
        fake_session = self._fake_session(deleted=[instance])
        _after_flush(fake_session, None)
        assert fake_session.info[_SESSION_DELETES] == [instance]

    def test_ignores_deleted_non_mixin_objects(self):
        """Deleted objects that are not WatchedFieldsMixin are not captured."""
        fake_session = self._fake_session(deleted=[object()])
        _after_flush(fake_session, None)
        assert _SESSION_DELETES not in fake_session.info
class TestAfterFlushPostexec:
    """Unit tests for ``_after_flush_postexec``: pending-new objects become creates."""

    def test_does_nothing_when_no_pending_new(self):
        """_after_flush_postexec does nothing when _SESSION_PENDING_NEW is absent."""
        fake_session = SimpleNamespace(info={})
        _after_flush_postexec(fake_session, None)
        assert _SESSION_CREATES not in fake_session.info

    def test_moves_pending_new_to_creates(self):
        """Objects from _SESSION_PENDING_NEW are moved to _SESSION_CREATES."""
        created = object()
        fake_session = SimpleNamespace(info={_SESSION_PENDING_NEW: [created]})

        _after_flush_postexec(fake_session, None)

        info = fake_session.info
        assert _SESSION_PENDING_NEW not in info
        assert info[_SESSION_CREATES] == [created]

    def test_extends_existing_creates(self):
        """Multiple flushes accumulate in _SESSION_CREATES."""
        earlier = object()
        later = object()
        fake_session = SimpleNamespace(
            info={_SESSION_PENDING_NEW: [later], _SESSION_CREATES: [earlier]}
        )

        _after_flush_postexec(fake_session, None)

        assert fake_session.info[_SESSION_CREATES] == [earlier, later]
class TestAfterRollback:
    """Unit tests for ``_after_rollback``: all pending tracking data is discarded."""

    def test_clears_all_session_info_keys(self):
        """_after_rollback removes all four tracking keys from session.info."""
        session = SimpleNamespace(
            info={
                _SESSION_PENDING_NEW: [object()],
                _SESSION_CREATES: [object()],
                _SESSION_DELETES: [object()],
                _SESSION_UPDATES: {1: ("obj", {"f": {"old": "a", "new": "b"}})},
            }
        )
        _after_rollback(session)
        assert _SESSION_PENDING_NEW not in session.info
        assert _SESSION_CREATES not in session.info
        assert _SESSION_DELETES not in session.info
        assert _SESSION_UPDATES not in session.info

    def test_tolerates_missing_keys(self):
        """_after_rollback does not raise when session.info has no pending data."""
        session = SimpleNamespace(info={})
        _after_rollback(session)  # must not raise
        # Previously this test asserted nothing; also check that nothing was
        # written into session.info when there was nothing to clear.
        assert session.info == {}
class TestTaskErrorHandler:
    """Unit tests for ``_task_error_handler``: logs failed tasks, ignores cancelled ones."""

    @pytest.mark.anyio
    async def test_logs_exception_from_failed_task(self):
        """_task_error_handler calls _logger.error when the task raised."""

        async def failing() -> None:
            raise ValueError("boom")

        task = asyncio.create_task(failing())
        # Block until the task has actually finished instead of assuming that a
        # single ``sleep(0)`` is enough scheduler turns. ``asyncio.wait`` does
        # not retrieve the task's exception, so the handler still observes it.
        await asyncio.wait({task})
        with patch.object(_watched_module._logger, "error") as mock_error:
            _task_error_handler(task)
        mock_error.assert_called_once()

    @pytest.mark.anyio
    async def test_ignores_cancelled_task(self):
        """_task_error_handler does not log when the task was cancelled."""

        async def slow() -> None:
            await asyncio.sleep(100)

        task = asyncio.create_task(slow())
        task.cancel()
        # Awaiting a cancelled task raises CancelledError; it is expected here.
        with suppress(asyncio.CancelledError):
            await task
        with patch.object(_watched_module._logger, "error") as mock_error:
            _task_error_handler(task)
        mock_error.assert_not_called()
class TestAfterCommitNoLoop:
    """``_after_commit`` invoked synchronously, with no running event loop."""

    def test_no_task_scheduled_when_no_running_loop(self):
        """_after_commit silently returns when called outside an async context."""
        invoked: list[str] = []
        watched = SimpleNamespace(on_create=lambda: invoked.append("create"))
        fake_session = SimpleNamespace(info={_SESSION_CREATES: [watched]})

        _after_commit(fake_session)

        assert not invoked

    def test_returns_early_when_all_pending_empty(self):
        """_after_commit does nothing when all pending lists are empty."""
        _after_commit(SimpleNamespace(info={}))  # should not raise
class TestWatchedFieldsMixin:
    """End-to-end lifecycle callbacks (create / delete / update / on_event).

    Every commit is followed by ``await asyncio.sleep(0)``, which yields to
    the event loop so that callbacks scheduled at commit time can run before
    the assertions inspect ``_test_events``.
    """

    @pytest.fixture(autouse=True)
    def clear_events(self):
        _test_events.clear()
        yield
        _test_events.clear()

    # --- on_create ---

    @pytest.mark.anyio
    async def test_on_create_fires_after_insert(self, mixin_session):
        """on_create is called after INSERT commit."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        creates = [e for e in _test_events if e["event"] == "create"]
        assert len(creates) == 1

    @pytest.mark.anyio
    async def test_on_create_server_defaults_populated(self, mixin_session):
        """id (server default via RETURNING) is available inside on_create."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        creates = [e for e in _test_events if e["event"] == "create"]
        assert creates[0]["obj_id"] is not None
        assert isinstance(creates[0]["obj_id"], uuid.UUID)

    @pytest.mark.anyio
    async def test_on_create_not_fired_on_update(self, mixin_session):
        """on_create is NOT called when an existing row is updated."""
        obj = WatchedModel(status="initial", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        obj.status = "updated"
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert not any(e["event"] == "create" for e in _test_events)

    # --- on_delete ---

    @pytest.mark.anyio
    async def test_on_delete_fires_after_delete(self, mixin_session):
        """on_delete is called after DELETE commit."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        # Keep the id before deleting so the delete event can be matched to it.
        saved_id = obj.id
        _test_events.clear()

        await mixin_session.delete(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        deletes = [e for e in _test_events if e["event"] == "delete"]
        assert len(deletes) == 1
        assert deletes[0]["obj_id"] == saved_id

    @pytest.mark.anyio
    async def test_on_delete_not_fired_on_insert(self, mixin_session):
        """on_delete is NOT called when a new row is inserted."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert not any(e["event"] == "delete" for e in _test_events)

    # --- on_update ---

    @pytest.mark.anyio
    async def test_on_update_fires_on_update(self, mixin_session):
        """on_update reports the correct before/after values on UPDATE."""
        obj = WatchedModel(status="initial", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        obj.status = "updated"
        await mixin_session.commit()
        await asyncio.sleep(0)
        changes_events = [e for e in _test_events if e["event"] == "update"]
        assert len(changes_events) == 1
        assert changes_events[0]["changes"]["status"] == {
            "old": "initial",
            "new": "updated",
        }

    @pytest.mark.anyio
    async def test_on_update_not_fired_on_insert(self, mixin_session):
        """on_update is NOT called on INSERT (on_create handles that)."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert not any(e["event"] == "update" for e in _test_events)

    @pytest.mark.anyio
    async def test_unwatched_field_update_no_callback(self, mixin_session):
        """Changing a field not listed in @watch does not fire on_update."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        obj.other = "changed"
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert _test_events == []

    @pytest.mark.anyio
    async def test_multiple_flushes_merge_earliest_old_latest_new(self, mixin_session):
        """Two flushes in one transaction produce a single callback with earliest old / latest new."""
        obj = WatchedModel(status="initial", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        obj.status = "intermediate"
        await mixin_session.flush()
        obj.status = "final"
        await mixin_session.commit()
        await asyncio.sleep(0)
        changes_events = [e for e in _test_events if e["event"] == "update"]
        assert len(changes_events) == 1
        assert changes_events[0]["changes"]["status"] == {
            "old": "initial",
            "new": "final",
        }

    @pytest.mark.anyio
    async def test_rollback_suppresses_all_callbacks(self, mixin_session):
        """No callbacks are fired when the transaction is rolled back."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        obj.status = "changed"
        # Flush records the change; the rollback must discard it.
        await mixin_session.flush()
        await mixin_session.rollback()
        await asyncio.sleep(0)
        assert _test_events == []

    @pytest.mark.anyio
    async def test_callback_exception_is_logged(self, mixin_session):
        """Exceptions raised inside on_create are logged, not propagated."""
        obj = FailingCallbackModel(name="boom")
        mixin_session.add(obj)
        # The sleep must happen inside the patch so the callback runs (and
        # logs) while _logger.error is still mocked.
        with patch.object(_watched_module._logger, "error") as mock_error:
            await mixin_session.commit()
            await asyncio.sleep(0)
        mock_error.assert_called_once()

    @pytest.mark.anyio
    async def test_non_watched_model_no_callback(self, mixin_session):
        """Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
        nw = NonWatchedModel(value="x")
        mixin_session.add(nw)
        await mixin_session.flush()
        nw.value = "y"
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert _test_events == []

    # --- on_event (catch-all) ---

    @pytest.mark.anyio
    async def test_on_event_receives_create(self, mixin_session):
        """on_event is called with ModelEvent.CREATE on INSERT when only on_event is overridden."""
        obj = OnEventModel(value="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        creates = [e for e in _test_events if e["event"] == ModelEvent.CREATE]
        assert len(creates) == 1
        assert creates[0]["changes"] is None

    @pytest.mark.anyio
    async def test_on_event_receives_delete(self, mixin_session):
        """on_event is called with ModelEvent.DELETE on DELETE when only on_event is overridden."""
        obj = OnEventModel(value="x")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        await mixin_session.delete(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        deletes = [e for e in _test_events if e["event"] == ModelEvent.DELETE]
        assert len(deletes) == 1
        assert deletes[0]["changes"] is None

    @pytest.mark.anyio
    async def test_on_event_receives_field_change(self, mixin_session):
        """on_event is called with ModelEvent.UPDATE on UPDATE when only on_event is overridden."""
        obj = OnEventModel(value="initial")
        mixin_session.add(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        obj.value = "updated"
        await mixin_session.commit()
        await asyncio.sleep(0)
        changes_events = [e for e in _test_events if e["event"] == ModelEvent.UPDATE]
        assert len(changes_events) == 1
        assert changes_events[0]["changes"]["value"] == {
            "old": "initial",
            "new": "updated",
        }
class TestTransientObject:
    """Create + delete within the same transaction should fire no events.

    Objects created and deleted by different statements in the same
    transaction must still produce their own events.
    """

    @pytest.fixture(autouse=True)
    def clear_events(self):
        _test_events.clear()
        yield
        _test_events.clear()

    @pytest.mark.anyio
    async def test_no_events_when_created_and_deleted_in_same_transaction(
        self, mixin_session
    ):
        """Neither on_create nor on_delete fires when the object never survives a commit."""
        obj = WatchedModel(status="active", other="x")
        mixin_session.add(obj)
        # Flush so the INSERT is actually issued before the DELETE.
        await mixin_session.flush()
        await mixin_session.delete(obj)
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert _test_events == []

    @pytest.mark.anyio
    async def test_other_objects_unaffected(self, mixin_session):
        """on_create still fires for objects that are not deleted in the same transaction."""
        survivor = WatchedModel(status="active", other="x")
        transient = WatchedModel(status="gone", other="y")
        mixin_session.add(survivor)
        mixin_session.add(transient)
        await mixin_session.flush()
        await mixin_session.delete(transient)
        await mixin_session.commit()
        await asyncio.sleep(0)
        creates = [e for e in _test_events if e["event"] == "create"]
        deletes = [e for e in _test_events if e["event"] == "delete"]
        assert len(creates) == 1
        assert creates[0]["obj_id"] == survivor.id
        assert deletes == []

    @pytest.mark.anyio
    async def test_distinct_create_and_delete_both_fire(self, mixin_session):
        """on_create and on_delete both fire when different objects are created and deleted."""
        existing = WatchedModel(status="old", other="x")
        mixin_session.add(existing)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _test_events.clear()

        new_obj = WatchedModel(status="new", other="y")
        mixin_session.add(new_obj)
        await mixin_session.delete(existing)
        await mixin_session.commit()
        await asyncio.sleep(0)
        creates = [e for e in _test_events if e["event"] == "create"]
        deletes = [e for e in _test_events if e["event"] == "delete"]
        assert len(creates) == 1
        assert len(deletes) == 1
class TestPolymorphism:
    """WatchedFieldsMixin with STI (Single Table Inheritance).

    Callbacks must fire exactly once per object and report the concrete
    subclass (``PolyDog``), not the base class.
    """

    @pytest.fixture(autouse=True)
    def clear_events(self):
        _poly_events.clear()
        yield
        _poly_events.clear()

    @pytest.mark.anyio
    async def test_on_create_fires_once_for_subclass(self, mixin_session):
        """on_create fires exactly once for a STI subclass instance."""
        dog = PolyDog(name="Rex")
        mixin_session.add(dog)
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert len(_poly_events) == 1
        assert _poly_events[0]["event"] == "create"
        assert _poly_events[0]["type"] == "PolyDog"

    @pytest.mark.anyio
    async def test_on_delete_fires_for_subclass(self, mixin_session):
        """on_delete fires for a STI subclass instance."""
        dog = PolyDog(name="Rex")
        mixin_session.add(dog)
        await mixin_session.commit()
        await asyncio.sleep(0)
        _poly_events.clear()

        await mixin_session.delete(dog)
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert len(_poly_events) == 1
        assert _poly_events[0]["event"] == "delete"
        assert _poly_events[0]["type"] == "PolyDog"

    @pytest.mark.anyio
    async def test_transient_subclass_fires_no_events(self, mixin_session):
        """Create + delete of a STI subclass in one transaction fires no events."""
        dog = PolyDog(name="Rex")
        mixin_session.add(dog)
        await mixin_session.flush()
        await mixin_session.delete(dog)
        await mixin_session.commit()
        await asyncio.sleep(0)
        assert _poly_events == []
class TestWatchAll:
    """Models without ``@watch`` report changes to any column through on_update."""

    @pytest.fixture(autouse=True)
    def clear_events(self):
        # Keep the shared event log isolated between tests.
        _test_events.clear()
        yield
        _test_events.clear()

    @staticmethod
    async def _persist_then_reset(session):
        """Commit a fresh WatchAllModel row, then discard the events the commit produced."""
        row = WatchAllModel(status="initial", other="x")
        session.add(row)
        await session.commit()
        await asyncio.sleep(0)
        _test_events.clear()
        return row

    @staticmethod
    def _update_events():
        """Return the recorded ``update`` events, in order."""
        return [e for e in _test_events if e["event"] == "update"]

    @pytest.mark.anyio
    async def test_watch_all_fires_for_any_field(self, mixin_session):
        """Changing one arbitrary column yields a single on_update that mentions it."""
        row = await self._persist_then_reset(mixin_session)

        row.other = "changed"
        await mixin_session.commit()
        await asyncio.sleep(0)

        updates = self._update_events()
        assert len(updates) == 1
        assert "other" in updates[0]["changes"]

    @pytest.mark.anyio
    async def test_watch_all_captures_multiple_fields(self, mixin_session):
        """Every column changed before one commit lands in a single on_update payload."""
        row = await self._persist_then_reset(mixin_session)

        row.status = "updated"
        row.other = "changed"
        await mixin_session.commit()
        await asyncio.sleep(0)

        updates = self._update_events()
        assert len(updates) == 1
        assert "status" in updates[0]["changes"]
        assert "other" in updates[0]["changes"]
class TestSyncCallbacks:
    """Plain (non-async) on_create / on_delete / on_update callbacks run after commit."""

    @pytest.fixture(autouse=True)
    def clear_events(self):
        # Each test starts and ends with an empty sync-callback log.
        _sync_events.clear()
        yield
        _sync_events.clear()

    @staticmethod
    async def _commit(session):
        """Commit, then yield one event-loop turn so scheduled callbacks can run."""
        await session.commit()
        await asyncio.sleep(0)

    @staticmethod
    def _of_kind(kind):
        """Return the recorded sync events whose ``event`` key equals ``kind``."""
        return [e for e in _sync_events if e["event"] == kind]

    @pytest.mark.anyio
    async def test_sync_on_create_fires(self, mixin_session):
        """Sync on_create runs once the INSERT is committed and receives a UUID id."""
        row = SyncCallbackModel(status="active")
        mixin_session.add(row)
        await self._commit(mixin_session)

        created = self._of_kind("create")
        assert len(created) == 1
        assert isinstance(created[0]["obj_id"], uuid.UUID)

    @pytest.mark.anyio
    async def test_sync_on_delete_fires(self, mixin_session):
        """Sync on_delete runs once the DELETE is committed."""
        row = SyncCallbackModel(status="active")
        mixin_session.add(row)
        await self._commit(mixin_session)
        _sync_events.clear()  # only the delete event matters below

        await mixin_session.delete(row)
        await self._commit(mixin_session)

        assert len(self._of_kind("delete")) == 1

    @pytest.mark.anyio
    async def test_sync_on_update_fires(self, mixin_session):
        """Sync on_update runs after the UPDATE commit with an old/new diff."""
        row = SyncCallbackModel(status="initial")
        mixin_session.add(row)
        await self._commit(mixin_session)
        _sync_events.clear()  # only the update event matters below

        row.status = "updated"
        await self._commit(mixin_session)

        updated = self._of_kind("update")
        assert len(updated) == 1
        assert updated[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
class TestFutureCallbacks:
    """Callbacks that hand back a non-coroutine awaitable, such as an asyncio.Task or Future."""

    @pytest.fixture(autouse=True)
    def clear_events(self):
        # Isolate the future-callback log per test.
        _future_events.clear()
        yield
        _future_events.clear()

    @pytest.mark.anyio
    async def test_task_callback_is_awaited(self, mixin_session):
        """A Task returned by on_create is awaited, so its side effect is observable."""
        row = FutureCallbackModel(name="test")
        mixin_session.add(row)
        await mixin_session.commit()
        # The first turn executes _run(); the second lets the inner _work() task finish.
        for _ in range(2):
            await asyncio.sleep(0)
        assert _future_events == ["created"]
class TestAttributeAccessInCallbacks:
    """Instance attributes must stay readable inside every kind of callback.

    These tests use the ``mixin_session_expire`` fixture with expire_on_commit=True
    (SQLAlchemy's default). Without the snapshot-restore logic in
    _schedule_with_snapshot, they would fail.
    """

    @pytest.fixture(autouse=True)
    def clear_events(self):
        # Start and finish every test with an empty attribute-access log.
        _attr_access_events.clear()
        yield
        _attr_access_events.clear()

    @staticmethod
    async def _commit_and_yield(session):
        """Commit, then yield one event-loop turn so scheduled callbacks can run."""
        await session.commit()
        await asyncio.sleep(0)

    @staticmethod
    def _single(kind):
        """Assert exactly one ``kind`` event was recorded and return it."""
        matching = [e for e in _attr_access_events if e["event"] == kind]
        assert len(matching) == 1
        return matching[0]

    @pytest.mark.anyio
    async def test_on_create_pk_and_field_accessible(self, mixin_session_expire):
        """on_create can read the server-generated id as well as ordinary columns."""
        row = AttrAccessModel(name="hello")
        mixin_session_expire.add(row)
        await self._commit_and_yield(mixin_session_expire)

        event = self._single("create")
        assert isinstance(event["id"], uuid.UUID)
        assert event["name"] == "hello"

    @pytest.mark.anyio
    async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire):
        """on_delete can read the id and ordinary columns of the removed row."""
        row = AttrAccessModel(name="to-delete")
        mixin_session_expire.add(row)
        await self._commit_and_yield(mixin_session_expire)
        _attr_access_events.clear()  # only the delete event matters below

        await mixin_session_expire.delete(row)
        await self._commit_and_yield(mixin_session_expire)

        event = self._single("delete")
        assert isinstance(event["id"], uuid.UUID)
        assert event["name"] == "to-delete"

    @pytest.mark.anyio
    async def test_on_update_pk_and_updated_field_accessible(
        self, mixin_session_expire
    ):
        """on_update can read the id and the freshly written column value."""
        row = AttrAccessModel(name="original")
        mixin_session_expire.add(row)
        await self._commit_and_yield(mixin_session_expire)
        _attr_access_events.clear()  # only the update event matters below

        row.name = "updated"
        await self._commit_and_yield(mixin_session_expire)

        event = self._single("update")
        assert isinstance(event["id"], uuid.UUID)
        assert event["name"] == "updated"

    @pytest.mark.anyio
    async def test_nullable_column_none_accessible_in_on_create(
        self, mixin_session_expire
    ):
        """A nullable column left unset reads as None in on_create without a greenlet error."""
        row = AttrAccessModel(name="no-url")  # callback_url omitted, so it stays None
        mixin_session_expire.add(row)
        await self._commit_and_yield(mixin_session_expire)

        assert self._single("create")["callback_url"] is None

    @pytest.mark.anyio
    async def test_nullable_column_with_value_accessible_in_on_create(
        self, mixin_session_expire
    ):
        """A populated nullable column reads back correctly in on_create without a greenlet error."""
        row = AttrAccessModel(name="with-url", callback_url="https://example.com/hook")
        mixin_session_expire.add(row)
        await self._commit_and_yield(mixin_session_expire)

        assert self._single("create")["callback_url"] == "https://example.com/hook"

    @pytest.mark.anyio
    async def test_nullable_column_accessible_after_update_to_none(
        self, mixin_session_expire
    ):
        """A nullable column cleared to None reads as None in on_update without a greenlet error."""
        row = AttrAccessModel(name="x", callback_url="https://example.com/hook")
        mixin_session_expire.add(row)
        await self._commit_and_yield(mixin_session_expire)
        _attr_access_events.clear()  # only the update event matters below

        row.callback_url = None
        await self._commit_and_yield(mixin_session_expire)

        assert self._single("update")["callback_url"] is None

    @pytest.mark.anyio
    async def test_expired_nullable_column_not_inferred_as_none(
        self, mixin_session_expire
    ):
        """An expired nullable column must not be snapshotted as None.

        The column holds a real value in the database, but expire_on_commit has
        expired it. Its current value is unknown without a DB refresh, so the
        snapshot must leave it out instead of guessing None.
        """
        row = AttrAccessModel(name="original", callback_url="https://example.com/hook")
        mixin_session_expire.add(row)
        await mixin_session_expire.commit()
        # expire_on_commit has run: the instance state is expired and
        # callback_url is absent from state.dict. Unlike the other tests,
        # there is no event-loop yield before taking the snapshot.
        snapshot = _snapshot_column_attrs(row)
        # The expired column must be missing, not present with a None placeholder.
        assert "callback_url" not in snapshot