diff --git a/tests/conftest.py b/tests/conftest.py index 68be228..10ac078 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -321,30 +321,3 @@ async def db_session(engine): # Drop tables after test async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) - - -@pytest.fixture -def sample_role_data() -> RoleCreate: - """Sample role creation data.""" - return RoleCreate(name="admin") - - -@pytest.fixture -def sample_user_data() -> UserCreate: - """Sample user creation data.""" - return UserCreate( - username="testuser", - email="test@example.com", - is_active=True, - ) - - -@pytest.fixture -def sample_post_data() -> PostCreate: - """Sample post creation data.""" - return PostCreate( - title="Test Post", - content="Test content", - is_published=True, - author_id=uuid.uuid4(), - ) diff --git a/tests/test_example_pagination_search.py b/tests/test_example_pagination_search.py index 3281cd2..2dbfacf 100644 --- a/tests/test_example_pagination_search.py +++ b/tests/test_example_pagination_search.py @@ -10,12 +10,13 @@ import datetime import pytest from fastapi import FastAPI from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession from docs_src.examples.pagination_search.db import get_db from docs_src.examples.pagination_search.models import Article, Base, Category from docs_src.examples.pagination_search.routes import router from fastapi_toolsets.exceptions import init_exceptions_handlers +from fastapi_toolsets.pytest import create_db_session from .conftest import DATABASE_URL @@ -35,20 +36,8 @@ def build_app(session: AsyncSession) -> FastAPI: @pytest.fixture(scope="function") async def ex_db_session(): """Isolated session for the example models (separate tables from conftest).""" - engine = create_async_engine(DATABASE_URL, echo=False) - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - session_factory = async_sessionmaker(engine, expire_on_commit=False) - session = session_factory() - - try: + async with create_db_session(DATABASE_URL, Base) as session: yield session - finally: - await session.close() - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) - await engine.dispose() @pytest.fixture diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py index 169765b..0ad3d4a 100644 --- a/tests/test_fixtures.py +++ b/tests/test_fixtures.py @@ -16,7 +16,7 @@ from fastapi_toolsets.fixtures import ( from fastapi_toolsets.fixtures.utils import _get_primary_key -from .conftest import IntRole, Permission, Role, User +from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud class TestContext: @@ -447,8 +447,6 @@ class TestLoadFixtures: assert "roles" in result assert len(result["roles"]) == 2 - from .conftest import RoleCrud - count = await RoleCrud.count(db_session) assert count == 2 @@ -479,8 +477,6 @@ class TestLoadFixtures: assert "roles" in result assert "users" in result - from .conftest import RoleCrud, UserCrud - assert await RoleCrud.count(db_session) == 1 assert await UserCrud.count(db_session) == 1 @@ -497,8 +493,6 @@ class TestLoadFixtures: await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) - from .conftest import RoleCrud - count = await RoleCrud.count(db_session) assert count == 1 @@ -526,8 +520,6 @@ class TestLoadFixtures: db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING ) - from .conftest import RoleCrud - role = await RoleCrud.first(db_session, [Role.id == role_id]) assert role is not None assert role.name == "original" @@ -553,8 +545,6 @@ class TestLoadFixtures: assert "roles" in result assert len(result["roles"]) == 2 - from .conftest import RoleCrud - count = await RoleCrud.count(db_session) assert count == 2 @@ -594,8 +584,6 @@ class TestLoadFixtures: assert "roles" in result assert "other_roles" in result - from .conftest import RoleCrud - count = await RoleCrud.count(db_session) assert count == 2 @@ -660,8 +648,6 @@ class TestLoadFixturesByContext: await load_fixtures_by_context(db_session, registry, Context.BASE) - from .conftest import RoleCrud - count = await RoleCrud.count(db_session) assert count == 1 @@ -688,8 +674,6 @@ class TestLoadFixturesByContext: db_session, registry, Context.BASE, Context.TESTING ) - from .conftest import RoleCrud - count = await RoleCrud.count(db_session) assert count == 2 @@ -717,8 +701,6 @@ class TestLoadFixturesByContext: await load_fixtures_by_context(db_session, registry, Context.TESTING) - from .conftest import RoleCrud, UserCrud - assert await RoleCrud.count(db_session) == 1 assert await UserCrud.count(db_session) == 1 diff --git a/tests/test_imports.py b/tests/test_imports.py index 7245208..1406989 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -171,8 +171,15 @@ class TestPytestImportGuard: class TestCliImportGuard: """Tests for CLI module import guard when typer is missing.""" - def test_import_raises_without_typer(self): - """Importing cli.app raises when typer is missing.""" + @pytest.mark.parametrize( + "expected_match", + [ + "typer", + r"pip install fastapi-toolsets\[cli\]", + ], + ) + def test_import_raises_without_typer(self, expected_match): + """Importing cli.app raises when typer is missing, with an informative error message.""" saved, blocking_import = _reload_without_package( "fastapi_toolsets.cli.app", ["typer"] ) @@ -186,33 +193,7 @@ class TestCliImportGuard: try: with patch("builtins.__import__", side_effect=blocking_import): - with pytest.raises(ImportError, match="typer"): - importlib.import_module("fastapi_toolsets.cli.app") - finally: - for key in list(sys.modules): - if key.startswith("fastapi_toolsets.cli.app") or key.startswith( - "fastapi_toolsets.cli.config" - ): - sys.modules.pop(key, None) - sys.modules.update(saved) - - def test_error_message_suggests_cli_extra(self): - """Error message suggests installing the cli extra.""" - saved, blocking_import = _reload_without_package( - "fastapi_toolsets.cli.app", ["typer"] - ) - config_keys = [ - k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config") - ] - for key in config_keys: - if key not in saved: - saved[key] = sys.modules.pop(key) - - try: - with patch("builtins.__import__", side_effect=blocking_import): - with pytest.raises( - ImportError, match=r"pip install fastapi-toolsets\[cli\]" - ): + with pytest.raises(ImportError, match=expected_match): importlib.import_module("fastapi_toolsets.cli.app") finally: for key in list(sys.modules): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index ace11b0..cefd249 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,6 +1,5 @@ """Tests for fastapi_toolsets.metrics module.""" -import os import tempfile from unittest.mock import AsyncMock, MagicMock @@ -287,6 +286,16 @@ class TestIncludeRegistry: class TestInitMetrics: """Tests for init_metrics function.""" + @pytest.fixture + def metrics_client(self): + """Create a FastAPI app with MetricsRegistry and return a TestClient.""" + app = FastAPI() + registry = MetricsRegistry() + init_metrics(app, registry) + client = TestClient(app) + yield client + client.close() + def test_returns_app(self): """Returns the FastAPI app.""" app = FastAPI() @@ -294,26 +303,14 @@ class TestInitMetrics: result = init_metrics(app, registry) assert result is app - def test_metrics_endpoint_responds(self): + def test_metrics_endpoint_responds(self, metrics_client): """The /metrics endpoint returns 200.""" - app = FastAPI() - registry = MetricsRegistry() - init_metrics(app, registry) - - client = TestClient(app) - response = client.get("/metrics") - + response = metrics_client.get("/metrics") assert response.status_code == 200 - def test_metrics_endpoint_content_type(self): + def test_metrics_endpoint_content_type(self, metrics_client): """The /metrics endpoint returns prometheus content type.""" - app = FastAPI() - registry = MetricsRegistry() - init_metrics(app, registry) - - client = TestClient(app) - response = client.get("/metrics") - + response = metrics_client.get("/metrics") assert "text/plain" in response.headers["content-type"] def test_custom_path(self): @@ -445,36 +442,33 @@ class TestInitMetrics: class TestMultiProcessMode: """Tests for multi-process Prometheus mode.""" - def test_multiprocess_with_env_var(self): + def test_multiprocess_with_env_var(self, monkeypatch): """Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set.""" with tempfile.TemporaryDirectory() as tmpdir: - os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir - try: - # Use a separate registry to avoid conflicts with default - prom_registry = CollectorRegistry() - app = FastAPI() - registry = MetricsRegistry() + monkeypatch.setenv("PROMETHEUS_MULTIPROC_DIR", tmpdir) + # Use a separate registry to avoid conflicts with default + prom_registry = CollectorRegistry() + app = FastAPI() + registry = MetricsRegistry() - @registry.register - def mp_counter(): - return Counter( - "mp_test_counter", - "A multiprocess counter", - registry=prom_registry, - ) + @registry.register + def mp_counter(): + return Counter( + "mp_test_counter", + "A multiprocess counter", + registry=prom_registry, + ) - init_metrics(app, registry) + init_metrics(app, registry) - client = TestClient(app) - response = client.get("/metrics") + client = TestClient(app) + response = client.get("/metrics") - assert response.status_code == 200 - finally: - del os.environ["PROMETHEUS_MULTIPROC_DIR"] + assert response.status_code == 200 - def test_single_process_without_env_var(self): + def test_single_process_without_env_var(self, monkeypatch): """Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set.""" - os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None) + monkeypatch.delenv("PROMETHEUS_MULTIPROC_DIR", raising=False) app = FastAPI() registry = MetricsRegistry() diff --git a/tests/test_models.py b/tests/test_models.py index 1c91742..c0ffba7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -8,9 +8,10 @@ from unittest.mock import patch import pytest from sqlalchemy import String -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from fastapi_toolsets.pytest import create_db_session + import fastapi_toolsets.models.watched as _watched_module from fastapi_toolsets.models import ( CreatedAtMixin, @@ -267,39 +268,17 @@ class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin): @pytest.fixture(scope="function") async def mixin_session(): - engine = create_async_engine(DATABASE_URL, echo=False) - async with engine.begin() as conn: - await conn.run_sync(MixinBase.metadata.create_all) - - session_factory = async_sessionmaker(engine, expire_on_commit=False) - session = session_factory() - - try: + async with create_db_session(DATABASE_URL, MixinBase) as session: yield session - finally: - await session.close() - async with engine.begin() as conn: - await conn.run_sync(MixinBase.metadata.drop_all) - await engine.dispose() @pytest.fixture(scope="function") async def mixin_session_expire(): """Session with expire_on_commit=True (the default) to exercise attribute access after commit.""" - engine = create_async_engine(DATABASE_URL, echo=False) - async with engine.begin() as conn: - await conn.run_sync(MixinBase.metadata.create_all) - - session_factory = async_sessionmaker(engine, expire_on_commit=True) - session = session_factory() - - try: + async with create_db_session( + DATABASE_URL, MixinBase, expire_on_commit=True + ) as session: yield session - finally: - await session.close() - async with engine.begin() as conn: - await conn.run_sync(MixinBase.metadata.drop_all) - await engine.dispose() class TestUUIDMixin: