"""Tests for fastapi_toolsets.db module."""

import asyncio
import uuid

import pytest
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase

from fastapi_toolsets.db import (
    LockMode,
    cleanup_tables,
    create_database,
    create_db_context,
    create_db_dependency,
    get_transaction,
    lock_tables,
    wait_for_row_change,
)
from fastapi_toolsets.exceptions import NotFoundError
from fastapi_toolsets.pytest import create_db_session

from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud

# Upper bound for polling tests that expect a change to be detected. The
# background writers fire after a few hundred milliseconds; the bound only
# exists so a detection regression fails fast instead of hanging the suite.
_WAIT_TIMEOUT = 5.0


class TestCreateDbDependency:
    """Tests for create_db_dependency."""

    @pytest.mark.anyio
    async def test_yields_session(self):
        """Dependency yields a valid session."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)

        # Drive the generator by hand and close it explicitly: breaking out of
        # ``async for`` leaves it suspended until garbage collection, so its
        # teardown would not run before the engine is disposed.
        dependency = get_db()
        try:
            session = await dependency.__anext__()
            assert isinstance(session, AsyncSession)
        finally:
            await dependency.aclose()
            await engine.dispose()

    @pytest.mark.anyio
    async def test_auto_commits_transaction(self):
        """Dependency auto-commits if transaction is active."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        try:
            get_db = create_db_dependency(session_factory)

            # Exhaust the generator so the code after its ``yield`` runs.
            async for session in get_db():
                role = Role(name="test_role_dep")
                session.add(role)
                await session.flush()

            # A separate session only sees the row if it was committed.
            async with session_factory() as verify_session:
                result = await RoleCrud.first(
                    verify_session, [Role.name == "test_role_dep"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()

    @pytest.mark.anyio
    async def test_in_transaction_on_yield(self):
        """Session is already in a transaction when the endpoint body starts."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(session_factory)

        # Close the generator explicitly (see test_yields_session): the session
        # has an open transaction here that must be torn down before dispose.
        dependency = get_db()
        try:
            session = await dependency.__anext__()
            assert session.in_transaction()
        finally:
            await dependency.aclose()
            await engine.dispose()

    @pytest.mark.anyio
    async def test_update_after_lock_tables_is_persisted(self):
        """Changes made after lock_tables exits (before endpoint returns) are committed.

        Regression: without the auto-begin fix, lock_tables would start and commit
        a real outer transaction, leaving the session idle. Any modifications after
        that point were silently dropped.
        """
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        try:
            get_db = create_db_dependency(session_factory)

            async for session in get_db():
                async with lock_tables(session, [Role]):
                    role = Role(name="lock_then_update")
                    session.add(role)
                    await session.flush()

                # lock_tables has exited — outer transaction must still be open
                assert session.in_transaction()

                role.name = "updated_after_lock"

            async with session_factory() as verify:
                result = await RoleCrud.first(
                    verify, [Role.name == "updated_after_lock"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()


class TestCreateDbContext:
    """Tests for create_db_context."""

    @pytest.mark.anyio
    async def test_context_manager_yields_session(self):
        """Context manager yields a valid session."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        get_db_context = create_db_context(session_factory)

        try:
            async with get_db_context() as session:
                assert isinstance(session, AsyncSession)
        finally:
            await engine.dispose()

    @pytest.mark.anyio
    async def test_context_manager_commits(self):
        """Context manager commits on exit."""
        engine = create_async_engine(DATABASE_URL, echo=False)
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        try:
            get_db_context = create_db_context(session_factory)

            async with get_db_context() as session:
                role = Role(name="context_role")
                session.add(role)
                await session.flush()

            # A separate session only sees the row if it was committed.
            async with session_factory() as verify_session:
                result = await RoleCrud.first(
                    verify_session, [Role.name == "context_role"]
                )
                assert result is not None
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            await engine.dispose()


class TestGetTransaction:
    """Tests for get_transaction context manager."""

    @pytest.mark.anyio
    async def test_starts_transaction(self, db_session: AsyncSession):
        """get_transaction starts a new transaction."""
        async with get_transaction(db_session):
            role = Role(name="tx_role")
            db_session.add(role)

        result = await RoleCrud.first(db_session, [Role.name == "tx_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_nested_transaction_uses_savepoint(self, db_session: AsyncSession):
        """Nested transactions use savepoints."""
        async with get_transaction(db_session):
            role1 = Role(name="outer_role")
            db_session.add(role1)
            await db_session.flush()

            async with get_transaction(db_session):
                role2 = Role(name="inner_role")
                db_session.add(role2)

        results = await RoleCrud.get_multi(db_session)
        names = {r.name for r in results}
        assert "outer_role" in names
        assert "inner_role" in names

    @pytest.mark.anyio
    async def test_rollback_on_exception(self, db_session: AsyncSession):
        """Transaction rolls back on exception."""
        # pytest.raises also asserts that get_transaction re-raises the error
        # instead of swallowing it.
        with pytest.raises(ValueError, match="Simulated error"):
            async with get_transaction(db_session):
                role = Role(name="rollback_role")
                db_session.add(role)
                await db_session.flush()
                raise ValueError("Simulated error")

        result = await RoleCrud.first(db_session, [Role.name == "rollback_role"])
        assert result is None

    @pytest.mark.anyio
    async def test_nested_rollback_preserves_outer(self, db_session: AsyncSession):
        """Nested rollback preserves outer transaction."""
        async with get_transaction(db_session):
            role1 = Role(name="preserved_role")
            db_session.add(role1)
            await db_session.flush()

            with pytest.raises(ValueError, match="Inner error"):
                async with get_transaction(db_session):
                    role2 = Role(name="rolled_back_role")
                    db_session.add(role2)
                    await db_session.flush()
                    raise ValueError("Inner error")

        outer = await RoleCrud.first(db_session, [Role.name == "preserved_role"])
        inner = await RoleCrud.first(db_session, [Role.name == "rolled_back_role"])
        assert outer is not None
        assert inner is None


class TestLockMode:
    """Tests for LockMode enum."""

    def test_lock_modes_exist(self):
        """All expected lock modes are defined."""
        assert LockMode.ACCESS_SHARE == "ACCESS SHARE"
        assert LockMode.ROW_SHARE == "ROW SHARE"
        assert LockMode.ROW_EXCLUSIVE == "ROW EXCLUSIVE"
        assert LockMode.SHARE_UPDATE_EXCLUSIVE == "SHARE UPDATE EXCLUSIVE"
        assert LockMode.SHARE == "SHARE"
        assert LockMode.SHARE_ROW_EXCLUSIVE == "SHARE ROW EXCLUSIVE"
        assert LockMode.EXCLUSIVE == "EXCLUSIVE"
        assert LockMode.ACCESS_EXCLUSIVE == "ACCESS EXCLUSIVE"

    def test_lock_mode_is_string(self):
        """Lock modes are string enums."""
        assert isinstance(LockMode.EXCLUSIVE, str)
        assert LockMode.EXCLUSIVE.value == "EXCLUSIVE"


class TestLockTables:
    """Tests for lock_tables context manager (PostgreSQL-specific)."""

    @pytest.mark.anyio
    async def test_lock_single_table(self, db_session: AsyncSession):
        """Lock a single table."""
        async with lock_tables(db_session, [Role]):
            # Inside the lock, we can still perform operations
            role = Role(name="locked_role")
            db_session.add(role)
            await db_session.flush()

        # After the lock context exits, the row is still visible to this session
        result = await RoleCrud.first(db_session, [Role.name == "locked_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_multiple_tables(self, db_session: AsyncSession):
        """Lock multiple tables."""
        async with lock_tables(db_session, [Role, User]):
            role = Role(name="multi_lock_role")
            db_session.add(role)
            await db_session.flush()

        result = await RoleCrud.first(db_session, [Role.name == "multi_lock_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_with_custom_mode(self, db_session: AsyncSession):
        """Lock with custom lock mode."""
        async with lock_tables(db_session, [Role], mode=LockMode.EXCLUSIVE):
            role = Role(name="exclusive_lock_role")
            db_session.add(role)
            await db_session.flush()

        result = await RoleCrud.first(db_session, [Role.name == "exclusive_lock_role"])
        assert result is not None

    @pytest.mark.anyio
    async def test_lock_rollback_on_exception(self, db_session: AsyncSession):
        """Lock context rolls back on exception."""
        # pytest.raises also asserts that lock_tables re-raises the error.
        with pytest.raises(ValueError, match="Simulated error"):
            async with lock_tables(db_session, [Role]):
                role = Role(name="lock_rollback_role")
                db_session.add(role)
                await db_session.flush()
                raise ValueError("Simulated error")

        result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
        assert result is None


class TestWaitForRowChange:
    """Tests for wait_for_row_change polling function."""

    @pytest.mark.anyio
    async def test_detects_update(self, db_session: AsyncSession, engine):
        """Returns updated instance when a column value changes."""
        role = Role(name="watch_role")
        db_session.add(role)
        await db_session.commit()

        async def update_later():
            await asyncio.sleep(0.15)
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                assert r is not None
                r.name = "updated_role"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        try:
            result = await wait_for_row_change(
                db_session, Role, role.id, interval=0.05, timeout=_WAIT_TIMEOUT
            )
        finally:
            # Always reap the writer so it cannot outlive the test, and so any
            # failure inside it surfaces here.
            await update_task

        assert result.name == "updated_role"

    @pytest.mark.anyio
    async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
        """Only triggers on changes to specified columns."""
        user = User(username="testuser", email="test@example.com")
        db_session.add(user)
        await db_session.commit()

        async def update_later():
            factory = async_sessionmaker(engine, expire_on_commit=False)
            # First: change email (not watched) — should not trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.email = "new@example.com"
                await other.commit()
            # Second: change username (watched) — should trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.username = "newuser"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        try:
            result = await wait_for_row_change(
                db_session,
                User,
                user.id,
                columns=["username"],
                interval=0.05,
                timeout=_WAIT_TIMEOUT,
            )
        finally:
            await update_task

        assert result.username == "newuser"
        assert result.email == "new@example.com"

    @pytest.mark.anyio
    async def test_nonexistent_row_raises(self, db_session: AsyncSession):
        """Raises NotFoundError when the row does not exist."""
        fake_id = uuid.uuid4()
        with pytest.raises(NotFoundError, match="not found"):
            await wait_for_row_change(db_session, Role, fake_id, interval=0.05)

    @pytest.mark.anyio
    async def test_timeout_raises(self, db_session: AsyncSession):
        """Raises TimeoutError when no change is detected within timeout."""
        role = Role(name="timeout_role")
        db_session.add(role)
        await db_session.commit()

        with pytest.raises(TimeoutError):
            await wait_for_row_change(
                db_session, Role, role.id, interval=0.05, timeout=0.2
            )

    @pytest.mark.anyio
    async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
        """Raises NotFoundError when the row is deleted during polling."""
        role = Role(name="delete_role")
        db_session.add(role)
        await db_session.commit()

        async def delete_later():
            await asyncio.sleep(0.15)
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                assert r is not None
                await other.delete(r)
                await other.commit()

        delete_task = asyncio.create_task(delete_later())
        try:
            with pytest.raises(NotFoundError):
                await wait_for_row_change(
                    db_session, Role, role.id, interval=0.05, timeout=_WAIT_TIMEOUT
                )
        finally:
            await delete_task


class TestCreateDatabase:
    """Tests for create_database."""

    @pytest.mark.anyio
    async def test_creates_database(self):
        """Database is created by create_database."""
        target_url = (
            make_url(DATABASE_URL)
            .set(database="test_create_db_general")
            .render_as_string(hide_password=False)
        )
        expected_db: str = make_url(target_url).database  # type: ignore[assignment]

        # AUTOCOMMIT is required: PostgreSQL refuses DROP/CREATE DATABASE
        # inside a transaction block.
        engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
        try:
            async with engine.connect() as conn:
                await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))

            await create_database(db_name=expected_db, server_url=DATABASE_URL)

            async with engine.connect() as conn:
                result = await conn.execute(
                    text("SELECT 1 FROM pg_database WHERE datname = :name"),
                    {"name": expected_db},
                )
                assert result.scalar() == 1
        finally:
            # Cleanup runs even when an assertion fails, and the engine is
            # disposed even if the drop itself fails.
            try:
                async with engine.connect() as conn:
                    await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
            finally:
                await engine.dispose()


class TestCleanupTables:
    """Tests for cleanup_tables helper."""

    @pytest.mark.anyio
    async def test_truncates_all_tables(self):
        """All table rows are removed after cleanup_tables."""
        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            role = Role(id=uuid.uuid4(), name="cleanup_role")
            session.add(role)
            await session.flush()

            user = User(
                id=uuid.uuid4(),
                username="cleanup_user",
                email="cleanup@test.com",
                role_id=role.id,
            )
            session.add(user)
            await session.commit()

            # Verify rows exist
            roles_count = await RoleCrud.count(session)
            users_count = await UserCrud.count(session)
            assert roles_count == 1
            assert users_count == 1

            await cleanup_tables(session, Base)

            # Verify tables are empty
            roles_count = await RoleCrud.count(session)
            users_count = await UserCrud.count(session)
            assert roles_count == 0
            assert users_count == 0

    @pytest.mark.anyio
    async def test_noop_for_empty_metadata(self):
        """cleanup_tables does not raise when metadata has no tables."""

        class EmptyBase(DeclarativeBase):
            pass

        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            # Should not raise
            await cleanup_tables(session, EmptyBase)