fix: defer on_create/on_update/on_delete dispatch until outermost transaction commits (#172)

This commit is contained in:
d3vyce
2026-03-24 19:56:03 +01:00
committed by GitHub
parent 6981c33dc8
commit 6681b7ade7

View File

@@ -26,6 +26,7 @@ _SESSION_PENDING_NEW = "_ft_pending_new"
_SESSION_CREATES = "_ft_creates" _SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes" _SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates" _SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
class ModelEvent(str, Enum): class ModelEvent(str, Enum):
@@ -92,6 +93,22 @@ def _upsert_changes(
pending[key] = (obj, changes) pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
def _after_transaction_create(session: Any, transaction: Any) -> None:
if transaction.nested:
session.info[_SESSION_SAVEPOINT_DEPTH] = (
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
def _after_transaction_end(session: Any, transaction: Any) -> None:
if transaction.nested:
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
if depth > 0: # pragma: no branch
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
@event.listens_for(AsyncSession.sync_session_class, "after_flush") @event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None: def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture references while session.new is still populated. # New objects: capture references while session.new is still populated.
@@ -189,6 +206,9 @@ def _schedule_with_snapshot(
@event.listens_for(AsyncSession.sync_session_class, "after_commit") @event.listens_for(AsyncSession.sync_session_class, "after_commit")
def _after_commit(session: Any) -> None: def _after_commit(session: Any) -> None:
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
return
creates: list[Any] = session.info.pop(_SESSION_CREATES, []) creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop( field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(