Compare commits

..

3 Commits

10 changed files with 76 additions and 158 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "2.4.1" version = "2.4.2"
description = "Production-ready utilities for FastAPI applications" description = "Production-ready utilities for FastAPI applications"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "2.4.1" __version__ = "2.4.2"

View File

@@ -26,6 +26,7 @@ _SESSION_PENDING_NEW = "_ft_pending_new"
_SESSION_CREATES = "_ft_creates" _SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes" _SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates" _SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
class ModelEvent(str, Enum): class ModelEvent(str, Enum):
@@ -92,6 +93,22 @@ def _upsert_changes(
pending[key] = (obj, changes) pending[key] = (obj, changes)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
def _after_transaction_create(session: Any, transaction: Any) -> None:
if transaction.nested:
session.info[_SESSION_SAVEPOINT_DEPTH] = (
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
)
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
def _after_transaction_end(session: Any, transaction: Any) -> None:
if transaction.nested:
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
if depth > 0: # pragma: no branch
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
@event.listens_for(AsyncSession.sync_session_class, "after_flush") @event.listens_for(AsyncSession.sync_session_class, "after_flush")
def _after_flush(session: Any, flush_context: Any) -> None: def _after_flush(session: Any, flush_context: Any) -> None:
# New objects: capture references while session.new is still populated. # New objects: capture references while session.new is still populated.
@@ -189,6 +206,9 @@ def _schedule_with_snapshot(
@event.listens_for(AsyncSession.sync_session_class, "after_commit") @event.listens_for(AsyncSession.sync_session_class, "after_commit")
def _after_commit(session: Any) -> None: def _after_commit(session: Any) -> None:
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
return
creates: list[Any] = session.info.pop(_SESSION_CREATES, []) creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
deletes: list[Any] = session.info.pop(_SESSION_DELETES, []) deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop( field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(

View File

@@ -321,30 +321,3 @@ async def db_session(engine):
# Drop tables after test # Drop tables after test
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.drop_all)
@pytest.fixture
def sample_role_data() -> RoleCreate:
"""Sample role creation data."""
return RoleCreate(name="admin")
@pytest.fixture
def sample_user_data() -> UserCreate:
"""Sample user creation data."""
return UserCreate(
username="testuser",
email="test@example.com",
is_active=True,
)
@pytest.fixture
def sample_post_data() -> PostCreate:
"""Sample post creation data."""
return PostCreate(
title="Test Post",
content="Test content",
is_published=True,
author_id=uuid.uuid4(),
)

View File

@@ -10,12 +10,13 @@ import datetime
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession
from docs_src.examples.pagination_search.db import get_db from docs_src.examples.pagination_search.db import get_db
from docs_src.examples.pagination_search.models import Article, Base, Category from docs_src.examples.pagination_search.models import Article, Base, Category
from docs_src.examples.pagination_search.routes import router from docs_src.examples.pagination_search.routes import router
from fastapi_toolsets.exceptions import init_exceptions_handlers from fastapi_toolsets.exceptions import init_exceptions_handlers
from fastapi_toolsets.pytest import create_db_session
from .conftest import DATABASE_URL from .conftest import DATABASE_URL
@@ -35,20 +36,8 @@ def build_app(session: AsyncSession) -> FastAPI:
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def ex_db_session(): async def ex_db_session():
"""Isolated session for the example models (separate tables from conftest).""" """Isolated session for the example models (separate tables from conftest)."""
engine = create_async_engine(DATABASE_URL, echo=False) async with create_db_session(DATABASE_URL, Base) as session:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
session = session_factory()
try:
yield session yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest.fixture @pytest.fixture

View File

@@ -16,7 +16,7 @@ from fastapi_toolsets.fixtures import (
from fastapi_toolsets.fixtures.utils import _get_primary_key from fastapi_toolsets.fixtures.utils import _get_primary_key
from .conftest import IntRole, Permission, Role, User from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud
class TestContext: class TestContext:
@@ -447,8 +447,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert len(result["roles"]) == 2 assert len(result["roles"]) == 2
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -479,8 +477,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert "users" in result assert "users" in result
from .conftest import RoleCrud, UserCrud
assert await RoleCrud.count(db_session) == 1 assert await RoleCrud.count(db_session) == 1
assert await UserCrud.count(db_session) == 1 assert await UserCrud.count(db_session) == 1
@@ -497,8 +493,6 @@ class TestLoadFixtures:
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
@@ -526,8 +520,6 @@ class TestLoadFixtures:
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
) )
from .conftest import RoleCrud
role = await RoleCrud.first(db_session, [Role.id == role_id]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "original" assert role.name == "original"
@@ -553,8 +545,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert len(result["roles"]) == 2 assert len(result["roles"]) == 2
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -594,8 +584,6 @@ class TestLoadFixtures:
assert "roles" in result assert "roles" in result
assert "other_roles" in result assert "other_roles" in result
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -660,8 +648,6 @@ class TestLoadFixturesByContext:
await load_fixtures_by_context(db_session, registry, Context.BASE) await load_fixtures_by_context(db_session, registry, Context.BASE)
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
@@ -688,8 +674,6 @@ class TestLoadFixturesByContext:
db_session, registry, Context.BASE, Context.TESTING db_session, registry, Context.BASE, Context.TESTING
) )
from .conftest import RoleCrud
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@@ -717,8 +701,6 @@ class TestLoadFixturesByContext:
await load_fixtures_by_context(db_session, registry, Context.TESTING) await load_fixtures_by_context(db_session, registry, Context.TESTING)
from .conftest import RoleCrud, UserCrud
assert await RoleCrud.count(db_session) == 1 assert await RoleCrud.count(db_session) == 1
assert await UserCrud.count(db_session) == 1 assert await UserCrud.count(db_session) == 1

View File

@@ -171,8 +171,15 @@ class TestPytestImportGuard:
class TestCliImportGuard: class TestCliImportGuard:
"""Tests for CLI module import guard when typer is missing.""" """Tests for CLI module import guard when typer is missing."""
def test_import_raises_without_typer(self): @pytest.mark.parametrize(
"""Importing cli.app raises when typer is missing.""" "expected_match",
[
"typer",
r"pip install fastapi-toolsets\[cli\]",
],
)
def test_import_raises_without_typer(self, expected_match):
"""Importing cli.app raises when typer is missing, with an informative error message."""
saved, blocking_import = _reload_without_package( saved, blocking_import = _reload_without_package(
"fastapi_toolsets.cli.app", ["typer"] "fastapi_toolsets.cli.app", ["typer"]
) )
@@ -186,33 +193,7 @@ class TestCliImportGuard:
try: try:
with patch("builtins.__import__", side_effect=blocking_import): with patch("builtins.__import__", side_effect=blocking_import):
with pytest.raises(ImportError, match="typer"): with pytest.raises(ImportError, match=expected_match):
importlib.import_module("fastapi_toolsets.cli.app")
finally:
for key in list(sys.modules):
if key.startswith("fastapi_toolsets.cli.app") or key.startswith(
"fastapi_toolsets.cli.config"
):
sys.modules.pop(key, None)
sys.modules.update(saved)
def test_error_message_suggests_cli_extra(self):
"""Error message suggests installing the cli extra."""
saved, blocking_import = _reload_without_package(
"fastapi_toolsets.cli.app", ["typer"]
)
config_keys = [
k for k in sys.modules if k.startswith("fastapi_toolsets.cli.config")
]
for key in config_keys:
if key not in saved:
saved[key] = sys.modules.pop(key)
try:
with patch("builtins.__import__", side_effect=blocking_import):
with pytest.raises(
ImportError, match=r"pip install fastapi-toolsets\[cli\]"
):
importlib.import_module("fastapi_toolsets.cli.app") importlib.import_module("fastapi_toolsets.cli.app")
finally: finally:
for key in list(sys.modules): for key in list(sys.modules):

View File

@@ -1,6 +1,5 @@
"""Tests for fastapi_toolsets.metrics module.""" """Tests for fastapi_toolsets.metrics module."""
import os
import tempfile import tempfile
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@@ -287,6 +286,16 @@ class TestIncludeRegistry:
class TestInitMetrics: class TestInitMetrics:
"""Tests for init_metrics function.""" """Tests for init_metrics function."""
@pytest.fixture
def metrics_client(self):
"""Create a FastAPI app with MetricsRegistry and return a TestClient."""
app = FastAPI()
registry = MetricsRegistry()
init_metrics(app, registry)
client = TestClient(app)
yield client
client.close()
def test_returns_app(self): def test_returns_app(self):
"""Returns the FastAPI app.""" """Returns the FastAPI app."""
app = FastAPI() app = FastAPI()
@@ -294,26 +303,14 @@ class TestInitMetrics:
result = init_metrics(app, registry) result = init_metrics(app, registry)
assert result is app assert result is app
def test_metrics_endpoint_responds(self): def test_metrics_endpoint_responds(self, metrics_client):
"""The /metrics endpoint returns 200.""" """The /metrics endpoint returns 200."""
app = FastAPI() response = metrics_client.get("/metrics")
registry = MetricsRegistry()
init_metrics(app, registry)
client = TestClient(app)
response = client.get("/metrics")
assert response.status_code == 200 assert response.status_code == 200
def test_metrics_endpoint_content_type(self): def test_metrics_endpoint_content_type(self, metrics_client):
"""The /metrics endpoint returns prometheus content type.""" """The /metrics endpoint returns prometheus content type."""
app = FastAPI() response = metrics_client.get("/metrics")
registry = MetricsRegistry()
init_metrics(app, registry)
client = TestClient(app)
response = client.get("/metrics")
assert "text/plain" in response.headers["content-type"] assert "text/plain" in response.headers["content-type"]
def test_custom_path(self): def test_custom_path(self):
@@ -445,36 +442,33 @@ class TestInitMetrics:
class TestMultiProcessMode: class TestMultiProcessMode:
"""Tests for multi-process Prometheus mode.""" """Tests for multi-process Prometheus mode."""
def test_multiprocess_with_env_var(self): def test_multiprocess_with_env_var(self, monkeypatch):
"""Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set.""" """Multi-process mode works when PROMETHEUS_MULTIPROC_DIR is set."""
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir monkeypatch.setenv("PROMETHEUS_MULTIPROC_DIR", tmpdir)
try: # Use a separate registry to avoid conflicts with default
# Use a separate registry to avoid conflicts with default prom_registry = CollectorRegistry()
prom_registry = CollectorRegistry() app = FastAPI()
app = FastAPI() registry = MetricsRegistry()
registry = MetricsRegistry()
@registry.register @registry.register
def mp_counter(): def mp_counter():
return Counter( return Counter(
"mp_test_counter", "mp_test_counter",
"A multiprocess counter", "A multiprocess counter",
registry=prom_registry, registry=prom_registry,
) )
init_metrics(app, registry) init_metrics(app, registry)
client = TestClient(app) client = TestClient(app)
response = client.get("/metrics") response = client.get("/metrics")
assert response.status_code == 200 assert response.status_code == 200
finally:
del os.environ["PROMETHEUS_MULTIPROC_DIR"]
def test_single_process_without_env_var(self): def test_single_process_without_env_var(self, monkeypatch):
"""Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set.""" """Single-process mode when PROMETHEUS_MULTIPROC_DIR is not set."""
os.environ.pop("PROMETHEUS_MULTIPROC_DIR", None) monkeypatch.delenv("PROMETHEUS_MULTIPROC_DIR", raising=False)
app = FastAPI() app = FastAPI()
registry = MetricsRegistry() registry = MetricsRegistry()

View File

@@ -8,9 +8,10 @@ from unittest.mock import patch
import pytest import pytest
from sqlalchemy import String from sqlalchemy import String
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from fastapi_toolsets.pytest import create_db_session
import fastapi_toolsets.models.watched as _watched_module import fastapi_toolsets.models.watched as _watched_module
from fastapi_toolsets.models import ( from fastapi_toolsets.models import (
CreatedAtMixin, CreatedAtMixin,
@@ -267,39 +268,17 @@ class FutureCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mixin_session(): async def mixin_session():
engine = create_async_engine(DATABASE_URL, echo=False) async with create_db_session(DATABASE_URL, MixinBase) as session:
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
session = session_factory()
try:
yield session yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.drop_all)
await engine.dispose()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def mixin_session_expire(): async def mixin_session_expire():
"""Session with expire_on_commit=True (the default) to exercise attribute access after commit.""" """Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
engine = create_async_engine(DATABASE_URL, echo=False) async with create_db_session(
async with engine.begin() as conn: DATABASE_URL, MixinBase, expire_on_commit=True
await conn.run_sync(MixinBase.metadata.create_all) ) as session:
session_factory = async_sessionmaker(engine, expire_on_commit=True)
session = session_factory()
try:
yield session yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(MixinBase.metadata.drop_all)
await engine.dispose()
class TestUUIDMixin: class TestUUIDMixin:

2
uv.lock generated
View File

@@ -251,7 +251,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "2.4.1" version = "2.4.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },