Files
fastapi-toolsets/tests/test_pytest.py
2026-04-12 18:46:57 +02:00

726 lines
26 KiB
Python

"""Tests for fastapi_toolsets.pytest module."""
import uuid
from typing import cast
import pytest
from fastapi import Depends, FastAPI
from httpx import AsyncClient
from sqlalchemy import ForeignKey, String, select, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.db import get_transaction
from fastapi_toolsets.fixtures import Context, FixtureRegistry, LoadStrategy
from fastapi_toolsets.pytest import (
create_async_client,
create_db_session,
create_worker_database,
register_fixtures,
worker_database_url,
)
from fastapi_toolsets.pytest.plugin import (
_get_primary_key,
_relationship_load_options,
_reload_with_relationships,
)
from fastapi_toolsets.pytest.utils import _get_xdist_worker
from .conftest import (
DATABASE_URL,
Base,
IntRole,
Permission,
Role,
RoleCrud,
User,
UserCrud,
)
test_registry = FixtureRegistry()
# Fixed UUIDs for test fixtures to allow consistent assertions
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
@test_registry.register(contexts=[Context.BASE])
def roles() -> list[Role]:
return [
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
Role(id=ROLE_USER_ID, name="plugin_user"),
]
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
def users() -> list[User]:
return [
User(
id=USER_ADMIN_ID,
username="plugin_admin",
email="padmin@test.com",
role_id=ROLE_ADMIN_ID,
),
User(
id=USER_USER_ID,
username="plugin_user",
email="puser@test.com",
role_id=ROLE_USER_ID,
),
]
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
def extra_users() -> list[User]:
return [
User(
id=USER_EXTRA_ID,
username="plugin_extra",
email="pextra@test.com",
role_id=ROLE_USER_ID,
),
]
register_fixtures(test_registry, globals())
class TestRegisterFixtures:
"""Tests for register_fixtures function."""
def test_creates_fixtures_in_namespace(self):
"""Fixtures are created in the namespace."""
assert "fixture_roles" in globals()
assert "fixture_users" in globals()
assert "fixture_extra_users" in globals()
def test_fixtures_are_callable(self):
"""Created fixtures are callable."""
assert callable(globals()["fixture_roles"])
assert callable(globals()["fixture_users"])
class TestGeneratedFixtures:
"""Tests for the generated pytest fixtures."""
@pytest.mark.anyio
async def test_fixture_loads_data(
self, db_session: AsyncSession, fixture_roles: list[Role]
):
"""Fixture loads data into database and returns it."""
assert len(fixture_roles) == 2
assert fixture_roles[0].name == "plugin_admin"
assert fixture_roles[1].name == "plugin_user"
# Verify data is in database
count = await RoleCrud.count(db_session)
assert count == 2
@pytest.mark.anyio
async def test_fixture_with_dependency(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Fixture with dependency loads parent fixture first."""
# fixture_users depends on fixture_roles
# Both should be loaded
assert len(fixture_users) == 2
# Roles should also be in database
roles_count = await RoleCrud.count(db_session)
assert roles_count == 2
# Users should be in database
users_count = await UserCrud.count(db_session)
assert users_count == 2
@pytest.mark.anyio
async def test_fixture_returns_models(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Fixture returns actual model instances."""
user = fixture_users[0]
assert isinstance(user, User)
assert user.id == USER_ADMIN_ID
assert user.username == "plugin_admin"
@pytest.mark.anyio
async def test_fixture_relationships_work(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Loaded fixtures have working relationships directly accessible."""
user = next(u for u in fixture_users if u.id == USER_ADMIN_ID)
assert user.role is not None
assert user.role.name == "plugin_admin"
@pytest.mark.anyio
async def test_chained_dependencies(
self, db_session: AsyncSession, fixture_extra_users: list[User]
):
"""Chained dependencies are resolved correctly."""
# fixture_extra_users -> fixture_users -> fixture_roles
assert len(fixture_extra_users) == 1
# All fixtures should be loaded
roles_count = await RoleCrud.count(db_session)
users_count = await UserCrud.count(db_session)
assert roles_count == 2
assert users_count == 3 # 2 from users + 1 from extra_users
@pytest.mark.anyio
async def test_can_query_loaded_data(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Can query the loaded fixture data."""
# Get all users loaded by fixture
users = await UserCrud.get_multi(
db_session,
order_by=User.username,
)
assert len(users) == 2
assert users[0].username == "plugin_admin"
assert users[1].username == "plugin_user"
@pytest.mark.anyio
async def test_fixture_auto_loads_relationships(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Fixtures automatically eager-load all direct relationships."""
user = next(u for u in fixture_users if u.username == "plugin_admin")
assert user.role is not None
assert user.role.name == "plugin_admin"
@pytest.mark.anyio
async def test_multiple_fixtures_in_same_test(
self,
db_session: AsyncSession,
fixture_roles: list[Role],
fixture_users: list[User],
):
"""Multiple fixtures can be used in the same test."""
assert len(fixture_roles) == 2
assert len(fixture_users) == 2
# Both should be in database
roles = await RoleCrud.get_multi(db_session)
users = await UserCrud.get_multi(db_session)
assert len(roles) == 2
assert len(users) == 2
class TestCreateAsyncClient:
"""Tests for create_async_client helper."""
@pytest.mark.anyio
async def test_creates_working_client(self):
"""Client can make requests to the app."""
app = FastAPI()
@app.get("/health")
async def health():
return {"status": "ok"}
async with create_async_client(app) as client:
assert isinstance(client, AsyncClient)
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.anyio
async def test_custom_base_url(self):
"""Client uses custom base URL."""
app = FastAPI()
@app.get("/test")
async def test_endpoint():
return {"url": "test"}
async with create_async_client(app, base_url="http://custom") as client:
assert str(client.base_url) == "http://custom"
@pytest.mark.anyio
async def test_client_closes_properly(self):
"""Client is properly closed after context exit."""
app = FastAPI()
async with create_async_client(app) as client:
client_ref = client
assert client_ref.is_closed
@pytest.mark.anyio
async def test_dependency_overrides_applied_and_cleaned(self):
"""Dependency overrides are applied during the context and removed after."""
app = FastAPI()
async def original_dep() -> str:
return "original"
async def override_dep() -> str:
return "overridden"
@app.get("/dep")
async def dep_endpoint(value: str = Depends(original_dep)):
return {"value": value}
async with create_async_client(
app, dependency_overrides={original_dep: override_dep}
) as client:
response = await client.get("/dep")
assert response.json() == {"value": "overridden"}
# Overrides should be cleaned up
assert original_dep not in app.dependency_overrides
class TestCreateDbSession:
"""Tests for create_db_session helper."""
@pytest.mark.anyio
async def test_creates_working_session(self):
"""Session can perform database operations."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base) as session:
assert isinstance(session, AsyncSession)
role = Role(id=role_id, name="test_helper_role")
session.add(role)
await session.commit()
result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one()
assert fetched.name == "test_helper_role"
@pytest.mark.anyio
async def test_tables_created_before_session(self):
"""Tables exist when session is yielded."""
async with create_db_session(DATABASE_URL, Base) as session:
# Should not raise - tables exist
result = await session.execute(select(Role))
assert result.all() == []
@pytest.mark.anyio
async def test_tables_dropped_after_session(self):
"""Tables are dropped after session closes when drop_tables=True."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=role_id, name="will_be_dropped")
session.add(role)
await session.commit()
# Verify tables were dropped by creating new session
async with create_db_session(DATABASE_URL, Base) as session:
result = await session.execute(select(Role))
assert result.all() == []
@pytest.mark.anyio
async def test_tables_preserved_when_drop_disabled(self):
"""Tables are preserved when drop_tables=False."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
role = Role(id=role_id, name="preserved_role")
session.add(role)
await session.commit()
# Create another session without dropping
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none()
assert fetched is not None
assert fetched.name == "preserved_role"
# Cleanup: drop tables manually
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass
@pytest.mark.anyio
async def test_cleanup_truncates_tables(self):
"""Tables are truncated after session closes when cleanup=True."""
role_id = uuid.uuid4()
async with create_db_session(
DATABASE_URL, Base, cleanup=True, drop_tables=False
) as session:
role = Role(id=role_id, name="will_be_cleaned")
session.add(role)
await session.commit()
# Data should have been truncated, but tables still exist
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
result = await session.execute(select(Role))
assert result.all() == []
@pytest.mark.anyio
async def test_get_transaction_commits_visible_to_separate_session(self):
"""Data written via get_transaction() is committed and visible to other sessions."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
# Simulate what _create_fixture_function does: insert via get_transaction
# with no explicit commit afterward.
async with get_transaction(session):
role = Role(id=role_id, name="visible_to_other_session")
session.add(role)
# The data must have been committed (begin/commit, not a savepoint),
# so a separate engine/session can read it.
other_engine = create_async_engine(DATABASE_URL, echo=False)
try:
other_session_maker = async_sessionmaker(
other_engine, expire_on_commit=False
)
async with other_session_maker() as other:
result = await other.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none()
assert fetched is not None, (
"Fixture data inserted via get_transaction() must be committed "
"and visible to a separate session. If create_db_session uses "
"create_db_context, auto-begin forces get_transaction() into "
"savepoints instead of real commits."
)
assert fetched.name == "visible_to_other_session"
finally:
await other_engine.dispose()
# Cleanup
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass
class TestGetXdistWorker:
"""Tests for _get_xdist_worker helper."""
def test_returns_default_test_db_without_env_var(
self, monkeypatch: pytest.MonkeyPatch
):
"""Returns default_test_db when PYTEST_XDIST_WORKER is not set."""
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
assert _get_xdist_worker("my_default") == "my_default"
def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
"""Returns the worker name from the environment variable."""
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
assert _get_xdist_worker("ignored") == "gw0"
class TestWorkerDatabaseUrl:
"""Tests for worker_database_url helper."""
def test_appends_default_test_db_without_xdist(
self, monkeypatch: pytest.MonkeyPatch
):
"""default_test_db is appended when not running under xdist."""
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
url = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
result = worker_database_url(url, default_test_db="fallback")
assert make_url(result).database == "mydb_fallback"
def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
"""Worker name is appended to the database name."""
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
url = "postgresql+asyncpg://user:pass@localhost:5432/db"
result = worker_database_url(url, default_test_db="unused")
assert make_url(result).database == "db_gw0"
def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
"""Host, port, username, password, and driver are preserved."""
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
url = "postgresql+asyncpg://myuser:secret@dbhost:6543/testdb"
result = make_url(worker_database_url(url, default_test_db="unused"))
assert result.drivername == "postgresql+asyncpg"
assert result.username == "myuser"
assert result.password == "secret"
assert result.host == "dbhost"
assert result.port == 6543
assert result.database == "testdb_gw2"
class TestCreateWorkerDatabase:
"""Tests for create_worker_database context manager."""
@pytest.mark.anyio
async def test_creates_default_db_without_xdist(
self, monkeypatch: pytest.MonkeyPatch
):
"""Without xdist, creates a database suffixed with default_test_db."""
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
default_test_db = "no_xdist_default"
expected_db = make_url(
worker_database_url(DATABASE_URL, default_test_db=default_test_db)
).database
async with create_worker_database(
DATABASE_URL, default_test_db=default_test_db
) as url:
assert make_url(url).database == expected_db
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
async with engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :name"),
{"name": expected_db},
)
assert result.scalar() == 1
await engine.dispose()
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
async with engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :name"),
{"name": expected_db},
)
assert result.scalar() is None
await engine.dispose()
@pytest.mark.anyio
async def test_creates_and_drops_worker_database(
self, monkeypatch: pytest.MonkeyPatch
):
"""Worker database exists inside the context and is dropped after."""
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
expected_db = make_url(
worker_database_url(DATABASE_URL, default_test_db="unused")
).database
async with create_worker_database(DATABASE_URL) as url:
assert make_url(url).database == expected_db
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
async with engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :name"),
{"name": expected_db},
)
assert result.scalar() == 1
await engine.dispose()
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
async with engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :name"),
{"name": expected_db},
)
assert result.scalar() is None
await engine.dispose()
@pytest.mark.anyio
async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
"""A pre-existing worker database is dropped and recreated."""
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
expected_db = make_url(
worker_database_url(DATABASE_URL, default_test_db="unused")
).database
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
async with engine.connect() as conn:
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
await conn.execute(text(f"CREATE DATABASE {expected_db}"))
await engine.dispose()
async with create_worker_database(DATABASE_URL) as url:
assert make_url(url).database == expected_db
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
async with engine.connect() as conn:
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :name"),
{"name": expected_db},
)
assert result.scalar() is None
await engine.dispose()
class _LocalBase(DeclarativeBase):
pass
class _Group(_LocalBase):
__tablename__ = "_test_groups"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50))
class _CompositeItem(_LocalBase):
"""Model with composite PK and a relationship — exercises the fallback path."""
__tablename__ = "_test_composite_items"
group_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("_test_groups.id"), primary_key=True
)
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
group: Mapped["_Group"] = relationship()
class TestGetPrimaryKey:
"""Unit tests for _get_primary_key — no DB needed."""
def test_single_pk_returns_value(self):
rid = uuid.UUID("00000000-0000-0000-0000-000000000001")
role = Role(id=rid, name="x")
assert _get_primary_key(role) == rid
def test_composite_pk_all_set_returns_tuple(self):
perm = Permission(subject="posts", action="read")
assert _get_primary_key(perm) == ("posts", "read")
def test_composite_pk_partial_none_returns_none(self):
perm = Permission(subject=None, action="read")
assert _get_primary_key(perm) is None
def test_composite_pk_all_none_returns_none(self):
perm = Permission(subject=None, action=None)
assert _get_primary_key(perm) is None
class TestRelationshipLoadOptions:
"""Unit tests for _relationship_load_options — no DB needed."""
def test_empty_for_model_with_no_relationships(self):
assert _relationship_load_options(IntRole) == []
def test_returns_options_for_model_with_relationships(self):
opts = _relationship_load_options(User)
assert len(opts) >= 1
class TestFixtureStrategies:
"""Integration tests covering INSERT, SKIP_EXISTING, empty fixture, no-rels model."""
@pytest.mark.anyio
async def test_empty_fixture_returns_empty_list(self, db_session: AsyncSession):
"""Fixture function returning [] produces an empty list."""
registry = FixtureRegistry()
@registry.register()
def empty() -> list[Role]:
return []
local_ns: dict = {}
register_fixtures(registry, local_ns, session_fixture="db_session")
inner = local_ns["fixture_empty"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert result == []
@pytest.mark.anyio
async def test_insert_strategy_no_relationships(self, db_session: AsyncSession):
"""INSERT strategy adds instances; model with no rels skips reload (line 135)."""
registry = FixtureRegistry()
@registry.register()
def int_roles() -> list[IntRole]:
return [IntRole(name="insert_role")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.INSERT,
)
inner = local_ns["fixture_int_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "insert_role"
@pytest.mark.anyio
async def test_skip_existing_inserts_new_record(self, db_session: AsyncSession):
"""SKIP_EXISTING inserts when the record does not yet exist."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register()
def new_roles() -> list[Role]:
return [Role(id=role_id, name="skip_new")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_new_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].id == role_id
@pytest.mark.anyio
async def test_skip_existing_returns_existing_record(
self, db_session: AsyncSession
):
"""SKIP_EXISTING returns the existing DB record when PK already present."""
role_id = uuid.uuid4()
existing = Role(id=role_id, name="already_there")
db_session.add(existing)
await db_session.flush()
registry = FixtureRegistry()
@registry.register()
def dup_roles() -> list[Role]:
return [Role(id=role_id, name="should_not_overwrite")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_dup_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "already_there"
@pytest.mark.anyio
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
"""SKIP_EXISTING with null PK (auto-increment) falls through to session.add()."""
registry = FixtureRegistry()
@registry.register()
def auto_roles() -> list[IntRole]:
return [IntRole(name="auto_int")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_auto_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "auto_int"
class TestReloadWithRelationshipsCompositePK:
"""Integration test for _reload_with_relationships composite-PK fallback."""
@pytest.mark.anyio
async def test_composite_pk_fallback_loads_relationships(self):
"""Models with composite PKs are reloaded per-instance via session.get()."""
async with create_db_session(DATABASE_URL, _LocalBase) as session:
group = _Group(id=uuid.uuid4(), name="g1")
session.add(group)
await session.flush()
item = _CompositeItem(group_id=group.id, item_code="A")
session.add(item)
await session.flush()
load_opts = _relationship_load_options(_CompositeItem)
assert load_opts # _CompositeItem has 'group' relationship
reloaded = await _reload_with_relationships(session, [item], load_opts)
assert len(reloaded) == 1
reloaded_item = cast(_CompositeItem, reloaded[0])
assert reloaded_item.group is not None
assert reloaded_item.group.name == "g1"