diff --git a/docs/module/models.md b/docs/module/models.md index 5613900..5ed4d77 100644 --- a/docs/module/models.md +++ b/docs/module/models.md @@ -214,12 +214,12 @@ The `changes` dict maps each watched field that changed to `{"old": ..., "new": !!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected." -!!! warning "Callbacks fire after the **outermost** transaction commits." +!!! warning "Callbacks fire when the **outermost active context** (savepoint or transaction) commits." If you create several related objects using `CrudFactory.create` and need callbacks to see all of them (including associations), wrap the whole - operation in a single [`get_transaction`](db.md) block. Without it, each - `create` call commits independently and `on_create` fires before the - remaining objects exist. + operation in a single [`get_transaction`](db.md) or [`lock_tables`](db.md) + block. Without it, each `create` call commits its own savepoint and + `on_create` fires before the remaining objects exist. ```python from fastapi_toolsets.db import get_transaction diff --git a/src/fastapi_toolsets/db.py b/src/fastapi_toolsets/db.py index fdeaab0..67dffd2 100644 --- a/src/fastapi_toolsets/db.py +++ b/src/fastapi_toolsets/db.py @@ -56,6 +56,7 @@ def create_db_dependency( async def get_db() -> AsyncGenerator[AsyncSession, None]: async with session_maker() as session: + await session.connection() yield session if session.in_transaction(): await session.commit() diff --git a/tests/test_db.py b/tests/test_db.py index 0cfad0b..5e7cc7c 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -68,6 +68,55 @@ class TestCreateDbDependency: await conn.run_sync(Base.metadata.drop_all) await engine.dispose() + @pytest.mark.anyio + async def test_in_transaction_on_yield(self): + """Session is already in a transaction when the endpoint body starts.""" + engine = create_async_engine(DATABASE_URL, echo=False) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + get_db = create_db_dependency(session_factory) + + async for session in get_db(): + assert session.in_transaction() + break + + await engine.dispose() + + @pytest.mark.anyio + async def test_update_after_lock_tables_is_persisted(self): + """Changes made after lock_tables exits (before endpoint returns) are committed. + + Regression: without the auto-begin fix, lock_tables would start and commit a + real outer transaction, leaving the session idle. Any modifications after that + point were silently dropped. + """ + engine = create_async_engine(DATABASE_URL, echo=False) + session_factory = async_sessionmaker(engine, expire_on_commit=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + try: + get_db = create_db_dependency(session_factory) + + async for session in get_db(): + async with lock_tables(session, [Role]): + role = Role(name="lock_then_update") + session.add(role) + await session.flush() + # lock_tables has exited — outer transaction must still be open + assert session.in_transaction() + role.name = "updated_after_lock" + + async with session_factory() as verify: + result = await RoleCrud.first( + verify, [Role.name == "updated_after_lock"] + ) + assert result is not None + finally: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await engine.dispose() + class TestCreateDbContext: """Tests for create_db_context."""