Initial commit

This commit is contained in:
2026-01-25 16:11:44 +01:00
commit 762ed35341
29 changed files with 5072 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
"""FastAPI utilities package.
Provides CRUD operations, fixtures, CLI, and standardized API responses
for FastAPI with async SQLAlchemy and PostgreSQL.
Example usage:
from fastapi import FastAPI, Depends
from fastapi_toolsets.exceptions import init_exceptions_handlers
from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.db import create_db_dependency
from fastapi_toolsets.schemas import Response
app = FastAPI()
init_exceptions_handlers(app)
UserCrud = CrudFactory(User)
@app.get("/users/{user_id}", response_model=Response[dict])
async def get_user(user_id: int, session = Depends(get_db)):
user = await UserCrud.get(session, [User.id == user_id])
return Response(data={"user": user.username}, message="Success")
"""
__version__ = "0.1.0"

View File

@@ -0,0 +1,5 @@
"""CLI for FastAPI projects."""
from .app import app, register_command
__all__ = ["app", "register_command"]

View File

@@ -0,0 +1,95 @@
"""Main CLI application."""
import importlib.util
import sys
from pathlib import Path
from typing import Annotated
import typer
from .commands import fixtures
app = typer.Typer(
name="fastapi-utils",
help="CLI utilities for FastAPI projects.",
no_args_is_help=True,
)
# Register built-in commands
app.add_typer(fixtures.app, name="fixtures")
def register_command(command: typer.Typer, name: str) -> None:
"""Register a custom command group.
Args:
command: Typer app for the command group
name: Name for the command group
Example:
# In your project's cli.py:
import typer
from fastapi_toolsets.cli import app, register_command
my_commands = typer.Typer()
@my_commands.command()
def seed():
'''Seed the database.'''
...
register_command(my_commands, "db")
# Now available as: fastapi-utils db seed
"""
app.add_typer(command, name=name)
@app.callback()
def main(
ctx: typer.Context,
config: Annotated[
Path | None,
typer.Option(
"--config",
"-c",
help="Path to project config file (Python module with fixtures registry).",
envvar="FASTAPI_TOOLSETS_CONFIG",
),
] = None,
) -> None:
"""FastAPI utilities CLI."""
ctx.ensure_object(dict)
if config:
ctx.obj["config_path"] = config
# Load the config module
config_module = _load_module_from_path(config)
ctx.obj["config_module"] = config_module
def _load_module_from_path(path: Path) -> object:
"""Load a Python module from a file path.
Handles both absolute and relative imports by adding the config's
parent directory to sys.path temporarily.
"""
path = path.resolve()
# Add the parent directory to sys.path to support relative imports
parent_dir = str(path.parent.parent) # Go up two levels (e.g., from app/cli_config.py to project root)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
# Also add immediate parent for direct module imports
immediate_parent = str(path.parent)
if immediate_parent not in sys.path:
sys.path.insert(0, immediate_parent)
spec = importlib.util.spec_from_file_location("config", path)
if spec is None or spec.loader is None:
raise typer.BadParameter(f"Cannot load module from {path}")
module = importlib.util.module_from_spec(spec)
sys.modules["config"] = module
spec.loader.exec_module(module)
return module

View File

@@ -0,0 +1 @@
"""Built-in CLI commands."""

View File

@@ -0,0 +1,211 @@
"""Fixture management commands."""
import asyncio
from typing import Annotated
import typer
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
app = typer.Typer(
name="fixtures",
help="Manage database fixtures.",
no_args_is_help=True,
)
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
"""Get fixture registry from context."""
config = ctx.obj.get("config_module") if ctx.obj else None
if config is None:
raise typer.BadParameter(
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
)
registry = getattr(config, "fixtures", None)
if registry is None:
raise typer.BadParameter(
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
)
if not isinstance(registry, FixtureRegistry):
raise typer.BadParameter(
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
)
return registry
def _get_db_context(ctx: typer.Context):
"""Get database context manager from config."""
config = ctx.obj.get("config_module") if ctx.obj else None
if config is None:
raise typer.BadParameter("No config provided.")
get_db_context = getattr(config, "get_db_context", None)
if get_db_context is None:
raise typer.BadParameter(
"Config module must have a 'get_db_context' function."
)
return get_db_context
@app.command("list")
def list_fixtures(
ctx: typer.Context,
context: Annotated[
str | None,
typer.Option("--context", "-c", help="Filter by context (base, production, development, testing)."),
] = None,
) -> None:
"""List all registered fixtures."""
registry = _get_registry(ctx)
if context:
fixtures = registry.get_by_context(context)
else:
fixtures = registry.get_all()
if not fixtures:
typer.echo("No fixtures found.")
return
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
typer.echo("-" * 80)
for fixture in fixtures:
contexts = ", ".join(fixture.contexts)
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
@app.command("graph")
def show_graph(
ctx: typer.Context,
fixture_name: Annotated[
str | None,
typer.Argument(help="Show dependencies for a specific fixture."),
] = None,
) -> None:
"""Show fixture dependency graph."""
registry = _get_registry(ctx)
if fixture_name:
try:
order = registry.resolve_dependencies(fixture_name)
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
for i, name in enumerate(order):
indent = " " * i
arrow = "└─> " if i > 0 else ""
typer.echo(f"{indent}{arrow}{name}")
except KeyError:
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
raise typer.Exit(1)
else:
# Show full graph
fixtures = registry.get_all()
typer.echo("\nFixture Dependency Graph:\n")
for fixture in fixtures:
deps = f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
typer.echo(f" {fixture.name}{deps}")
@app.command("load")
def load(
ctx: typer.Context,
contexts: Annotated[
list[str] | None,
typer.Argument(help="Contexts to load (base, production, development, testing)."),
] = None,
strategy: Annotated[
str,
typer.Option("--strategy", "-s", help="Load strategy: merge, insert, skip_existing."),
] = "merge",
dry_run: Annotated[
bool,
typer.Option("--dry-run", "-n", help="Show what would be loaded without loading."),
] = False,
) -> None:
"""Load fixtures into the database."""
registry = _get_registry(ctx)
get_db_context = _get_db_context(ctx)
# Parse contexts
if contexts:
context_list = contexts
else:
context_list = [Context.BASE]
# Parse strategy
try:
load_strategy = LoadStrategy(strategy)
except ValueError:
typer.echo(f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True)
raise typer.Exit(1)
# Resolve what will be loaded
ordered = registry.resolve_context_dependencies(*context_list)
if not ordered:
typer.echo("No fixtures to load for the specified context(s).")
return
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
for name in ordered:
fixture = registry.get(name)
instances = list(fixture.func())
model_name = type(instances[0]).__name__ if instances else "?"
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
if dry_run:
typer.echo("\n[Dry run - no changes made]")
return
typer.echo("\nLoading...")
async def do_load():
async with get_db_context() as session:
result = await load_fixtures_by_context(
session, registry, *context_list, strategy=load_strategy
)
return result
result = asyncio.run(do_load())
total = sum(len(items) for items in result.values())
typer.echo(f"\nLoaded {total} record(s) successfully.")
@app.command("show")
def show_fixture(
ctx: typer.Context,
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
) -> None:
"""Show details of a specific fixture."""
registry = _get_registry(ctx)
try:
fixture = registry.get(name)
except KeyError:
typer.echo(f"Fixture '{name}' not found.", err=True)
raise typer.Exit(1)
typer.echo(f"\nFixture: {fixture.name}")
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
typer.echo(f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}")
# Show instances
instances = list(fixture.func())
if instances:
model_name = type(instances[0]).__name__
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
for instance in instances[:10]: # Limit to 10
typer.echo(f" - {instance!r}")
if len(instances) > 10:
typer.echo(f" ... and {len(instances) - 10} more")
else:
typer.echo("\nNo instances (empty fixture)")

View File

@@ -0,0 +1,378 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
from pydantic import BaseModel
from sqlalchemy import and_, func, select
from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.roles import WhereHavingRole
from .db import get_transaction
from .exceptions import NotFoundError
__all__ = [
"AsyncCrud",
"CrudFactory",
]
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models.
Subclass this and set the `model` class variable, or use `CrudFactory`.
Example:
class UserCrud(AsyncCrud[User]):
model = User
# Or use the factory:
UserCrud = CrudFactory(User)
# Then use it:
user = await UserCrud.get(session, [User.id == 1])
users = await UserCrud.get_multi(session, limit=10)
"""
model: ClassVar[type[DeclarativeBase]]
@classmethod
async def create(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
) -> ModelType:
"""Create a new record in the database.
Args:
session: DB async session
obj: Pydantic model with data to create
Returns:
Created model instance
"""
async with get_transaction(session):
db_model = cls.model(**obj.model_dump())
session.add(db_model)
await session.refresh(db_model)
return cast(ModelType, db_model)
@classmethod
async def get(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
with_for_update: bool = False,
load_options: list[Any] | None = None,
) -> ModelType:
"""Get exactly one record. Raises NotFoundError if not found.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload)
Returns:
Model instance
Raises:
NotFoundError: If no record found
MultipleResultsFound: If more than one record found
"""
q = select(cls.model).where(and_(*filters))
if load_options:
q = q.options(*load_options)
if with_for_update:
q = q.with_for_update()
result = await session.execute(q)
item = result.unique().scalar_one_or_none()
if not item:
raise NotFoundError()
return cast(ModelType, item)
@classmethod
async def first(
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
*,
load_options: list[Any] | None = None,
) -> ModelType | None:
"""Get the first matching record, or None.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
load_options: SQLAlchemy loader options
Returns:
Model instance or None
"""
q = select(cls.model)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
result = await session.execute(q)
return cast(ModelType | None, result.unique().scalars().first())
@classmethod
async def get_multi(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
load_options: list[Any] | None = None,
order_by: Any | None = None,
limit: int | None = None,
offset: int | None = None,
) -> Sequence[ModelType]:
"""Get multiple records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
limit: Max number of rows to return
offset: Rows to skip
Returns:
List of model instances
"""
q = select(cls.model)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if order_by is not None:
q = q.order_by(order_by)
if offset is not None:
q = q.offset(offset)
if limit is not None:
q = q.limit(limit)
result = await session.execute(q)
return cast(Sequence[ModelType], result.unique().scalars().all())
@classmethod
async def update(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
) -> ModelType:
"""Update a record in the database.
Args:
session: DB async session
obj: Pydantic model with update data
filters: List of SQLAlchemy filter conditions
exclude_unset: Exclude fields not explicitly set in the schema
exclude_none: Exclude fields with None value
Returns:
Updated model instance
Raises:
NotFoundError: If no record found
"""
async with get_transaction(session):
db_model = await cls.get(session=session, filters=filters)
values = obj.model_dump(
exclude_unset=exclude_unset, exclude_none=exclude_none
)
for key, value in values.items():
setattr(db_model, key, value)
await session.refresh(db_model)
return db_model
@classmethod
async def upsert(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
index_elements: list[str],
*,
set_: BaseModel | None = None,
where: WhereHavingRole | None = None,
) -> ModelType | None:
"""Create or update a record (PostgreSQL only).
Uses INSERT ... ON CONFLICT for atomic upsert.
Args:
session: DB async session
obj: Pydantic model with data
index_elements: Columns for ON CONFLICT (unique constraint)
set_: Pydantic model for ON CONFLICT DO UPDATE SET
where: WHERE clause for ON CONFLICT DO UPDATE
Returns:
Model instance
"""
async with get_transaction(session):
values = obj.model_dump(exclude_unset=True)
q = insert(cls.model).values(**values)
if set_:
q = q.on_conflict_do_update(
index_elements=index_elements,
set_=set_.model_dump(exclude_unset=True),
where=where,
)
else:
q = q.on_conflict_do_nothing(index_elements=index_elements)
q = q.returning(cls.model)
result = await session.execute(q)
try:
db_model = result.unique().scalar_one()
except NoResultFound:
db_model = await cls.first(
session=session,
filters=[getattr(cls.model, k) == v for k, v in values.items()],
)
return cast(ModelType | None, db_model)
@classmethod
async def delete(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
) -> bool:
"""Delete records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
Returns:
True if deletion was executed
"""
async with get_transaction(session):
q = sql_delete(cls.model).where(and_(*filters))
await session.execute(q)
return True
@classmethod
async def count(
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
) -> int:
"""Count records matching the filters.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
Returns:
Number of matching records
"""
q = select(func.count()).select_from(cls.model)
if filters:
q = q.where(and_(*filters))
result = await session.execute(q)
return result.scalar_one()
@classmethod
async def exists(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
) -> bool:
"""Check if a record exists.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
Returns:
True if at least one record matches
"""
q = select(cls.model).where(and_(*filters)).exists().select()
result = await session.execute(q)
return bool(result.scalar())
@classmethod
async def paginate(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
load_options: list[Any] | None = None,
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
) -> dict[str, Any]:
"""Get paginated results with metadata.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
page: Page number (1-indexed)
items_per_page: Number of items per page
Returns:
Dict with 'data' and 'pagination' keys
"""
filters = filters or []
offset = (page - 1) * items_per_page
items = await cls.get_multi(
session,
filters=filters,
load_options=load_options,
order_by=order_by,
limit=items_per_page,
offset=offset,
)
total_count = await cls.count(session, filters=filters)
return {
"data": items,
"pagination": {
"total_count": total_count,
"items_per_page": items_per_page,
"page": page,
"has_more": page * items_per_page < total_count,
},
}
def CrudFactory(
model: type[ModelType],
) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model.
Args:
model: SQLAlchemy model class
Returns:
AsyncCrud subclass bound to the model
Example:
from fastapi_toolsets.crud import CrudFactory
from myapp.models import User, Post
UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post)
# Usage
user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
"""
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
return cast(type[AsyncCrud[ModelType]], cls)

175
src/fastapi_toolsets/db.py Normal file
View File

@@ -0,0 +1,175 @@
"""Database utilities: sessions, transactions, and locks."""
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from enum import Enum
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
__all__ = [
"LockMode",
"create_db_context",
"create_db_dependency",
"lock_tables",
"get_transaction",
]
def create_db_dependency(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
"""Create a FastAPI dependency for database sessions.
Creates a dependency function that yields a session and auto-commits
if a transaction is active when the request completes.
Args:
session_maker: Async session factory from create_session_factory()
Returns:
An async generator function usable with FastAPI's Depends()
Example:
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from fastapi_toolsets.db import create_db_dependency
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
get_db = create_db_dependency(SessionLocal)
@app.get("/users")
async def list_users(session: AsyncSession = Depends(get_db)):
...
"""
async def get_db() -> AsyncGenerator[AsyncSession, None]:
async with session_maker() as session:
yield session
if session.in_transaction():
await session.commit()
return get_db
def create_db_context(
session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
"""Create a context manager for database sessions.
Creates a context manager for use outside of FastAPI request handlers,
such as in background tasks, CLI commands, or tests.
Args:
session_maker: Async session factory from create_session_factory()
Returns:
An async context manager function
Example:
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from fastapi_toolsets.db import create_db_context
engine = create_async_engine("postgresql+asyncpg://...")
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
get_db_context = create_db_context(SessionLocal)
async def background_task():
async with get_db_context() as session:
user = await UserCrud.get(session, [User.id == 1])
...
"""
get_db = create_db_dependency(session_maker)
return asynccontextmanager(get_db)
@asynccontextmanager
async def get_transaction(
session: AsyncSession,
) -> AsyncGenerator[AsyncSession, None]:
"""Get a transaction context, handling nested transactions.
If already in a transaction, creates a savepoint (nested transaction).
Otherwise, starts a new transaction.
Args:
session: AsyncSession instance
Yields:
The session within the transaction context
Example:
async with get_transaction(session):
session.add(model)
# Auto-commits on exit, rolls back on exception
"""
if session.in_transaction():
async with session.begin_nested():
yield session
else:
async with session.begin():
yield session
class LockMode(str, Enum):
"""PostgreSQL table lock modes.
See: https://www.postgresql.org/docs/current/explicit-locking.html
"""
ACCESS_SHARE = "ACCESS SHARE"
ROW_SHARE = "ROW SHARE"
ROW_EXCLUSIVE = "ROW EXCLUSIVE"
SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
SHARE = "SHARE"
SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
EXCLUSIVE = "EXCLUSIVE"
ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
@asynccontextmanager
async def lock_tables(
session: AsyncSession,
tables: list[type[DeclarativeBase]],
*,
mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
timeout: str = "5s",
) -> AsyncGenerator[AsyncSession, None]:
"""Lock PostgreSQL tables for the duration of a transaction.
Acquires table-level locks that are held until the transaction ends.
Useful for preventing concurrent modifications during critical operations.
Args:
session: AsyncSession instance
tables: List of SQLAlchemy model classes to lock
mode: Lock mode (default: SHARE UPDATE EXCLUSIVE)
timeout: Lock timeout (default: "5s")
Yields:
The session with locked tables
Raises:
SQLAlchemyError: If lock cannot be acquired within timeout
Example:
from fastapi_toolsets.db import lock_tables, LockMode
async with lock_tables(session, [User, Account]):
# Tables are locked with SHARE UPDATE EXCLUSIVE mode
user = await UserCrud.get(session, [User.id == 1])
user.balance += 100
# With custom lock mode
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
# Exclusive lock - no other transactions can access
await process_order(session, order_id)
"""
table_names = ",".join(table.__tablename__ for table in tables)
async with get_transaction(session):
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
yield session

View File

@@ -0,0 +1,19 @@
from .exceptions import (
ApiException,
ConflictError,
ForbiddenError,
NotFoundError,
UnauthorizedError,
generate_error_responses,
)
from .handler import init_exceptions_handlers
__all__ = [
"init_exceptions_handlers",
"generate_error_responses",
"ApiException",
"ConflictError",
"ForbiddenError",
"NotFoundError",
"UnauthorizedError",
]

View File

@@ -0,0 +1,166 @@
"""Custom exceptions with standardized API error responses."""
from typing import Any, ClassVar
from ..schemas import ApiError, ErrorResponse, ResponseStatus
class ApiException(Exception):
"""Base exception for API errors with structured response.
Subclass this to create custom API exceptions with consistent error format.
The exception handler will use api_error to generate the response.
Example:
class CustomError(ApiException):
api_error = ApiError(
code=400,
msg="Bad Request",
desc="The request was invalid.",
err_code="CUSTOM-400",
)
"""
api_error: ClassVar[ApiError]
def __init__(self, detail: str | None = None):
"""Initialize the exception.
Args:
detail: Optional override for the error message
"""
super().__init__(detail or self.api_error.msg)
class UnauthorizedError(ApiException):
"""HTTP 401 - User is not authenticated."""
api_error = ApiError(
code=401,
msg="Unauthorized",
desc="Authentication credentials were missing or invalid.",
err_code="AUTH-401",
)
class ForbiddenError(ApiException):
"""HTTP 403 - User lacks required permissions."""
api_error = ApiError(
code=403,
msg="Forbidden",
desc="You do not have permission to access this resource.",
err_code="AUTH-403",
)
class NotFoundError(ApiException):
"""HTTP 404 - Resource not found."""
api_error = ApiError(
code=404,
msg="Not Found",
desc="The requested resource was not found.",
err_code="RES-404",
)
class ConflictError(ApiException):
"""HTTP 409 - Resource conflict."""
api_error = ApiError(
code=409,
msg="Conflict",
desc="The request conflicts with the current state of the resource.",
err_code="RES-409",
)
class InsufficientRolesError(ForbiddenError):
"""User does not have the required roles."""
api_error = ApiError(
code=403,
msg="Insufficient Roles",
desc="You do not have the required roles to access this resource.",
err_code="RBAC-403",
)
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
self.required_roles = required_roles
self.user_roles = user_roles
desc = f"Required roles: {', '.join(required_roles)}"
if user_roles is not None:
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
super().__init__(desc)
class UserNotFoundError(NotFoundError):
"""User was not found."""
api_error = ApiError(
code=404,
msg="User Not Found",
desc="The requested user was not found.",
err_code="USER-404",
)
class RoleNotFoundError(NotFoundError):
"""Role was not found."""
api_error = ApiError(
code=404,
msg="Role Not Found",
desc="The requested role was not found.",
err_code="ROLE-404",
)
def generate_error_responses(
*errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]:
"""Generate OpenAPI response documentation for exceptions.
Use this to document possible error responses for an endpoint.
Args:
*errors: Exception classes that inherit from ApiException
Returns:
Dict suitable for FastAPI's responses parameter
Example:
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
@app.get(
"/admin",
responses=generate_error_responses(UnauthorizedError, ForbiddenError)
)
async def admin_endpoint():
...
"""
responses: dict[int | str, dict[str, Any]] = {}
for error in errors:
api_error = error.api_error
responses[api_error.code] = {
"model": ErrorResponse,
"description": api_error.msg,
"content": {
"application/json": {
"example": {
"data": None,
"status": ResponseStatus.FAIL.value,
"message": api_error.msg,
"description": api_error.desc,
"error_code": api_error.err_code,
}
}
},
}
return responses

View File

@@ -0,0 +1,169 @@
"""Exception handlers for FastAPI applications."""
from typing import Any
from fastapi import FastAPI, Request, Response, status
from fastapi.exceptions import RequestValidationError, ResponseValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from ..schemas import ResponseStatus
from .exceptions import ApiException
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
_register_exception_handlers(app)
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
return app
def _register_exception_handlers(app: FastAPI) -> None:
"""Register all exception handlers on a FastAPI application.
Args:
app: FastAPI application instance
Example:
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
app = FastAPI()
init_exceptions_handlers(app)
"""
@app.exception_handler(ApiException)
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
"""Handle custom API exceptions with structured response."""
api_error = exc.api_error
return JSONResponse(
status_code=api_error.code,
content={
"data": None,
"status": ResponseStatus.FAIL.value,
"message": api_error.msg,
"description": api_error.desc,
"error_code": api_error.err_code,
},
)
@app.exception_handler(RequestValidationError)
async def request_validation_handler(
request: Request, exc: RequestValidationError
) -> Response:
"""Handle Pydantic request validation errors (422)."""
return _format_validation_error(exc)
@app.exception_handler(ResponseValidationError)
async def response_validation_handler(
request: Request, exc: ResponseValidationError
) -> Response:
"""Handle Pydantic response validation errors (422)."""
return _format_validation_error(exc)
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
"""Handle all unhandled exceptions with a generic 500 response."""
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"data": None,
"status": ResponseStatus.FAIL.value,
"message": "Internal Server Error",
"description": "An unexpected error occurred. Please try again later.",
"error_code": "SERVER-500",
},
)
def _format_validation_error(
exc: RequestValidationError | ResponseValidationError,
) -> JSONResponse:
"""Format validation errors into a structured response."""
errors = exc.errors()
formatted_errors = []
for error in errors:
field_path = ".".join(
str(loc)
for loc in error["loc"]
if loc not in ("body", "query", "path", "header", "cookie")
)
formatted_errors.append(
{
"field": field_path or "root",
"message": error.get("msg", ""),
"type": error.get("type", ""),
}
)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={
"data": {"errors": formatted_errors},
"status": ResponseStatus.FAIL.value,
"message": "Validation Error",
"description": f"{len(formatted_errors)} validation error(s) detected",
"error_code": "VAL-422",
},
)
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
"""Generate custom OpenAPI schema with standardized error format.
Replaces default 422 validation error responses with the custom format.
Args:
app: FastAPI application instance
Returns:
OpenAPI schema dict
Example:
from fastapi import FastAPI
from fastapi_toolsets.exceptions import init_exceptions_handlers
app = FastAPI()
init_exceptions_handlers(app) # Automatically sets custom OpenAPI
"""
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
openapi_version=app.openapi_version,
description=app.description,
routes=app.routes,
)
for path_data in openapi_schema.get("paths", {}).values():
for operation in path_data.values():
if isinstance(operation, dict) and "responses" in operation:
if "422" in operation["responses"]:
operation["responses"]["422"] = {
"description": "Validation Error",
"content": {
"application/json": {
"example": {
"data": {
"errors": [
{
"field": "field_name",
"message": "value is not valid",
"type": "value_error",
}
]
},
"status": ResponseStatus.FAIL.value,
"message": "Validation Error",
"description": "1 validation error(s) detected",
"error_code": "VAL-422",
}
}
},
}
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@@ -0,0 +1,17 @@
from .fixtures import (
Context,
FixtureRegistry,
LoadStrategy,
load_fixtures,
load_fixtures_by_context,
)
from .pytest_plugin import register_fixtures
__all__ = [
"Context",
"FixtureRegistry",
"LoadStrategy",
"load_fixtures",
"load_fixtures_by_context",
"register_fixtures",
]

View File

@@ -0,0 +1,321 @@
"""Fixture system with dependency management and context support."""
import logging
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, cast
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction
logger = logging.getLogger(__name__)
class LoadStrategy(str, Enum):
"""Strategy for loading fixtures into the database."""
INSERT = "insert"
"""Insert new records. Fails if record already exists."""
MERGE = "merge"
"""Insert or update based on primary key (SQLAlchemy merge)."""
SKIP_EXISTING = "skip_existing"
"""Insert only if record doesn't exist (based on primary key)."""
class Context(str, Enum):
"""Predefined fixture contexts."""
BASE = "base"
"""Base fixtures loaded in all environments."""
PRODUCTION = "production"
"""Production-only fixtures."""
DEVELOPMENT = "development"
"""Development fixtures."""
TESTING = "testing"
"""Test fixtures."""
@dataclass
class Fixture:
"""A fixture definition with metadata."""
name: str
func: Callable[[], Sequence[DeclarativeBase]]
depends_on: list[str] = field(default_factory=list)
contexts: list[str] = field(default_factory=lambda: [Context.BASE])
class FixtureRegistry:
"""Registry for managing fixtures with dependencies.
Example:
from fastapi_toolsets.fixtures import FixtureRegistry, Context
fixtures = FixtureRegistry()
@fixtures.register
def roles():
return [
Role(id=1, name="admin"),
Role(id=2, name="user"),
]
@fixtures.register(depends_on=["roles"])
def users():
return [
User(id=1, username="admin", role_id=1),
]
@fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
def test_data():
return [
Post(id=1, title="Test", user_id=1),
]
"""
def __init__(self) -> None:
self._fixtures: dict[str, Fixture] = {}
def register(
self,
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
*,
name: str | None = None,
depends_on: list[str] | None = None,
contexts: list[str | Context] | None = None,
) -> Callable[..., Any]:
"""Register a fixture function.
Can be used as a decorator with or without arguments.
Args:
func: Fixture function returning list of model instances
name: Fixture name (defaults to function name)
depends_on: List of fixture names this depends on
contexts: List of contexts this fixture belongs to
Example:
@fixtures.register
def roles():
return [Role(id=1, name="admin")]
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users():
return [User(id=1, username="test", role_id=1)]
"""
def decorator(
fn: Callable[[], Sequence[DeclarativeBase]],
) -> Callable[[], Sequence[DeclarativeBase]]:
fixture_name = name or cast(Any, fn).__name__
fixture_contexts = [
c.value if isinstance(c, Context) else c
for c in (contexts or [Context.BASE])
]
self._fixtures[fixture_name] = Fixture(
name=fixture_name,
func=fn,
depends_on=depends_on or [],
contexts=fixture_contexts,
)
return fn
if func is not None:
return decorator(func)
return decorator
def get(self, name: str) -> Fixture:
"""Get a fixture by name."""
if name not in self._fixtures:
raise KeyError(f"Fixture '{name}' not found")
return self._fixtures[name]
def get_all(self) -> list[Fixture]:
"""Get all registered fixtures."""
return list(self._fixtures.values())
def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
"""Get fixtures for specific contexts."""
context_values = {c.value if isinstance(c, Context) else c for c in contexts}
return [f for f in self._fixtures.values() if set(f.contexts) & context_values]
def resolve_dependencies(self, *names: str) -> list[str]:
"""Resolve fixture dependencies in topological order.
Args:
*names: Fixture names to resolve
Returns:
List of fixture names in load order (dependencies first)
Raises:
KeyError: If a fixture is not found
ValueError: If circular dependency detected
"""
resolved: list[str] = []
seen: set[str] = set()
visiting: set[str] = set()
def visit(name: str) -> None:
if name in resolved:
return
if name in visiting:
raise ValueError(f"Circular dependency detected: {name}")
visiting.add(name)
fixture = self.get(name)
for dep in fixture.depends_on:
visit(dep)
visiting.remove(name)
resolved.append(name)
seen.add(name)
for name in names:
visit(name)
return resolved
def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
"""Resolve all fixtures for contexts with dependencies.
Args:
*contexts: Contexts to load
Returns:
List of fixture names in load order
"""
context_fixtures = self.get_by_context(*contexts)
names = [f.name for f in context_fixtures]
all_deps: set[str] = set()
for name in names:
deps = self.resolve_dependencies(name)
all_deps.update(deps)
return self.resolve_dependencies(*all_deps)
async def load_fixtures(
session: AsyncSession,
registry: FixtureRegistry,
*names: str,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load specific fixtures by name with dependencies.
Args:
session: Database session
registry: Fixture registry
*names: Fixture names to load (dependencies auto-resolved)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
Example:
# Loads 'roles' first (dependency), then 'users'
result = await load_fixtures(session, fixtures, "users")
print(result["users"]) # [User(...), ...]
"""
ordered = registry.resolve_dependencies(*names)
return await _load_ordered(session, registry, ordered, strategy)
async def load_fixtures_by_context(
session: AsyncSession,
registry: FixtureRegistry,
*contexts: str | Context,
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]:
"""Load all fixtures for specific contexts.
Args:
session: Database session
registry: Fixture registry
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
strategy: How to handle existing records
Returns:
Dict mapping fixture names to loaded instances
Example:
# Load base + testing fixtures
await load_fixtures_by_context(
session, fixtures,
Context.BASE, Context.TESTING
)
"""
ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy)
async def _load_ordered(
session: AsyncSession,
registry: FixtureRegistry,
ordered_names: list[str],
strategy: LoadStrategy,
) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order."""
results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names:
fixture = registry.get(name)
instances = list(fixture.func())
if not instances:
results[name] = []
continue
model_name = type(instances[0]).__name__
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
return results
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None

View File

@@ -0,0 +1,205 @@
"""Pytest plugin for using FixtureRegistry fixtures in tests.
This module provides utilities to automatically generate pytest fixtures
from your FixtureRegistry, with proper dependency resolution.
Example:
# conftest.py
import pytest
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from app.fixtures import fixtures # Your FixtureRegistry
from app.models import Base
from fastapi_toolsets.pytest_plugin import register_fixtures
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
@pytest.fixture
async def engine():
engine = create_async_engine(DATABASE_URL)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(engine):
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
session = session_factory()
try:
yield session
finally:
await session.close()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
# Automatically generate pytest fixtures from registry
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
register_fixtures(fixtures, globals())
Usage in tests:
# test_users.py
async def test_user_count(db_session, fixture_users):
# fixture_users automatically loads fixture_roles first (if dependency)
# and returns the list of User models
assert len(fixture_users) > 0
async def test_user_role(db_session, fixture_users):
user = fixture_users[0]
assert user.role_id is not None
"""
from collections.abc import Callable, Sequence
from typing import Any
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction
from .fixtures import FixtureRegistry, LoadStrategy
def register_fixtures(
registry: FixtureRegistry,
namespace: dict[str, Any],
*,
prefix: str = "fixture_",
session_fixture: str = "db_session",
strategy: LoadStrategy = LoadStrategy.MERGE,
) -> list[str]:
"""Register pytest fixtures from a FixtureRegistry.
Automatically creates pytest fixtures for each fixture in the registry.
Dependencies are resolved via pytest fixture dependencies.
Args:
registry: The FixtureRegistry containing fixtures
namespace: The module's globals() dict to add fixtures to
prefix: Prefix for generated fixture names (default: "fixture_")
session_fixture: Name of the db session fixture (default: "db_session")
strategy: Loading strategy for fixtures (default: MERGE)
Returns:
List of created fixture names
Example:
# conftest.py
from app.fixtures import fixtures
from fastapi_toolsets.pytest_plugin import register_fixtures
register_fixtures(fixtures, globals())
# Creates fixtures like:
# - fixture_roles
# - fixture_users (depends on fixture_roles if users depends on roles)
# - fixture_posts (depends on fixture_users if posts depends on users)
"""
created_fixtures: list[str] = []
for fixture in registry.get_all():
fixture_name = f"{prefix}{fixture.name}"
# Build list of pytest fixture dependencies
pytest_deps = [session_fixture]
for dep in fixture.depends_on:
pytest_deps.append(f"{prefix}{dep}")
# Create the fixture function
fixture_func = _create_fixture_function(
registry=registry,
fixture_name=fixture.name,
dependencies=pytest_deps,
strategy=strategy,
)
# Apply pytest.fixture decorator
decorated = pytest.fixture(fixture_func)
# Add to namespace
namespace[fixture_name] = decorated
created_fixtures.append(fixture_name)
return created_fixtures
def _create_fixture_function(
registry: FixtureRegistry,
fixture_name: str,
dependencies: list[str],
strategy: LoadStrategy,
) -> Callable[..., Any]:
"""Create a fixture function with the correct signature.
The function signature must include all dependencies as parameters
for pytest to resolve them correctly.
"""
# Get the fixture definition
fixture_def = registry.get(fixture_name)
# Build the function dynamically with correct parameters
# We need the session as first param, then all dependencies
async def fixture_func(**kwargs: Any) -> Sequence[DeclarativeBase]:
# Get session from kwargs (first dependency)
session: AsyncSession = kwargs[dependencies[0]]
# Load the fixture data
instances = list(fixture_def.func())
if not instances:
return []
loaded: list[DeclarativeBase] = []
async with get_transaction(session):
for instance in instances:
if strategy == LoadStrategy.INSERT:
session.add(instance)
loaded.append(instance)
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
loaded.append(existing)
else:
session.add(instance)
loaded.append(instance)
return loaded
# Update function signature to include dependencies
# This is needed for pytest to inject the right fixtures
params = ", ".join(dependencies)
code = f"async def {fixture_name}_fixture({params}):\n return await _impl({', '.join(f'{d}={d}' for d in dependencies)})"
local_ns: dict[str, Any] = {"_impl": fixture_func}
exec(code, local_ns) # noqa: S102
created_func = local_ns[f"{fixture_name}_fixture"]
created_func.__doc__ = f"Load {fixture_name} fixture data."
return created_func
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
return getattr(instance, pk_cols[0].name, None)
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
if all(v is not None for v in pk_values):
return pk_values
return None

View File

View File

@@ -0,0 +1,116 @@
"""Base Pydantic schemas for API responses."""
from enum import Enum
from typing import ClassVar, Generic, TypeVar
from pydantic import BaseModel, ConfigDict
__all__ = [
"ApiError",
"ErrorResponse",
"Pagination",
"PaginatedResponse",
"Response",
"ResponseStatus",
]
DataT = TypeVar("DataT")
class PydanticBase(BaseModel):
"""Base class for all Pydantic models with common configuration."""
model_config: ClassVar[ConfigDict] = ConfigDict(
from_attributes=True,
validate_assignment=True,
)
class ResponseStatus(str, Enum):
"""Standard API response status."""
SUCCESS = "SUCCESS"
FAIL = "FAIL"
class ApiError(PydanticBase):
"""Structured API error definition.
Used to define standard error responses with consistent format.
Attributes:
code: HTTP status code
msg: Short error message
desc: Detailed error description
err_code: Application-specific error code (e.g., "AUTH-401")
"""
code: int
msg: str
desc: str
err_code: str
class BaseResponse(PydanticBase):
"""Base response structure for all API responses.
Attributes:
status: SUCCESS or FAIL
message: Human-readable message
error_code: Error code if status is FAIL, None otherwise
"""
status: ResponseStatus = ResponseStatus.SUCCESS
message: str = "Success"
error_code: str | None = None
class Response(BaseResponse, Generic[DataT]):
"""Generic API response with data payload.
Example:
Response[UserRead](data=user, message="User retrieved")
"""
data: DataT | None = None
class ErrorResponse(BaseResponse):
"""Error response with additional description field.
Used for error responses that need more context.
"""
status: ResponseStatus = ResponseStatus.FAIL
description: str | None = None
data: None = None
class Pagination(PydanticBase):
"""Pagination metadata for list responses.
Attributes:
total_count: Total number of items across all pages
items_per_page: Number of items per page
page: Current page number (1-indexed)
has_more: Whether there are more pages
"""
total_count: int
items_per_page: int
page: int
has_more: bool
class PaginatedResponse(BaseResponse, Generic[DataT]):
"""Paginated API response for list endpoints.
Example:
PaginatedResponse[UserRead](
data=users,
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
)
"""
data: list[DataT]
pagination: Pagination