diff --git a/src/fastapi_toolsets/models/watched.py b/src/fastapi_toolsets/models/watched.py index ff3e706..2c04bc7 100644 --- a/src/fastapi_toolsets/models/watched.py +++ b/src/fastapi_toolsets/models/watched.py @@ -9,6 +9,7 @@ from typing import Any, TypeVar from sqlalchemy import event from sqlalchemy import inspect as sa_inspect from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value from ..logger import get_logger @@ -53,6 +54,17 @@ def watch(*fields: str) -> Any: return decorator +def _snapshot_column_attrs(obj: Any) -> dict[str, Any]: + """Read currently-loaded column values into a plain dict.""" + state = sa_inspect(obj) # InstanceState + state_dict = state.dict + return { + prop.key: state_dict[prop.key] + for prop in state.mapper.column_attrs + if prop.key in state_dict + } + + def _upsert_changes( pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]], obj: Any, @@ -139,16 +151,31 @@ def _task_error_handler(task: asyncio.Task[Any]) -> None: _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) -def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None: - """Dispatch *fn* with *args*, handling both sync and async callables.""" - try: - result = fn(*args) - except Exception as exc: - _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) - return - if asyncio.iscoroutine(result): - task = loop.create_task(result) - task.add_done_callback(_task_error_handler) +def _schedule_with_snapshot( + loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any +) -> None: + """Snapshot *obj*'s column attrs now (before expire_on_commit wipes them), + then schedule a coroutine that restores the snapshot and calls *fn*. + """ + snapshot = _snapshot_column_attrs(obj) + + async def _run( + obj: Any = obj, + fn: Any = fn, + snapshot: dict[str, Any] = snapshot, + args: tuple = args, + ) -> None: + for key, value in snapshot.items(): + _sa_set_committed_value(obj, key, value) + try: + result = fn(*args) + if asyncio.iscoroutine(result): + await result + except Exception as exc: + _logger.error(_CALLBACK_ERROR_MSG, exc_info=exc) + + task = loop.create_task(_run()) + task.add_done_callback(_task_error_handler) @event.listens_for(AsyncSession.sync_session_class, "after_commit") @@ -168,13 +195,13 @@ def _after_commit(session: Any) -> None: return for obj in creates: - _call_callback(loop, obj.on_create) + _schedule_with_snapshot(loop, obj, obj.on_create) for obj in deletes: - _call_callback(loop, obj.on_delete) + _schedule_with_snapshot(loop, obj, obj.on_delete) for obj, changes in field_changes.values(): - _call_callback(loop, obj.on_update, changes) + _schedule_with_snapshot(loop, obj, obj.on_update, changes) class WatchedFieldsMixin: diff --git a/tests/test_models.py b/tests/test_models.py index fb1b9a4..9ef63c9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -31,7 +31,6 @@ from fastapi_toolsets.models.watched import ( _after_flush, _after_flush_postexec, _after_rollback, - _call_callback, _task_error_handler, _upsert_changes, ) @@ -128,6 +127,17 @@ class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin): _test_events.append({"event": "update", "obj_id": self.id, "changes": changes}) +class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): + """Model whose on_create always raises to test exception logging.""" + + __tablename__ = "mixin_failing_callback_models" + + name: Mapped[str] = mapped_column(String(50)) + + async def on_create(self) -> None: + raise RuntimeError("callback intentionally failed") + + class NonWatchedModel(MixinBase): __tablename__ = "mixin_non_watched_models" @@ -135,6 +145,32 @@ class NonWatchedModel(MixinBase): value: Mapped[str] = mapped_column(String(50)) +_attr_access_events: list[dict] = [] + + +class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin): + """Model used to verify that self attributes are accessible in every callback.""" + + __tablename__ = "mixin_attr_access_models" + + name: Mapped[str] = mapped_column(String(50)) + + async def on_create(self) -> None: + _attr_access_events.append( + {"event": "create", "id": self.id, "name": self.name} + ) + + async def on_delete(self) -> None: + _attr_access_events.append( + {"event": "delete", "id": self.id, "name": self.name} + ) + + async def on_update(self, changes: dict) -> None: + _attr_access_events.append( + {"event": "update", "id": self.id, "name": self.name} + ) + + _sync_events: list[dict] = [] @@ -174,6 +210,25 @@ async def mixin_session(): await engine.dispose() +@pytest.fixture(scope="function") +async def mixin_session_expire(): + """Session with expire_on_commit=True (the default) to exercise attribute access after commit.""" + engine = create_async_engine(DATABASE_URL, echo=False) + async with engine.begin() as conn: + await conn.run_sync(MixinBase.metadata.create_all) + + session_factory = async_sessionmaker(engine, expire_on_commit=True) + session = session_factory() + + try: + yield session + finally: + await session.close() + async with engine.begin() as conn: + await conn.run_sync(MixinBase.metadata.drop_all) + await engine.dispose() + + class TestUUIDMixin: @pytest.mark.anyio async def test_uuid_generated_by_db(self, mixin_session): @@ -742,6 +797,16 @@ class TestWatchedFieldsMixin: assert _test_events == [] + @pytest.mark.anyio + async def test_callback_exception_is_logged(self, mixin_session): + """Exceptions raised inside on_create are logged, not propagated.""" + obj = FailingCallbackModel(name="boom") + mixin_session.add(obj) + with patch.object(_watched_module._logger, "error") as mock_error: + await mixin_session.commit() + await asyncio.sleep(0) + mock_error.assert_called_once() + @pytest.mark.anyio async def test_non_watched_model_no_callback(self, mixin_session): """Dirty objects whose type is not a WatchedFieldsMixin are skipped.""" @@ -903,65 +968,66 @@ class TestSyncCallbacks: assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"} -class TestCallCallback: +class TestAttributeAccessInCallbacks: + """Verify that self attributes are accessible inside every callback type. + + Uses expire_on_commit=True (the SQLAlchemy default) so the tests would fail + without the snapshot-restore logic in _schedule_with_snapshot. + """ + + @pytest.fixture(autouse=True) + def clear_events(self): + _attr_access_events.clear() + yield + _attr_access_events.clear() + @pytest.mark.anyio - async def test_async_callback_scheduled_as_task(self): - """_call_callback schedules async functions as tasks.""" - called = [] - - async def async_fn() -> None: - called.append("async") - - loop = asyncio.get_running_loop() - _call_callback(loop, async_fn) + async def test_on_create_pk_and_field_accessible(self, mixin_session_expire): + """id (server default) and regular fields are readable inside on_create.""" + obj = AttrAccessModel(name="hello") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() await asyncio.sleep(0) - assert called == ["async"] + + events = [e for e in _attr_access_events if e["event"] == "create"] + assert len(events) == 1 + assert isinstance(events[0]["id"], uuid.UUID) + assert events[0]["name"] == "hello" @pytest.mark.anyio - async def test_sync_callback_called_directly(self): - """_call_callback invokes sync functions immediately.""" - called = [] - - def sync_fn() -> None: - called.append("sync") - - loop = asyncio.get_running_loop() - _call_callback(loop, sync_fn) - assert called == ["sync"] - - @pytest.mark.anyio - async def test_sync_callback_exception_logged(self): - """_call_callback logs exceptions from sync callbacks.""" - - def failing_fn() -> None: - raise RuntimeError("sync error") - - loop = asyncio.get_running_loop() - with patch.object(_watched_module._logger, "error") as mock_error: - _call_callback(loop, failing_fn) - mock_error.assert_called_once() - - @pytest.mark.anyio - async def test_async_callback_with_args(self): - """_call_callback passes arguments to async callbacks.""" - received = [] - - async def async_fn(changes: dict) -> None: - received.append(changes) - - loop = asyncio.get_running_loop() - _call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}}) + async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire): + """id and regular fields are readable inside on_delete.""" + obj = AttrAccessModel(name="to-delete") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() await asyncio.sleep(0) - assert received == [{"status": {"old": "a", "new": "b"}}] + _attr_access_events.clear() + + await mixin_session_expire.delete(obj) + await mixin_session_expire.commit() + await asyncio.sleep(0) + + events = [e for e in _attr_access_events if e["event"] == "delete"] + assert len(events) == 1 + assert isinstance(events[0]["id"], uuid.UUID) + assert events[0]["name"] == "to-delete" @pytest.mark.anyio - async def test_sync_callback_with_args(self): - """_call_callback passes arguments to sync callbacks.""" - received = [] + async def test_on_update_pk_and_updated_field_accessible( + self, mixin_session_expire + ): + """id and the new field value are readable inside on_update.""" + obj = AttrAccessModel(name="original") + mixin_session_expire.add(obj) + await mixin_session_expire.commit() + await asyncio.sleep(0) + _attr_access_events.clear() - def sync_fn(changes: dict) -> None: - received.append(changes) + obj.name = "updated" + await mixin_session_expire.commit() + await asyncio.sleep(0) - loop = asyncio.get_running_loop() - _call_callback(loop, sync_fn, {"x": 1}) - assert received == [{"x": 1}] + events = [e for e in _attr_access_events if e["event"] == "update"] + assert len(events) == 1 + assert isinstance(events[0]["id"], uuid.UUID) + assert events[0]["name"] == "updated"