8 Commits

Author SHA1 Message Date
0f50c8a0f0 Version 0.5.0 2026-02-03 09:12:20 -05:00
d3vyce
691fb78fda feat: add include_registry to FixtureRegistry + add context default to the registry (#25) 2026-02-03 14:59:36 +01:00
d3vyce
34ef4da317 feat: simplify CLI feature (#23)
* chore: cleanup + add tests

* chore: remove graph and show fixtures commands

* feat: add async_command wrapper
2026-02-03 14:35:15 +01:00
d3vyce
8c287b3ce7 feat: add join to crud functions (#21) 2026-02-01 15:01:10 +01:00
54f5479c24 Version 0.4.1 2026-01-29 14:15:55 -05:00
d3vyce
f467754df1 fix: cast to String non-text columns for crud search (#18)
fix: cast to String non-text columns for crud search
2026-01-29 19:44:48 +01:00
b57ce40b05 tests: change models to use UUID as primary key 2026-01-29 13:43:03 -05:00
5264631550 fix: cast to String non-text columns for crud search 2026-01-29 13:35:20 -05:00
19 changed files with 1295 additions and 325 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.4.0" version = "0.5.0"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
@@ -59,7 +59,7 @@ dev = [
] ]
[project.scripts] [project.scripts]
fastapi-toolsets = "fastapi_toolsets.cli:app" manager = "fastapi_toolsets.cli:cli"
[build-system] [build-system]
requires = ["uv_build>=0.9.26,<0.10.0"] requires = ["uv_build>=0.9.26,<0.10.0"]

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "0.4.0" __version__ = "0.5.0"

View File

@@ -1,5 +1,6 @@
"""CLI for FastAPI projects.""" """CLI for FastAPI projects."""
from .app import app, register_command from .app import cli
from .utils import async_command
__all__ = ["app", "register_command"] __all__ = ["async_command", "cli"]

View File

@@ -1,97 +1,25 @@
"""Main CLI application.""" """Main CLI application."""
import importlib.util
import sys
from pathlib import Path
from typing import Annotated
import typer import typer
from .commands import fixtures from .config import load_config
app = typer.Typer( cli = typer.Typer(
name="fastapi-utils", name="manager",
help="CLI utilities for FastAPI projects.", help="CLI utilities for FastAPI projects.",
no_args_is_help=True, no_args_is_help=True,
) )
# Register built-in commands _config = load_config()
app.add_typer(fixtures.app, name="fixtures")
if _config.fixtures:
from .commands.fixtures import fixture_cli
cli.add_typer(fixture_cli, name="fixtures")
def register_command(command: typer.Typer, name: str) -> None: @cli.callback()
"""Register a custom command group. def main(ctx: typer.Context) -> None:
Args:
command: Typer app for the command group
name: Name for the command group
Example:
# In your project's cli.py:
import typer
from fastapi_toolsets.cli import app, register_command
my_commands = typer.Typer()
@my_commands.command()
def seed():
'''Seed the database.'''
...
register_command(my_commands, "db")
# Now available as: fastapi-utils db seed
"""
app.add_typer(command, name=name)
@app.callback()
def main(
ctx: typer.Context,
config: Annotated[
Path | None,
typer.Option(
"--config",
"-c",
help="Path to project config file (Python module with fixtures registry).",
envvar="FASTAPI_TOOLSETS_CONFIG",
),
] = None,
) -> None:
"""FastAPI utilities CLI.""" """FastAPI utilities CLI."""
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj["config"] = _config
if config:
ctx.obj["config_path"] = config
# Load the config module
config_module = _load_module_from_path(config)
ctx.obj["config_module"] = config_module
def _load_module_from_path(path: Path) -> object:
"""Load a Python module from a file path.
Handles both absolute and relative imports by adding the config's
parent directory to sys.path temporarily.
"""
path = path.resolve()
# Add the parent directory to sys.path to support relative imports
parent_dir = str(
path.parent.parent
) # Go up two levels (e.g., from app/cli_config.py to project root)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
# Also add immediate parent for direct module imports
immediate_parent = str(path.parent)
if immediate_parent not in sys.path:
sys.path.insert(0, immediate_parent)
spec = importlib.util.spec_from_file_location("config", path)
if spec is None or spec.loader is None:
raise typer.BadParameter(f"Cannot load module from {path}")
module = importlib.util.module_from_spec(spec)
sys.modules["config"] = module
spec.loader.exec_module(module)
return module

View File

@@ -1,55 +1,29 @@
"""Fixture management commands.""" """Fixture management commands."""
import asyncio
from typing import Annotated from typing import Annotated
import typer import typer
from rich.console import Console
from rich.table import Table
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
from ..config import CliConfig
from ..utils import async_command
app = typer.Typer( fixture_cli = typer.Typer(
name="fixtures", name="fixtures",
help="Manage database fixtures.", help="Manage database fixtures.",
no_args_is_help=True, no_args_is_help=True,
) )
console = Console()
def _get_registry(ctx: typer.Context) -> FixtureRegistry: def _get_config(ctx: typer.Context) -> CliConfig:
"""Get fixture registry from context.""" """Get CLI config from context."""
config = ctx.obj.get("config_module") if ctx.obj else None return ctx.obj["config"]
if config is None:
raise typer.BadParameter(
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
)
registry = getattr(config, "fixtures", None)
if registry is None:
raise typer.BadParameter(
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
)
if not isinstance(registry, FixtureRegistry):
raise typer.BadParameter(
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
)
return registry
def _get_db_context(ctx: typer.Context): @fixture_cli.command("list")
"""Get database context manager from config."""
config = ctx.obj.get("config_module") if ctx.obj else None
if config is None:
raise typer.BadParameter("No config provided.")
get_db_context = getattr(config, "get_db_context", None)
if get_db_context is None:
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
return get_db_context
@app.command("list")
def list_fixtures( def list_fixtures(
ctx: typer.Context, ctx: typer.Context,
context: Annotated[ context: Annotated[
@@ -62,64 +36,28 @@ def list_fixtures(
] = None, ] = None,
) -> None: ) -> None:
"""List all registered fixtures.""" """List all registered fixtures."""
registry = _get_registry(ctx) config = _get_config(ctx)
registry = config.get_fixtures_registry()
if context: fixtures = registry.get_by_context(context) if context else registry.get_all()
fixtures = registry.get_by_context(context)
else:
fixtures = registry.get_all()
if not fixtures: if not fixtures:
typer.echo("No fixtures found.") print("No fixtures found.")
return return
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}") table = Table("Name", "Contexts", "Dependencies")
typer.echo("-" * 80)
for fixture in fixtures: for fixture in fixtures:
contexts = ", ".join(fixture.contexts) contexts = ", ".join(fixture.contexts)
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-" deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}") table.add_row(fixture.name, contexts, deps)
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)") console.print(table)
print(f"\nTotal: {len(fixtures)} fixture(s)")
@app.command("graph") @fixture_cli.command("load")
def show_graph( @async_command
ctx: typer.Context, async def load(
fixture_name: Annotated[
str | None,
typer.Argument(help="Show dependencies for a specific fixture."),
] = None,
) -> None:
"""Show fixture dependency graph."""
registry = _get_registry(ctx)
if fixture_name:
try:
order = registry.resolve_dependencies(fixture_name)
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
for i, name in enumerate(order):
indent = " " * i
arrow = "└─> " if i > 0 else ""
typer.echo(f"{indent}{arrow}{name}")
except KeyError:
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
raise typer.Exit(1)
else:
# Show full graph
fixtures = registry.get_all()
typer.echo("\nFixture Dependency Graph:\n")
for fixture in fixtures:
deps = (
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
)
typer.echo(f" {fixture.name}{deps}")
@app.command("load")
def load(
ctx: typer.Context, ctx: typer.Context,
contexts: Annotated[ contexts: Annotated[
list[str] | None, list[str] | None,
@@ -141,16 +79,12 @@ def load(
] = False, ] = False,
) -> None: ) -> None:
"""Load fixtures into the database.""" """Load fixtures into the database."""
registry = _get_registry(ctx) config = _get_config(ctx)
get_db_context = _get_db_context(ctx) registry = config.get_fixtures_registry()
get_db_context = config.get_db_context()
# Parse contexts context_list = contexts if contexts else [Context.BASE]
if contexts:
context_list = contexts
else:
context_list = [Context.BASE]
# Parse strategy
try: try:
load_strategy = LoadStrategy(strategy) load_strategy = LoadStrategy(strategy)
except ValueError: except ValueError:
@@ -159,67 +93,27 @@ def load(
) )
raise typer.Exit(1) raise typer.Exit(1)
# Resolve what will be loaded
ordered = registry.resolve_context_dependencies(*context_list) ordered = registry.resolve_context_dependencies(*context_list)
if not ordered: if not ordered:
typer.echo("No fixtures to load for the specified context(s).") print("No fixtures to load for the specified context(s).")
return return
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):") print(f"\nFixtures to load ({load_strategy.value} strategy):")
for name in ordered: for name in ordered:
fixture = registry.get(name) fixture = registry.get(name)
instances = list(fixture.func()) instances = list(fixture.func())
model_name = type(instances[0]).__name__ if instances else "?" model_name = type(instances[0]).__name__ if instances else "?"
typer.echo(f" - {name}: {len(instances)} {model_name}(s)") print(f" - {name}: {len(instances)} {model_name}(s)")
if dry_run: if dry_run:
typer.echo("\n[Dry run - no changes made]") print("\n[Dry run - no changes made]")
return return
typer.echo("\nLoading...") async with get_db_context() as session:
result = await load_fixtures_by_context(
async def do_load(): session, registry, *context_list, strategy=load_strategy
async with get_db_context() as session: )
result = await load_fixtures_by_context(
session, registry, *context_list, strategy=load_strategy
)
return result
result = asyncio.run(do_load())
total = sum(len(items) for items in result.values()) total = sum(len(items) for items in result.values())
typer.echo(f"\nLoaded {total} record(s) successfully.") print(f"\nLoaded {total} record(s) successfully.")
@app.command("show")
def show_fixture(
ctx: typer.Context,
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
) -> None:
"""Show details of a specific fixture."""
registry = _get_registry(ctx)
try:
fixture = registry.get(name)
except KeyError:
typer.echo(f"Fixture '{name}' not found.", err=True)
raise typer.Exit(1)
typer.echo(f"\nFixture: {fixture.name}")
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
typer.echo(
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
)
# Show instances
instances = list(fixture.func())
if instances:
model_name = type(instances[0]).__name__
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
for instance in instances[:10]: # Limit to 10
typer.echo(f" - {instance!r}")
if len(instances) > 10:
typer.echo(f" ... and {len(instances) - 10} more")
else:
typer.echo("\nNo instances (empty fixture)")

View File

@@ -0,0 +1,92 @@
"""CLI configuration."""
import importlib
import sys
import tomllib
from dataclasses import dataclass
from pathlib import Path
import typer
@dataclass
class CliConfig:
"""CLI configuration loaded from pyproject.toml."""
fixtures: str | None = None
db_context: str | None = None
def get_fixtures_registry(self):
"""Import and return the fixtures registry."""
from ..fixtures import FixtureRegistry
if not self.fixtures:
raise typer.BadParameter(
"No fixtures registry configured. "
"Add 'fixtures' to [tool.fastapi-toolsets] in pyproject.toml."
)
registry = _import_from_string(self.fixtures)
if not isinstance(registry, FixtureRegistry):
raise typer.BadParameter(
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
)
return registry
def get_db_context(self):
"""Import and return the db_context function."""
if not self.db_context:
raise typer.BadParameter(
"No db_context configured. "
"Add 'db_context' to [tool.fastapi-toolsets] in pyproject.toml."
)
return _import_from_string(self.db_context)
def _import_from_string(import_path: str):
"""Import an object from a string path like 'module.submodule:attribute'."""
if ":" not in import_path:
raise typer.BadParameter(
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
)
module_path, attr_name = import_path.rsplit(":", 1)
# Add cwd to sys.path for local imports
cwd = str(Path.cwd())
if cwd not in sys.path:
sys.path.insert(0, cwd)
try:
module = importlib.import_module(module_path)
except ImportError as e:
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
if not hasattr(module, attr_name):
raise typer.BadParameter(
f"Module '{module_path}' has no attribute '{attr_name}'"
)
return getattr(module, attr_name)
def load_config() -> CliConfig:
"""Load CLI configuration from pyproject.toml."""
pyproject_path = Path.cwd() / "pyproject.toml"
if not pyproject_path.exists():
return CliConfig()
try:
with open(pyproject_path, "rb") as f:
data = tomllib.load(f)
tool_config = data.get("tool", {}).get("fastapi-toolsets", {})
return CliConfig(
fixtures=tool_config.get("fixtures"),
db_context=tool_config.get("db_context"),
)
except Exception:
return CliConfig()

View File

@@ -0,0 +1,27 @@
"""CLI utility functions."""
import asyncio
import functools
from collections.abc import Callable, Coroutine
from typing import Any, ParamSpec, TypeVar
P = ParamSpec("P")
T = TypeVar("T")
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
"""Decorator to run an async function as a sync CLI command.
Example:
@fixture_cli.command("load")
@async_command
async def load(ctx: typer.Context) -> None:
async with get_db_context() as session:
await load_fixtures(session, registry)
"""
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return asyncio.run(func(*args, **kwargs))
return wrapper

View File

@@ -4,7 +4,6 @@ from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory from .factory import CrudFactory
from .search import ( from .search import (
SearchConfig, SearchConfig,
SearchFieldType,
get_searchable_fields, get_searchable_fields,
) )
@@ -13,5 +12,4 @@ __all__ = [
"get_searchable_fields", "get_searchable_fields",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"SearchConfig", "SearchConfig",
"SearchFieldType",
] ]

View File

@@ -17,6 +17,7 @@ from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
@@ -55,6 +56,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*, *,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
) -> ModelType: ) -> ModelType:
@@ -63,6 +66,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload) load_options: SQLAlchemy loader options (e.g., selectinload)
@@ -73,7 +78,15 @@ class AsyncCrud(Generic[ModelType]):
NotFoundError: If no record found NotFoundError: If no record found
MultipleResultsFound: If more than one record found MultipleResultsFound: If more than one record found
""" """
q = select(cls.model).where(and_(*filters)) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters))
if load_options: if load_options:
q = q.options(*load_options) q = q.options(*load_options)
if with_for_update: if with_for_update:
@@ -90,6 +103,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any] | None = None, filters: list[Any] | None = None,
*, *,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
) -> ModelType | None: ) -> ModelType | None:
"""Get the first matching record, or None. """Get the first matching record, or None.
@@ -97,12 +112,21 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
Returns: Returns:
Model instance or None Model instance or None
""" """
q = select(cls.model) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if load_options:
@@ -116,6 +140,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
*, *,
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
limit: int | None = None, limit: int | None = None,
@@ -126,6 +152,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
limit: Max number of rows to return limit: Max number of rows to return
@@ -135,6 +163,13 @@ class AsyncCrud(Generic[ModelType]):
List of model instances List of model instances
""" """
q = select(cls.model) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if load_options:
@@ -254,17 +289,29 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
filters: list[Any] | None = None, filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> int: ) -> int:
"""Count records matching the filters. """Count records matching the filters.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns: Returns:
Number of matching records Number of matching records
""" """
q = select(func.count()).select_from(cls.model) q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
result = await session.execute(q) result = await session.execute(q)
@@ -275,17 +322,30 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> bool: ) -> bool:
"""Check if a record exists. """Check if a record exists.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns: Returns:
True if at least one record matches True if at least one record matches
""" """
q = select(cls.model).where(and_(*filters)).exists().select() q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q) result = await session.execute(q)
return bool(result.scalar()) return bool(result.scalar())
@@ -295,6 +355,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
*, *,
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
page: int = 1, page: int = 1,
@@ -307,6 +369,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
page: Page number (1-indexed) page: Page number (1-indexed)
@@ -319,7 +383,7 @@ class AsyncCrud(Generic[ModelType]):
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
joins: list[Any] = [] search_joins: list[Any] = []
# Build search filters # Build search filters
if search: if search:
@@ -330,11 +394,21 @@ class AsyncCrud(Generic[ModelType]):
default_fields=cls.searchable_fields, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters) filters.extend(search_filters)
joins.extend(search_joins)
# Build query with joins # Build query with joins
q = select(cls.model) q = select(cls.model)
for join_rel in joins:
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel) q = q.outerjoin(join_rel)
if filters: if filters:
@@ -352,8 +426,20 @@ class AsyncCrud(Generic[ModelType]):
pk_col = cls.model.__mapper__.primary_key[0] pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name)))) count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model) count_q = count_q.select_from(cls.model)
for join_rel in joins:
# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel) count_q = count_q.outerjoin(join_rel)
if filters: if filters:
count_q = count_q.where(and_(*filters)) count_q = count_q.where(and_(*filters))
@@ -404,6 +490,20 @@ def CrudFactory(
# With search # With search
result = await UserCrud.paginate(session, search="john") result = await UserCrud.paginate(session, search="john")
# With joins (inner join by default):
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
filters=[Post.published == True],
)
# With outer join:
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
outer_join=True,
)
""" """
cls = type( cls = type(
f"Async{model.__name__}Crud", f"Async{model.__name__}Crud",

View File

@@ -129,11 +129,12 @@ def build_search_filters(
else: else:
column = field column = field
# Build the filter # Build the filter (cast to String for non-text columns)
column_as_string = column.cast(String)
if config.case_sensitive: if config.case_sensitive:
filters.append(column.like(f"%{query}%")) filters.append(column_as_string.like(f"%{query}%"))
else: else:
filters.append(column.ilike(f"%{query}%")) filters.append(column_as_string.ilike(f"%{query}%"))
if not filters: if not filters:
return [], [] return [], []

View File

@@ -1,4 +1,5 @@
from .exceptions import ( from .exceptions import (
ApiError,
ApiException, ApiException,
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
@@ -10,11 +11,12 @@ from .exceptions import (
from .handler import init_exceptions_handlers from .handler import init_exceptions_handlers
__all__ = [ __all__ = [
"init_exceptions_handlers", "ApiError",
"generate_error_responses",
"ApiException", "ApiException",
"ConflictError", "ConflictError",
"ForbiddenError", "ForbiddenError",
"generate_error_responses",
"init_exceptions_handlers",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"NotFoundError", "NotFoundError",
"UnauthorizedError", "UnauthorizedError",

View File

@@ -50,8 +50,16 @@ class FixtureRegistry:
] ]
""" """
def __init__(self) -> None: def __init__(
self,
contexts: list[str | Context] | None = None,
) -> None:
self._fixtures: dict[str, Fixture] = {} self._fixtures: dict[str, Fixture] = {}
self._default_contexts: list[str] | None = (
[c.value if isinstance(c, Context) else c for c in contexts]
if contexts
else None
)
def register( def register(
self, self,
@@ -85,10 +93,14 @@ class FixtureRegistry:
fn: Callable[[], Sequence[DeclarativeBase]], fn: Callable[[], Sequence[DeclarativeBase]],
) -> Callable[[], Sequence[DeclarativeBase]]: ) -> Callable[[], Sequence[DeclarativeBase]]:
fixture_name = name or cast(Any, fn).__name__ fixture_name = name or cast(Any, fn).__name__
fixture_contexts = [ if contexts is not None:
c.value if isinstance(c, Context) else c fixture_contexts = [
for c in (contexts or [Context.BASE]) c.value if isinstance(c, Context) else c for c in contexts
] ]
elif self._default_contexts is not None:
fixture_contexts = self._default_contexts
else:
fixture_contexts = [Context.BASE.value]
self._fixtures[fixture_name] = Fixture( self._fixtures[fixture_name] = Fixture(
name=fixture_name, name=fixture_name,
@@ -102,6 +114,32 @@ class FixtureRegistry:
return decorator(func) return decorator(func)
return decorator return decorator
def include_registry(self, registry: "FixtureRegistry") -> None:
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
Args:
registry: The `FixtureRegistry` to include
Raises:
ValueError: If a fixture name already exists in the current registry
Example:
registry = FixtureRegistry()
dev_registry = FixtureRegistry()
@dev_registry.register
def dev_data():
return [...]
registry.include_registry(registry=dev_registry)
"""
for name, fixture in registry._fixtures.items():
if name in self._fixtures:
raise ValueError(
f"Fixture '{name}' already exists in the current registry"
)
self._fixtures[name] = fixture
def get(self, name: str) -> Fixture: def get(self, name: str) -> Fixture:
"""Get a fixture by name.""" """Get a fixture by name."""
if name not in self._fixtures: if name not in self._fixtures:

View File

@@ -1,10 +1,11 @@
"""Shared pytest fixtures for fastapi-utils tests.""" """Shared pytest fixtures for fastapi-utils tests."""
import os import os
import uuid
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import ForeignKey, String from sqlalchemy import ForeignKey, String, Uuid
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -33,7 +34,7 @@ class Role(Base):
__tablename__ = "roles" __tablename__ = "roles"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True) name: Mapped[str] = mapped_column(String(50), unique=True)
users: Mapped[list["User"]] = relationship(back_populates="role") users: Mapped[list["User"]] = relationship(back_populates="role")
@@ -44,11 +45,13 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
username: Mapped[str] = mapped_column(String(50), unique=True) username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True) email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True) is_active: Mapped[bool] = mapped_column(default=True)
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True) role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True
)
role: Mapped[Role | None] = relationship(back_populates="users") role: Mapped[Role | None] = relationship(back_populates="users")
@@ -58,11 +61,11 @@ class Post(Base):
__tablename__ = "posts" __tablename__ = "posts"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(200)) title: Mapped[str] = mapped_column(String(200))
content: Mapped[str] = mapped_column(String(1000), default="") content: Mapped[str] = mapped_column(String(1000), default="")
is_published: Mapped[bool] = mapped_column(default=False) is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[int] = mapped_column(ForeignKey("users.id")) author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
# ============================================================================= # =============================================================================
@@ -73,7 +76,7 @@ class Post(Base):
class RoleCreate(BaseModel): class RoleCreate(BaseModel):
"""Schema for creating a role.""" """Schema for creating a role."""
id: int | None = None id: uuid.UUID | None = None
name: str name: str
@@ -86,11 +89,11 @@ class RoleUpdate(BaseModel):
class UserCreate(BaseModel): class UserCreate(BaseModel):
"""Schema for creating a user.""" """Schema for creating a user."""
id: int | None = None id: uuid.UUID | None = None
username: str username: str
email: str email: str
is_active: bool = True is_active: bool = True
role_id: int | None = None role_id: uuid.UUID | None = None
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
@@ -99,17 +102,17 @@ class UserUpdate(BaseModel):
username: str | None = None username: str | None = None
email: str | None = None email: str | None = None
is_active: bool | None = None is_active: bool | None = None
role_id: int | None = None role_id: uuid.UUID | None = None
class PostCreate(BaseModel): class PostCreate(BaseModel):
"""Schema for creating a post.""" """Schema for creating a post."""
id: int | None = None id: uuid.UUID | None = None
title: str title: str
content: str = "" content: str = ""
is_published: bool = False is_published: bool = False
author_id: int author_id: uuid.UUID
class PostUpdate(BaseModel): class PostUpdate(BaseModel):
@@ -195,5 +198,5 @@ def sample_post_data() -> PostCreate:
title="Test Post", title="Test Post",
content="Test content", content="Test content",
is_published=True, is_published=True,
author_id=1, author_id=uuid.uuid4(),
) )

322
tests/test_cli.py Normal file
View File

@@ -0,0 +1,322 @@
"""Tests for fastapi_toolsets.cli module."""
import sys
import pytest
from typer.testing import CliRunner
from fastapi_toolsets.cli.config import CliConfig, _import_from_string, load_config
from fastapi_toolsets.cli.utils import async_command
from fastapi_toolsets.fixtures import FixtureRegistry
runner = CliRunner()
class TestCliConfig:
"""Tests for CliConfig dataclass."""
def test_default_values(self):
"""Config has None defaults."""
config = CliConfig()
assert config.fixtures is None
assert config.db_context is None
def test_with_values(self):
"""Config stores provided values."""
config = CliConfig(
fixtures="app.fixtures:registry",
db_context="app.db:get_session",
)
assert config.fixtures == "app.fixtures:registry"
assert config.db_context == "app.db:get_session"
def test_get_fixtures_registry_without_config(self):
"""get_fixtures_registry raises error when not configured."""
config = CliConfig()
with pytest.raises(Exception) as exc_info:
config.get_fixtures_registry()
assert "No fixtures registry configured" in str(exc_info.value)
def test_get_db_context_without_config(self):
"""get_db_context raises error when not configured."""
config = CliConfig()
with pytest.raises(Exception) as exc_info:
config.get_db_context()
assert "No db_context configured" in str(exc_info.value)
class TestImportFromString:
"""Tests for _import_from_string function."""
def test_import_valid_path(self):
"""Import valid module:attribute path."""
result = _import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
assert result is FixtureRegistry
def test_import_without_colon_raises_error(self):
"""Import path without colon raises error."""
with pytest.raises(Exception) as exc_info:
_import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
assert "Expected format: 'module:attribute'" in str(exc_info.value)
def test_import_nonexistent_module_raises_error(self):
"""Import nonexistent module raises error."""
with pytest.raises(Exception) as exc_info:
_import_from_string("nonexistent.module:something")
assert "Cannot import module" in str(exc_info.value)
def test_import_nonexistent_attribute_raises_error(self):
"""Import nonexistent attribute raises error."""
with pytest.raises(Exception) as exc_info:
_import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
assert "has no attribute" in str(exc_info.value)
class TestLoadConfig:
"""Tests for load_config function."""
def test_load_without_pyproject(self, tmp_path, monkeypatch):
"""Returns empty config when no pyproject.toml exists."""
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures is None
assert config.db_context is None
def test_load_without_tool_section(self, tmp_path, monkeypatch):
"""Returns empty config when no [tool.fastapi-toolsets] section."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text("[project]\nname = 'test'\n")
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures is None
assert config.db_context is None
def test_load_with_fixtures_config(self, tmp_path, monkeypatch):
"""Loads fixtures config from pyproject.toml."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
)
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures == "app.fixtures:registry"
assert config.db_context is None
def test_load_with_full_config(self, tmp_path, monkeypatch):
"""Loads full config from pyproject.toml."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(
"[tool.fastapi-toolsets]\n"
'fixtures = "app.fixtures:registry"\n'
'db_context = "app.db:get_session"\n'
)
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures == "app.fixtures:registry"
assert config.db_context == "app.db:get_session"
def test_load_with_invalid_toml(self, tmp_path, monkeypatch):
"""Returns empty config when pyproject.toml is invalid."""
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text("invalid toml {{{")
monkeypatch.chdir(tmp_path)
config = load_config()
assert config.fixtures is None
class TestCliApp:
"""Tests for CLI application."""
def test_cli_import(self):
"""CLI can be imported."""
from fastapi_toolsets.cli import cli
assert cli is not None
def test_cli_help(self, tmp_path, monkeypatch):
"""CLI shows help without fixtures."""
monkeypatch.chdir(tmp_path)
# Need to reload the module to pick up new cwd
import importlib
from fastapi_toolsets.cli import app
importlib.reload(app)
result = runner.invoke(app.cli, ["--help"])
assert result.exit_code == 0
assert "CLI utilities for FastAPI projects" in result.output
class TestFixturesCli:
"""Tests for fixtures CLI commands."""
@pytest.fixture
def cli_env(self, tmp_path, monkeypatch):
"""Set up CLI environment with fixtures config."""
# Create pyproject.toml
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text(
"[tool.fastapi-toolsets]\n"
'fixtures = "fixtures:registry"\n'
'db_context = "db:get_session"\n'
)
# Create fixtures module
fixtures_file = tmp_path / "fixtures.py"
fixtures_file.write_text(
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
"\n"
"registry = FixtureRegistry()\n"
"\n"
"@registry.register(contexts=[Context.BASE])\n"
"def roles():\n"
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
"\n"
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
"def users():\n"
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
)
# Create db module
db_file = tmp_path / "db.py"
db_file.write_text(
"from contextlib import asynccontextmanager\n"
"\n"
"@asynccontextmanager\n"
"async def get_session():\n"
" yield None\n"
)
monkeypatch.chdir(tmp_path)
# Add tmp_path to sys.path for imports
if str(tmp_path) not in sys.path:
sys.path.insert(0, str(tmp_path))
# Reload the CLI module to pick up new config
import importlib
from fastapi_toolsets.cli import app, config
importlib.reload(config)
importlib.reload(app)
yield tmp_path, app.cli
# Cleanup
if str(tmp_path) in sys.path:
sys.path.remove(str(tmp_path))
def test_fixtures_list(self, cli_env):
"""fixtures list shows registered fixtures."""
tmp_path, cli = cli_env
result = runner.invoke(cli, ["fixtures", "list"])
assert result.exit_code == 0
assert "roles" in result.output
assert "users" in result.output
assert "Total: 2 fixture(s)" in result.output
def test_fixtures_list_with_context(self, cli_env):
"""fixtures list --context filters by context."""
tmp_path, cli = cli_env
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
assert result.exit_code == 0
assert "roles" in result.output
assert "users" not in result.output
assert "Total: 1 fixture(s)" in result.output
def test_fixtures_load_dry_run(self, cli_env):
"""fixtures load --dry-run shows what would be loaded."""
tmp_path, cli = cli_env
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
assert result.exit_code == 0
assert "Fixtures to load" in result.output
assert "roles" in result.output
assert "[Dry run - no changes made]" in result.output
def test_fixtures_load_invalid_strategy(self, cli_env):
"""fixtures load with invalid strategy shows error."""
tmp_path, cli = cli_env
result = runner.invoke(
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
)
assert result.exit_code == 1
assert "Invalid strategy" in result.output
class TestCliWithoutFixturesConfig:
"""Tests for CLI when fixtures is not configured."""
def test_no_fixtures_command(self, tmp_path, monkeypatch):
"""fixtures command is not available when not configured."""
# Create pyproject.toml without fixtures
pyproject = tmp_path / "pyproject.toml"
pyproject.write_text('[project]\nname = "test"\n')
monkeypatch.chdir(tmp_path)
# Reload the CLI module
import importlib
from fastapi_toolsets.cli import app, config
importlib.reload(config)
importlib.reload(app)
result = runner.invoke(app.cli, ["--help"])
assert result.exit_code == 0
assert "fixtures" not in result.output
class TestAsyncCommand:
"""Tests for async_command decorator."""
def test_async_command_runs_coroutine(self):
"""async_command runs async function synchronously."""
@async_command
async def async_func(value: int) -> int:
return value * 2
result = async_func(21)
assert result == 42
def test_async_command_preserves_signature(self):
"""async_command preserves function signature."""
@async_command
async def async_func(name: str, count: int = 1) -> str:
return f"{name} x {count}"
result = async_func("test", count=3)
assert result == "test x 3"
def test_async_command_preserves_docstring(self):
"""async_command preserves function docstring."""
@async_command
async def async_func() -> None:
"""This is a docstring."""
pass
assert async_func.__doc__ == """This is a docstring."""
def test_async_command_preserves_name(self):
"""async_command preserves function name."""
@async_command
async def my_async_function() -> None:
pass
assert my_async_function.__name__ == "my_async_function"

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.crud module.""" """Tests for fastapi_toolsets.crud module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,6 +10,9 @@ from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError from fastapi_toolsets.exceptions import NotFoundError
from .conftest import ( from .conftest import (
Post,
PostCreate,
PostCrud,
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
@@ -89,8 +94,9 @@ class TestCrudGet:
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_raises_not_found(self, db_session: AsyncSession): async def test_get_raises_not_found(self, db_session: AsyncSession):
"""Get raises NotFoundError for missing records.""" """Get raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.get(db_session, [Role.id == 99999]) await RoleCrud.get(db_session, [Role.id == non_existent_id])
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_with_multiple_filters(self, db_session: AsyncSession): async def test_get_with_multiple_filters(self, db_session: AsyncSession):
@@ -223,11 +229,12 @@ class TestCrudUpdate:
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_raises_not_found(self, db_session: AsyncSession): async def test_update_raises_not_found(self, db_session: AsyncSession):
"""Update raises NotFoundError for missing records.""" """Update raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.update( await RoleCrud.update(
db_session, db_session,
RoleUpdate(name="new"), RoleUpdate(name="new"),
[Role.id == 99999], [Role.id == non_existent_id],
) )
@pytest.mark.anyio @pytest.mark.anyio
@@ -340,7 +347,8 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_insert_new_record(self, db_session: AsyncSession): async def test_upsert_insert_new_record(self, db_session: AsyncSession):
"""Upsert inserts a new record when it doesn't exist.""" """Upsert inserts a new record when it doesn't exist."""
data = RoleCreate(id=1, name="upsert_new") role_id = uuid.uuid4()
data = RoleCreate(id=role_id, name="upsert_new")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
data, data,
@@ -353,12 +361,13 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_update_existing_record(self, db_session: AsyncSession): async def test_upsert_update_existing_record(self, db_session: AsyncSession):
"""Upsert updates an existing record.""" """Upsert updates an existing record."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=100, name="original_name") data = RoleCreate(id=role_id, name="original_name")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert with update # Upsert with update
updated_data = RoleCreate(id=100, name="updated_name") updated_data = RoleCreate(id=role_id, name="updated_name")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
updated_data, updated_data,
@@ -370,22 +379,23 @@ class TestCrudUpsert:
assert role.name == "updated_name" assert role.name == "updated_name"
# Verify only one record exists # Verify only one record exists
count = await RoleCrud.count(db_session, [Role.id == 100]) count = await RoleCrud.count(db_session, [Role.id == role_id])
assert count == 1 assert count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession): async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
"""Upsert does nothing on conflict when set_ is not provided.""" """Upsert does nothing on conflict when set_ is not provided."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=200, name="do_nothing_original") data = RoleCreate(id=role_id, name="do_nothing_original")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert without set_ (do nothing) # Upsert without set_ (do nothing)
conflict_data = RoleCreate(id=200, name="do_nothing_conflict") conflict_data = RoleCreate(id=role_id, name="do_nothing_conflict")
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"]) await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
# Original value should be preserved # Original value should be preserved
role = await RoleCrud.first(db_session, [Role.id == 200]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "do_nothing_original" assert role.name == "do_nothing_original"
@@ -474,3 +484,271 @@ class TestCrudPaginate:
names = [r.name for r in result["data"]] names = [r.name for r in result["data"]]
assert names == ["alpha", "bravo", "charlie"] assert names == ["alpha", "bravo", "charlie"]
class TestCrudJoins:
"""Tests for CRUD operations with joins."""
@pytest.mark.anyio
async def test_get_with_join(self, db_session: AsyncSession):
"""Get with inner join filters correctly."""
# Create user with posts
user = await UserCrud.create(
db_session,
UserCreate(username="author", email="author@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Post 1", author_id=user.id, is_published=True),
)
# Get user with join on published posts
fetched = await UserCrud.get(
db_session,
filters=[User.id == user.id, Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert fetched.id == user.id
@pytest.mark.anyio
async def test_first_with_join(self, db_session: AsyncSession):
"""First with join returns matching record."""
user = await UserCrud.create(
db_session,
UserCreate(username="writer", email="writer@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft", author_id=user.id, is_published=False),
)
# Find user with unpublished posts
result = await UserCrud.first(
db_session,
filters=[Post.is_published == False], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_first_with_outer_join(self, db_session: AsyncSession):
"""First with outer join includes records without related data."""
# User without posts
user = await UserCrud.create(
db_session,
UserCreate(username="no_posts", email="no_posts@test.com"),
)
# With outer join, user should be found even without posts
result = await UserCrud.first(
db_session,
filters=[User.id == user.id],
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_get_multi_with_inner_join(self, db_session: AsyncSession):
"""Get multiple with inner join only returns matching records."""
# User with published post
user1 = await UserCrud.create(
db_session,
UserCreate(username="publisher", email="pub@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published", author_id=user1.id, is_published=True),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="lurker", email="lurk@test.com"),
)
# Inner join should only return user with published post
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "publisher"
@pytest.mark.anyio
async def test_get_multi_with_outer_join(self, db_session: AsyncSession):
"""Get multiple with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="has_post", email="has@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="My Post", author_id=user1.id),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="no_post", email="no@test.com"),
)
# Outer join should return both users
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert len(users) == 2
@pytest.mark.anyio
async def test_count_with_join(self, db_session: AsyncSession):
"""Count with join counts correctly."""
# Create users with different post statuses
user1 = await UserCrud.create(
db_session,
UserCreate(username="active_author", email="active@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published 1", author_id=user1.id, is_published=True),
)
user2 = await UserCrud.create(
db_session,
UserCreate(username="draft_author", email="draft@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft 1", author_id=user2.id, is_published=False),
)
# Count users with published posts
count = await UserCrud.count(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert count == 1
@pytest.mark.anyio
async def test_exists_with_join(self, db_session: AsyncSession):
"""Exists with join checks correctly."""
user = await UserCrud.create(
db_session,
UserCreate(username="poster", email="poster@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Exists Post", author_id=user.id, is_published=True),
)
# Check if user with published post exists
exists = await UserCrud.exists(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert exists is True
# Check if user with specific title exists
exists = await UserCrud.exists(
db_session,
filters=[Post.title == "Nonexistent"],
joins=[(Post, Post.author_id == User.id)],
)
assert exists is False
@pytest.mark.anyio
async def test_paginate_with_join(self, db_session: AsyncSession):
"""Paginate with join works correctly."""
# Create users with posts
for i in range(5):
user = await UserCrud.create(
db_session,
UserCreate(username=f"author{i}", email=f"author{i}@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(
title=f"Post {i}",
author_id=user.id,
is_published=i % 2 == 0,
),
)
# Paginate users with published posts
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 3
assert len(result["data"]) == 3
@pytest.mark.anyio
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
"""Paginate with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="with_post", email="with@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="A Post", author_id=user1.id),
)
# User without post
await UserCrud.create(
db_session,
UserCreate(username="without_post", email="without@test.com"),
)
# Paginate with outer join
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 2
assert len(result["data"]) == 2
@pytest.mark.anyio
async def test_multiple_joins(self, db_session: AsyncSession):
"""Multiple joins can be applied."""
role = await RoleCrud.create(db_session, RoleCreate(name="author_role"))
user = await UserCrud.create(
db_session,
UserCreate(
username="multi_join",
email="multi@test.com",
role_id=role.id,
),
)
await PostCrud.create(
db_session,
PostCreate(title="Multi Join Post", author_id=user.id, is_published=True),
)
# Join both Role and Post
users = await UserCrud.get_multi(
db_session,
joins=[
(Role, Role.id == User.role_id),
(Post, Post.author_id == User.id),
],
filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "multi_join"

View File

@@ -1,5 +1,7 @@
"""Tests for CRUD search functionality.""" """Tests for CRUD search functionality."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -272,6 +274,27 @@ class TestPaginateSearch:
usernames = [u.username for u in result["data"]] usernames = [u.username for u in result["data"]]
assert usernames == ["alice", "bob", "charlie"] assert usernames == ["alice", "bob", "charlie"]
@pytest.mark.anyio
async def test_search_non_string_column(self, db_session: AsyncSession):
"""Search on non-string columns (e.g., UUID) works via cast."""
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
await UserCrud.create(
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="jane", email="jane@test.com")
)
# Search by UUID (partial match)
result = await UserCrud.paginate(
db_session,
search="12345678",
search_fields=[User.id, User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].id == user_id
class TestSearchConfig: class TestSearchConfig:
"""Tests for SearchConfig options.""" """Tests for SearchConfig options."""

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.fixtures module.""" """Tests for fastapi_toolsets.fixtures module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
def test_register_with_decorator(self): def test_register_with_decorator(self):
"""Register fixture with decorator.""" """Register fixture with decorator."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
assert "roles" in [f.name for f in registry.get_all()] assert "roles" in [f.name for f in registry.get_all()]
def test_register_with_custom_name(self): def test_register_with_custom_name(self):
"""Register fixture with custom name.""" """Register fixture with custom name."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(name="custom_roles") @registry.register(name="custom_roles")
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
fixture = registry.get("custom_roles") fixture = registry.get("custom_roles")
assert fixture.name == "custom_roles" assert fixture.name == "custom_roles"
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
def test_register_with_dependencies(self): def test_register_with_dependencies(self):
"""Register fixture with dependencies.""" """Register fixture with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
fixture = registry.get("users") fixture = registry.get("users")
assert fixture.depends_on == ["roles"] assert fixture.depends_on == ["roles"]
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
def test_register_with_contexts(self): def test_register_with_contexts(self):
"""Register fixture with contexts.""" """Register fixture with contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_data(): def test_data():
return [Role(id=100, name="test")] return [Role(id=role_id, name="test")]
fixture = registry.get("test_data") fixture = registry.get("test_data")
assert Context.TESTING.value in fixture.contexts assert Context.TESTING.value in fixture.contexts
@@ -145,6 +159,178 @@ class TestFixtureRegistry:
assert names == {"test_data"} assert names == {"test_data"}
class TestIncludeRegistry:
"""Tests for FixtureRegistry.include_registry method."""
def test_include_empty_registry(self):
"""Include an empty registry does nothing."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register
def roles():
return []
main_registry.include_registry(other_registry)
assert len(main_registry.get_all()) == 1
def test_include_registry_adds_fixtures(self):
"""Include registry adds all fixtures from the other registry."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register
def roles():
return []
@other_registry.register
def users():
return []
@other_registry.register
def posts():
return []
main_registry.include_registry(other_registry)
names = {f.name for f in main_registry.get_all()}
assert names == {"roles", "users", "posts"}
def test_include_registry_preserves_dependencies(self):
"""Include registry preserves fixture dependencies."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register
def roles():
return []
@other_registry.register(depends_on=["roles"])
def users():
return []
main_registry.include_registry(other_registry)
fixture = main_registry.get("users")
assert fixture.depends_on == ["roles"]
def test_include_registry_preserves_contexts(self):
"""Include registry preserves fixture contexts."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT])
def test_data():
return []
main_registry.include_registry(other_registry)
fixture = main_registry.get("test_data")
assert Context.TESTING.value in fixture.contexts
assert Context.DEVELOPMENT.value in fixture.contexts
def test_include_registry_raises_on_duplicate(self):
"""Include registry raises ValueError on duplicate fixture names."""
main_registry = FixtureRegistry()
other_registry = FixtureRegistry()
@main_registry.register(name="roles")
def roles_main():
return []
@other_registry.register(name="roles")
def roles_other():
return []
with pytest.raises(ValueError, match="already exists"):
main_registry.include_registry(other_registry)
def test_include_multiple_registries(self):
"""Include multiple registries sequentially."""
main_registry = FixtureRegistry()
dev_registry = FixtureRegistry()
test_registry = FixtureRegistry()
@main_registry.register
def base():
return []
@dev_registry.register
def dev_data():
return []
@test_registry.register
def test_data():
return []
main_registry.include_registry(dev_registry)
main_registry.include_registry(test_registry)
names = {f.name for f in main_registry.get_all()}
assert names == {"base", "dev_data", "test_data"}
class TestDefaultContexts:
"""Tests for FixtureRegistry default contexts."""
def test_default_contexts_applied_to_fixtures(self):
"""Default contexts are applied when no contexts specified."""
registry = FixtureRegistry(contexts=[Context.TESTING])
@registry.register
def test_data():
return []
fixture = registry.get("test_data")
assert fixture.contexts == [Context.TESTING.value]
def test_explicit_contexts_override_default(self):
"""Explicit contexts override default contexts."""
registry = FixtureRegistry(contexts=[Context.TESTING])
@registry.register(contexts=[Context.PRODUCTION])
def prod_data():
return []
fixture = registry.get("prod_data")
assert fixture.contexts == [Context.PRODUCTION.value]
def test_no_default_contexts_uses_base(self):
"""Without default contexts, BASE is used."""
registry = FixtureRegistry()
@registry.register
def data():
return []
fixture = registry.get("data")
assert fixture.contexts == [Context.BASE.value]
def test_multiple_default_contexts(self):
"""Multiple default contexts are applied."""
registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING])
@registry.register
def dev_test_data():
return []
fixture = registry.get("dev_test_data")
assert Context.DEVELOPMENT.value in fixture.contexts
assert Context.TESTING.value in fixture.contexts
def test_default_contexts_with_string_values(self):
"""Default contexts work with string values."""
registry = FixtureRegistry(contexts=["custom_context"])
@registry.register
def custom_data():
return []
fixture = registry.get("custom_data")
assert fixture.contexts == ["custom_context"]
class TestDependencyResolution: class TestDependencyResolution:
"""Tests for fixture dependency resolution.""" """Tests for fixture dependency resolution."""
@@ -244,12 +430,14 @@ class TestLoadFixtures:
async def test_load_single_fixture(self, db_session: AsyncSession): async def test_load_single_fixture(self, db_session: AsyncSession):
"""Load a single fixture.""" """Load a single fixture."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures(db_session, registry, "roles") result = await load_fixtures(db_session, registry, "roles")
@@ -266,14 +454,23 @@ class TestLoadFixtures:
async def test_load_with_dependencies(self, db_session: AsyncSession): async def test_load_with_dependencies(self, db_session: AsyncSession):
"""Load fixtures with dependencies.""" """Load fixtures with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
result = await load_fixtures(db_session, registry, "users") result = await load_fixtures(db_session, registry, "users")
@@ -289,10 +486,11 @@ class TestLoadFixtures:
async def test_load_with_merge_strategy(self, db_session: AsyncSession): async def test_load_with_merge_strategy(self, db_session: AsyncSession):
"""Load fixtures with MERGE strategy updates existing.""" """Load fixtures with MERGE strategy updates existing."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
@@ -306,10 +504,11 @@ class TestLoadFixtures:
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession): async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
"""Load fixtures with SKIP_EXISTING strategy.""" """Load fixtures with SKIP_EXISTING strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="original")] return [Role(id=role_id, name="original")]
await load_fixtures( await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
@@ -317,7 +516,7 @@ class TestLoadFixtures:
@registry.register(name="roles_updated") @registry.register(name="roles_updated")
def roles_v2(): def roles_v2():
return [Role(id=1, name="updated")] return [Role(id=role_id, name="updated")]
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated") registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
@@ -327,7 +526,7 @@ class TestLoadFixtures:
from .conftest import RoleCrud from .conftest import RoleCrud
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "original" assert role.name == "original"
@@ -335,12 +534,14 @@ class TestLoadFixtures:
async def test_load_with_insert_strategy(self, db_session: AsyncSession): async def test_load_with_insert_strategy(self, db_session: AsyncSession):
"""Load fixtures with INSERT strategy.""" """Load fixtures with INSERT strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures( result = await load_fixtures(
@@ -375,14 +576,16 @@ class TestLoadFixtures:
): ):
"""Load multiple independent fixtures.""" """Load multiple independent fixtures."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id_1, name="admin")]
@registry.register @registry.register
def other_roles(): def other_roles():
return [Role(id=2, name="user")] return [Role(id=role_id_2, name="user")]
result = await load_fixtures(db_session, registry, "roles", "other_roles") result = await load_fixtures(db_session, registry, "roles", "other_roles")
@@ -402,14 +605,16 @@ class TestLoadFixturesByContext:
async def test_load_by_single_context(self, db_session: AsyncSession): async def test_load_by_single_context(self, db_session: AsyncSession):
"""Load fixtures by single context.""" """Load fixtures by single context."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context(db_session, registry, Context.BASE) await load_fixtures_by_context(db_session, registry, Context.BASE)
@@ -418,7 +623,7 @@ class TestLoadFixturesByContext:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == base_role_id])
assert role is not None assert role is not None
assert role.name == "base_role" assert role.name == "base_role"
@@ -426,14 +631,16 @@ class TestLoadFixturesByContext:
async def test_load_by_multiple_contexts(self, db_session: AsyncSession): async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
"""Load fixtures by multiple contexts.""" """Load fixtures by multiple contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context( await load_fixtures_by_context(
db_session, registry, Context.BASE, Context.TESTING db_session, registry, Context.BASE, Context.TESTING
@@ -448,14 +655,23 @@ class TestLoadFixturesByContext:
async def test_load_context_with_dependencies(self, db_session: AsyncSession): async def test_load_context_with_dependencies(self, db_session: AsyncSession):
"""Load context fixtures with cross-context dependencies.""" """Load context fixtures with cross-context dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"], contexts=[Context.TESTING]) @registry.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users(): def test_users():
return [User(id=1, username="tester", email="test@test.com", role_id=1)] return [
User(
id=user_id,
username="tester",
email="test@test.com",
role_id=role_id,
)
]
await load_fixtures_by_context(db_session, registry, Context.TESTING) await load_fixtures_by_context(db_session, registry, Context.TESTING)
@@ -471,20 +687,41 @@ class TestGetObjByAttr:
def setup_method(self): def setup_method(self):
"""Set up test fixtures for each test.""" """Set up test fixtures for each test."""
self.registry = FixtureRegistry() self.registry = FixtureRegistry()
self.role_id_1 = uuid.uuid4()
self.role_id_2 = uuid.uuid4()
self.role_id_3 = uuid.uuid4()
self.user_id_1 = uuid.uuid4()
self.user_id_2 = uuid.uuid4()
role_id_1 = self.role_id_1
role_id_2 = self.role_id_2
role_id_3 = self.role_id_3
user_id_1 = self.user_id_1
user_id_2 = self.user_id_2
@self.registry.register @self.registry.register
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
Role(id=3, name="moderator"), Role(id=role_id_3, name="moderator"),
] ]
@self.registry.register(depends_on=["roles"]) @self.registry.register(depends_on=["roles"])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1, username="alice", email="alice@example.com", role_id=1), User(
User(id=2, username="bob", email="bob@example.com", role_id=1), id=user_id_1,
username="alice",
email="alice@example.com",
role_id=role_id_1,
),
User(
id=user_id_2,
username="bob",
email="bob@example.com",
role_id=role_id_1,
),
] ]
self.roles = roles self.roles = roles
@@ -492,18 +729,18 @@ class TestGetObjByAttr:
def test_get_by_id(self): def test_get_by_id(self):
"""Get an object by its id attribute.""" """Get an object by its id attribute."""
role = get_obj_by_attr(self.roles, "id", 1) role = get_obj_by_attr(self.roles, "id", self.role_id_1)
assert role.name == "admin" assert role.name == "admin"
def test_get_user_by_username(self): def test_get_user_by_username(self):
"""Get a user by username.""" """Get a user by username."""
user = get_obj_by_attr(self.users, "username", "bob") user = get_obj_by_attr(self.users, "username", "bob")
assert user.id == 2 assert user.id == self.user_id_2
assert user.email == "bob@example.com" assert user.email == "bob@example.com"
def test_returns_first_match(self): def test_returns_first_match(self):
"""Returns the first matching object when multiple could match.""" """Returns the first matching object when multiple could match."""
user = get_obj_by_attr(self.users, "role_id", 1) user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
assert user.username == "alice" assert user.username == "alice"
def test_no_match_raises_stop_iteration(self): def test_no_match_raises_stop_iteration(self):
@@ -514,4 +751,4 @@ class TestGetObjByAttr:
def test_no_match_on_wrong_value_type(self): def test_no_match_on_wrong_value_type(self):
"""Raises StopIteration when value type doesn't match.""" """Raises StopIteration when value type doesn't match."""
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
get_obj_by_attr(self.roles, "id", "1") get_obj_by_attr(self.roles, "id", "not-a-uuid")

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.pytest module.""" """Tests for fastapi_toolsets.pytest module."""
import uuid
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from httpx import AsyncClient from httpx import AsyncClient
@@ -18,27 +20,49 @@ from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
test_registry = FixtureRegistry() test_registry = FixtureRegistry()
# Fixed UUIDs for test fixtures to allow consistent assertions
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
@test_registry.register(contexts=[Context.BASE]) @test_registry.register(contexts=[Context.BASE])
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1000, name="plugin_admin"), Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
Role(id=1001, name="plugin_user"), Role(id=ROLE_USER_ID, name="plugin_user"),
] ]
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE]) @test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000), User(
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001), id=USER_ADMIN_ID,
username="plugin_admin",
email="padmin@test.com",
role_id=ROLE_ADMIN_ID,
),
User(
id=USER_USER_ID,
username="plugin_user",
email="puser@test.com",
role_id=ROLE_USER_ID,
),
] ]
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING]) @test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
def extra_users() -> list[User]: def extra_users() -> list[User]:
return [ return [
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001), User(
id=USER_EXTRA_ID,
username="plugin_extra",
email="pextra@test.com",
role_id=ROLE_USER_ID,
),
] ]
@@ -73,7 +97,7 @@ class TestGeneratedFixtures:
assert fixture_roles[1].name == "plugin_user" assert fixture_roles[1].name == "plugin_user"
# Verify data is in database # Verify data is in database
count = await RoleCrud.count(db_session, [Role.id >= 1000]) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -86,11 +110,11 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Roles should also be in database # Roles should also be in database
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
# Users should be in database # Users should be in database
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert users_count == 2 assert users_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -100,7 +124,7 @@ class TestGeneratedFixtures:
"""Fixture returns actual model instances.""" """Fixture returns actual model instances."""
user = fixture_users[0] user = fixture_users[0]
assert isinstance(user, User) assert isinstance(user, User)
assert user.id == 1000 assert user.id == USER_ADMIN_ID
assert user.username == "plugin_admin" assert user.username == "plugin_admin"
@pytest.mark.anyio @pytest.mark.anyio
@@ -111,7 +135,7 @@ class TestGeneratedFixtures:
# Load user with role relationship # Load user with role relationship
user = await UserCrud.get( user = await UserCrud.get(
db_session, db_session,
[User.id == 1000], [User.id == USER_ADMIN_ID],
load_options=[selectinload(User.role)], load_options=[selectinload(User.role)],
) )
@@ -127,8 +151,8 @@ class TestGeneratedFixtures:
assert len(fixture_extra_users) == 1 assert len(fixture_extra_users) == 1
# All fixtures should be loaded # All fixtures should be loaded
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
assert users_count == 3 # 2 from users + 1 from extra_users assert users_count == 3 # 2 from users + 1 from extra_users
@@ -141,8 +165,7 @@ class TestGeneratedFixtures:
# Get all users loaded by fixture # Get all users loaded by fixture
users = await UserCrud.get_multi( users = await UserCrud.get_multi(
db_session, db_session,
filters=[User.id >= 1000], order_by=User.username,
order_by=User.id,
) )
assert len(users) == 2 assert len(users) == 2
@@ -161,8 +184,8 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Both should be in database # Both should be in database
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000]) roles = await RoleCrud.get_multi(db_session)
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000]) users = await UserCrud.get_multi(db_session)
assert len(roles) == 2 assert len(roles) == 2
assert len(users) == 2 assert len(users) == 2
@@ -215,14 +238,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_creates_working_session(self): async def test_creates_working_session(self):
"""Session can perform database operations.""" """Session can perform database operations."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base) as session: async with create_db_session(DATABASE_URL, Base) as session:
assert isinstance(session, AsyncSession) assert isinstance(session, AsyncSession)
role = Role(id=9001, name="test_helper_role") role = Role(id=role_id, name="test_helper_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
result = await session.execute(select(Role).where(Role.id == 9001)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one() fetched = result.scalar_one()
assert fetched.name == "test_helper_role" assert fetched.name == "test_helper_role"
@@ -237,8 +261,9 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_dropped_after_session(self): async def test_tables_dropped_after_session(self):
"""Tables are dropped after session closes when drop_tables=True.""" """Tables are dropped after session closes when drop_tables=True."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=9002, name="will_be_dropped") role = Role(id=role_id, name="will_be_dropped")
session.add(role) session.add(role)
await session.commit() await session.commit()
@@ -250,14 +275,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_preserved_when_drop_disabled(self): async def test_tables_preserved_when_drop_disabled(self):
"""Tables are preserved when drop_tables=False.""" """Tables are preserved when drop_tables=False."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
role = Role(id=9003, name="preserved_role") role = Role(id=role_id, name="preserved_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
# Create another session without dropping # Create another session without dropping
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
result = await session.execute(select(Role).where(Role.id == 9003)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none() fetched = result.scalar_one_or_none()
assert fetched is not None assert fetched is not None
assert fetched.name == "preserved_role" assert fetched.name == "preserved_role"

2
uv.lock generated
View File

@@ -220,7 +220,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.4.0" version = "0.5.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },