10 Commits

16 changed files with 752 additions and 17 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "fastapi-toolsets"
version = "0.6.1"
version = "0.8.0"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md"
license = "MIT"
@@ -49,6 +49,7 @@ Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
test = [
"pytest>=8.0.0",
"pytest-anyio>=0.0.0",
"pytest-xdist>=3.0.0",
"coverage>=7.0.0",
"pytest-cov>=4.0.0",
]
@@ -62,7 +63,7 @@ dev = [
manager = "fastapi_toolsets.cli.app:cli"
[build-system]
requires = ["uv_build>=0.9.26,<0.10.0"]
requires = ["uv_build>=0.10,<0.11.0"]
build-backend = "uv_build"
[tool.pytest.ini_options]

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success")
"""
__version__ = "0.6.1"
__version__ = "0.8.0"

View File

@@ -2,6 +2,7 @@
import typer
from ..logger import configure_logging
from .config import get_custom_cli
from .pyproject import load_pyproject
@@ -27,4 +28,5 @@ if _config.get("fixtures") and _config.get("db_context"):
@cli.callback()
def main(ctx: typer.Context) -> None:
"""FastAPI utilities CLI."""
configure_logging()
ctx.ensure_object(dict)

View File

@@ -1,5 +1,7 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload

View File

@@ -1,8 +1,10 @@
"""Database utilities: sessions, transactions, and locks."""
import asyncio
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from enum import Enum
from typing import Any, TypeVar
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -14,6 +16,7 @@ __all__ = [
"create_db_dependency",
"lock_tables",
"get_transaction",
"wait_for_row_change",
]
@@ -173,3 +176,69 @@ async def lock_tables(
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
yield session
_M = TypeVar("_M", bound=DeclarativeBase)
async def wait_for_row_change(
    session: AsyncSession,
    model: type[_M],
    pk_value: Any,
    *,
    columns: list[str] | None = None,
    interval: float = 0.5,
    timeout: float | None = None,
) -> _M:
    """Poll a database row until a change is detected.

    Queries the row every ``interval`` seconds and returns the model instance
    once a change is detected in any column (or only the specified ``columns``).

    Args:
        session: AsyncSession instance
        model: SQLAlchemy model class
        pk_value: Primary key value of the row to watch
        columns: Optional list of column names to watch. If None, all columns
            are watched.
        interval: Polling interval in seconds (default: 0.5)
        timeout: Maximum time to wait in seconds. None means wait forever.

    Returns:
        The refreshed model instance with updated values

    Raises:
        LookupError: If the row does not exist or is deleted during polling
        TimeoutError: If timeout expires before a change is detected
    """
    instance = await session.get(model, pk_value)
    if instance is None:
        raise LookupError(f"{model.__name__} with pk={pk_value!r} not found")

    if columns is not None:
        watch_cols = columns
    else:
        watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
    initial = {col: getattr(instance, col) for col in watch_cols}

    # Track the deadline with the event loop's monotonic clock instead of
    # accumulating nominal sleep intervals: sleeps overshoot and each query
    # takes time, so a simple `elapsed += interval` counter drifts from
    # real elapsed time.
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout

    while True:
        await asyncio.sleep(interval)

        # Drop the cached instance so session.get() re-queries the database
        # instead of returning the identity-map entry unchanged.
        session.expunge(instance)
        instance = await session.get(model, pk_value)
        if instance is None:
            raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted")

        current = {col: getattr(instance, col) for col in watch_cols}
        if current != initial:
            return instance

        # Check the deadline *after* comparing, so a change committed during
        # the final sleep interval is returned rather than reported as a
        # timeout.
        if deadline is not None and loop.time() >= deadline:
            raise TimeoutError(
                f"No change detected on {model.__name__} "
                f"with pk={pk_value!r} within {timeout}s"
            )

View File

@@ -1,15 +1,15 @@
"""Fixture system with dependency management and context support."""
import logging
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, cast
from sqlalchemy.orm import DeclarativeBase
from ..logger import get_logger
from .enum import Context
logger = logging.getLogger(__name__)
logger = get_logger()
@dataclass

View File

@@ -1,4 +1,3 @@
import logging
from collections.abc import Callable, Sequence
from typing import Any, TypeVar
@@ -6,10 +5,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction
from ..logger import get_logger
from .enum import LoadStrategy
from .registry import Context, FixtureRegistry
logger = logging.getLogger(__name__)
logger = get_logger()
T = TypeVar("T", bound=DeclarativeBase)
@@ -29,9 +29,14 @@ def get_obj_by_attr(
The first model instance where the attribute matches the given value.
Raises:
StopIteration: If no matching object is found.
StopIteration: If no matching object is found in the fixture group.
"""
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
try:
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
except StopIteration:
raise StopIteration(
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
) from None
async def load_fixtures(

View File

@@ -0,0 +1,81 @@
"""Logging configuration for FastAPI applications and CLI tools."""
import logging
import sys
from typing import Literal
__all__ = ["LogLevel", "configure_logging", "get_logger"]
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
def configure_logging(
level: LogLevel | int = "INFO",
fmt: str = DEFAULT_FORMAT,
logger_name: str | None = None,
) -> logging.Logger:
"""Configure logging with a stdout handler and consistent format.
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
given format and level. Also configures the uvicorn loggers so that
FastAPI access logs use the same format.
Calling this function multiple times is safe -- existing handlers are
replaced rather than duplicated.
Args:
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
fmt: Log format string. Defaults to
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
logger_name: Logger name to configure. ``None`` (the default)
configures the root logger so all loggers inherit the settings.
Returns:
The configured Logger instance.
"""
formatter = logging.Formatter(fmt)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger = logging.getLogger(logger_name)
logger.handlers.clear()
logger.addHandler(handler)
logger.setLevel(level)
for name in UVICORN_LOGGERS:
uv_logger = logging.getLogger(name)
uv_logger.handlers.clear()
uv_logger.addHandler(handler)
uv_logger.setLevel(level)
return logger
_SENTINEL = object()
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
"""Return a logger with the given *name*.
A thin convenience wrapper around :func:`logging.getLogger` that keeps
logging imports consistent across the codebase.
When called without arguments, the caller's ``__name__`` is used
automatically, so ``get_logger()`` in a module is equivalent to
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
root logger.
Args:
name: Logger name. Defaults to the caller's ``__name__``.
Pass ``None`` to get the root logger.
Returns:
A Logger instance.
"""
if name is _SENTINEL:
name = sys._getframe(1).f_globals.get("__name__")
return logging.getLogger(name)

View File

@@ -1,8 +1,17 @@
from .plugin import register_fixtures
from .utils import create_async_client, create_db_session
from .utils import (
cleanup_tables,
create_async_client,
create_db_session,
create_worker_database,
worker_database_url,
)
__all__ = [
"cleanup_tables",
"create_async_client",
"create_db_session",
"create_worker_database",
"register_fixtures",
"worker_database_url",
]

View File

@@ -1,11 +1,18 @@
"""Pytest helper utilities for FastAPI testing."""
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from ..db import create_db_context
@@ -108,3 +115,147 @@ async def create_db_session(
await conn.run_sync(base.metadata.drop_all)
finally:
await engine.dispose()
def _get_xdist_worker() -> str | None:
"""Return the pytest-xdist worker name, or ``None`` when not running under xdist.
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
When xdist is not installed or not active, the variable is absent and
``None`` is returned.
"""
return os.environ.get("PYTEST_XDIST_WORKER")
def worker_database_url(database_url: str) -> str:
    """Derive a per-worker database URL for pytest-xdist parallel runs.

    Appends ``_{worker_name}`` to the database name so each xdist worker
    operates on its own database. When not running under xdist the
    original URL is returned unchanged.

    The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
    variable (set automatically by xdist in each worker process).

    Args:
        database_url: Original database connection URL.

    Returns:
        A database URL with the worker-specific database name, or the
        original URL when not running under xdist.

    Example:
        ```python
        # With PYTEST_XDIST_WORKER="gw0":
        url = worker_database_url(
            "postgresql+asyncpg://user:pass@localhost/test_db"
        )
        # "postgresql+asyncpg://user:pass@localhost/test_db_gw0"
        ```
    """
    worker = _get_xdist_worker()
    # Outside xdist there is nothing to isolate: hand the URL back as-is.
    if worker is None:
        return database_url
    parsed = make_url(database_url)
    renamed = parsed.set(database=f"{parsed.database}_{worker}")
    # hide_password=False keeps credentials intact so the URL stays usable.
    return renamed.render_as_string(hide_password=False)
@asynccontextmanager
async def create_worker_database(
    database_url: str,
) -> AsyncGenerator[str, None]:
    """Create and drop a per-worker database for pytest-xdist isolation.

    Intended for use as a **session-scoped** fixture. Connects to the server
    using the original *database_url* (with ``AUTOCOMMIT`` isolation for DDL),
    creates a dedicated database for the worker, and yields the worker-specific
    URL. On cleanup the worker database is dropped.

    When not running under xdist (``PYTEST_XDIST_WORKER`` is unset), the
    original URL is yielded without any database creation or teardown.

    Args:
        database_url: Original database connection URL.

    Yields:
        The worker-specific database URL.

    Example:
        ```python
        from fastapi_toolsets.pytest import (
            create_worker_database, create_db_session, cleanup_tables
        )

        DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"

        @pytest.fixture(scope="session")
        async def worker_db_url():
            async with create_worker_database(DATABASE_URL) as url:
                yield url

        @pytest.fixture
        async def db_session(worker_db_url):
            async with create_db_session(worker_db_url, Base) as session:
                yield session
                await cleanup_tables(session, Base)
        ```
    """
    if _get_xdist_worker() is None:
        yield database_url
        return

    worker_url = worker_database_url(database_url)
    worker_db_name = make_url(worker_url).database

    # CREATE/DROP DATABASE cannot run inside a transaction block, hence
    # AUTOCOMMIT isolation on the maintenance connection.
    engine = create_async_engine(
        database_url,
        isolation_level="AUTOCOMMIT",
    )
    try:
        async with engine.connect() as conn:
            # Quote the identifier: the name is derived from the URL plus the
            # worker id, so unquoted interpolation would break on names that
            # need escaping (and silently lower-case mixed-case names).
            await conn.execute(text(f'DROP DATABASE IF EXISTS "{worker_db_name}"'))
            await conn.execute(text(f'CREATE DATABASE "{worker_db_name}"'))
        try:
            yield worker_url
        finally:
            # Drop in a finally block so the worker database is removed even
            # when the test session raises through the context manager;
            # otherwise a failing run would leak the database.
            async with engine.connect() as conn:
                await conn.execute(
                    text(f'DROP DATABASE IF EXISTS "{worker_db_name}"')
                )
    finally:
        await engine.dispose()
async def cleanup_tables(
    session: AsyncSession,
    base: type[DeclarativeBase],
) -> None:
    """Truncate all tables for fast between-test cleanup.

    Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
    across every table in *base*'s metadata, which is significantly faster
    than dropping and re-creating tables between tests.

    This is a no-op when the metadata contains no tables.

    Args:
        session: An active async database session.
        base: SQLAlchemy DeclarativeBase class containing model metadata.

    Example:
        ```python
        @pytest.fixture
        async def db_session(worker_db_url):
            async with create_db_session(worker_db_url, Base) as session:
                yield session
                await cleanup_tables(session, Base)
        ```
    """
    table_list = base.metadata.sorted_tables
    # Nothing registered on this Base: issuing TRUNCATE with no targets
    # would be a syntax error, so bail out early.
    if not table_list:
        return
    quoted_names = [f'"{table.name}"' for table in table_list]
    statement = f"TRUNCATE {', '.join(quoted_names)} RESTART IDENTITY CASCADE"
    await session.execute(text(statement))
    await session.commit()

View File

@@ -10,6 +10,7 @@ __all__ = [
"ErrorResponse",
"Pagination",
"PaginatedResponse",
"PydanticBase",
"Response",
"ResponseStatus",
]

View File

@@ -1,5 +1,8 @@
"""Tests for fastapi_toolsets.db module."""
import asyncio
import uuid
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
@@ -9,6 +12,7 @@ from fastapi_toolsets.db import (
create_db_dependency,
get_transaction,
lock_tables,
wait_for_row_change,
)
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
@@ -241,3 +245,101 @@ class TestLockTables:
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
assert result is None
class TestWaitForRowChange:
    """Tests for wait_for_row_change polling function."""

    @pytest.mark.anyio
    async def test_detects_update(self, db_session: AsyncSession, engine):
        """Returns updated instance when a column value changes."""
        role = Role(name="watch_role")
        db_session.add(role)
        await db_session.commit()

        async def update_later():
            # Sleep first so the watcher records its initial snapshot
            # before the change lands.
            await asyncio.sleep(0.15)
            # Commit from a *separate* session so the watcher only sees the
            # change through an actual database re-query.
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                assert r is not None
                r.name = "updated_role"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
        await update_task
        assert result.name == "updated_role"

    @pytest.mark.anyio
    async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
        """Only triggers on changes to specified columns."""
        user = User(username="testuser", email="test@example.com")
        db_session.add(user)
        await db_session.commit()

        async def update_later():
            factory = async_sessionmaker(engine, expire_on_commit=False)
            # First: change email (not watched) — should not trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.email = "new@example.com"
                await other.commit()
            # Second: change username (watched) — should trigger
            await asyncio.sleep(0.15)
            async with factory() as other:
                u = await other.get(User, user.id)
                assert u is not None
                u.username = "newuser"
                await other.commit()

        update_task = asyncio.create_task(update_later())
        result = await wait_for_row_change(
            db_session, User, user.id, columns=["username"], interval=0.05
        )
        await update_task
        # The watcher returned on the username change, and the refreshed
        # instance also carries the earlier (unwatched) email update.
        assert result.username == "newuser"
        assert result.email == "new@example.com"

    @pytest.mark.anyio
    async def test_nonexistent_row_raises(self, db_session: AsyncSession):
        """Raises LookupError when the row does not exist."""
        fake_id = uuid.uuid4()
        with pytest.raises(LookupError, match="not found"):
            await wait_for_row_change(db_session, Role, fake_id, interval=0.05)

    @pytest.mark.anyio
    async def test_timeout_raises(self, db_session: AsyncSession):
        """Raises TimeoutError when no change is detected within timeout."""
        role = Role(name="timeout_role")
        db_session.add(role)
        await db_session.commit()

        # No concurrent writer: the row never changes, so the 0.2s timeout
        # must fire.
        with pytest.raises(TimeoutError):
            await wait_for_row_change(
                db_session, Role, role.id, interval=0.05, timeout=0.2
            )

    @pytest.mark.anyio
    async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
        """Raises LookupError when the row is deleted during polling."""
        role = Role(name="delete_role")
        db_session.add(role)
        await db_session.commit()

        async def delete_later():
            # Delete from a separate session mid-poll so the watcher's
            # re-query comes back empty.
            await asyncio.sleep(0.15)
            factory = async_sessionmaker(engine, expire_on_commit=False)
            async with factory() as other:
                r = await other.get(Role, role.id)
                await other.delete(r)
                await other.commit()

        delete_task = asyncio.create_task(delete_later())
        with pytest.raises(LookupError):
            await wait_for_row_change(db_session, Role, role.id, interval=0.05)
        await delete_task

View File

@@ -744,8 +744,11 @@ class TestGetObjByAttr:
assert user.username == "alice"
def test_no_match_raises_stop_iteration(self):
"""Raises StopIteration when no object matches."""
with pytest.raises(StopIteration):
"""Raises StopIteration with contextual message when no object matches."""
with pytest.raises(
StopIteration,
match="No object with name=nonexistent found in fixture 'roles'",
):
get_obj_by_attr(self.roles, "name", "nonexistent")
def test_no_match_on_wrong_value_type(self):

118
tests/test_logger.py Normal file
View File

@@ -0,0 +1,118 @@
import logging
import sys
import pytest
from fastapi_toolsets.logger import (
DEFAULT_FORMAT,
UVICORN_LOGGERS,
configure_logging,
get_logger,
)
@pytest.fixture(autouse=True)
def _reset_loggers():
    """Reset the root and uvicorn loggers after each test."""
    yield
    # Teardown: strip every handler installed by configure_logging and
    # restore the default levels (WARNING on root, NOTSET on uvicorn
    # loggers) so tests cannot leak state into each other.
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.setLevel(logging.WARNING)
    for logger_name in UVICORN_LOGGERS:
        uvicorn_logger = logging.getLogger(logger_name)
        uvicorn_logger.handlers.clear()
        uvicorn_logger.setLevel(logging.NOTSET)
class TestConfigureLogging:
    # Behavioral tests for configure_logging: handler installation, level
    # handling, format handling, idempotency, and uvicorn logger setup.

    def test_sets_up_handler_and_format(self):
        # Default call installs exactly one stdout StreamHandler using the
        # library's DEFAULT_FORMAT.
        logger = configure_logging()
        assert len(logger.handlers) == 1
        handler = logger.handlers[0]
        assert isinstance(handler, logging.StreamHandler)
        assert handler.stream is sys.stdout
        assert handler.formatter is not None
        assert handler.formatter._fmt == DEFAULT_FORMAT

    def test_default_level_is_info(self):
        logger = configure_logging()
        assert logger.level == logging.INFO

    def test_custom_level_string(self):
        # String level names are accepted.
        logger = configure_logging(level="DEBUG")
        assert logger.level == logging.DEBUG

    def test_custom_level_int(self):
        # Numeric levels (e.g. logging.WARNING) are accepted as well.
        logger = configure_logging(level=logging.WARNING)
        assert logger.level == logging.WARNING

    def test_custom_format(self):
        custom_fmt = "%(levelname)s: %(message)s"
        logger = configure_logging(fmt=custom_fmt)
        handler = logger.handlers[0]
        assert handler.formatter is not None
        assert handler.formatter._fmt == custom_fmt

    def test_named_logger(self):
        # Passing logger_name targets that logger instead of the root.
        logger = configure_logging(logger_name="myapp")
        assert logger.name == "myapp"
        assert len(logger.handlers) == 1

    def test_default_configures_root_logger(self):
        # logger_name=None (the default) configures the root logger.
        logger = configure_logging()
        assert logger is logging.getLogger()

    def test_idempotent_no_duplicate_handlers(self):
        # Repeated calls must replace the handler, not stack duplicates.
        configure_logging()
        configure_logging()
        logger = configure_logging()
        assert len(logger.handlers) == 1

    def test_configures_uvicorn_loggers(self):
        # The uvicorn loggers receive the same handler, level, and format.
        configure_logging(level="DEBUG")
        for name in UVICORN_LOGGERS:
            uv_logger = logging.getLogger(name)
            assert len(uv_logger.handlers) == 1
            assert uv_logger.level == logging.DEBUG
            handler = uv_logger.handlers[0]
            assert handler.formatter is not None
            assert handler.formatter._fmt == DEFAULT_FORMAT

    def test_returns_configured_logger(self):
        logger = configure_logging(logger_name="test.return")
        assert isinstance(logger, logging.Logger)
        assert logger.name == "test.return"
class TestGetLogger:
    # Behavioral tests for the get_logger convenience wrapper.

    def test_returns_named_logger(self):
        result = get_logger("myapp.services")
        assert isinstance(result, logging.Logger)
        assert result.name == "myapp.services"

    def test_returns_root_logger_when_none(self):
        # Explicit None maps to the root logger, mirroring logging.getLogger.
        assert get_logger(None) is logging.getLogger()

    def test_defaults_to_caller_module_name(self):
        # Calling with no argument resolves this test module's __name__.
        result = get_logger()
        assert result.name == __name__

    def test_same_name_returns_same_logger(self):
        first = get_logger("myapp")
        second = get_logger("myapp")
        assert first is second

View File

@@ -5,16 +5,21 @@ import uuid
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy import select, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase, selectinload
from fastapi_toolsets.fixtures import Context, FixtureRegistry
from fastapi_toolsets.pytest import (
cleanup_tables,
create_async_client,
create_db_session,
create_worker_database,
register_fixtures,
worker_database_url,
)
from fastapi_toolsets.pytest.utils import _get_xdist_worker
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
@@ -291,3 +296,164 @@ class TestCreateDbSession:
# Cleanup: drop tables manually
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass
class TestGetXdistWorker:
    """Tests for get_xdist_worker helper."""

    def test_returns_none_without_env_var(self, monkeypatch: pytest.MonkeyPatch):
        """Returns None when PYTEST_XDIST_WORKER is not set."""
        monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
        result = _get_xdist_worker()
        assert result is None

    def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
        """Returns the worker name from the environment variable."""
        monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
        result = _get_xdist_worker()
        assert result == "gw0"
class TestWorkerDatabaseUrl:
    """Tests for worker_database_url helper."""

    def test_returns_original_url_without_xdist(self, monkeypatch: pytest.MonkeyPatch):
        """URL is returned unchanged when not running under xdist."""
        monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
        original = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
        assert worker_database_url(original) == original

    def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
        """Worker name is appended to the database name."""
        monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
        derived = worker_database_url("postgresql+asyncpg://user:pass@localhost:5432/db")
        assert make_url(derived).database == "db_gw0"

    def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
        """Host, port, username, password, and driver are preserved."""
        monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
        parsed = make_url(
            worker_database_url("postgresql+asyncpg://myuser:secret@dbhost:6543/testdb")
        )
        assert parsed.drivername == "postgresql+asyncpg"
        assert parsed.username == "myuser"
        assert parsed.password == "secret"
        assert parsed.host == "dbhost"
        assert parsed.port == 6543
        assert parsed.database == "testdb_gw2"
class TestCreateWorkerDatabase:
    """Tests for create_worker_database context manager."""

    @pytest.mark.anyio
    async def test_yields_original_url_without_xdist(
        self, monkeypatch: pytest.MonkeyPatch
    ):
        """Without xdist, yields the original URL without database operations."""
        monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
        async with create_worker_database(DATABASE_URL) as url:
            assert url == DATABASE_URL

    @pytest.mark.anyio
    async def test_creates_and_drops_worker_database(
        self, monkeypatch: pytest.MonkeyPatch
    ):
        """Worker database exists inside the context and is dropped after."""
        monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
        expected_db = make_url(worker_database_url(DATABASE_URL)).database
        async with create_worker_database(DATABASE_URL) as url:
            assert make_url(url).database == expected_db
            # Verify the database exists while inside the context.
            # pg_database is queried via a separate AUTOCOMMIT engine so the
            # check is independent of the context manager's own connection.
            engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
            async with engine.connect() as conn:
                result = await conn.execute(
                    text("SELECT 1 FROM pg_database WHERE datname = :name"),
                    {"name": expected_db},
                )
                assert result.scalar() == 1
            await engine.dispose()
        # After context exit the database should be dropped
        engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
        async with engine.connect() as conn:
            result = await conn.execute(
                text("SELECT 1 FROM pg_database WHERE datname = :name"),
                {"name": expected_db},
            )
            # scalar() is None when the catalog query matched no rows.
            assert result.scalar() is None
        await engine.dispose()

    @pytest.mark.anyio
    async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
        """A pre-existing worker database is dropped and recreated."""
        monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
        expected_db = make_url(worker_database_url(DATABASE_URL)).database
        # Pre-create the database to simulate a stale leftover
        engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
        async with engine.connect() as conn:
            await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
            await conn.execute(text(f"CREATE DATABASE {expected_db}"))
        await engine.dispose()
        # Should succeed despite the database already existing
        async with create_worker_database(DATABASE_URL) as url:
            assert make_url(url).database == expected_db
        # Verify cleanup after context exit
        engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
        async with engine.connect() as conn:
            result = await conn.execute(
                text("SELECT 1 FROM pg_database WHERE datname = :name"),
                {"name": expected_db},
            )
            assert result.scalar() is None
        await engine.dispose()
class TestCleanupTables:
    """Tests for cleanup_tables helper."""

    @pytest.mark.anyio
    async def test_truncates_all_tables(self):
        """All table rows are removed after cleanup_tables."""
        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            # Seed one Role and one User that references it, so TRUNCATE's
            # CASCADE behavior across the FK is exercised.
            role = Role(id=uuid.uuid4(), name="cleanup_role")
            session.add(role)
            await session.flush()
            user = User(
                id=uuid.uuid4(),
                username="cleanup_user",
                email="cleanup@test.com",
                role_id=role.id,
            )
            session.add(user)
            await session.commit()
            # Verify rows exist
            roles_count = await RoleCrud.count(session)
            users_count = await UserCrud.count(session)
            assert roles_count == 1
            assert users_count == 1
            await cleanup_tables(session, Base)
            # Verify tables are empty
            roles_count = await RoleCrud.count(session)
            users_count = await UserCrud.count(session)
            assert roles_count == 0
            assert users_count == 0

    @pytest.mark.anyio
    async def test_noop_for_empty_metadata(self):
        """cleanup_tables does not raise when metadata has no tables."""

        # A fresh DeclarativeBase with no mapped models — its metadata has
        # no tables, so cleanup_tables should return without executing SQL.
        class EmptyBase(DeclarativeBase):
            pass

        async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
            # Should not raise
            await cleanup_tables(session, EmptyBase)

27
uv.lock generated
View File

@@ -203,6 +203,15 @@ toml = [
{ name = "tomli", marker = "python_full_version <= '3.11'" },
]
[[package]]
name = "execnet"
version = "2.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" },
]
[[package]]
name = "fastapi"
version = "0.128.1"
@@ -220,7 +229,7 @@ wheels = [
[[package]]
name = "fastapi-toolsets"
version = "0.6.1"
version = "0.8.0"
source = { editable = "." }
dependencies = [
{ name = "asyncpg" },
@@ -237,6 +246,7 @@ dev = [
{ name = "pytest" },
{ name = "pytest-anyio" },
{ name = "pytest-cov" },
{ name = "pytest-xdist" },
{ name = "ruff" },
{ name = "ty" },
]
@@ -245,6 +255,7 @@ test = [
{ name = "pytest" },
{ name = "pytest-anyio" },
{ name = "pytest-cov" },
{ name = "pytest-xdist" },
]
[package.metadata]
@@ -258,6 +269,7 @@ requires-dist = [
{ name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" },
{ name = "pytest-anyio", marker = "extra == 'test'", specifier = ">=0.0.0" },
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.0.0" },
{ name = "pytest-xdist", marker = "extra == 'test'", specifier = ">=3.0.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0" },
{ name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.1a0" },
@@ -575,6 +587,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" },
]
[[package]]
name = "pytest-xdist"
version = "3.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "execnet" },
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" },
]
[[package]]
name = "rich"
version = "14.3.2"