diff --git a/src/fastapi_toolsets/models/watched.py b/src/fastapi_toolsets/models/watched.py index 00e1e5c..6198e48 100644 --- a/src/fastapi_toolsets/models/watched.py +++ b/src/fastapi_toolsets/models/watched.py @@ -26,6 +26,7 @@ _SESSION_PENDING_NEW = "_ft_pending_new" _SESSION_CREATES = "_ft_creates" _SESSION_DELETES = "_ft_deletes" _SESSION_UPDATES = "_ft_updates" +_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth" class ModelEvent(str, Enum): @@ -92,6 +93,22 @@ def _upsert_changes( pending[key] = (obj, changes) +@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create") +def _after_transaction_create(session: Any, transaction: Any) -> None: + if transaction.nested: + session.info[_SESSION_SAVEPOINT_DEPTH] = ( + session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1 + ) + + +@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end") +def _after_transaction_end(session: Any, transaction: Any) -> None: + if transaction.nested: + depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + if depth > 0: # pragma: no branch + session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1 + + @event.listens_for(AsyncSession.sync_session_class, "after_flush") def _after_flush(session: Any, flush_context: Any) -> None: # New objects: capture references while session.new is still populated. @@ -189,6 +206,9 @@ def _schedule_with_snapshot( @event.listens_for(AsyncSession.sync_session_class, "after_commit") def _after_commit(session: Any) -> None: + if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0: + return + creates: list[Any] = session.info.pop(_SESSION_CREATES, []) deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(