mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
244 lines
8.6 KiB
Python
244 lines
8.6 KiB
Python
"""Tests for fastapi_toolsets.db module."""
|
|
|
|
import pytest
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from fastapi_toolsets.db import (
|
|
LockMode,
|
|
create_db_context,
|
|
create_db_dependency,
|
|
get_transaction,
|
|
lock_tables,
|
|
)
|
|
|
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
|
|
|
|
|
class TestCreateDbDependency:
|
|
"""Tests for create_db_dependency."""
|
|
|
|
@pytest.mark.anyio
|
|
async def test_yields_session(self):
|
|
"""Dependency yields a valid session."""
|
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
|
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
get_db = create_db_dependency(session_factory)
|
|
|
|
async for session in get_db():
|
|
assert isinstance(session, AsyncSession)
|
|
break
|
|
|
|
await engine.dispose()
|
|
|
|
@pytest.mark.anyio
|
|
async def test_auto_commits_transaction(self):
|
|
"""Dependency auto-commits if transaction is active."""
|
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
|
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
|
|
async with engine.begin() as conn:
|
|
await conn.run_sync(Base.metadata.create_all)
|
|
|
|
try:
|
|
get_db = create_db_dependency(session_factory)
|
|
|
|
async for session in get_db():
|
|
role = Role(name="test_role_dep")
|
|
session.add(role)
|
|
await session.flush()
|
|
|
|
async with session_factory() as verify_session:
|
|
result = await RoleCrud.first(
|
|
verify_session, [Role.name == "test_role_dep"]
|
|
)
|
|
assert result is not None
|
|
finally:
|
|
async with engine.begin() as conn:
|
|
await conn.run_sync(Base.metadata.drop_all)
|
|
await engine.dispose()
|
|
|
|
|
|
class TestCreateDbContext:
|
|
"""Tests for create_db_context."""
|
|
|
|
@pytest.mark.anyio
|
|
async def test_context_manager_yields_session(self):
|
|
"""Context manager yields a valid session."""
|
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
|
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
get_db_context = create_db_context(session_factory)
|
|
|
|
async with get_db_context() as session:
|
|
assert isinstance(session, AsyncSession)
|
|
|
|
await engine.dispose()
|
|
|
|
@pytest.mark.anyio
|
|
async def test_context_manager_commits(self):
|
|
"""Context manager commits on exit."""
|
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
|
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
|
|
async with engine.begin() as conn:
|
|
await conn.run_sync(Base.metadata.create_all)
|
|
|
|
try:
|
|
get_db_context = create_db_context(session_factory)
|
|
|
|
async with get_db_context() as session:
|
|
role = Role(name="context_role")
|
|
session.add(role)
|
|
await session.flush()
|
|
|
|
async with session_factory() as verify_session:
|
|
result = await RoleCrud.first(
|
|
verify_session, [Role.name == "context_role"]
|
|
)
|
|
assert result is not None
|
|
finally:
|
|
async with engine.begin() as conn:
|
|
await conn.run_sync(Base.metadata.drop_all)
|
|
await engine.dispose()
|
|
|
|
|
|
class TestGetTransaction:
|
|
"""Tests for get_transaction context manager."""
|
|
|
|
@pytest.mark.anyio
|
|
async def test_starts_transaction(self, db_session: AsyncSession):
|
|
"""get_transaction starts a new transaction."""
|
|
async with get_transaction(db_session):
|
|
role = Role(name="tx_role")
|
|
db_session.add(role)
|
|
|
|
result = await RoleCrud.first(db_session, [Role.name == "tx_role"])
|
|
assert result is not None
|
|
|
|
@pytest.mark.anyio
|
|
async def test_nested_transaction_uses_savepoint(self, db_session: AsyncSession):
|
|
"""Nested transactions use savepoints."""
|
|
async with get_transaction(db_session):
|
|
role1 = Role(name="outer_role")
|
|
db_session.add(role1)
|
|
await db_session.flush()
|
|
|
|
async with get_transaction(db_session):
|
|
role2 = Role(name="inner_role")
|
|
db_session.add(role2)
|
|
|
|
results = await RoleCrud.get_multi(db_session)
|
|
names = {r.name for r in results}
|
|
assert "outer_role" in names
|
|
assert "inner_role" in names
|
|
|
|
@pytest.mark.anyio
|
|
async def test_rollback_on_exception(self, db_session: AsyncSession):
|
|
"""Transaction rolls back on exception."""
|
|
try:
|
|
async with get_transaction(db_session):
|
|
role = Role(name="rollback_role")
|
|
db_session.add(role)
|
|
await db_session.flush()
|
|
raise ValueError("Simulated error")
|
|
except ValueError:
|
|
pass
|
|
|
|
result = await RoleCrud.first(db_session, [Role.name == "rollback_role"])
|
|
assert result is None
|
|
|
|
@pytest.mark.anyio
|
|
async def test_nested_rollback_preserves_outer(self, db_session: AsyncSession):
|
|
"""Nested rollback preserves outer transaction."""
|
|
async with get_transaction(db_session):
|
|
role1 = Role(name="preserved_role")
|
|
db_session.add(role1)
|
|
await db_session.flush()
|
|
|
|
try:
|
|
async with get_transaction(db_session):
|
|
role2 = Role(name="rolled_back_role")
|
|
db_session.add(role2)
|
|
await db_session.flush()
|
|
raise ValueError("Inner error")
|
|
except ValueError:
|
|
pass
|
|
|
|
outer = await RoleCrud.first(db_session, [Role.name == "preserved_role"])
|
|
inner = await RoleCrud.first(db_session, [Role.name == "rolled_back_role"])
|
|
assert outer is not None
|
|
assert inner is None
|
|
|
|
|
|
class TestLockMode:
|
|
"""Tests for LockMode enum."""
|
|
|
|
def test_lock_modes_exist(self):
|
|
"""All expected lock modes are defined."""
|
|
assert LockMode.ACCESS_SHARE == "ACCESS SHARE"
|
|
assert LockMode.ROW_SHARE == "ROW SHARE"
|
|
assert LockMode.ROW_EXCLUSIVE == "ROW EXCLUSIVE"
|
|
assert LockMode.SHARE_UPDATE_EXCLUSIVE == "SHARE UPDATE EXCLUSIVE"
|
|
assert LockMode.SHARE == "SHARE"
|
|
assert LockMode.SHARE_ROW_EXCLUSIVE == "SHARE ROW EXCLUSIVE"
|
|
assert LockMode.EXCLUSIVE == "EXCLUSIVE"
|
|
assert LockMode.ACCESS_EXCLUSIVE == "ACCESS EXCLUSIVE"
|
|
|
|
def test_lock_mode_is_string(self):
|
|
"""Lock modes are string enums."""
|
|
assert isinstance(LockMode.EXCLUSIVE, str)
|
|
assert LockMode.EXCLUSIVE.value == "EXCLUSIVE"
|
|
|
|
|
|
class TestLockTables:
|
|
"""Tests for lock_tables context manager (PostgreSQL-specific)."""
|
|
|
|
@pytest.mark.anyio
|
|
async def test_lock_single_table(self, db_session: AsyncSession):
|
|
"""Lock a single table."""
|
|
async with lock_tables(db_session, [Role]):
|
|
# Inside the lock, we can still perform operations
|
|
role = Role(name="locked_role")
|
|
db_session.add(role)
|
|
await db_session.flush()
|
|
|
|
# After lock is released, verify the data was committed
|
|
result = await RoleCrud.first(db_session, [Role.name == "locked_role"])
|
|
assert result is not None
|
|
|
|
@pytest.mark.anyio
|
|
async def test_lock_multiple_tables(self, db_session: AsyncSession):
|
|
"""Lock multiple tables."""
|
|
async with lock_tables(db_session, [Role, User]):
|
|
role = Role(name="multi_lock_role")
|
|
db_session.add(role)
|
|
await db_session.flush()
|
|
|
|
result = await RoleCrud.first(db_session, [Role.name == "multi_lock_role"])
|
|
assert result is not None
|
|
|
|
@pytest.mark.anyio
|
|
async def test_lock_with_custom_mode(self, db_session: AsyncSession):
|
|
"""Lock with custom lock mode."""
|
|
async with lock_tables(db_session, [Role], mode=LockMode.EXCLUSIVE):
|
|
role = Role(name="exclusive_lock_role")
|
|
db_session.add(role)
|
|
await db_session.flush()
|
|
|
|
result = await RoleCrud.first(db_session, [Role.name == "exclusive_lock_role"])
|
|
assert result is not None
|
|
|
|
@pytest.mark.anyio
|
|
async def test_lock_rollback_on_exception(self, db_session: AsyncSession):
|
|
"""Lock context rolls back on exception."""
|
|
try:
|
|
async with lock_tables(db_session, [Role]):
|
|
role = Role(name="lock_rollback_role")
|
|
db_session.add(role)
|
|
await db_session.flush()
|
|
raise ValueError("Simulated error")
|
|
except ValueError:
|
|
pass
|
|
|
|
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
|
assert result is None
|