mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Initial commit
This commit is contained in:
24
src/fastapi_toolsets/__init__.py
Normal file
24
src/fastapi_toolsets/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""FastAPI utilities package.
|
||||
|
||||
Provides CRUD operations, fixtures, CLI, and standardized API responses
|
||||
for FastAPI with async SQLAlchemy and PostgreSQL.
|
||||
|
||||
Example usage:
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from fastapi_toolsets.db import create_db_dependency
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
|
||||
UserCrud = CrudFactory(User)
|
||||
|
||||
@app.get("/users/{user_id}", response_model=Response[dict])
|
||||
async def get_user(user_id: int, session = Depends(get_db)):
|
||||
user = await UserCrud.get(session, [User.id == user_id])
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
5
src/fastapi_toolsets/cli/__init__.py
Normal file
5
src/fastapi_toolsets/cli/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""CLI for FastAPI projects."""
|
||||
|
||||
from .app import app, register_command
|
||||
|
||||
__all__ = ["app", "register_command"]
|
||||
95
src/fastapi_toolsets/cli/app.py
Normal file
95
src/fastapi_toolsets/cli/app.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Main CLI application."""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from .commands import fixtures
|
||||
|
||||
app = typer.Typer(
|
||||
name="fastapi-utils",
|
||||
help="CLI utilities for FastAPI projects.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
# Register built-in commands
|
||||
app.add_typer(fixtures.app, name="fixtures")
|
||||
|
||||
|
||||
def register_command(command: typer.Typer, name: str) -> None:
|
||||
"""Register a custom command group.
|
||||
|
||||
Args:
|
||||
command: Typer app for the command group
|
||||
name: Name for the command group
|
||||
|
||||
Example:
|
||||
# In your project's cli.py:
|
||||
import typer
|
||||
from fastapi_toolsets.cli import app, register_command
|
||||
|
||||
my_commands = typer.Typer()
|
||||
|
||||
@my_commands.command()
|
||||
def seed():
|
||||
'''Seed the database.'''
|
||||
...
|
||||
|
||||
register_command(my_commands, "db")
|
||||
# Now available as: fastapi-utils db seed
|
||||
"""
|
||||
app.add_typer(command, name=name)
|
||||
|
||||
|
||||
@app.callback()
|
||||
def main(
|
||||
ctx: typer.Context,
|
||||
config: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--config",
|
||||
"-c",
|
||||
help="Path to project config file (Python module with fixtures registry).",
|
||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""FastAPI utilities CLI."""
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
if config:
|
||||
ctx.obj["config_path"] = config
|
||||
# Load the config module
|
||||
config_module = _load_module_from_path(config)
|
||||
ctx.obj["config_module"] = config_module
|
||||
|
||||
|
||||
def _load_module_from_path(path: Path) -> object:
|
||||
"""Load a Python module from a file path.
|
||||
|
||||
Handles both absolute and relative imports by adding the config's
|
||||
parent directory to sys.path temporarily.
|
||||
"""
|
||||
path = path.resolve()
|
||||
|
||||
# Add the parent directory to sys.path to support relative imports
|
||||
parent_dir = str(path.parent.parent) # Go up two levels (e.g., from app/cli_config.py to project root)
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
# Also add immediate parent for direct module imports
|
||||
immediate_parent = str(path.parent)
|
||||
if immediate_parent not in sys.path:
|
||||
sys.path.insert(0, immediate_parent)
|
||||
|
||||
spec = importlib.util.spec_from_file_location("config", path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["config"] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
1
src/fastapi_toolsets/cli/commands/__init__.py
Normal file
1
src/fastapi_toolsets/cli/commands/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Built-in CLI commands."""
|
||||
211
src/fastapi_toolsets/cli/commands/fixtures.py
Normal file
211
src/fastapi_toolsets/cli/commands/fixtures.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Fixture management commands."""
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
||||
|
||||
app = typer.Typer(
|
||||
name="fixtures",
|
||||
help="Manage database fixtures.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
||||
"""Get fixture registry from context."""
|
||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
||||
if config is None:
|
||||
raise typer.BadParameter(
|
||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
||||
)
|
||||
|
||||
registry = getattr(config, "fixtures", None)
|
||||
if registry is None:
|
||||
raise typer.BadParameter(
|
||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
||||
)
|
||||
|
||||
if not isinstance(registry, FixtureRegistry):
|
||||
raise typer.BadParameter(
|
||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||
)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def _get_db_context(ctx: typer.Context):
|
||||
"""Get database context manager from config."""
|
||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
||||
if config is None:
|
||||
raise typer.BadParameter("No config provided.")
|
||||
|
||||
get_db_context = getattr(config, "get_db_context", None)
|
||||
if get_db_context is None:
|
||||
raise typer.BadParameter(
|
||||
"Config module must have a 'get_db_context' function."
|
||||
)
|
||||
|
||||
return get_db_context
|
||||
|
||||
|
||||
@app.command("list")
|
||||
def list_fixtures(
|
||||
ctx: typer.Context,
|
||||
context: Annotated[
|
||||
str | None,
|
||||
typer.Option("--context", "-c", help="Filter by context (base, production, development, testing)."),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""List all registered fixtures."""
|
||||
registry = _get_registry(ctx)
|
||||
|
||||
if context:
|
||||
fixtures = registry.get_by_context(context)
|
||||
else:
|
||||
fixtures = registry.get_all()
|
||||
|
||||
if not fixtures:
|
||||
typer.echo("No fixtures found.")
|
||||
return
|
||||
|
||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
||||
typer.echo("-" * 80)
|
||||
|
||||
for fixture in fixtures:
|
||||
contexts = ", ".join(fixture.contexts)
|
||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
||||
|
||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||
|
||||
|
||||
@app.command("graph")
|
||||
def show_graph(
|
||||
ctx: typer.Context,
|
||||
fixture_name: Annotated[
|
||||
str | None,
|
||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Show fixture dependency graph."""
|
||||
registry = _get_registry(ctx)
|
||||
|
||||
if fixture_name:
|
||||
try:
|
||||
order = registry.resolve_dependencies(fixture_name)
|
||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
||||
for i, name in enumerate(order):
|
||||
indent = " " * i
|
||||
arrow = "└─> " if i > 0 else ""
|
||||
typer.echo(f"{indent}{arrow}{name}")
|
||||
except KeyError:
|
||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
# Show full graph
|
||||
fixtures = registry.get_all()
|
||||
|
||||
typer.echo("\nFixture Dependency Graph:\n")
|
||||
for fixture in fixtures:
|
||||
deps = f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
||||
typer.echo(f" {fixture.name}{deps}")
|
||||
|
||||
|
||||
@app.command("load")
|
||||
def load(
|
||||
ctx: typer.Context,
|
||||
contexts: Annotated[
|
||||
list[str] | None,
|
||||
typer.Argument(help="Contexts to load (base, production, development, testing)."),
|
||||
] = None,
|
||||
strategy: Annotated[
|
||||
str,
|
||||
typer.Option("--strategy", "-s", help="Load strategy: merge, insert, skip_existing."),
|
||||
] = "merge",
|
||||
dry_run: Annotated[
|
||||
bool,
|
||||
typer.Option("--dry-run", "-n", help="Show what would be loaded without loading."),
|
||||
] = False,
|
||||
) -> None:
|
||||
"""Load fixtures into the database."""
|
||||
registry = _get_registry(ctx)
|
||||
get_db_context = _get_db_context(ctx)
|
||||
|
||||
# Parse contexts
|
||||
if contexts:
|
||||
context_list = contexts
|
||||
else:
|
||||
context_list = [Context.BASE]
|
||||
|
||||
# Parse strategy
|
||||
try:
|
||||
load_strategy = LoadStrategy(strategy)
|
||||
except ValueError:
|
||||
typer.echo(f"Invalid strategy: {strategy}. Use: merge, insert, skip_existing", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Resolve what will be loaded
|
||||
ordered = registry.resolve_context_dependencies(*context_list)
|
||||
|
||||
if not ordered:
|
||||
typer.echo("No fixtures to load for the specified context(s).")
|
||||
return
|
||||
|
||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
||||
for name in ordered:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
model_name = type(instances[0]).__name__ if instances else "?"
|
||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
||||
|
||||
if dry_run:
|
||||
typer.echo("\n[Dry run - no changes made]")
|
||||
return
|
||||
|
||||
typer.echo("\nLoading...")
|
||||
|
||||
async def do_load():
|
||||
async with get_db_context() as session:
|
||||
result = await load_fixtures_by_context(
|
||||
session, registry, *context_list, strategy=load_strategy
|
||||
)
|
||||
return result
|
||||
|
||||
result = asyncio.run(do_load())
|
||||
|
||||
total = sum(len(items) for items in result.values())
|
||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
||||
|
||||
|
||||
@app.command("show")
|
||||
def show_fixture(
|
||||
ctx: typer.Context,
|
||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
||||
) -> None:
|
||||
"""Show details of a specific fixture."""
|
||||
registry = _get_registry(ctx)
|
||||
|
||||
try:
|
||||
fixture = registry.get(name)
|
||||
except KeyError:
|
||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
||||
typer.echo(f"\nFixture: {fixture.name}")
|
||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
||||
typer.echo(f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}")
|
||||
|
||||
# Show instances
|
||||
instances = list(fixture.func())
|
||||
if instances:
|
||||
model_name = type(instances[0]).__name__
|
||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
||||
for instance in instances[:10]: # Limit to 10
|
||||
typer.echo(f" - {instance!r}")
|
||||
if len(instances) > 10:
|
||||
typer.echo(f" ... and {len(instances) - 10} more")
|
||||
else:
|
||||
typer.echo("\nNo instances (empty fixture)")
|
||||
378
src/fastapi_toolsets/crud.py
Normal file
378
src/fastapi_toolsets/crud.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy import delete as sql_delete
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from .db import get_transaction
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
__all__ = [
|
||||
"AsyncCrud",
|
||||
"CrudFactory",
|
||||
]
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
|
||||
|
||||
class AsyncCrud(Generic[ModelType]):
|
||||
"""Generic async CRUD operations for SQLAlchemy models.
|
||||
|
||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
||||
|
||||
Example:
|
||||
class UserCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
# Or use the factory:
|
||||
UserCrud = CrudFactory(User)
|
||||
|
||||
# Then use it:
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
users = await UserCrud.get_multi(session, limit=10)
|
||||
"""
|
||||
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
) -> ModelType:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
|
||||
Returns:
|
||||
Created model instance
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = cls.model(**obj.model_dump())
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
return cast(ModelType, db_model)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
MultipleResultsFound: If more than one record found
|
||||
"""
|
||||
q = select(cls.model).where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if with_for_update:
|
||||
q = q.with_for_update()
|
||||
result = await session.execute(q)
|
||||
item = result.unique().scalar_one_or_none()
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
return cast(ModelType, item)
|
||||
|
||||
@classmethod
|
||||
async def first(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any] | None = None,
|
||||
*,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType | None:
|
||||
"""Get the first matching record, or None.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
load_options: SQLAlchemy loader options
|
||||
|
||||
Returns:
|
||||
Model instance or None
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
result = await session.execute(q)
|
||||
return cast(ModelType | None, result.unique().scalars().first())
|
||||
|
||||
@classmethod
|
||||
async def get_multi(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
load_options: list[Any] | None = None,
|
||||
order_by: Any | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> Sequence[ModelType]:
|
||||
"""Get multiple records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
load_options: SQLAlchemy loader options
|
||||
order_by: Column or list of columns to order by
|
||||
limit: Max number of rows to return
|
||||
offset: Rows to skip
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
if offset is not None:
|
||||
q = q.offset(offset)
|
||||
if limit is not None:
|
||||
q = q.limit(limit)
|
||||
result = await session.execute(q)
|
||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||
|
||||
@classmethod
|
||||
async def update(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
) -> ModelType:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with update data
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
|
||||
Returns:
|
||||
Updated model instance
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = await cls.get(session=session, filters=filters)
|
||||
values = obj.model_dump(
|
||||
exclude_unset=exclude_unset, exclude_none=exclude_none
|
||||
)
|
||||
for key, value in values.items():
|
||||
setattr(db_model, key, value)
|
||||
await session.refresh(db_model)
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
async def upsert(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
index_elements: list[str],
|
||||
*,
|
||||
set_: BaseModel | None = None,
|
||||
where: WhereHavingRole | None = None,
|
||||
) -> ModelType | None:
|
||||
"""Create or update a record (PostgreSQL only).
|
||||
|
||||
Uses INSERT ... ON CONFLICT for atomic upsert.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data
|
||||
index_elements: Columns for ON CONFLICT (unique constraint)
|
||||
set_: Pydantic model for ON CONFLICT DO UPDATE SET
|
||||
where: WHERE clause for ON CONFLICT DO UPDATE
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
values = obj.model_dump(exclude_unset=True)
|
||||
q = insert(cls.model).values(**values)
|
||||
if set_:
|
||||
q = q.on_conflict_do_update(
|
||||
index_elements=index_elements,
|
||||
set_=set_.model_dump(exclude_unset=True),
|
||||
where=where,
|
||||
)
|
||||
else:
|
||||
q = q.on_conflict_do_nothing(index_elements=index_elements)
|
||||
q = q.returning(cls.model)
|
||||
result = await session.execute(q)
|
||||
try:
|
||||
db_model = result.unique().scalar_one()
|
||||
except NoResultFound:
|
||||
db_model = await cls.first(
|
||||
session=session,
|
||||
filters=[getattr(cls.model, k) == v for k, v in values.items()],
|
||||
)
|
||||
return cast(ModelType | None, db_model)
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
) -> bool:
|
||||
"""Delete records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
|
||||
Returns:
|
||||
True if deletion was executed
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def count(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any] | None = None,
|
||||
) -> int:
|
||||
"""Count records matching the filters.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
|
||||
Returns:
|
||||
Number of matching records
|
||||
"""
|
||||
q = select(func.count()).select_from(cls.model)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
result = await session.execute(q)
|
||||
return result.scalar_one()
|
||||
|
||||
@classmethod
|
||||
async def exists(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
) -> bool:
|
||||
"""Check if a record exists.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
|
||||
Returns:
|
||||
True if at least one record matches
|
||||
"""
|
||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
||||
result = await session.execute(q)
|
||||
return bool(result.scalar())
|
||||
|
||||
@classmethod
|
||||
async def paginate(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
load_options: list[Any] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
) -> dict[str, Any]:
|
||||
"""Get paginated results with metadata.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
load_options: SQLAlchemy loader options
|
||||
order_by: Column or list of columns to order by
|
||||
page: Page number (1-indexed)
|
||||
items_per_page: Number of items per page
|
||||
|
||||
Returns:
|
||||
Dict with 'data' and 'pagination' keys
|
||||
"""
|
||||
filters = filters or []
|
||||
offset = (page - 1) * items_per_page
|
||||
|
||||
items = await cls.get_multi(
|
||||
session,
|
||||
filters=filters,
|
||||
load_options=load_options,
|
||||
order_by=order_by,
|
||||
limit=items_per_page,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
total_count = await cls.count(session, filters=filters)
|
||||
|
||||
return {
|
||||
"data": items,
|
||||
"pagination": {
|
||||
"total_count": total_count,
|
||||
"items_per_page": items_per_page,
|
||||
"page": page,
|
||||
"has_more": page * items_per_page < total_count,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def CrudFactory(
|
||||
model: type[ModelType],
|
||||
) -> type[AsyncCrud[ModelType]]:
|
||||
"""Create a CRUD class for a specific model.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
|
||||
Returns:
|
||||
AsyncCrud subclass bound to the model
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from myapp.models import User, Post
|
||||
|
||||
UserCrud = CrudFactory(User)
|
||||
PostCrud = CrudFactory(Post)
|
||||
|
||||
# Usage
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||
"""
|
||||
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
|
||||
return cast(type[AsyncCrud[ModelType]], cls)
|
||||
175
src/fastapi_toolsets/db.py
Normal file
175
src/fastapi_toolsets/db.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Database utilities: sessions, transactions, and locks."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
__all__ = [
|
||||
"LockMode",
|
||||
"create_db_context",
|
||||
"create_db_dependency",
|
||||
"lock_tables",
|
||||
"get_transaction",
|
||||
]
|
||||
|
||||
|
||||
def create_db_dependency(
|
||||
session_maker: async_sessionmaker[AsyncSession],
|
||||
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
|
||||
"""Create a FastAPI dependency for database sessions.
|
||||
|
||||
Creates a dependency function that yields a session and auto-commits
|
||||
if a transaction is active when the request completes.
|
||||
|
||||
Args:
|
||||
session_maker: Async session factory from create_session_factory()
|
||||
|
||||
Returns:
|
||||
An async generator function usable with FastAPI's Depends()
|
||||
|
||||
Example:
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from fastapi_toolsets.db import create_db_dependency
|
||||
|
||||
engine = create_async_engine("postgresql+asyncpg://...")
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
get_db = create_db_dependency(SessionLocal)
|
||||
|
||||
@app.get("/users")
|
||||
async def list_users(session: AsyncSession = Depends(get_db)):
|
||||
...
|
||||
"""
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with session_maker() as session:
|
||||
yield session
|
||||
if session.in_transaction():
|
||||
await session.commit()
|
||||
|
||||
return get_db
|
||||
|
||||
|
||||
def create_db_context(
|
||||
session_maker: async_sessionmaker[AsyncSession],
|
||||
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
|
||||
"""Create a context manager for database sessions.
|
||||
|
||||
Creates a context manager for use outside of FastAPI request handlers,
|
||||
such as in background tasks, CLI commands, or tests.
|
||||
|
||||
Args:
|
||||
session_maker: Async session factory from create_session_factory()
|
||||
|
||||
Returns:
|
||||
An async context manager function
|
||||
|
||||
Example:
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from fastapi_toolsets.db import create_db_context
|
||||
|
||||
engine = create_async_engine("postgresql+asyncpg://...")
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
get_db_context = create_db_context(SessionLocal)
|
||||
|
||||
async def background_task():
|
||||
async with get_db_context() as session:
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
...
|
||||
"""
|
||||
get_db = create_db_dependency(session_maker)
|
||||
return asynccontextmanager(get_db)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_transaction(
|
||||
session: AsyncSession,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a transaction context, handling nested transactions.
|
||||
|
||||
If already in a transaction, creates a savepoint (nested transaction).
|
||||
Otherwise, starts a new transaction.
|
||||
|
||||
Args:
|
||||
session: AsyncSession instance
|
||||
|
||||
Yields:
|
||||
The session within the transaction context
|
||||
|
||||
Example:
|
||||
async with get_transaction(session):
|
||||
session.add(model)
|
||||
# Auto-commits on exit, rolls back on exception
|
||||
"""
|
||||
if session.in_transaction():
|
||||
async with session.begin_nested():
|
||||
yield session
|
||||
else:
|
||||
async with session.begin():
|
||||
yield session
|
||||
|
||||
|
||||
class LockMode(str, Enum):
|
||||
"""PostgreSQL table lock modes.
|
||||
|
||||
See: https://www.postgresql.org/docs/current/explicit-locking.html
|
||||
"""
|
||||
|
||||
ACCESS_SHARE = "ACCESS SHARE"
|
||||
ROW_SHARE = "ROW SHARE"
|
||||
ROW_EXCLUSIVE = "ROW EXCLUSIVE"
|
||||
SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
|
||||
SHARE = "SHARE"
|
||||
SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
|
||||
EXCLUSIVE = "EXCLUSIVE"
|
||||
ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock_tables(
|
||||
session: AsyncSession,
|
||||
tables: list[type[DeclarativeBase]],
|
||||
*,
|
||||
mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
|
||||
timeout: str = "5s",
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Lock PostgreSQL tables for the duration of a transaction.
|
||||
|
||||
Acquires table-level locks that are held until the transaction ends.
|
||||
Useful for preventing concurrent modifications during critical operations.
|
||||
|
||||
Args:
|
||||
session: AsyncSession instance
|
||||
tables: List of SQLAlchemy model classes to lock
|
||||
mode: Lock mode (default: SHARE UPDATE EXCLUSIVE)
|
||||
timeout: Lock timeout (default: "5s")
|
||||
|
||||
Yields:
|
||||
The session with locked tables
|
||||
|
||||
Raises:
|
||||
SQLAlchemyError: If lock cannot be acquired within timeout
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.db import lock_tables, LockMode
|
||||
|
||||
async with lock_tables(session, [User, Account]):
|
||||
# Tables are locked with SHARE UPDATE EXCLUSIVE mode
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
user.balance += 100
|
||||
|
||||
# With custom lock mode
|
||||
async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
|
||||
# Exclusive lock - no other transactions can access
|
||||
await process_order(session, order_id)
|
||||
"""
|
||||
table_names = ",".join(table.__tablename__ for table in tables)
|
||||
|
||||
async with get_transaction(session):
|
||||
await session.execute(text(f"SET LOCAL lock_timeout='{timeout}'"))
|
||||
await session.execute(text(f"LOCK {table_names} IN {mode.value} MODE"))
|
||||
yield session
|
||||
19
src/fastapi_toolsets/exceptions/__init__.py
Normal file
19
src/fastapi_toolsets/exceptions/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from .exceptions import (
|
||||
ApiException,
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
generate_error_responses,
|
||||
)
|
||||
from .handler import init_exceptions_handlers
|
||||
|
||||
__all__ = [
|
||||
"init_exceptions_handlers",
|
||||
"generate_error_responses",
|
||||
"ApiException",
|
||||
"ConflictError",
|
||||
"ForbiddenError",
|
||||
"NotFoundError",
|
||||
"UnauthorizedError",
|
||||
]
|
||||
166
src/fastapi_toolsets/exceptions/exceptions.py
Normal file
166
src/fastapi_toolsets/exceptions/exceptions.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Custom exceptions with standardized API error responses."""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from ..schemas import ApiError, ErrorResponse, ResponseStatus
|
||||
|
||||
|
||||
class ApiException(Exception):
|
||||
"""Base exception for API errors with structured response.
|
||||
|
||||
Subclass this to create custom API exceptions with consistent error format.
|
||||
The exception handler will use api_error to generate the response.
|
||||
|
||||
Example:
|
||||
class CustomError(ApiException):
|
||||
api_error = ApiError(
|
||||
code=400,
|
||||
msg="Bad Request",
|
||||
desc="The request was invalid.",
|
||||
err_code="CUSTOM-400",
|
||||
)
|
||||
"""
|
||||
|
||||
api_error: ClassVar[ApiError]
|
||||
|
||||
def __init__(self, detail: str | None = None):
|
||||
"""Initialize the exception.
|
||||
|
||||
Args:
|
||||
detail: Optional override for the error message
|
||||
"""
|
||||
super().__init__(detail or self.api_error.msg)
|
||||
|
||||
|
||||
class UnauthorizedError(ApiException):
|
||||
"""HTTP 401 - User is not authenticated."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=401,
|
||||
msg="Unauthorized",
|
||||
desc="Authentication credentials were missing or invalid.",
|
||||
err_code="AUTH-401",
|
||||
)
|
||||
|
||||
|
||||
class ForbiddenError(ApiException):
|
||||
"""HTTP 403 - User lacks required permissions."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=403,
|
||||
msg="Forbidden",
|
||||
desc="You do not have permission to access this resource.",
|
||||
err_code="AUTH-403",
|
||||
)
|
||||
|
||||
|
||||
class NotFoundError(ApiException):
|
||||
"""HTTP 404 - Resource not found."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=404,
|
||||
msg="Not Found",
|
||||
desc="The requested resource was not found.",
|
||||
err_code="RES-404",
|
||||
)
|
||||
|
||||
|
||||
class ConflictError(ApiException):
|
||||
"""HTTP 409 - Resource conflict."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=409,
|
||||
msg="Conflict",
|
||||
desc="The request conflicts with the current state of the resource.",
|
||||
err_code="RES-409",
|
||||
)
|
||||
|
||||
|
||||
class InsufficientRolesError(ForbiddenError):
|
||||
"""User does not have the required roles."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=403,
|
||||
msg="Insufficient Roles",
|
||||
desc="You do not have the required roles to access this resource.",
|
||||
err_code="RBAC-403",
|
||||
)
|
||||
|
||||
def __init__(self, required_roles: list[str], user_roles: set[str] | None = None):
|
||||
self.required_roles = required_roles
|
||||
self.user_roles = user_roles
|
||||
|
||||
desc = f"Required roles: {', '.join(required_roles)}"
|
||||
if user_roles is not None:
|
||||
desc += f". User has: {', '.join(user_roles) if user_roles else 'no roles'}"
|
||||
|
||||
super().__init__(desc)
|
||||
|
||||
|
||||
class UserNotFoundError(NotFoundError):
|
||||
"""User was not found."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=404,
|
||||
msg="User Not Found",
|
||||
desc="The requested user was not found.",
|
||||
err_code="USER-404",
|
||||
)
|
||||
|
||||
|
||||
class RoleNotFoundError(NotFoundError):
|
||||
"""Role was not found."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=404,
|
||||
msg="Role Not Found",
|
||||
desc="The requested role was not found.",
|
||||
err_code="ROLE-404",
|
||||
)
|
||||
|
||||
|
||||
def generate_error_responses(
|
||||
*errors: type[ApiException],
|
||||
) -> dict[int | str, dict[str, Any]]:
|
||||
"""Generate OpenAPI response documentation for exceptions.
|
||||
|
||||
Use this to document possible error responses for an endpoint.
|
||||
|
||||
Args:
|
||||
*errors: Exception classes that inherit from ApiException
|
||||
|
||||
Returns:
|
||||
Dict suitable for FastAPI's responses parameter
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.exceptions import generate_error_responses, UnauthorizedError, ForbiddenError
|
||||
|
||||
@app.get(
|
||||
"/admin",
|
||||
responses=generate_error_responses(UnauthorizedError, ForbiddenError)
|
||||
)
|
||||
async def admin_endpoint():
|
||||
...
|
||||
"""
|
||||
responses: dict[int | str, dict[str, Any]] = {}
|
||||
|
||||
for error in errors:
|
||||
api_error = error.api_error
|
||||
|
||||
responses[api_error.code] = {
|
||||
"model": ErrorResponse,
|
||||
"description": api_error.msg,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": {
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": api_error.msg,
|
||||
"description": api_error.desc,
|
||||
"error_code": api_error.err_code,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return responses
|
||||
169
src/fastapi_toolsets/exceptions/handler.py
Normal file
169
src/fastapi_toolsets/exceptions/handler.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Exception handlers for FastAPI applications."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request, Response, status
|
||||
from fastapi.exceptions import RequestValidationError, ResponseValidationError
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..schemas import ResponseStatus
|
||||
from .exceptions import ApiException
|
||||
|
||||
|
||||
def init_exceptions_handlers(app: FastAPI) -> FastAPI:
|
||||
_register_exception_handlers(app)
|
||||
app.openapi = lambda: _custom_openapi(app) # type: ignore[method-assign]
|
||||
return app
|
||||
|
||||
|
||||
def _register_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register all exception handlers on a FastAPI application.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
|
||||
Example:
|
||||
from fastapi import FastAPI
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app)
|
||||
"""
|
||||
|
||||
@app.exception_handler(ApiException)
|
||||
async def api_exception_handler(request: Request, exc: ApiException) -> Response:
|
||||
"""Handle custom API exceptions with structured response."""
|
||||
api_error = exc.api_error
|
||||
|
||||
return JSONResponse(
|
||||
status_code=api_error.code,
|
||||
content={
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": api_error.msg,
|
||||
"description": api_error.desc,
|
||||
"error_code": api_error.err_code,
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def request_validation_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
) -> Response:
|
||||
"""Handle Pydantic request validation errors (422)."""
|
||||
return _format_validation_error(exc)
|
||||
|
||||
@app.exception_handler(ResponseValidationError)
|
||||
async def response_validation_handler(
|
||||
request: Request, exc: ResponseValidationError
|
||||
) -> Response:
|
||||
"""Handle Pydantic response validation errors (422)."""
|
||||
return _format_validation_error(exc)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_exception_handler(request: Request, exc: Exception) -> Response:
|
||||
"""Handle all unhandled exceptions with a generic 500 response."""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"data": None,
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Internal Server Error",
|
||||
"description": "An unexpected error occurred. Please try again later.",
|
||||
"error_code": "SERVER-500",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _format_validation_error(
|
||||
exc: RequestValidationError | ResponseValidationError,
|
||||
) -> JSONResponse:
|
||||
"""Format validation errors into a structured response."""
|
||||
errors = exc.errors()
|
||||
formatted_errors = []
|
||||
|
||||
for error in errors:
|
||||
field_path = ".".join(
|
||||
str(loc)
|
||||
for loc in error["loc"]
|
||||
if loc not in ("body", "query", "path", "header", "cookie")
|
||||
)
|
||||
formatted_errors.append(
|
||||
{
|
||||
"field": field_path or "root",
|
||||
"message": error.get("msg", ""),
|
||||
"type": error.get("type", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"data": {"errors": formatted_errors},
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Validation Error",
|
||||
"description": f"{len(formatted_errors)} validation error(s) detected",
|
||||
"error_code": "VAL-422",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _custom_openapi(app: FastAPI) -> dict[str, Any]:
|
||||
"""Generate custom OpenAPI schema with standardized error format.
|
||||
|
||||
Replaces default 422 validation error responses with the custom format.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
|
||||
Returns:
|
||||
OpenAPI schema dict
|
||||
|
||||
Example:
|
||||
from fastapi import FastAPI
|
||||
from fastapi_toolsets.exceptions import init_exceptions_handlers
|
||||
|
||||
app = FastAPI()
|
||||
init_exceptions_handlers(app) # Automatically sets custom OpenAPI
|
||||
"""
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
openapi_version=app.openapi_version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
for path_data in openapi_schema.get("paths", {}).values():
|
||||
for operation in path_data.values():
|
||||
if isinstance(operation, dict) and "responses" in operation:
|
||||
if "422" in operation["responses"]:
|
||||
operation["responses"]["422"] = {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"example": {
|
||||
"data": {
|
||||
"errors": [
|
||||
{
|
||||
"field": "field_name",
|
||||
"message": "value is not valid",
|
||||
"type": "value_error",
|
||||
}
|
||||
]
|
||||
},
|
||||
"status": ResponseStatus.FAIL.value,
|
||||
"message": "Validation Error",
|
||||
"description": "1 validation error(s) detected",
|
||||
"error_code": "VAL-422",
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
17
src/fastapi_toolsets/fixtures/__init__.py
Normal file
17
src/fastapi_toolsets/fixtures/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .fixtures import (
|
||||
Context,
|
||||
FixtureRegistry,
|
||||
LoadStrategy,
|
||||
load_fixtures,
|
||||
load_fixtures_by_context,
|
||||
)
|
||||
from .pytest_plugin import register_fixtures
|
||||
|
||||
__all__ = [
|
||||
"Context",
|
||||
"FixtureRegistry",
|
||||
"LoadStrategy",
|
||||
"load_fixtures",
|
||||
"load_fixtures_by_context",
|
||||
"register_fixtures",
|
||||
]
|
||||
321
src/fastapi_toolsets/fixtures/fixtures.py
Normal file
321
src/fastapi_toolsets/fixtures/fixtures.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Fixture system with dependency management and context support."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadStrategy(str, Enum):
|
||||
"""Strategy for loading fixtures into the database."""
|
||||
|
||||
INSERT = "insert"
|
||||
"""Insert new records. Fails if record already exists."""
|
||||
|
||||
MERGE = "merge"
|
||||
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
||||
|
||||
SKIP_EXISTING = "skip_existing"
|
||||
"""Insert only if record doesn't exist (based on primary key)."""
|
||||
|
||||
|
||||
class Context(str, Enum):
|
||||
"""Predefined fixture contexts."""
|
||||
|
||||
BASE = "base"
|
||||
"""Base fixtures loaded in all environments."""
|
||||
|
||||
PRODUCTION = "production"
|
||||
"""Production-only fixtures."""
|
||||
|
||||
DEVELOPMENT = "development"
|
||||
"""Development fixtures."""
|
||||
|
||||
TESTING = "testing"
|
||||
"""Test fixtures."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fixture:
|
||||
"""A fixture definition with metadata."""
|
||||
|
||||
name: str
|
||||
func: Callable[[], Sequence[DeclarativeBase]]
|
||||
depends_on: list[str] = field(default_factory=list)
|
||||
contexts: list[str] = field(default_factory=lambda: [Context.BASE])
|
||||
|
||||
|
||||
class FixtureRegistry:
|
||||
"""Registry for managing fixtures with dependencies.
|
||||
|
||||
Example:
|
||||
from fastapi_toolsets.fixtures import FixtureRegistry, Context
|
||||
|
||||
fixtures = FixtureRegistry()
|
||||
|
||||
@fixtures.register
|
||||
def roles():
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
]
|
||||
|
||||
@fixtures.register(depends_on=["roles"])
|
||||
def users():
|
||||
return [
|
||||
User(id=1, username="admin", role_id=1),
|
||||
]
|
||||
|
||||
@fixtures.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||
def test_data():
|
||||
return [
|
||||
Post(id=1, title="Test", user_id=1),
|
||||
]
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._fixtures: dict[str, Fixture] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
|
||||
*,
|
||||
name: str | None = None,
|
||||
depends_on: list[str] | None = None,
|
||||
contexts: list[str | Context] | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Register a fixture function.
|
||||
|
||||
Can be used as a decorator with or without arguments.
|
||||
|
||||
Args:
|
||||
func: Fixture function returning list of model instances
|
||||
name: Fixture name (defaults to function name)
|
||||
depends_on: List of fixture names this depends on
|
||||
contexts: List of contexts this fixture belongs to
|
||||
|
||||
Example:
|
||||
@fixtures.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
|
||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||
def test_users():
|
||||
return [User(id=1, username="test", role_id=1)]
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||
fixture_name = name or cast(Any, fn).__name__
|
||||
fixture_contexts = [
|
||||
c.value if isinstance(c, Context) else c
|
||||
for c in (contexts or [Context.BASE])
|
||||
]
|
||||
|
||||
self._fixtures[fixture_name] = Fixture(
|
||||
name=fixture_name,
|
||||
func=fn,
|
||||
depends_on=depends_on or [],
|
||||
contexts=fixture_contexts,
|
||||
)
|
||||
return fn
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
def get(self, name: str) -> Fixture:
|
||||
"""Get a fixture by name."""
|
||||
if name not in self._fixtures:
|
||||
raise KeyError(f"Fixture '{name}' not found")
|
||||
return self._fixtures[name]
|
||||
|
||||
def get_all(self) -> list[Fixture]:
|
||||
"""Get all registered fixtures."""
|
||||
return list(self._fixtures.values())
|
||||
|
||||
def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
|
||||
"""Get fixtures for specific contexts."""
|
||||
context_values = {c.value if isinstance(c, Context) else c for c in contexts}
|
||||
return [f for f in self._fixtures.values() if set(f.contexts) & context_values]
|
||||
|
||||
def resolve_dependencies(self, *names: str) -> list[str]:
|
||||
"""Resolve fixture dependencies in topological order.
|
||||
|
||||
Args:
|
||||
*names: Fixture names to resolve
|
||||
|
||||
Returns:
|
||||
List of fixture names in load order (dependencies first)
|
||||
|
||||
Raises:
|
||||
KeyError: If a fixture is not found
|
||||
ValueError: If circular dependency detected
|
||||
"""
|
||||
resolved: list[str] = []
|
||||
seen: set[str] = set()
|
||||
visiting: set[str] = set()
|
||||
|
||||
def visit(name: str) -> None:
|
||||
if name in resolved:
|
||||
return
|
||||
if name in visiting:
|
||||
raise ValueError(f"Circular dependency detected: {name}")
|
||||
|
||||
visiting.add(name)
|
||||
fixture = self.get(name)
|
||||
|
||||
for dep in fixture.depends_on:
|
||||
visit(dep)
|
||||
|
||||
visiting.remove(name)
|
||||
resolved.append(name)
|
||||
seen.add(name)
|
||||
|
||||
for name in names:
|
||||
visit(name)
|
||||
|
||||
return resolved
|
||||
|
||||
def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
|
||||
"""Resolve all fixtures for contexts with dependencies.
|
||||
|
||||
Args:
|
||||
*contexts: Contexts to load
|
||||
|
||||
Returns:
|
||||
List of fixture names in load order
|
||||
"""
|
||||
context_fixtures = self.get_by_context(*contexts)
|
||||
names = [f.name for f in context_fixtures]
|
||||
|
||||
all_deps: set[str] = set()
|
||||
for name in names:
|
||||
deps = self.resolve_dependencies(name)
|
||||
all_deps.update(deps)
|
||||
|
||||
return self.resolve_dependencies(*all_deps)
|
||||
|
||||
|
||||
async def load_fixtures(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*names: str,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load specific fixtures by name with dependencies.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*names: Fixture names to load (dependencies auto-resolved)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Loads 'roles' first (dependency), then 'users'
|
||||
result = await load_fixtures(session, fixtures, "users")
|
||||
print(result["users"]) # [User(...), ...]
|
||||
"""
|
||||
ordered = registry.resolve_dependencies(*names)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def load_fixtures_by_context(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
*contexts: str | Context,
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load all fixtures for specific contexts.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
registry: Fixture registry
|
||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
||||
strategy: How to handle existing records
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
# Load base + testing fixtures
|
||||
await load_fixtures_by_context(
|
||||
session, fixtures,
|
||||
Context.BASE, Context.TESTING
|
||||
)
|
||||
"""
|
||||
ordered = registry.resolve_context_dependencies(*contexts)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
ordered_names: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
|
||||
if not instances:
|
||||
results[name] = []
|
||||
continue
|
||||
|
||||
model_name = type(instances[0]).__name__
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
205
src/fastapi_toolsets/fixtures/pytest_plugin.py
Normal file
205
src/fastapi_toolsets/fixtures/pytest_plugin.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Pytest plugin for using FixtureRegistry fixtures in tests.
|
||||
|
||||
This module provides utilities to automatically generate pytest fixtures
|
||||
from your FixtureRegistry, with proper dependency resolution.
|
||||
|
||||
Example:
|
||||
# conftest.py
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from app.fixtures import fixtures # Your FixtureRegistry
|
||||
from app.models import Base
|
||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||
|
||||
DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/test_db"
|
||||
|
||||
@pytest.fixture
|
||||
async def engine():
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(engine):
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
session = session_factory()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
# Automatically generate pytest fixtures from registry
|
||||
# Creates: fixture_roles, fixture_users, fixture_posts, etc.
|
||||
register_fixtures(fixtures, globals())
|
||||
|
||||
Usage in tests:
|
||||
# test_users.py
|
||||
async def test_user_count(db_session, fixture_users):
|
||||
# fixture_users automatically loads fixture_roles first (if dependency)
|
||||
# and returns the list of User models
|
||||
assert len(fixture_users) > 0
|
||||
|
||||
async def test_user_role(db_session, fixture_users):
|
||||
user = fixture_users[0]
|
||||
assert user.role_id is not None
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from .fixtures import FixtureRegistry, LoadStrategy
|
||||
|
||||
|
||||
|
||||
def register_fixtures(
|
||||
registry: FixtureRegistry,
|
||||
namespace: dict[str, Any],
|
||||
*,
|
||||
prefix: str = "fixture_",
|
||||
session_fixture: str = "db_session",
|
||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||
) -> list[str]:
|
||||
"""Register pytest fixtures from a FixtureRegistry.
|
||||
|
||||
Automatically creates pytest fixtures for each fixture in the registry.
|
||||
Dependencies are resolved via pytest fixture dependencies.
|
||||
|
||||
Args:
|
||||
registry: The FixtureRegistry containing fixtures
|
||||
namespace: The module's globals() dict to add fixtures to
|
||||
prefix: Prefix for generated fixture names (default: "fixture_")
|
||||
session_fixture: Name of the db session fixture (default: "db_session")
|
||||
strategy: Loading strategy for fixtures (default: MERGE)
|
||||
|
||||
Returns:
|
||||
List of created fixture names
|
||||
|
||||
Example:
|
||||
# conftest.py
|
||||
from app.fixtures import fixtures
|
||||
from fastapi_toolsets.pytest_plugin import register_fixtures
|
||||
|
||||
register_fixtures(fixtures, globals())
|
||||
|
||||
# Creates fixtures like:
|
||||
# - fixture_roles
|
||||
# - fixture_users (depends on fixture_roles if users depends on roles)
|
||||
# - fixture_posts (depends on fixture_users if posts depends on users)
|
||||
"""
|
||||
created_fixtures: list[str] = []
|
||||
|
||||
for fixture in registry.get_all():
|
||||
fixture_name = f"{prefix}{fixture.name}"
|
||||
|
||||
# Build list of pytest fixture dependencies
|
||||
pytest_deps = [session_fixture]
|
||||
for dep in fixture.depends_on:
|
||||
pytest_deps.append(f"{prefix}{dep}")
|
||||
|
||||
# Create the fixture function
|
||||
fixture_func = _create_fixture_function(
|
||||
registry=registry,
|
||||
fixture_name=fixture.name,
|
||||
dependencies=pytest_deps,
|
||||
strategy=strategy,
|
||||
)
|
||||
|
||||
# Apply pytest.fixture decorator
|
||||
decorated = pytest.fixture(fixture_func)
|
||||
|
||||
# Add to namespace
|
||||
namespace[fixture_name] = decorated
|
||||
created_fixtures.append(fixture_name)
|
||||
|
||||
return created_fixtures
|
||||
|
||||
|
||||
def _create_fixture_function(
|
||||
registry: FixtureRegistry,
|
||||
fixture_name: str,
|
||||
dependencies: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> Callable[..., Any]:
|
||||
"""Create a fixture function with the correct signature.
|
||||
|
||||
The function signature must include all dependencies as parameters
|
||||
for pytest to resolve them correctly.
|
||||
"""
|
||||
# Get the fixture definition
|
||||
fixture_def = registry.get(fixture_name)
|
||||
|
||||
# Build the function dynamically with correct parameters
|
||||
# We need the session as first param, then all dependencies
|
||||
async def fixture_func(**kwargs: Any) -> Sequence[DeclarativeBase]:
|
||||
# Get session from kwargs (first dependency)
|
||||
session: AsyncSession = kwargs[dependencies[0]]
|
||||
|
||||
# Load the fixture data
|
||||
instances = list(fixture_def.func())
|
||||
|
||||
if not instances:
|
||||
return []
|
||||
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
loaded.append(existing)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
return loaded
|
||||
|
||||
# Update function signature to include dependencies
|
||||
# This is needed for pytest to inject the right fixtures
|
||||
params = ", ".join(dependencies)
|
||||
code = f"async def {fixture_name}_fixture({params}):\n return await _impl({', '.join(f'{d}={d}' for d in dependencies)})"
|
||||
|
||||
local_ns: dict[str, Any] = {"_impl": fixture_func}
|
||||
exec(code, local_ns) # noqa: S102
|
||||
|
||||
created_func = local_ns[f"{fixture_name}_fixture"]
|
||||
created_func.__doc__ = f"Load {fixture_name} fixture data."
|
||||
|
||||
return created_func
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
0
src/fastapi_toolsets/py.typed
Normal file
0
src/fastapi_toolsets/py.typed
Normal file
116
src/fastapi_toolsets/schemas.py
Normal file
116
src/fastapi_toolsets/schemas.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Base Pydantic schemas for API responses."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
__all__ = [
|
||||
"ApiError",
|
||||
"ErrorResponse",
|
||||
"Pagination",
|
||||
"PaginatedResponse",
|
||||
"Response",
|
||||
"ResponseStatus",
|
||||
]
|
||||
|
||||
DataT = TypeVar("DataT")
|
||||
|
||||
|
||||
class PydanticBase(BaseModel):
|
||||
"""Base class for all Pydantic models with common configuration."""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
from_attributes=True,
|
||||
validate_assignment=True,
|
||||
)
|
||||
|
||||
|
||||
class ResponseStatus(str, Enum):
|
||||
"""Standard API response status."""
|
||||
|
||||
SUCCESS = "SUCCESS"
|
||||
FAIL = "FAIL"
|
||||
|
||||
|
||||
class ApiError(PydanticBase):
|
||||
"""Structured API error definition.
|
||||
|
||||
Used to define standard error responses with consistent format.
|
||||
|
||||
Attributes:
|
||||
code: HTTP status code
|
||||
msg: Short error message
|
||||
desc: Detailed error description
|
||||
err_code: Application-specific error code (e.g., "AUTH-401")
|
||||
"""
|
||||
|
||||
code: int
|
||||
msg: str
|
||||
desc: str
|
||||
err_code: str
|
||||
|
||||
|
||||
class BaseResponse(PydanticBase):
|
||||
"""Base response structure for all API responses.
|
||||
|
||||
Attributes:
|
||||
status: SUCCESS or FAIL
|
||||
message: Human-readable message
|
||||
error_code: Error code if status is FAIL, None otherwise
|
||||
"""
|
||||
|
||||
status: ResponseStatus = ResponseStatus.SUCCESS
|
||||
message: str = "Success"
|
||||
error_code: str | None = None
|
||||
|
||||
|
||||
class Response(BaseResponse, Generic[DataT]):
|
||||
"""Generic API response with data payload.
|
||||
|
||||
Example:
|
||||
Response[UserRead](data=user, message="User retrieved")
|
||||
"""
|
||||
|
||||
data: DataT | None = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseResponse):
|
||||
"""Error response with additional description field.
|
||||
|
||||
Used for error responses that need more context.
|
||||
"""
|
||||
|
||||
status: ResponseStatus = ResponseStatus.FAIL
|
||||
description: str | None = None
|
||||
data: None = None
|
||||
|
||||
|
||||
class Pagination(PydanticBase):
|
||||
"""Pagination metadata for list responses.
|
||||
|
||||
Attributes:
|
||||
total_count: Total number of items across all pages
|
||||
items_per_page: Number of items per page
|
||||
page: Current page number (1-indexed)
|
||||
has_more: Whether there are more pages
|
||||
"""
|
||||
|
||||
total_count: int
|
||||
items_per_page: int
|
||||
page: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class PaginatedResponse(BaseResponse, Generic[DataT]):
|
||||
"""Paginated API response for list endpoints.
|
||||
|
||||
Example:
|
||||
PaginatedResponse[UserRead](
|
||||
data=users,
|
||||
pagination=Pagination(total_count=100, items_per_page=10, page=1, has_more=True)
|
||||
)
|
||||
"""
|
||||
|
||||
data: list[DataT]
|
||||
pagination: Pagination
|
||||
Reference in New Issue
Block a user