4 Commits

Author SHA1 Message Date
0f50c8a0f0 Version 0.5.0 2026-02-03 09:12:20 -05:00
d3vyce
691fb78fda feat: add include_registry to FixtureRegistry + add context default to the registry (#25) 2026-02-03 14:59:36 +01:00
d3vyce
34ef4da317 feat: simplify CLI feature (#23)
* chore: cleanup + add tests

* chore: remove graph and show fixtures commands

* feat: add async_command wrapper
2026-02-03 14:35:15 +01:00
d3vyce
8c287b3ce7 feat: add join to crud functions (#21) 2026-02-01 15:01:10 +01:00
14 changed files with 1087 additions and 244 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "fastapi-toolsets"
version = "0.4.1"
version = "0.5.0"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md"
license = "MIT"
@@ -59,7 +59,7 @@ dev = [
]
[project.scripts]
fastapi-toolsets = "fastapi_toolsets.cli:app"
manager = "fastapi_toolsets.cli:cli"
[build-system]
requires = ["uv_build>=0.9.26,<0.10.0"]

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success")
"""
__version__ = "0.4.1"
__version__ = "0.5.0"

View File

@@ -1,5 +1,6 @@
"""CLI for FastAPI projects."""
from .app import app, register_command
from .app import cli
from .utils import async_command
__all__ = ["app", "register_command"]
__all__ = ["async_command", "cli"]

View File

@@ -1,97 +1,25 @@
"""Main CLI application."""
import importlib.util
import sys
from pathlib import Path
from typing import Annotated
import typer
from .commands import fixtures
from .config import load_config
app = typer.Typer(
name="fastapi-utils",
cli = typer.Typer(
name="manager",
help="CLI utilities for FastAPI projects.",
no_args_is_help=True,
)
# Register built-in commands
app.add_typer(fixtures.app, name="fixtures")
_config = load_config()
if _config.fixtures:
from .commands.fixtures import fixture_cli
cli.add_typer(fixture_cli, name="fixtures")
def register_command(command: typer.Typer, name: str) -> None:
"""Register a custom command group.
Args:
command: Typer app for the command group
name: Name for the command group
Example:
# In your project's cli.py:
import typer
from fastapi_toolsets.cli import app, register_command
my_commands = typer.Typer()
@my_commands.command()
def seed():
'''Seed the database.'''
...
register_command(my_commands, "db")
# Now available as: fastapi-utils db seed
"""
app.add_typer(command, name=name)
@app.callback()
def main(
ctx: typer.Context,
config: Annotated[
Path | None,
typer.Option(
"--config",
"-c",
help="Path to project config file (Python module with fixtures registry).",
envvar="FASTAPI_TOOLSETS_CONFIG",
),
] = None,
) -> None:
@cli.callback()
def main(ctx: typer.Context) -> None:
"""FastAPI utilities CLI."""
ctx.ensure_object(dict)
if config:
ctx.obj["config_path"] = config
# Load the config module
config_module = _load_module_from_path(config)
ctx.obj["config_module"] = config_module
def _load_module_from_path(path: Path) -> object:
"""Load a Python module from a file path.
Handles both absolute and relative imports by adding the config's
parent directory to sys.path temporarily.
"""
path = path.resolve()
# Add the parent directory to sys.path to support relative imports
parent_dir = str(
path.parent.parent
) # Go up two levels (e.g., from app/cli_config.py to project root)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
# Also add immediate parent for direct module imports
immediate_parent = str(path.parent)
if immediate_parent not in sys.path:
sys.path.insert(0, immediate_parent)
spec = importlib.util.spec_from_file_location("config", path)
if spec is None or spec.loader is None:
raise typer.BadParameter(f"Cannot load module from {path}")
module = importlib.util.module_from_spec(spec)
sys.modules["config"] = module
spec.loader.exec_module(module)
return module
ctx.obj["config"] = _config

View File

@@ -1,55 +1,29 @@
"""Fixture management commands."""
import asyncio
from typing import Annotated
import typer
from rich.console import Console
from rich.table import Table
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
from ..config import CliConfig
from ..utils import async_command
app = typer.Typer(
fixture_cli = typer.Typer(
name="fixtures",
help="Manage database fixtures.",
no_args_is_help=True,
)
console = Console()
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
"""Get fixture registry from context."""
config = ctx.obj.get("config_module") if ctx.obj else None
if config is None:
raise typer.BadParameter(
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
)
registry = getattr(config, "fixtures", None)
if registry is None:
raise typer.BadParameter(
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
)
if not isinstance(registry, FixtureRegistry):
raise typer.BadParameter(
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
)
return registry
def _get_config(ctx: typer.Context) -> CliConfig:
"""Get CLI config from context."""
return ctx.obj["config"]
def _get_db_context(ctx: typer.Context):
"""Get database context manager from config."""
config = ctx.obj.get("config_module") if ctx.obj else None
if config is None:
raise typer.BadParameter("No config provided.")
get_db_context = getattr(config, "get_db_context", None)
if get_db_context is None:
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
return get_db_context
@app.command("list")
@fixture_cli.command("list")
def list_fixtures(
ctx: typer.Context,
context: Annotated[
@@ -62,64 +36,28 @@ def list_fixtures(
] = None,
) -> None:
"""List all registered fixtures."""
registry = _get_registry(ctx)
if context:
fixtures = registry.get_by_context(context)
else:
fixtures = registry.get_all()
config = _get_config(ctx)
registry = config.get_fixtures_registry()
fixtures = registry.get_by_context(context) if context else registry.get_all()
if not fixtures:
typer.echo("No fixtures found.")
print("No fixtures found.")
return
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
typer.echo("-" * 80)
table = Table("Name", "Contexts", "Dependencies")
for fixture in fixtures:
contexts = ", ".join(fixture.contexts)
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
table.add_row(fixture.name, contexts, deps)
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
console.print(table)
print(f"\nTotal: {len(fixtures)} fixture(s)")
@app.command("graph")
def show_graph(
ctx: typer.Context,
fixture_name: Annotated[
str | None,
typer.Argument(help="Show dependencies for a specific fixture."),
] = None,
) -> None:
"""Show fixture dependency graph."""
registry = _get_registry(ctx)
if fixture_name:
try:
order = registry.resolve_dependencies(fixture_name)
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
for i, name in enumerate(order):
indent = " " * i
arrow = "└─> " if i > 0 else ""
typer.echo(f"{indent}{arrow}{name}")
except KeyError:
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
raise typer.Exit(1)
else:
# Show full graph
fixtures = registry.get_all()
typer.echo("\nFixture Dependency Graph:\n")
for fixture in fixtures:
deps = (
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
)
typer.echo(f" {fixture.name}{deps}")
@app.command("load")
def load(
@fixture_cli.command("load")
@async_command
async def load(
ctx: typer.Context,
contexts: Annotated[
list[str] | None,
@@ -141,16 +79,12 @@ def load(
] = False,
) -> None:
"""Load fixtures into the database."""
registry = _get_registry(ctx)
get_db_context = _get_db_context(ctx)
config = _get_config(ctx)
registry = config.get_fixtures_registry()
get_db_context = config.get_db_context()
# Parse contexts
if contexts:
context_list = contexts
else:
context_list = [Context.BASE]
context_list = contexts if contexts else [Context.BASE]
# Parse strategy
try:
load_strategy = LoadStrategy(strategy)
except ValueError:
@@ -159,67 +93,27 @@ def load(
)
raise typer.Exit(1)
# Resolve what will be loaded
ordered = registry.resolve_context_dependencies(*context_list)
if not ordered:
typer.echo("No fixtures to load for the specified context(s).")
print("No fixtures to load for the specified context(s).")
return
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
print(f"\nFixtures to load ({load_strategy.value} strategy):")
for name in ordered:
fixture = registry.get(name)
instances = list(fixture.func())
model_name = type(instances[0]).__name__ if instances else "?"
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
print(f" - {name}: {len(instances)} {model_name}(s)")
if dry_run:
typer.echo("\n[Dry run - no changes made]")
print("\n[Dry run - no changes made]")
return
typer.echo("\nLoading...")
async def do_load():
async with get_db_context() as session:
result = await load_fixtures_by_context(
session, registry, *context_list, strategy=load_strategy
)
return result
result = asyncio.run(do_load())
total = sum(len(items) for items in result.values())
typer.echo(f"\nLoaded {total} record(s) successfully.")
@app.command("show")
def show_fixture(
ctx: typer.Context,
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
) -> None:
"""Show details of a specific fixture."""
registry = _get_registry(ctx)
try:
fixture = registry.get(name)
except KeyError:
typer.echo(f"Fixture '{name}' not found.", err=True)
raise typer.Exit(1)
typer.echo(f"\nFixture: {fixture.name}")
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
typer.echo(
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
)
# Show instances
instances = list(fixture.func())
if instances:
model_name = type(instances[0]).__name__
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
for instance in instances[:10]: # Limit to 10
typer.echo(f" - {instance!r}")
if len(instances) > 10:
typer.echo(f" ... and {len(instances) - 10} more")
else:
typer.echo("\nNo instances (empty fixture)")
print(f"\nLoaded {total} record(s) successfully.")

View File

@@ -0,0 +1,92 @@
"""CLI configuration."""
import importlib
import sys
import tomllib
from dataclasses import dataclass
from pathlib import Path
import typer
@dataclass
class CliConfig:
"""CLI configuration loaded from pyproject.toml."""
fixtures: str | None = None
db_context: str | None = None
def get_fixtures_registry(self):
"""Import and return the fixtures registry."""
from ..fixtures import FixtureRegistry
if not self.fixtures:
raise typer.BadParameter(
"No fixtures registry configured. "
"Add 'fixtures' to [tool.fastapi-toolsets] in pyproject.toml."
)
registry = _import_from_string(self.fixtures)
if not isinstance(registry, FixtureRegistry):
raise typer.BadParameter(
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
)
return registry
def get_db_context(self):
"""Import and return the db_context function."""
if not self.db_context:
raise typer.BadParameter(
"No db_context configured. "
"Add 'db_context' to [tool.fastapi-toolsets] in pyproject.toml."
)
return _import_from_string(self.db_context)
def _import_from_string(import_path: str):
"""Import an object from a string path like 'module.submodule:attribute'."""
if ":" not in import_path:
raise typer.BadParameter(
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
)
module_path, attr_name = import_path.rsplit(":", 1)
# Add cwd to sys.path for local imports
cwd = str(Path.cwd())
if cwd not in sys.path:
sys.path.insert(0, cwd)
try:
module = importlib.import_module(module_path)
except ImportError as e:
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
if not hasattr(module, attr_name):
raise typer.BadParameter(
f"Module '{module_path}' has no attribute '{attr_name}'"
)
return getattr(module, attr_name)
def load_config() -> CliConfig:
"""Load CLI configuration from pyproject.toml."""
pyproject_path = Path.cwd() / "pyproject.toml"
if not pyproject_path.exists():
return CliConfig()
try:
with open(pyproject_path, "rb") as f:
data = tomllib.load(f)
tool_config = data.get("tool", {}).get("fastapi-toolsets", {})
return CliConfig(
fixtures=tool_config.get("fixtures"),
db_context=tool_config.get("db_context"),
)
except Exception:
return CliConfig()

View File

@@ -0,0 +1,27 @@
"""CLI utility functions."""
import asyncio
import functools
from collections.abc import Callable, Coroutine
from typing import Any, ParamSpec, TypeVar
P = ParamSpec("P")
T = TypeVar("T")
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
"""Decorator to run an async function as a sync CLI command.
Example:
@fixture_cli.command("load")
@async_command
async def load(ctx: typer.Context) -> None:
async with get_db_context() as session:
await load_fixtures(session, registry)
"""
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return asyncio.run(func(*args, **kwargs))
return wrapper

View File

@@ -4,7 +4,6 @@ from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory
from .search import (
SearchConfig,
SearchFieldType,
get_searchable_fields,
)
@@ -13,5 +12,4 @@ __all__ = [
"get_searchable_fields",
"NoSearchableFieldsError",
"SearchConfig",
"SearchFieldType",
]

View File

@@ -17,6 +17,7 @@ from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]
class AsyncCrud(Generic[ModelType]):
@@ -55,6 +56,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
) -> ModelType:
@@ -63,6 +66,8 @@ class AsyncCrud(Generic[ModelType]):
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload)
@@ -73,7 +78,15 @@ class AsyncCrud(Generic[ModelType]):
NotFoundError: If no record found
MultipleResultsFound: If more than one record found
"""
q = select(cls.model).where(and_(*filters))
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if with_for_update:
@@ -90,6 +103,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession,
filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
) -> ModelType | None:
"""Get the first matching record, or None.
@@ -97,12 +112,21 @@ class AsyncCrud(Generic[ModelType]):
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
Returns:
Model instance or None
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
if load_options:
@@ -116,6 +140,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
order_by: Any | None = None,
limit: int | None = None,
@@ -126,6 +152,8 @@ class AsyncCrud(Generic[ModelType]):
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
limit: Max number of rows to return
@@ -135,6 +163,13 @@ class AsyncCrud(Generic[ModelType]):
List of model instances
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
if load_options:
@@ -254,17 +289,29 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> int:
"""Count records matching the filters.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns:
Number of matching records
"""
q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
result = await session.execute(q)
@@ -275,17 +322,30 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> bool:
"""Check if a record exists.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns:
True if at least one record matches
"""
q = select(cls.model).where(and_(*filters)).exists().select()
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q)
return bool(result.scalar())
@@ -295,6 +355,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
order_by: Any | None = None,
page: int = 1,
@@ -307,6 +369,8 @@ class AsyncCrud(Generic[ModelType]):
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
page: Page number (1-indexed)
@@ -319,7 +383,7 @@ class AsyncCrud(Generic[ModelType]):
"""
filters = list(filters) if filters else []
offset = (page - 1) * items_per_page
joins: list[Any] = []
search_joins: list[Any] = []
# Build search filters
if search:
@@ -330,11 +394,21 @@ class AsyncCrud(Generic[ModelType]):
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
joins.extend(search_joins)
# Build query with joins
q = select(cls.model)
for join_rel in joins:
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
if filters:
@@ -352,8 +426,20 @@ class AsyncCrud(Generic[ModelType]):
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
for join_rel in joins:
# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel)
if filters:
count_q = count_q.where(and_(*filters))
@@ -404,6 +490,20 @@ def CrudFactory(
# With search
result = await UserCrud.paginate(session, search="john")
# With joins (inner join by default):
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
filters=[Post.published == True],
)
# With outer join:
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
outer_join=True,
)
"""
cls = type(
f"Async{model.__name__}Crud",

View File

@@ -50,8 +50,16 @@ class FixtureRegistry:
]
"""
def __init__(self) -> None:
def __init__(
self,
contexts: list[str | Context] | None = None,
) -> None:
self._fixtures: dict[str, Fixture] = {}
self._default_contexts: list[str] | None = (
[c.value if isinstance(c, Context) else c for c in contexts]
if contexts
else None
)
def register(
self,
@@ -85,10 +93,14 @@ class FixtureRegistry:
fn: Callable[[], Sequence[DeclarativeBase]],
) -> Callable[[], Sequence[DeclarativeBase]]:
fixture_name = name or cast(Any, fn).__name__
if contexts is not None:
fixture_contexts = [
c.value if isinstance(c, Context) else c
for c in (contexts or [Context.BASE])
c.value if isinstance(c, Context) else c for c in contexts
]
elif self._default_contexts is not None:
fixture_contexts = self._default_contexts
else:
fixture_contexts = [Context.BASE.value]
self._fixtures[fixture_name] = Fixture(
name=fixture_name,
@@ -102,6 +114,32 @@ class FixtureRegistry:
return decorator(func)
return decorator
def include_registry(self, registry: "FixtureRegistry") -> None:
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
Args:
registry: The `FixtureRegistry` to include
Raises:
ValueError: If a fixture name already exists in the current registry
Example:
registry = FixtureRegistry()
dev_registry = FixtureRegistry()
@dev_registry.register
def dev_data():
return [...]
registry.include_registry(registry=dev_registry)
"""
for name, fixture in registry._fixtures.items():
if name in self._fixtures:
raise ValueError(
f"Fixture '{name}' already exists in the current registry"
)
self._fixtures[name] = fixture
def get(self, name: str) -> Fixture:
"""Get a fixture by name."""
if name not in self._fixtures:

322
tests/test_cli.py Normal file
View File

@@ -0,0 +1,322 @@
"""Tests for fastapi_toolsets.cli module."""
import sys
import pytest
from typer.testing import CliRunner
from fastapi_toolsets.cli.config import CliConfig, _import_from_string, load_config
from fastapi_toolsets.cli.utils import async_command
from fastapi_toolsets.fixtures import FixtureRegistry
runner = CliRunner()
class TestCliConfig:
"""Tests for CliConfig dataclass."""
def test_default_values(self):
"""Config has None defaults."""
config = CliConfig()
assert config.fixtures is None
assert config.db_context is None
def test_with_values(self):
"""Config stores provided values."""
config = CliConfig(
fixtures="app.fixtures:registry",
db_context="app.db:get_session",
)
assert config.fixtures == "app.fixtures:registry"
assert config.db_context == "app.db:get_session"
def test_get_fixtures_registry_without_config(self):
"""get_fixtures_registry raises error when not configured."""
config = CliConfig()
with pytest.raises(Exception) as exc_info:
config.get_fixtures_registry()
assert "No fixtures registry configured" in str(exc_info.value)
def test_get_db_context_without_config(self):
"""get_db_context raises error when not configured."""
config = CliConfig()
with pytest.raises(Exception) as exc_info:
config.get_db_context()
assert "No db_context configured" in str(exc_info.value)
class TestImportFromString:
"""Tests for _import_from_string function."""
def test_import_valid_path(self):
"""Import valid module:attribute path."""
result = _import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
assert result is FixtureRegistry
def test_import_without_colon_raises_error(self):
"""Import path without colon raises error."""
with pytest.raises(Exception) as exc_info:
_import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
assert "Expected format: 'module:attribute'" in str(exc_info.value)
def test_import_nonexistent_module_raises_error(self):
"""Import nonexistent module raises error."""
with pytest.raises(Exception) as exc_info:
_import_from_string("nonexistent.module:something")
assert "Cannot import module" in str(exc_info.value)
def test_import_nonexistent_attribute_raises_error(self):
"""Import nonexistent attribute raises error."""
with pytest.raises(Exception) as exc_info:
_import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
assert "has no attribute" in str(exc_info.value)
class TestLoadConfig:
"""Tests for load_config function."""
def test_load_without_pyproject(self, tmp_path, monkeypatch):
"""Returns empty config when no pyproject.toml exists."""
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures is None
assert config.db_context is None
def test_load_without_tool_section(self, tmp_path, monkeypatch):
"""Returns empty config when no [tool.fastapi-toolsets] section."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text("[project]\nname = 'test'\n")
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures is None
assert config.db_context is None
def test_load_with_fixtures_config(self, tmp_path, monkeypatch):
"""Loads fixtures config from pyproject.toml."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
)
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures == "app.fixtures:registry"
assert config.db_context is None
def test_load_with_full_config(self, tmp_path, monkeypatch):
"""Loads full config from pyproject.toml."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(
"[tool.fastapi-toolsets]\n"
'fixtures = "app.fixtures:registry"\n'
'db_context = "app.db:get_session"\n'
)
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures == "app.fixtures:registry"
assert config.db_context == "app.db:get_session"
def test_load_with_invalid_toml(self, tmp_path, monkeypatch):
"""Returns empty config when pyproject.toml is invalid."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text("invalid toml {{{")
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures is None
class TestCliApp:
"""Tests for CLI application."""
def test_cli_import(self):
"""CLI can be imported."""
from fastapi_toolsets.cli import cli
assert cli is not None
def test_cli_help(self, tmp_path, monkeypatch):
"""CLI shows help without fixtures."""
monkeypatch.chdir(tmp_path)
# Need to reload the module to pick up new cwd
import importlib
from fastapi_toolsets.cli import app
importlib.reload(app)
result = runner.invoke(app.cli, ["--help"])
assert result.exit_code == 0
assert "CLI utilities for FastAPI projects" in result.output
class TestFixturesCli:
"""Tests for fixtures CLI commands."""
@pytest.fixture
def cli_env(self, tmp_path, monkeypatch):
"""Set up CLI environment with fixtures config."""
# Create pyproject.toml
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(
"[tool.fastapi-toolsets]\n"
'fixtures = "fixtures:registry"\n'
'db_context = "db:get_session"\n'
)
# Create fixtures module
fixtures_file = tmp_path / "fixtures.py"
fixtures_file.write_text(
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
"\n"
"registry = FixtureRegistry()\n"
"\n"
"@registry.register(contexts=[Context.BASE])\n"
"def roles():\n"
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
"\n"
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
"def users():\n"
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
)
# Create db module
db_file = tmp_path / "db.py"
db_file.write_text(
"from contextlib import asynccontextmanager\n"
"\n"
"@asynccontextmanager\n"
"async def get_session():\n"
" yield None\n"
)
monkeypatch.chdir(tmp_path)
# Add tmp_path to sys.path for imports
if str(tmp_path) not in sys.path:
sys.path.insert(0, str(tmp_path))
# Reload the CLI module to pick up new config
import importlib
from fastapi_toolsets.cli import app, config
importlib.reload(config)
importlib.reload(app)
yield tmp_path, app.cli
# Cleanup
if str(tmp_path) in sys.path:
sys.path.remove(str(tmp_path))
def test_fixtures_list(self, cli_env):
"""fixtures list shows registered fixtures."""
tmp_path, cli = cli_env
result = runner.invoke(cli, ["fixtures", "list"])
assert result.exit_code == 0
assert "roles" in result.output
assert "users" in result.output
assert "Total: 2 fixture(s)" in result.output
def test_fixtures_list_with_context(self, cli_env):
"""fixtures list --context filters by context."""
tmp_path, cli = cli_env
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
assert result.exit_code == 0
assert "roles" in result.output
assert "users" not in result.output
assert "Total: 1 fixture(s)" in result.output
def test_fixtures_load_dry_run(self, cli_env):
"""fixtures load --dry-run shows what would be loaded."""
tmp_path, cli = cli_env
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
assert result.exit_code == 0
assert "Fixtures to load" in result.output
assert "roles" in result.output
assert "[Dry run - no changes made]" in result.output
def test_fixtures_load_invalid_strategy(self, cli_env):
"""fixtures load with invalid strategy shows error."""
tmp_path, cli = cli_env
result = runner.invoke(
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
)
assert result.exit_code == 1
assert "Invalid strategy" in result.output
class TestCliWithoutFixturesConfig:
"""Tests for CLI when fixtures is not configured."""
def test_no_fixtures_command(self, tmp_path, monkeypatch):
"""fixtures command is not available when not configured."""
# Create pyproject.toml without fixtures
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text('[project]\nname = "test"\n')
monkeypatch.chdir(tmp_path)
# Reload the CLI module
import importlib
from fastapi_toolsets.cli import app, config
importlib.reload(config)
importlib.reload(app)
result = runner.invoke(app.cli, ["--help"])
assert result.exit_code == 0
assert "fixtures" not in result.output
class TestAsyncCommand:
"""Tests for async_command decorator."""
def test_async_command_runs_coroutine(self):
"""async_command runs async function synchronously."""
@async_command
async def async_func(value: int) -> int:
return value * 2
result = async_func(21)
assert result == 42
def test_async_command_preserves_signature(self):
"""async_command preserves function signature."""
@async_command
async def async_func(name: str, count: int = 1) -> str:
return f"{name} x {count}"
result = async_func("test", count=3)
assert result == "test x 3"
def test_async_command_preserves_docstring(self):
"""async_command preserves function docstring."""
@async_command
async def async_func() -> None:
"""This is a docstring."""
pass
assert async_func.__doc__ == """This is a docstring."""
def test_async_command_preserves_name(self):
"""async_command preserves function name."""
@async_command
async def my_async_function() -> None:
pass
assert my_async_function.__name__ == "my_async_function"

View File

@@ -10,6 +10,9 @@ from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError
from .conftest import (
Post,
PostCreate,
PostCrud,
Role,
RoleCreate,
RoleCrud,
@@ -481,3 +484,271 @@ class TestCrudPaginate:
names = [r.name for r in result["data"]]
assert names == ["alpha", "bravo", "charlie"]
class TestCrudJoins:
"""Tests for CRUD operations with joins."""
@pytest.mark.anyio
async def test_get_with_join(self, db_session: AsyncSession):
"""Get with inner join filters correctly."""
# Create user with posts
user = await UserCrud.create(
db_session,
UserCreate(username="author", email="author@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Post 1", author_id=user.id, is_published=True),
)
# Get user with join on published posts
fetched = await UserCrud.get(
db_session,
filters=[User.id == user.id, Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert fetched.id == user.id
@pytest.mark.anyio
async def test_first_with_join(self, db_session: AsyncSession):
"""First with join returns matching record."""
user = await UserCrud.create(
db_session,
UserCreate(username="writer", email="writer@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft", author_id=user.id, is_published=False),
)
# Find user with unpublished posts
result = await UserCrud.first(
db_session,
filters=[Post.is_published == False], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_first_with_outer_join(self, db_session: AsyncSession):
"""First with outer join includes records without related data."""
# User without posts
user = await UserCrud.create(
db_session,
UserCreate(username="no_posts", email="no_posts@test.com"),
)
# With outer join, user should be found even without posts
result = await UserCrud.first(
db_session,
filters=[User.id == user.id],
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_get_multi_with_inner_join(self, db_session: AsyncSession):
"""Get multiple with inner join only returns matching records."""
# User with published post
user1 = await UserCrud.create(
db_session,
UserCreate(username="publisher", email="pub@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published", author_id=user1.id, is_published=True),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="lurker", email="lurk@test.com"),
)
# Inner join should only return user with published post
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "publisher"
@pytest.mark.anyio
async def test_get_multi_with_outer_join(self, db_session: AsyncSession):
"""Get multiple with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="has_post", email="has@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="My Post", author_id=user1.id),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="no_post", email="no@test.com"),
)
# Outer join should return both users
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert len(users) == 2
@pytest.mark.anyio
async def test_count_with_join(self, db_session: AsyncSession):
"""Count with join counts correctly."""
# Create users with different post statuses
user1 = await UserCrud.create(
db_session,
UserCreate(username="active_author", email="active@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published 1", author_id=user1.id, is_published=True),
)
user2 = await UserCrud.create(
db_session,
UserCreate(username="draft_author", email="draft@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft 1", author_id=user2.id, is_published=False),
)
# Count users with published posts
count = await UserCrud.count(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert count == 1
@pytest.mark.anyio
async def test_exists_with_join(self, db_session: AsyncSession):
"""Exists with join checks correctly."""
user = await UserCrud.create(
db_session,
UserCreate(username="poster", email="poster@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Exists Post", author_id=user.id, is_published=True),
)
# Check if user with published post exists
exists = await UserCrud.exists(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert exists is True
# Check if user with specific title exists
exists = await UserCrud.exists(
db_session,
filters=[Post.title == "Nonexistent"],
joins=[(Post, Post.author_id == User.id)],
)
assert exists is False
@pytest.mark.anyio
async def test_paginate_with_join(self, db_session: AsyncSession):
"""Paginate with join works correctly."""
# Create users with posts
for i in range(5):
user = await UserCrud.create(
db_session,
UserCreate(username=f"author{i}", email=f"author{i}@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(
title=f"Post {i}",
author_id=user.id,
is_published=i % 2 == 0,
),
)
# Paginate users with published posts
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 3
assert len(result["data"]) == 3
@pytest.mark.anyio
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
"""Paginate with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="with_post", email="with@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="A Post", author_id=user1.id),
)
# User without post
await UserCrud.create(
db_session,
UserCreate(username="without_post", email="without@test.com"),
)
# Paginate with outer join
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 2
assert len(result["data"]) == 2
@pytest.mark.anyio
async def test_multiple_joins(self, db_session: AsyncSession):
"""Multiple joins can be applied."""
role = await RoleCrud.create(db_session, RoleCreate(name="author_role"))
user = await UserCrud.create(
db_session,
UserCreate(
username="multi_join",
email="multi@test.com",
role_id=role.id,
),
)
await PostCrud.create(
db_session,
PostCreate(title="Multi Join Post", author_id=user.id, is_published=True),
)
# Join both Role and Post
users = await UserCrud.get_multi(
db_session,
joins=[
(Role, Role.id == User.role_id),
(Post, Post.author_id == User.id),
],
filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "multi_join"

View File

@@ -159,6 +159,178 @@ class TestFixtureRegistry:
assert names == {"test_data"}
class TestIncludeRegistry:
"""Tests for FixtureRegistry.include_registry method."""
def test_include_empty_registry(self):
"""Include an empty registry does nothing."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register
def roles():
return []
main_registry.include_registry(other_registry)
assert len(main_registry.get_all()) == 1
def test_include_registry_adds_fixtures(self):
"""Include registry adds all fixtures from the other registry."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register
def roles():
return []
@other_registry.register
def users():
return []
@other_registry.register
def posts():
return []
main_registry.include_registry(other_registry)
names = {f.name for f in main_registry.get_all()}
assert names == {"roles", "users", "posts"}
def test_include_registry_preserves_dependencies(self):
"""Include registry preserves fixture dependencies."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register
def roles():
return []
@other_registry.register(depends_on=["roles"])
def users():
return []
main_registry.include_registry(other_registry)
fixture = main_registry.get("users")
assert fixture.depends_on == ["roles"]
def test_include_registry_preserves_contexts(self):
"""Include registry preserves fixture contexts."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT])
def test_data():
return []
main_registry.include_registry(other_registry)
fixture = main_registry.get("test_data")
assert Context.TESTING.value in fixture.contexts
assert Context.DEVELOPMENT.value in fixture.contexts
def test_include_registry_raises_on_duplicate(self):
"""Include registry raises ValueError on duplicate fixture names."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register(name="roles")
def roles_main():
return []
@other_registry.register(name="roles")
def roles_other():
return []
with pytest.raises(ValueError, match="already exists"):
main_registry.include_registry(other_registry)
def test_include_multiple_registries(self):
"""Include multiple registries sequentially."""
main_registry = FixtureRegistry()
dev_registry = FixtureRegistry()
test_registry = FixtureRegistry()
@main_registry.register
def base():
return []
@dev_registry.register
def dev_data():
return []
@test_registry.register
def test_data():
return []
main_registry.include_registry(dev_registry)
main_registry.include_registry(test_registry)
names = {f.name for f in main_registry.get_all()}
assert names == {"base", "dev_data", "test_data"}
class TestDefaultContexts:
"""Tests for FixtureRegistry default contexts."""
def test_default_contexts_applied_to_fixtures(self):
"""Default contexts are applied when no contexts specified."""
registry = FixtureRegistry(contexts=[Context.TESTING])
@registry.register
def test_data():
return []
fixture = registry.get("test_data")
assert fixture.contexts == [Context.TESTING.value]
def test_explicit_contexts_override_default(self):
"""Explicit contexts override default contexts."""
registry = FixtureRegistry(contexts=[Context.TESTING])
@registry.register(contexts=[Context.PRODUCTION])
def prod_data():
return []
fixture = registry.get("prod_data")
assert fixture.contexts == [Context.PRODUCTION.value]
def test_no_default_contexts_uses_base(self):
"""Without default contexts, BASE is used."""
registry = FixtureRegistry()
@registry.register
def data():
return []
fixture = registry.get("data")
assert fixture.contexts == [Context.BASE.value]
def test_multiple_default_contexts(self):
"""Multiple default contexts are applied."""
registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING])
@registry.register
def dev_test_data():
return []
fixture = registry.get("dev_test_data")
assert Context.DEVELOPMENT.value in fixture.contexts
assert Context.TESTING.value in fixture.contexts
def test_default_contexts_with_string_values(self):
"""Default contexts work with string values."""
registry = FixtureRegistry(contexts=["custom_context"])
@registry.register
def custom_data():
return []
fixture = registry.get("custom_data")
assert fixture.contexts == ["custom_context"]
class TestDependencyResolution:
"""Tests for fixture dependency resolution."""

2
uv.lock generated
View File

@@ -220,7 +220,7 @@ wheels = [
[[package]]
name = "fastapi-toolsets"
version = "0.4.1"
version = "0.5.0"
source = { editable = "." }
dependencies = [
{ name = "asyncpg" },