"""Tests for fastapi_toolsets.db module.""" import asyncio import uuid import pytest from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from fastapi_toolsets.db import ( LockMode, create_db_context, create_db_dependency, get_transaction, lock_tables, wait_for_row_change, ) from .conftest import DATABASE_URL, Base, Role, RoleCrud, User class TestCreateDbDependency: """Tests for create_db_dependency.""" @pytest.mark.anyio async def test_yields_session(self): """Dependency yields a valid session.""" engine = create_async_engine(DATABASE_URL, echo=False) session_factory = async_sessionmaker(engine, expire_on_commit=False) get_db = create_db_dependency(session_factory) async for session in get_db(): assert isinstance(session, AsyncSession) break await engine.dispose() @pytest.mark.anyio async def test_auto_commits_transaction(self): """Dependency auto-commits if transaction is active.""" engine = create_async_engine(DATABASE_URL, echo=False) session_factory = async_sessionmaker(engine, expire_on_commit=False) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) try: get_db = create_db_dependency(session_factory) async for session in get_db(): role = Role(name="test_role_dep") session.add(role) await session.flush() async with session_factory() as verify_session: result = await RoleCrud.first( verify_session, [Role.name == "test_role_dep"] ) assert result is not None finally: async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) await engine.dispose() class TestCreateDbContext: """Tests for create_db_context.""" @pytest.mark.anyio async def test_context_manager_yields_session(self): """Context manager yields a valid session.""" engine = create_async_engine(DATABASE_URL, echo=False) session_factory = async_sessionmaker(engine, expire_on_commit=False) get_db_context = create_db_context(session_factory) async with get_db_context() as session: assert isinstance(session, AsyncSession) await engine.dispose() @pytest.mark.anyio async def test_context_manager_commits(self): """Context manager commits on exit.""" engine = create_async_engine(DATABASE_URL, echo=False) session_factory = async_sessionmaker(engine, expire_on_commit=False) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) try: get_db_context = create_db_context(session_factory) async with get_db_context() as session: role = Role(name="context_role") session.add(role) await session.flush() async with session_factory() as verify_session: result = await RoleCrud.first( verify_session, [Role.name == "context_role"] ) assert result is not None finally: async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) await engine.dispose() class TestGetTransaction: """Tests for get_transaction context manager.""" @pytest.mark.anyio async def test_starts_transaction(self, db_session: AsyncSession): """get_transaction starts a new transaction.""" async with get_transaction(db_session): role = Role(name="tx_role") db_session.add(role) result = await RoleCrud.first(db_session, [Role.name == "tx_role"]) assert result is not None @pytest.mark.anyio async def test_nested_transaction_uses_savepoint(self, db_session: AsyncSession): """Nested transactions use savepoints.""" async with get_transaction(db_session): role1 = Role(name="outer_role") db_session.add(role1) await db_session.flush() async with get_transaction(db_session): role2 = Role(name="inner_role") db_session.add(role2) results = await RoleCrud.get_multi(db_session) names = {r.name for r in results} assert "outer_role" in names assert "inner_role" in names @pytest.mark.anyio async def test_rollback_on_exception(self, db_session: AsyncSession): """Transaction rolls back on exception.""" try: async with get_transaction(db_session): role = Role(name="rollback_role") db_session.add(role) await db_session.flush() raise ValueError("Simulated error") except ValueError: pass result = await RoleCrud.first(db_session, [Role.name == "rollback_role"]) assert result is None @pytest.mark.anyio async def test_nested_rollback_preserves_outer(self, db_session: AsyncSession): """Nested rollback preserves outer transaction.""" async with get_transaction(db_session): role1 = Role(name="preserved_role") db_session.add(role1) await db_session.flush() try: async with get_transaction(db_session): role2 = Role(name="rolled_back_role") db_session.add(role2) await db_session.flush() raise ValueError("Inner error") except ValueError: pass outer = await RoleCrud.first(db_session, [Role.name == "preserved_role"]) inner = await RoleCrud.first(db_session, [Role.name == "rolled_back_role"]) assert outer is not None assert inner is None class TestLockMode: """Tests for LockMode enum.""" def test_lock_modes_exist(self): """All expected lock modes are defined.""" assert LockMode.ACCESS_SHARE == "ACCESS SHARE" assert LockMode.ROW_SHARE == "ROW SHARE" assert LockMode.ROW_EXCLUSIVE == "ROW EXCLUSIVE" assert LockMode.SHARE_UPDATE_EXCLUSIVE == "SHARE UPDATE EXCLUSIVE" assert LockMode.SHARE == "SHARE" assert LockMode.SHARE_ROW_EXCLUSIVE == "SHARE ROW EXCLUSIVE" assert LockMode.EXCLUSIVE == "EXCLUSIVE" assert LockMode.ACCESS_EXCLUSIVE == "ACCESS EXCLUSIVE" def test_lock_mode_is_string(self): """Lock modes are string enums.""" assert isinstance(LockMode.EXCLUSIVE, str) assert LockMode.EXCLUSIVE.value == "EXCLUSIVE" class TestLockTables: """Tests for lock_tables context manager (PostgreSQL-specific).""" @pytest.mark.anyio async def test_lock_single_table(self, db_session: AsyncSession): """Lock a single table.""" async with lock_tables(db_session, [Role]): # Inside the lock, we can still perform operations role = Role(name="locked_role") db_session.add(role) await db_session.flush() # After lock is released, verify the data was committed result = await RoleCrud.first(db_session, [Role.name == "locked_role"]) assert result is not None @pytest.mark.anyio async def test_lock_multiple_tables(self, db_session: AsyncSession): """Lock multiple tables.""" async with lock_tables(db_session, [Role, User]): role = Role(name="multi_lock_role") db_session.add(role) await db_session.flush() result = await RoleCrud.first(db_session, [Role.name == "multi_lock_role"]) assert result is not None @pytest.mark.anyio async def test_lock_with_custom_mode(self, db_session: AsyncSession): """Lock with custom lock mode.""" async with lock_tables(db_session, [Role], mode=LockMode.EXCLUSIVE): role = Role(name="exclusive_lock_role") db_session.add(role) await db_session.flush() result = await RoleCrud.first(db_session, [Role.name == "exclusive_lock_role"]) assert result is not None @pytest.mark.anyio async def test_lock_rollback_on_exception(self, db_session: AsyncSession): """Lock context rolls back on exception.""" try: async with lock_tables(db_session, [Role]): role = Role(name="lock_rollback_role") db_session.add(role) await db_session.flush() raise ValueError("Simulated error") except ValueError: pass result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"]) assert result is None class TestWaitForRowChange: """Tests for wait_for_row_change polling function.""" @pytest.mark.anyio async def test_detects_update(self, db_session: AsyncSession, engine): """Returns updated instance when a column value changes.""" role = Role(name="watch_role") db_session.add(role) await db_session.commit() async def update_later(): await asyncio.sleep(0.15) factory = async_sessionmaker(engine, expire_on_commit=False) async with factory() as other: r = await other.get(Role, role.id) assert r is not None r.name = "updated_role" await other.commit() update_task = asyncio.create_task(update_later()) result = await wait_for_row_change(db_session, Role, role.id, interval=0.05) await update_task assert result.name == "updated_role" @pytest.mark.anyio async def test_watches_specific_columns(self, db_session: AsyncSession, engine): """Only triggers on changes to specified columns.""" user = User(username="testuser", email="test@example.com") db_session.add(user) await db_session.commit() async def update_later(): factory = async_sessionmaker(engine, expire_on_commit=False) # First: change email (not watched) — should not trigger await asyncio.sleep(0.15) async with factory() as other: u = await other.get(User, user.id) assert u is not None u.email = "new@example.com" await other.commit() # Second: change username (watched) — should trigger await asyncio.sleep(0.15) async with factory() as other: u = await other.get(User, user.id) assert u is not None u.username = "newuser" await other.commit() update_task = asyncio.create_task(update_later()) result = await wait_for_row_change( db_session, User, user.id, columns=["username"], interval=0.05 ) await update_task assert result.username == "newuser" assert result.email == "new@example.com" @pytest.mark.anyio async def test_nonexistent_row_raises(self, db_session: AsyncSession): """Raises LookupError when the row does not exist.""" fake_id = uuid.uuid4() with pytest.raises(LookupError, match="not found"): await wait_for_row_change(db_session, Role, fake_id, interval=0.05) @pytest.mark.anyio async def test_timeout_raises(self, db_session: AsyncSession): """Raises TimeoutError when no change is detected within timeout.""" role = Role(name="timeout_role") db_session.add(role) await db_session.commit() with pytest.raises(TimeoutError): await wait_for_row_change( db_session, Role, role.id, interval=0.05, timeout=0.2 ) @pytest.mark.anyio async def test_deleted_row_raises(self, db_session: AsyncSession, engine): """Raises LookupError when the row is deleted during polling.""" role = Role(name="delete_role") db_session.add(role) await db_session.commit() async def delete_later(): await asyncio.sleep(0.15) factory = async_sessionmaker(engine, expire_on_commit=False) async with factory() as other: r = await other.get(Role, role.id) await other.delete(r) await other.commit() delete_task = asyncio.create_task(delete_later()) with pytest.raises(LookupError): await wait_for_row_change(db_session, Role, role.id, interval=0.05) await delete_task