From 6681b7ade7fb1b0c3a366167383ec0fdb5c023e2 Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Tue, 24 Mar 2026 19:56:03 +0100 Subject: [PATCH] fix: defer on_create/on_update/on_delete dispatch until outermost transaction commits (#172) --- src/fastapi_toolsets/models/watched.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/fastapi_toolsets/models/watched.py b/src/fastapi_toolsets/models/watched.py index 00e1e5c..6198e48 100644 --- a/src/fastapi_toolsets/models/watched.py +++ b/src/fastapi_toolsets/models/watched.py @@ -26,6 +26,7 @@ _SESSION_PENDING_NEW = "_ft_pending_new" _SESSION_CREATES = "_ft_creates" _SESSION_DELETES = "_ft_deletes" _SESSION_UPDATES = "_ft_updates" +_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth" class ModelEvent(str, Enum): @@ -92,6 +93,22 @@ def _upsert_changes( pending[key] = (obj, changes) +@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create") +def _after_transaction_create(session: Any, transaction: Any) -> None: + if transaction.nested: + session.info[_SESSION_SAVEPOINT_DEPTH] = ( + session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1 + ) + + +@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end") +def _after_transaction_end(session: Any, transaction: Any) -> None: + if transaction.nested: + depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + if depth > 0: # pragma: no branch + session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1 + + @event.listens_for(AsyncSession.sync_session_class, "after_flush") def _after_flush(session: Any, flush_context: Any) -> None: # New objects: capture references while session.new is still populated. @@ -189,6 +206,9 @@ def _schedule_with_snapshot( @event.listens_for(AsyncSession.sync_session_class, "after_commit") def _after_commit(session: Any) -> None: + if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0: + return + creates: list[Any] = session.info.pop(_SESSION_CREATES, []) deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(