mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
2020fa2f92
|
|||
|
|
1ea316bef4 | ||
|
|
ced1a655f2 | ||
|
|
290b2a06ec | ||
|
|
baa9711665 | ||
|
d526969d0e
|
|||
|
|
e24153053e | ||
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 |
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "0.6.1"
|
||||
version = "0.8.0"
|
||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
@@ -49,6 +49,7 @@ Issues = "https://github.com/d3vyce/fastapi-toolsets/issues"
|
||||
test = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-anyio>=0.0.0",
|
||||
"pytest-xdist>=3.0.0",
|
||||
"coverage>=7.0.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
]
|
||||
@@ -62,7 +63,7 @@ dev = [
|
||||
manager = "fastapi_toolsets.cli.app:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
||||
requires = ["uv_build>=0.10,<0.11.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "0.6.1"
|
||||
__version__ = "0.8.0"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import typer
|
||||
|
||||
from ..logger import configure_logging
|
||||
from .config import get_custom_cli
|
||||
from .pyproject import load_pyproject
|
||||
|
||||
@@ -27,4 +28,5 @@ if _config.get("fixtures") and _config.get("db_context"):
|
||||
@cli.callback()
|
||||
def main(ctx: typer.Context) -> None:
|
||||
"""FastAPI utilities CLI."""
|
||||
configure_logging()
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select
|
||||
@@ -14,6 +16,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from ..schemas import PaginatedResponse, Pagination, Response
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
@@ -29,26 +32,80 @@ class AsyncCrud(Generic[ModelType]):
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
) -> ModelType:
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Created model instance
|
||||
Created model instance or Response wrapping it
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = cls.model(**obj.model_dump())
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
return cast(ModelType, db_model)
|
||||
result = cast(ModelType, db_model)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
return result
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
@@ -60,7 +117,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType:
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
@@ -70,9 +128,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
Model instance or Response wrapping it
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
@@ -95,7 +154,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
item = result.unique().scalar_one_or_none()
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
return cast(ModelType, item)
|
||||
result = cast(ModelType, item)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def first(
|
||||
@@ -183,6 +245,32 @@ class AsyncCrud(Generic[ModelType]):
|
||||
result = await session.execute(q)
|
||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def update(
|
||||
cls: type[Self],
|
||||
@@ -192,7 +280,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
) -> ModelType:
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
@@ -201,9 +290,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Updated model instance
|
||||
Updated model instance or Response wrapping it
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
@@ -216,6 +306,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
for key, value in values.items():
|
||||
setattr(db_model, key, value)
|
||||
await session.refresh(db_model)
|
||||
if as_response:
|
||||
return Response(data=db_model)
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
@@ -264,24 +356,49 @@ class AsyncCrud(Generic[ModelType]):
|
||||
)
|
||||
return cast(ModelType | None, db_model)
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def delete( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[None]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def delete( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> bool: ...
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
) -> bool:
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> bool | Response[None]:
|
||||
"""Delete records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
True if deletion was executed
|
||||
True if deletion was executed, or Response wrapping it
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
if as_response:
|
||||
return Response(data=None)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@@ -363,7 +480,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> PaginatedResponse[ModelType]:
|
||||
"""Get paginated results with metadata.
|
||||
|
||||
Args:
|
||||
@@ -420,7 +537,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
q = q.offset(offset).limit(items_per_page)
|
||||
result = await session.execute(q)
|
||||
items = result.unique().scalars().all()
|
||||
items = cast(list[ModelType], result.unique().scalars().all())
|
||||
|
||||
# Count query (with same joins and filters)
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
@@ -446,15 +563,15 @@ class AsyncCrud(Generic[ModelType]):
|
||||
count_result = await session.execute(count_q)
|
||||
total_count = count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"data": items,
|
||||
"pagination": {
|
||||
"total_count": total_count,
|
||||
"items_per_page": items_per_page,
|
||||
"page": page,
|
||||
"has_more": page * items_per_page < total_count,
|
||||
},
|
||||
}
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=Pagination(
|
||||
total_count=total_count,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
has_more=page * items_per_page < total_count,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def CrudFactory(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Database utilities: sessions, transactions, and locks."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
@@ -14,6 +16,7 @@ __all__ = [
|
||||
"create_db_dependency",
|
||||
"lock_tables",
|
||||
"get_transaction",
|
||||
"wait_for_row_change",
|
||||
]
|
||||
|
||||
|
||||
@@ -173,3 +176,69 @@ async def lock_tables(
|
||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||
yield session
|
||||
|
||||
|
||||
_M = TypeVar("_M", bound=DeclarativeBase)
|
||||
|
||||
|
||||
async def wait_for_row_change(
|
||||
session: AsyncSession,
|
||||
model: type[_M],
|
||||
pk_value: Any,
|
||||
*,
|
||||
columns: list[str] | None = None,
|
||||
interval: float = 0.5,
|
||||
timeout: float | None = None,
|
||||
) -> _M:
|
||||
"""Poll a database row until a change is detected.
|
||||
|
||||
Queries the row every ``interval`` seconds and returns the model instance
|
||||
once a change is detected in any column (or only the specified ``columns``).
|
||||
|
||||
Args:
|
||||
session: AsyncSession instance
|
||||
model: SQLAlchemy model class
|
||||
pk_value: Primary key value of the row to watch
|
||||
columns: Optional list of column names to watch. If None, all columns
|
||||
are watched.
|
||||
interval: Polling interval in seconds (default: 0.5)
|
||||
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||
|
||||
Returns:
|
||||
The refreshed model instance with updated values
|
||||
|
||||
Raises:
|
||||
LookupError: If the row does not exist or is deleted during polling
|
||||
TimeoutError: If timeout expires before a change is detected
|
||||
"""
|
||||
instance = await session.get(model, pk_value)
|
||||
if instance is None:
|
||||
raise LookupError(f"{model.__name__} with pk={pk_value!r} not found")
|
||||
|
||||
if columns is not None:
|
||||
watch_cols = columns
|
||||
else:
|
||||
watch_cols = [attr.key for attr in model.__mapper__.column_attrs]
|
||||
|
||||
initial = {col: getattr(instance, col) for col in watch_cols}
|
||||
|
||||
elapsed = 0.0
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
elapsed += interval
|
||||
|
||||
if timeout is not None and elapsed >= timeout:
|
||||
raise TimeoutError(
|
||||
f"No change detected on {model.__name__} "
|
||||
f"with pk={pk_value!r} within {timeout}s"
|
||||
)
|
||||
|
||||
session.expunge(instance)
|
||||
instance = await session.get(model, pk_value)
|
||||
|
||||
if instance is None:
|
||||
raise LookupError(f"{model.__name__} with pk={pk_value!r} was deleted")
|
||||
|
||||
current = {col: getattr(instance, col) for col in watch_cols}
|
||||
if current != initial:
|
||||
return instance
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Fixture system with dependency management and context support."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..logger import get_logger
|
||||
from .enum import Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, TypeVar
|
||||
|
||||
@@ -6,10 +5,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..logger import get_logger
|
||||
from .enum import LoadStrategy
|
||||
from .registry import Context, FixtureRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger()
|
||||
|
||||
T = TypeVar("T", bound=DeclarativeBase)
|
||||
|
||||
@@ -29,9 +29,14 @@ def get_obj_by_attr(
|
||||
The first model instance where the attribute matches the given value.
|
||||
|
||||
Raises:
|
||||
StopIteration: If no matching object is found.
|
||||
StopIteration: If no matching object is found in the fixture group.
|
||||
"""
|
||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||
try:
|
||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||
except StopIteration:
|
||||
raise StopIteration(
|
||||
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
|
||||
) from None
|
||||
|
||||
|
||||
async def load_fixtures(
|
||||
|
||||
81
src/fastapi_toolsets/logger.py
Normal file
81
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Literal
|
||||
|
||||
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||
|
||||
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||
|
||||
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
|
||||
def configure_logging(
|
||||
level: LogLevel | int = "INFO",
|
||||
fmt: str = DEFAULT_FORMAT,
|
||||
logger_name: str | None = None,
|
||||
) -> logging.Logger:
|
||||
"""Configure logging with a stdout handler and consistent format.
|
||||
|
||||
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||
given format and level. Also configures the uvicorn loggers so that
|
||||
FastAPI access logs use the same format.
|
||||
|
||||
Calling this function multiple times is safe -- existing handlers are
|
||||
replaced rather than duplicated.
|
||||
|
||||
Args:
|
||||
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||
fmt: Log format string. Defaults to
|
||||
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||
logger_name: Logger name to configure. ``None`` (the default)
|
||||
configures the root logger so all loggers inherit the settings.
|
||||
|
||||
Returns:
|
||||
The configured Logger instance.
|
||||
"""
|
||||
formatter = logging.Formatter(fmt)
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.handlers.clear()
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(level)
|
||||
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv_logger = logging.getLogger(name)
|
||||
uv_logger.handlers.clear()
|
||||
uv_logger.addHandler(handler)
|
||||
uv_logger.setLevel(level)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
|
||||
"""Return a logger with the given *name*.
|
||||
|
||||
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||
logging imports consistent across the codebase.
|
||||
|
||||
When called without arguments, the caller's ``__name__`` is used
|
||||
automatically, so ``get_logger()`` in a module is equivalent to
|
||||
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||
root logger.
|
||||
|
||||
Args:
|
||||
name: Logger name. Defaults to the caller's ``__name__``.
|
||||
Pass ``None`` to get the root logger.
|
||||
|
||||
Returns:
|
||||
A Logger instance.
|
||||
"""
|
||||
if name is _SENTINEL:
|
||||
name = sys._getframe(1).f_globals.get("__name__")
|
||||
return logging.getLogger(name)
|
||||
@@ -1,8 +1,17 @@
|
||||
from .plugin import register_fixtures
|
||||
from .utils import create_async_client, create_db_session
|
||||
from .utils import (
|
||||
cleanup_tables,
|
||||
create_async_client,
|
||||
create_db_session,
|
||||
create_worker_database,
|
||||
worker_database_url,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"cleanup_tables",
|
||||
"create_async_client",
|
||||
"create_db_session",
|
||||
"create_worker_database",
|
||||
"register_fixtures",
|
||||
"worker_database_url",
|
||||
]
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
"""Pytest helper utilities for FastAPI testing."""
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import create_db_context
|
||||
@@ -108,3 +115,147 @@ async def create_db_session(
|
||||
await conn.run_sync(base.metadata.drop_all)
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def _get_xdist_worker() -> str | None:
|
||||
"""Return the pytest-xdist worker name, or ``None`` when not running under xdist.
|
||||
|
||||
Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
|
||||
automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
|
||||
When xdist is not installed or not active, the variable is absent and
|
||||
``None`` is returned.
|
||||
"""
|
||||
return os.environ.get("PYTEST_XDIST_WORKER")
|
||||
|
||||
|
||||
def worker_database_url(database_url: str) -> str:
|
||||
"""Derive a per-worker database URL for pytest-xdist parallel runs.
|
||||
|
||||
Appends ``_{worker_name}`` to the database name so each xdist worker
|
||||
operates on its own database. When not running under xdist the
|
||||
original URL is returned unchanged.
|
||||
|
||||
The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
|
||||
variable (set automatically by xdist in each worker process).
|
||||
|
||||
Args:
|
||||
database_url: Original database connection URL.
|
||||
|
||||
Returns:
|
||||
A database URL with the worker-specific database name, or the
|
||||
original URL when not running under xdist.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# With PYTEST_XDIST_WORKER="gw0":
|
||||
url = worker_database_url(
|
||||
"postgresql+asyncpg://user:pass@localhost/test_db"
|
||||
)
|
||||
# "postgresql+asyncpg://user:pass@localhost/test_db_gw0"
|
||||
```
|
||||
"""
|
||||
worker = _get_xdist_worker()
|
||||
if worker is None:
|
||||
return database_url
|
||||
|
||||
url = make_url(database_url)
|
||||
url = url.set(database=f"{url.database}_{worker}")
|
||||
return url.render_as_string(hide_password=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_worker_database(
|
||||
database_url: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Create and drop a per-worker database for pytest-xdist isolation.
|
||||
|
||||
Intended for use as a **session-scoped** fixture. Connects to the server
|
||||
using the original *database_url* (with ``AUTOCOMMIT`` isolation for DDL),
|
||||
creates a dedicated database for the worker, and yields the worker-specific
|
||||
URL. On cleanup the worker database is dropped.
|
||||
|
||||
When not running under xdist (``PYTEST_XDIST_WORKER`` is unset), the
|
||||
original URL is yielded without any database creation or teardown.
|
||||
|
||||
Args:
|
||||
database_url: Original database connection URL.
|
||||
|
||||
Yields:
|
||||
The worker-specific database URL.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi_toolsets.pytest import (
|
||||
create_worker_database, create_db_session, cleanup_tables
|
||||
)
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def worker_db_url():
|
||||
async with create_worker_database(DATABASE_URL) as url:
|
||||
yield url
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(worker_db_url):
|
||||
async with create_db_session(worker_db_url, Base) as session:
|
||||
yield session
|
||||
await cleanup_tables(session, Base)
|
||||
```
|
||||
"""
|
||||
if _get_xdist_worker() is None:
|
||||
yield database_url
|
||||
return
|
||||
|
||||
worker_url = worker_database_url(database_url)
|
||||
worker_db_name = make_url(worker_url).database
|
||||
|
||||
engine = create_async_engine(
|
||||
database_url,
|
||||
isolation_level="AUTOCOMMIT",
|
||||
)
|
||||
try:
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||
await conn.execute(text(f"CREATE DATABASE {worker_db_name}"))
|
||||
|
||||
yield worker_url
|
||||
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
async def cleanup_tables(
|
||||
session: AsyncSession,
|
||||
base: type[DeclarativeBase],
|
||||
) -> None:
|
||||
"""Truncate all tables for fast between-test cleanup.
|
||||
|
||||
Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
|
||||
across every table in *base*'s metadata, which is significantly faster
|
||||
than dropping and re-creating tables between tests.
|
||||
|
||||
This is a no-op when the metadata contains no tables.
|
||||
|
||||
Args:
|
||||
session: An active async database session.
|
||||
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@pytest.fixture
|
||||
async def db_session(worker_db_url):
|
||||
async with create_db_session(worker_db_url, Base) as session:
|
||||
yield session
|
||||
await cleanup_tables(session, Base)
|
||||
```
|
||||
"""
|
||||
tables = base.metadata.sorted_tables
|
||||
if not tables:
|
||||
return
|
||||
|
||||
table_names = ", ".join(f'"{t.name}"' for t in tables)
|
||||
await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
|
||||
await session.commit()
|
||||
|
||||
@@ -10,6 +10,7 @@ __all__ = [
|
||||
"ErrorResponse",
|
||||
"Pagination",
|
||||
"PaginatedResponse",
|
||||
"PydanticBase",
|
||||
"Response",
|
||||
"ResponseStatus",
|
||||
]
|
||||
|
||||
@@ -429,11 +429,11 @@ class TestCrudPaginate:
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
||||
|
||||
assert len(result["data"]) == 10
|
||||
assert result["pagination"]["total_count"] == 25
|
||||
assert result["pagination"]["page"] == 1
|
||||
assert result["pagination"]["items_per_page"] == 10
|
||||
assert result["pagination"]["has_more"] is True
|
||||
assert len(result.data) == 10
|
||||
assert result.pagination.total_count == 25
|
||||
assert result.pagination.page == 1
|
||||
assert result.pagination.items_per_page == 10
|
||||
assert result.pagination.has_more is True
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_last_page(self, db_session: AsyncSession):
|
||||
@@ -443,8 +443,8 @@ class TestCrudPaginate:
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10)
|
||||
|
||||
assert len(result["data"]) == 5
|
||||
assert result["pagination"]["has_more"] is False
|
||||
assert len(result.data) == 5
|
||||
assert result.pagination.has_more is False
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_filters(self, db_session: AsyncSession):
|
||||
@@ -466,7 +466,7 @@ class TestCrudPaginate:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 5
|
||||
assert result.pagination.total_count == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_ordering(self, db_session: AsyncSession):
|
||||
@@ -482,7 +482,7 @@ class TestCrudPaginate:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
names = [r.name for r in result["data"]]
|
||||
names = [r.name for r in result.data]
|
||||
assert names == ["alpha", "bravo", "charlie"]
|
||||
|
||||
|
||||
@@ -690,8 +690,8 @@ class TestCrudJoins:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 3
|
||||
assert len(result["data"]) == 3
|
||||
assert result.pagination.total_count == 3
|
||||
assert len(result.data) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
|
||||
@@ -721,8 +721,8 @@ class TestCrudJoins:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert len(result["data"]) == 2
|
||||
assert result.pagination.total_count == 2
|
||||
assert len(result.data) == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_joins(self, db_session: AsyncSession):
|
||||
@@ -752,3 +752,63 @@ class TestCrudJoins:
|
||||
)
|
||||
assert len(users) == 1
|
||||
assert users[0].username == "multi_join"
|
||||
|
||||
|
||||
class TestAsResponse:
|
||||
"""Tests for as_response parameter."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_as_response(self, db_session: AsyncSession):
|
||||
"""Create with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
data = RoleCreate(name="response_role")
|
||||
result = await RoleCrud.create(db_session, data, as_response=True)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.name == "response_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_as_response(self, db_session: AsyncSession):
|
||||
"""Get with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.id == created.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_as_response(self, db_session: AsyncSession):
|
||||
"""Update with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="new_name"),
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.name == "new_name"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_as_response(self, db_session: AsyncSession):
|
||||
"""Delete with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
|
||||
result = await RoleCrud.delete(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is None
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
||||
@@ -57,7 +57,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
||||
@@ -84,7 +84,7 @@ class TestPaginateSearch:
|
||||
search_fields=[(User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
||||
@@ -102,7 +102,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username, (User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
||||
@@ -117,7 +117,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
||||
@@ -132,7 +132,7 @@ class TestPaginateSearch:
|
||||
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert result["pagination"]["total_count"] == 0
|
||||
assert result.pagination.total_count == 0
|
||||
|
||||
# Should find (case match)
|
||||
result = await UserCrud.paginate(
|
||||
@@ -140,7 +140,7 @@ class TestPaginateSearch:
|
||||
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_empty_query(self, db_session: AsyncSession):
|
||||
@@ -153,10 +153,10 @@ class TestPaginateSearch:
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="")
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
result = await UserCrud.paginate(db_session, search=None)
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
||||
@@ -177,8 +177,8 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].username == "active_john"
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "active_john"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
||||
@@ -189,7 +189,7 @@ class TestPaginateSearch:
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="findme")
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_no_results(self, db_session: AsyncSession):
|
||||
@@ -204,8 +204,8 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 0
|
||||
assert result["data"] == []
|
||||
assert result.pagination.total_count == 0
|
||||
assert result.data == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_pagination(self, db_session: AsyncSession):
|
||||
@@ -224,9 +224,9 @@ class TestPaginateSearch:
|
||||
items_per_page=5,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 15
|
||||
assert len(result["data"]) == 5
|
||||
assert result["pagination"]["has_more"] is True
|
||||
assert result.pagination.total_count == 15
|
||||
assert len(result.data) == 5
|
||||
assert result.pagination.has_more is True
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_null_relationship(self, db_session: AsyncSession):
|
||||
@@ -248,7 +248,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_order_by(self, db_session: AsyncSession):
|
||||
@@ -270,8 +270,8 @@ class TestPaginateSearch:
|
||||
order_by=User.username,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 3
|
||||
usernames = [u.username for u in result["data"]]
|
||||
assert result.pagination.total_count == 3
|
||||
usernames = [u.username for u in result.data]
|
||||
assert usernames == ["alice", "bob", "charlie"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -292,8 +292,8 @@ class TestPaginateSearch:
|
||||
search_fields=[User.id, User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].id == user_id
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].id == user_id
|
||||
|
||||
|
||||
class TestSearchConfig:
|
||||
@@ -318,8 +318,8 @@ class TestSearchConfig:
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].username == "john_test"
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "john_test"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
||||
@@ -333,7 +333,7 @@ class TestSearchConfig:
|
||||
search=SearchConfig(query="findme", fields=[User.email]),
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
|
||||
class TestNoSearchableFieldsError:
|
||||
|
||||
102
tests/test_db.py
102
tests/test_db.py
@@ -1,5 +1,8 @@
|
||||
"""Tests for fastapi_toolsets.db module."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
@@ -9,6 +12,7 @@ from fastapi_toolsets.db import (
|
||||
create_db_dependency,
|
||||
get_transaction,
|
||||
lock_tables,
|
||||
wait_for_row_change,
|
||||
)
|
||||
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User
|
||||
@@ -241,3 +245,101 @@ class TestLockTables:
|
||||
|
||||
result = await RoleCrud.first(db_session, [Role.name == "lock_rollback_role"])
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestWaitForRowChange:
|
||||
"""Tests for wait_for_row_change polling function."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_detects_update(self, db_session: AsyncSession, engine):
|
||||
"""Returns updated instance when a column value changes."""
|
||||
role = Role(name="watch_role")
|
||||
db_session.add(role)
|
||||
await db_session.commit()
|
||||
|
||||
async def update_later():
|
||||
await asyncio.sleep(0.15)
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as other:
|
||||
r = await other.get(Role, role.id)
|
||||
assert r is not None
|
||||
r.name = "updated_role"
|
||||
await other.commit()
|
||||
|
||||
update_task = asyncio.create_task(update_later())
|
||||
result = await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||
await update_task
|
||||
|
||||
assert result.name == "updated_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_watches_specific_columns(self, db_session: AsyncSession, engine):
|
||||
"""Only triggers on changes to specified columns."""
|
||||
user = User(username="testuser", email="test@example.com")
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
async def update_later():
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
# First: change email (not watched) — should not trigger
|
||||
await asyncio.sleep(0.15)
|
||||
async with factory() as other:
|
||||
u = await other.get(User, user.id)
|
||||
assert u is not None
|
||||
u.email = "new@example.com"
|
||||
await other.commit()
|
||||
# Second: change username (watched) — should trigger
|
||||
await asyncio.sleep(0.15)
|
||||
async with factory() as other:
|
||||
u = await other.get(User, user.id)
|
||||
assert u is not None
|
||||
u.username = "newuser"
|
||||
await other.commit()
|
||||
|
||||
update_task = asyncio.create_task(update_later())
|
||||
result = await wait_for_row_change(
|
||||
db_session, User, user.id, columns=["username"], interval=0.05
|
||||
)
|
||||
await update_task
|
||||
|
||||
assert result.username == "newuser"
|
||||
assert result.email == "new@example.com"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_nonexistent_row_raises(self, db_session: AsyncSession):
|
||||
"""Raises LookupError when the row does not exist."""
|
||||
fake_id = uuid.uuid4()
|
||||
with pytest.raises(LookupError, match="not found"):
|
||||
await wait_for_row_change(db_session, Role, fake_id, interval=0.05)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_timeout_raises(self, db_session: AsyncSession):
|
||||
"""Raises TimeoutError when no change is detected within timeout."""
|
||||
role = Role(name="timeout_role")
|
||||
db_session.add(role)
|
||||
await db_session.commit()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
await wait_for_row_change(
|
||||
db_session, Role, role.id, interval=0.05, timeout=0.2
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_deleted_row_raises(self, db_session: AsyncSession, engine):
|
||||
"""Raises LookupError when the row is deleted during polling."""
|
||||
role = Role(name="delete_role")
|
||||
db_session.add(role)
|
||||
await db_session.commit()
|
||||
|
||||
async def delete_later():
|
||||
await asyncio.sleep(0.15)
|
||||
factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with factory() as other:
|
||||
r = await other.get(Role, role.id)
|
||||
await other.delete(r)
|
||||
await other.commit()
|
||||
|
||||
delete_task = asyncio.create_task(delete_later())
|
||||
with pytest.raises(LookupError):
|
||||
await wait_for_row_change(db_session, Role, role.id, interval=0.05)
|
||||
await delete_task
|
||||
|
||||
@@ -744,8 +744,11 @@ class TestGetObjByAttr:
|
||||
assert user.username == "alice"
|
||||
|
||||
def test_no_match_raises_stop_iteration(self):
|
||||
"""Raises StopIteration when no object matches."""
|
||||
with pytest.raises(StopIteration):
|
||||
"""Raises StopIteration with contextual message when no object matches."""
|
||||
with pytest.raises(
|
||||
StopIteration,
|
||||
match="No object with name=nonexistent found in fixture 'roles'",
|
||||
):
|
||||
get_obj_by_attr(self.roles, "name", "nonexistent")
|
||||
|
||||
def test_no_match_on_wrong_value_type(self):
|
||||
|
||||
118
tests/test_logger.py
Normal file
118
tests/test_logger.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_toolsets.logger import (
|
||||
DEFAULT_FORMAT,
|
||||
UVICORN_LOGGERS,
|
||||
configure_logging,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_loggers():
|
||||
"""Reset the root and uvicorn loggers after each test."""
|
||||
yield
|
||||
root = logging.getLogger()
|
||||
root.handlers.clear()
|
||||
root.setLevel(logging.WARNING)
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv = logging.getLogger(name)
|
||||
uv.handlers.clear()
|
||||
uv.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class TestConfigureLogging:
|
||||
def test_sets_up_handler_and_format(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert len(logger.handlers) == 1
|
||||
handler = logger.handlers[0]
|
||||
assert isinstance(handler, logging.StreamHandler)
|
||||
assert handler.stream is sys.stdout
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||
|
||||
def test_default_level_is_info(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert logger.level == logging.INFO
|
||||
|
||||
def test_custom_level_string(self):
|
||||
logger = configure_logging(level="DEBUG")
|
||||
|
||||
assert logger.level == logging.DEBUG
|
||||
|
||||
def test_custom_level_int(self):
|
||||
logger = configure_logging(level=logging.WARNING)
|
||||
|
||||
assert logger.level == logging.WARNING
|
||||
|
||||
def test_custom_format(self):
|
||||
custom_fmt = "%(levelname)s: %(message)s"
|
||||
logger = configure_logging(fmt=custom_fmt)
|
||||
|
||||
handler = logger.handlers[0]
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == custom_fmt
|
||||
|
||||
def test_named_logger(self):
|
||||
logger = configure_logging(logger_name="myapp")
|
||||
|
||||
assert logger.name == "myapp"
|
||||
assert len(logger.handlers) == 1
|
||||
|
||||
def test_default_configures_root_logger(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert logger is logging.getLogger()
|
||||
|
||||
def test_idempotent_no_duplicate_handlers(self):
|
||||
configure_logging()
|
||||
configure_logging()
|
||||
logger = configure_logging()
|
||||
|
||||
assert len(logger.handlers) == 1
|
||||
|
||||
def test_configures_uvicorn_loggers(self):
|
||||
configure_logging(level="DEBUG")
|
||||
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv_logger = logging.getLogger(name)
|
||||
assert len(uv_logger.handlers) == 1
|
||||
assert uv_logger.level == logging.DEBUG
|
||||
handler = uv_logger.handlers[0]
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||
|
||||
def test_returns_configured_logger(self):
|
||||
logger = configure_logging(logger_name="test.return")
|
||||
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert logger.name == "test.return"
|
||||
|
||||
|
||||
class TestGetLogger:
|
||||
def test_returns_named_logger(self):
|
||||
logger = get_logger("myapp.services")
|
||||
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert logger.name == "myapp.services"
|
||||
|
||||
def test_returns_root_logger_when_none(self):
|
||||
logger = get_logger(None)
|
||||
|
||||
assert logger is logging.getLogger()
|
||||
|
||||
def test_defaults_to_caller_module_name(self):
|
||||
logger = get_logger()
|
||||
|
||||
assert logger.name == __name__
|
||||
|
||||
def test_same_name_returns_same_logger(self):
|
||||
a = get_logger("myapp")
|
||||
b = get_logger("myapp")
|
||||
|
||||
assert a is b
|
||||
@@ -5,16 +5,21 @@ import uuid
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, selectinload
|
||||
|
||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||
from fastapi_toolsets.pytest import (
|
||||
cleanup_tables,
|
||||
create_async_client,
|
||||
create_db_session,
|
||||
create_worker_database,
|
||||
register_fixtures,
|
||||
worker_database_url,
|
||||
)
|
||||
from fastapi_toolsets.pytest.utils import _get_xdist_worker
|
||||
|
||||
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||
|
||||
@@ -291,3 +296,164 @@ class TestCreateDbSession:
|
||||
# Cleanup: drop tables manually
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||
pass
|
||||
|
||||
|
||||
class TestGetXdistWorker:
|
||||
"""Tests for get_xdist_worker helper."""
|
||||
|
||||
def test_returns_none_without_env_var(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Returns None when PYTEST_XDIST_WORKER is not set."""
|
||||
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||
assert _get_xdist_worker() is None
|
||||
|
||||
def test_returns_worker_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Returns the worker name from the environment variable."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||
assert _get_xdist_worker() == "gw0"
|
||||
|
||||
|
||||
class TestWorkerDatabaseUrl:
|
||||
"""Tests for worker_database_url helper."""
|
||||
|
||||
def test_returns_original_url_without_xdist(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""URL is returned unchanged when not running under xdist."""
|
||||
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||
url = "postgresql+asyncpg://user:pass@localhost:5432/mydb"
|
||||
assert worker_database_url(url) == url
|
||||
|
||||
def test_appends_worker_id_to_database_name(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Worker name is appended to the database name."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw0")
|
||||
url = "postgresql+asyncpg://user:pass@localhost:5432/db"
|
||||
result = worker_database_url(url)
|
||||
assert make_url(result).database == "db_gw0"
|
||||
|
||||
def test_preserves_url_components(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Host, port, username, password, and driver are preserved."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw2")
|
||||
url = "postgresql+asyncpg://myuser:secret@dbhost:6543/testdb"
|
||||
result = make_url(worker_database_url(url))
|
||||
|
||||
assert result.drivername == "postgresql+asyncpg"
|
||||
assert result.username == "myuser"
|
||||
assert result.password == "secret"
|
||||
assert result.host == "dbhost"
|
||||
assert result.port == 6543
|
||||
assert result.database == "testdb_gw2"
|
||||
|
||||
|
||||
class TestCreateWorkerDatabase:
|
||||
"""Tests for create_worker_database context manager."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_yields_original_url_without_xdist(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Without xdist, yields the original URL without database operations."""
|
||||
monkeypatch.delenv("PYTEST_XDIST_WORKER", raising=False)
|
||||
async with create_worker_database(DATABASE_URL) as url:
|
||||
assert url == DATABASE_URL
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_creates_and_drops_worker_database(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Worker database exists inside the context and is dropped after."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_create")
|
||||
expected_db = make_url(worker_database_url(DATABASE_URL)).database
|
||||
|
||||
async with create_worker_database(DATABASE_URL) as url:
|
||||
assert make_url(url).database == expected_db
|
||||
|
||||
# Verify the database exists while inside the context
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() == 1
|
||||
await engine.dispose()
|
||||
|
||||
# After context exit the database should be dropped
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() is None
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleans_up_stale_database(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""A pre-existing worker database is dropped and recreated."""
|
||||
monkeypatch.setenv("PYTEST_XDIST_WORKER", "gw_test_stale")
|
||||
expected_db = make_url(worker_database_url(DATABASE_URL)).database
|
||||
|
||||
# Pre-create the database to simulate a stale leftover
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
await conn.execute(text(f"DROP DATABASE IF EXISTS {expected_db}"))
|
||||
await conn.execute(text(f"CREATE DATABASE {expected_db}"))
|
||||
await engine.dispose()
|
||||
|
||||
# Should succeed despite the database already existing
|
||||
async with create_worker_database(DATABASE_URL) as url:
|
||||
assert make_url(url).database == expected_db
|
||||
|
||||
# Verify cleanup after context exit
|
||||
engine = create_async_engine(DATABASE_URL, isolation_level="AUTOCOMMIT")
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :name"),
|
||||
{"name": expected_db},
|
||||
)
|
||||
assert result.scalar() is None
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestCleanupTables:
|
||||
"""Tests for cleanup_tables helper."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_truncates_all_tables(self):
|
||||
"""All table rows are removed after cleanup_tables."""
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
role = Role(id=uuid.uuid4(), name="cleanup_role")
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
username="cleanup_user",
|
||||
email="cleanup@test.com",
|
||||
role_id=role.id,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
# Verify rows exist
|
||||
roles_count = await RoleCrud.count(session)
|
||||
users_count = await UserCrud.count(session)
|
||||
assert roles_count == 1
|
||||
assert users_count == 1
|
||||
|
||||
await cleanup_tables(session, Base)
|
||||
|
||||
# Verify tables are empty
|
||||
roles_count = await RoleCrud.count(session)
|
||||
users_count = await UserCrud.count(session)
|
||||
assert roles_count == 0
|
||||
assert users_count == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_noop_for_empty_metadata(self):
|
||||
"""cleanup_tables does not raise when metadata has no tables."""
|
||||
|
||||
class EmptyBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
# Should not raise
|
||||
await cleanup_tables(session, EmptyBase)
|
||||
|
||||
33
uv.lock
generated
33
uv.lock
generated
@@ -203,9 +203,18 @@ toml = [
|
||||
{ name = "tomli", marker = "python_full_version <= '3.11'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "execnet"
|
||||
version = "2.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
version = "0.128.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-doc" },
|
||||
@@ -213,14 +222,14 @@ dependencies = [
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f6/59/28bde150415783ff084334e3de106eb7461a57864cf69f343950ad5a5ddd/fastapi-0.128.1.tar.gz", hash = "sha256:ce5be4fa26d4ce6f54debcc873d1fb8e0e248f5c48d7502ba6c61457ab2dc766", size = 374260, upload-time = "2026-02-04T17:35:10.542Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/08/3953db1979ea131c68279b997c6465080118b407f0800445b843f8e164b3/fastapi-0.128.1-py3-none-any.whl", hash = "sha256:ee82146bbf91ea5bbf2bb8629e4c6e056c4fbd997ea6068501b11b15260b50fb", size = 103810, upload-time = "2026-02-04T17:35:08.02Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi-toolsets"
|
||||
version = "0.6.1"
|
||||
version = "0.8.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "asyncpg" },
|
||||
@@ -237,6 +246,7 @@ dev = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-anyio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-xdist" },
|
||||
{ name = "ruff" },
|
||||
{ name = "ty" },
|
||||
]
|
||||
@@ -245,6 +255,7 @@ test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-anyio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-xdist" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
@@ -258,6 +269,7 @@ requires-dist = [
|
||||
{ name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" },
|
||||
{ name = "pytest-anyio", marker = "extra == 'test'", specifier = ">=0.0.0" },
|
||||
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.0.0" },
|
||||
{ name = "pytest-xdist", marker = "extra == 'test'", specifier = ">=3.0.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
|
||||
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0" },
|
||||
{ name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.1a0" },
|
||||
@@ -575,6 +587,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-xdist"
|
||||
version = "3.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "execnet" },
|
||||
{ name = "pytest" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "14.3.2"
|
||||
|
||||
Reference in New Issue
Block a user