mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
0f50c8a0f0
|
|||
|
|
691fb78fda | ||
|
|
34ef4da317 | ||
|
|
8c287b3ce7 | ||
|
54f5479c24
|
|||
|
|
f467754df1 | ||
|
b57ce40b05
|
|||
|
5264631550
|
|||
|
a76f7c439d
|
|||
|
|
d14551781c | ||
|
|
577e087321 |
2
.github/workflows/build-release.yml
vendored
2
.github/workflows/build-release.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
|||||||
uses: astral-sh/setup-uv@v7
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
run: uv python install 3.13
|
run: uv python install 3.14
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync
|
run: uv sync
|
||||||
|
|||||||
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.11", "3.12", "3.13"]
|
python-version: ["3.11", "3.12", "3.13", "3.14"]
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
@@ -92,7 +92,7 @@ jobs:
|
|||||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
if: matrix.python-version == '3.13'
|
if: matrix.python-version == '3.14'
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
3.13
|
3.14
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "0.3.0"
|
version = "0.5.0"
|
||||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@@ -24,6 +24,7 @@ classifiers = [
|
|||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: 3.13",
|
"Programming Language :: Python :: 3.13",
|
||||||
|
"Programming Language :: Python :: 3.14",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development",
|
"Topic :: Software Development",
|
||||||
@@ -58,7 +59,7 @@ dev = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
fastapi-toolsets = "fastapi_toolsets.cli:app"
|
manager = "fastapi_toolsets.cli:cli"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
requires = ["uv_build>=0.9.26,<0.10.0"]
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.3.0"
|
__version__ = "0.5.0"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""CLI for FastAPI projects."""
|
"""CLI for FastAPI projects."""
|
||||||
|
|
||||||
from .app import app, register_command
|
from .app import cli
|
||||||
|
from .utils import async_command
|
||||||
|
|
||||||
__all__ = ["app", "register_command"]
|
__all__ = ["async_command", "cli"]
|
||||||
|
|||||||
@@ -1,97 +1,25 @@
|
|||||||
"""Main CLI application."""
|
"""Main CLI application."""
|
||||||
|
|
||||||
import importlib.util
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from .commands import fixtures
|
from .config import load_config
|
||||||
|
|
||||||
app = typer.Typer(
|
cli = typer.Typer(
|
||||||
name="fastapi-utils",
|
name="manager",
|
||||||
help="CLI utilities for FastAPI projects.",
|
help="CLI utilities for FastAPI projects.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register built-in commands
|
_config = load_config()
|
||||||
app.add_typer(fixtures.app, name="fixtures")
|
|
||||||
|
if _config.fixtures:
|
||||||
|
from .commands.fixtures import fixture_cli
|
||||||
|
|
||||||
|
cli.add_typer(fixture_cli, name="fixtures")
|
||||||
|
|
||||||
|
|
||||||
def register_command(command: typer.Typer, name: str) -> None:
|
@cli.callback()
|
||||||
"""Register a custom command group.
|
def main(ctx: typer.Context) -> None:
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Typer app for the command group
|
|
||||||
name: Name for the command group
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# In your project's cli.py:
|
|
||||||
import typer
|
|
||||||
from fastapi_toolsets.cli import app, register_command
|
|
||||||
|
|
||||||
my_commands = typer.Typer()
|
|
||||||
|
|
||||||
@my_commands.command()
|
|
||||||
def seed():
|
|
||||||
'''Seed the database.'''
|
|
||||||
...
|
|
||||||
|
|
||||||
register_command(my_commands, "db")
|
|
||||||
# Now available as: fastapi-utils db seed
|
|
||||||
"""
|
|
||||||
app.add_typer(command, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback()
|
|
||||||
def main(
|
|
||||||
ctx: typer.Context,
|
|
||||||
config: Annotated[
|
|
||||||
Path | None,
|
|
||||||
typer.Option(
|
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
help="Path to project config file (Python module with fixtures registry).",
|
|
||||||
envvar="FASTAPI_TOOLSETS_CONFIG",
|
|
||||||
),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""FastAPI utilities CLI."""
|
"""FastAPI utilities CLI."""
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
|
ctx.obj["config"] = _config
|
||||||
if config:
|
|
||||||
ctx.obj["config_path"] = config
|
|
||||||
# Load the config module
|
|
||||||
config_module = _load_module_from_path(config)
|
|
||||||
ctx.obj["config_module"] = config_module
|
|
||||||
|
|
||||||
|
|
||||||
def _load_module_from_path(path: Path) -> object:
|
|
||||||
"""Load a Python module from a file path.
|
|
||||||
|
|
||||||
Handles both absolute and relative imports by adding the config's
|
|
||||||
parent directory to sys.path temporarily.
|
|
||||||
"""
|
|
||||||
path = path.resolve()
|
|
||||||
|
|
||||||
# Add the parent directory to sys.path to support relative imports
|
|
||||||
parent_dir = str(
|
|
||||||
path.parent.parent
|
|
||||||
) # Go up two levels (e.g., from app/cli_config.py to project root)
|
|
||||||
if parent_dir not in sys.path:
|
|
||||||
sys.path.insert(0, parent_dir)
|
|
||||||
|
|
||||||
# Also add immediate parent for direct module imports
|
|
||||||
immediate_parent = str(path.parent)
|
|
||||||
if immediate_parent not in sys.path:
|
|
||||||
sys.path.insert(0, immediate_parent)
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location("config", path)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise typer.BadParameter(f"Cannot load module from {path}")
|
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules["config"] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|||||||
@@ -1,55 +1,29 @@
|
|||||||
"""Fixture management commands."""
|
"""Fixture management commands."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
from ...fixtures import Context, FixtureRegistry, LoadStrategy, load_fixtures_by_context
|
from ...fixtures import Context, LoadStrategy, load_fixtures_by_context
|
||||||
|
from ..config import CliConfig
|
||||||
|
from ..utils import async_command
|
||||||
|
|
||||||
app = typer.Typer(
|
fixture_cli = typer.Typer(
|
||||||
name="fixtures",
|
name="fixtures",
|
||||||
help="Manage database fixtures.",
|
help="Manage database fixtures.",
|
||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def _get_registry(ctx: typer.Context) -> FixtureRegistry:
|
def _get_config(ctx: typer.Context) -> CliConfig:
|
||||||
"""Get fixture registry from context."""
|
"""Get CLI config from context."""
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
return ctx.obj["config"]
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"No config provided. Use --config to specify a config file with a 'fixtures' registry."
|
|
||||||
)
|
|
||||||
|
|
||||||
registry = getattr(config, "fixtures", None)
|
|
||||||
if registry is None:
|
|
||||||
raise typer.BadParameter(
|
|
||||||
"Config module must have a 'fixtures' attribute (FixtureRegistry instance)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(registry, FixtureRegistry):
|
|
||||||
raise typer.BadParameter(
|
|
||||||
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
def _get_db_context(ctx: typer.Context):
|
@fixture_cli.command("list")
|
||||||
"""Get database context manager from config."""
|
|
||||||
config = ctx.obj.get("config_module") if ctx.obj else None
|
|
||||||
if config is None:
|
|
||||||
raise typer.BadParameter("No config provided.")
|
|
||||||
|
|
||||||
get_db_context = getattr(config, "get_db_context", None)
|
|
||||||
if get_db_context is None:
|
|
||||||
raise typer.BadParameter("Config module must have a 'get_db_context' function.")
|
|
||||||
|
|
||||||
return get_db_context
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("list")
|
|
||||||
def list_fixtures(
|
def list_fixtures(
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
context: Annotated[
|
context: Annotated[
|
||||||
@@ -62,64 +36,28 @@ def list_fixtures(
|
|||||||
] = None,
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List all registered fixtures."""
|
"""List all registered fixtures."""
|
||||||
registry = _get_registry(ctx)
|
config = _get_config(ctx)
|
||||||
|
registry = config.get_fixtures_registry()
|
||||||
if context:
|
fixtures = registry.get_by_context(context) if context else registry.get_all()
|
||||||
fixtures = registry.get_by_context(context)
|
|
||||||
else:
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
if not fixtures:
|
if not fixtures:
|
||||||
typer.echo("No fixtures found.")
|
print("No fixtures found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\n{'Name':<30} {'Contexts':<30} {'Dependencies'}")
|
table = Table("Name", "Contexts", "Dependencies")
|
||||||
typer.echo("-" * 80)
|
|
||||||
|
|
||||||
for fixture in fixtures:
|
for fixture in fixtures:
|
||||||
contexts = ", ".join(fixture.contexts)
|
contexts = ", ".join(fixture.contexts)
|
||||||
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
deps = ", ".join(fixture.depends_on) if fixture.depends_on else "-"
|
||||||
typer.echo(f"{fixture.name:<30} {contexts:<30} {deps}")
|
table.add_row(fixture.name, contexts, deps)
|
||||||
|
|
||||||
typer.echo(f"\nTotal: {len(fixtures)} fixture(s)")
|
console.print(table)
|
||||||
|
print(f"\nTotal: {len(fixtures)} fixture(s)")
|
||||||
|
|
||||||
|
|
||||||
@app.command("graph")
|
@fixture_cli.command("load")
|
||||||
def show_graph(
|
@async_command
|
||||||
ctx: typer.Context,
|
async def load(
|
||||||
fixture_name: Annotated[
|
|
||||||
str | None,
|
|
||||||
typer.Argument(help="Show dependencies for a specific fixture."),
|
|
||||||
] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Show fixture dependency graph."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
if fixture_name:
|
|
||||||
try:
|
|
||||||
order = registry.resolve_dependencies(fixture_name)
|
|
||||||
typer.echo(f"\nDependency chain for '{fixture_name}':\n")
|
|
||||||
for i, name in enumerate(order):
|
|
||||||
indent = " " * i
|
|
||||||
arrow = "└─> " if i > 0 else ""
|
|
||||||
typer.echo(f"{indent}{arrow}{name}")
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{fixture_name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
else:
|
|
||||||
# Show full graph
|
|
||||||
fixtures = registry.get_all()
|
|
||||||
|
|
||||||
typer.echo("\nFixture Dependency Graph:\n")
|
|
||||||
for fixture in fixtures:
|
|
||||||
deps = (
|
|
||||||
f" -> [{', '.join(fixture.depends_on)}]" if fixture.depends_on else ""
|
|
||||||
)
|
|
||||||
typer.echo(f" {fixture.name}{deps}")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("load")
|
|
||||||
def load(
|
|
||||||
ctx: typer.Context,
|
ctx: typer.Context,
|
||||||
contexts: Annotated[
|
contexts: Annotated[
|
||||||
list[str] | None,
|
list[str] | None,
|
||||||
@@ -141,16 +79,12 @@ def load(
|
|||||||
] = False,
|
] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load fixtures into the database."""
|
"""Load fixtures into the database."""
|
||||||
registry = _get_registry(ctx)
|
config = _get_config(ctx)
|
||||||
get_db_context = _get_db_context(ctx)
|
registry = config.get_fixtures_registry()
|
||||||
|
get_db_context = config.get_db_context()
|
||||||
|
|
||||||
# Parse contexts
|
context_list = contexts if contexts else [Context.BASE]
|
||||||
if contexts:
|
|
||||||
context_list = contexts
|
|
||||||
else:
|
|
||||||
context_list = [Context.BASE]
|
|
||||||
|
|
||||||
# Parse strategy
|
|
||||||
try:
|
try:
|
||||||
load_strategy = LoadStrategy(strategy)
|
load_strategy = LoadStrategy(strategy)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -159,67 +93,27 @@ def load(
|
|||||||
)
|
)
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
# Resolve what will be loaded
|
|
||||||
ordered = registry.resolve_context_dependencies(*context_list)
|
ordered = registry.resolve_context_dependencies(*context_list)
|
||||||
|
|
||||||
if not ordered:
|
if not ordered:
|
||||||
typer.echo("No fixtures to load for the specified context(s).")
|
print("No fixtures to load for the specified context(s).")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo(f"\nFixtures to load ({load_strategy.value} strategy):")
|
print(f"\nFixtures to load ({load_strategy.value} strategy):")
|
||||||
for name in ordered:
|
for name in ordered:
|
||||||
fixture = registry.get(name)
|
fixture = registry.get(name)
|
||||||
instances = list(fixture.func())
|
instances = list(fixture.func())
|
||||||
model_name = type(instances[0]).__name__ if instances else "?"
|
model_name = type(instances[0]).__name__ if instances else "?"
|
||||||
typer.echo(f" - {name}: {len(instances)} {model_name}(s)")
|
print(f" - {name}: {len(instances)} {model_name}(s)")
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
typer.echo("\n[Dry run - no changes made]")
|
print("\n[Dry run - no changes made]")
|
||||||
return
|
return
|
||||||
|
|
||||||
typer.echo("\nLoading...")
|
async with get_db_context() as session:
|
||||||
|
result = await load_fixtures_by_context(
|
||||||
async def do_load():
|
session, registry, *context_list, strategy=load_strategy
|
||||||
async with get_db_context() as session:
|
)
|
||||||
result = await load_fixtures_by_context(
|
|
||||||
session, registry, *context_list, strategy=load_strategy
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = asyncio.run(do_load())
|
|
||||||
|
|
||||||
total = sum(len(items) for items in result.values())
|
total = sum(len(items) for items in result.values())
|
||||||
typer.echo(f"\nLoaded {total} record(s) successfully.")
|
print(f"\nLoaded {total} record(s) successfully.")
|
||||||
|
|
||||||
|
|
||||||
@app.command("show")
|
|
||||||
def show_fixture(
|
|
||||||
ctx: typer.Context,
|
|
||||||
name: Annotated[str, typer.Argument(help="Fixture name to show.")],
|
|
||||||
) -> None:
|
|
||||||
"""Show details of a specific fixture."""
|
|
||||||
registry = _get_registry(ctx)
|
|
||||||
|
|
||||||
try:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
except KeyError:
|
|
||||||
typer.echo(f"Fixture '{name}' not found.", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
typer.echo(f"\nFixture: {fixture.name}")
|
|
||||||
typer.echo(f"Contexts: {', '.join(fixture.contexts)}")
|
|
||||||
typer.echo(
|
|
||||||
f"Dependencies: {', '.join(fixture.depends_on) if fixture.depends_on else 'None'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show instances
|
|
||||||
instances = list(fixture.func())
|
|
||||||
if instances:
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
typer.echo(f"\nInstances ({len(instances)} {model_name}):")
|
|
||||||
for instance in instances[:10]: # Limit to 10
|
|
||||||
typer.echo(f" - {instance!r}")
|
|
||||||
if len(instances) > 10:
|
|
||||||
typer.echo(f" ... and {len(instances) - 10} more")
|
|
||||||
else:
|
|
||||||
typer.echo("\nNo instances (empty fixture)")
|
|
||||||
|
|||||||
92
src/fastapi_toolsets/cli/config.py
Normal file
92
src/fastapi_toolsets/cli/config.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""CLI configuration."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import tomllib
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CliConfig:
|
||||||
|
"""CLI configuration loaded from pyproject.toml."""
|
||||||
|
|
||||||
|
fixtures: str | None = None
|
||||||
|
db_context: str | None = None
|
||||||
|
|
||||||
|
def get_fixtures_registry(self):
|
||||||
|
"""Import and return the fixtures registry."""
|
||||||
|
from ..fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
if not self.fixtures:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
"No fixtures registry configured. "
|
||||||
|
"Add 'fixtures' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||||
|
)
|
||||||
|
|
||||||
|
registry = _import_from_string(self.fixtures)
|
||||||
|
|
||||||
|
if not isinstance(registry, FixtureRegistry):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"'fixtures' must be a FixtureRegistry instance, got {type(registry).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return registry
|
||||||
|
|
||||||
|
def get_db_context(self):
|
||||||
|
"""Import and return the db_context function."""
|
||||||
|
if not self.db_context:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
"No db_context configured. "
|
||||||
|
"Add 'db_context' to [tool.fastapi-toolsets] in pyproject.toml."
|
||||||
|
)
|
||||||
|
return _import_from_string(self.db_context)
|
||||||
|
|
||||||
|
|
||||||
|
def _import_from_string(import_path: str):
|
||||||
|
"""Import an object from a string path like 'module.submodule:attribute'."""
|
||||||
|
if ":" not in import_path:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Invalid import path '{import_path}'. Expected format: 'module:attribute'"
|
||||||
|
)
|
||||||
|
|
||||||
|
module_path, attr_name = import_path.rsplit(":", 1)
|
||||||
|
|
||||||
|
# Add cwd to sys.path for local imports
|
||||||
|
cwd = str(Path.cwd())
|
||||||
|
if cwd not in sys.path:
|
||||||
|
sys.path.insert(0, cwd)
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise typer.BadParameter(f"Cannot import module '{module_path}': {e}")
|
||||||
|
|
||||||
|
if not hasattr(module, attr_name):
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Module '{module_path}' has no attribute '{attr_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return getattr(module, attr_name)
|
||||||
|
|
||||||
|
|
||||||
|
def load_config() -> CliConfig:
|
||||||
|
"""Load CLI configuration from pyproject.toml."""
|
||||||
|
pyproject_path = Path.cwd() / "pyproject.toml"
|
||||||
|
|
||||||
|
if not pyproject_path.exists():
|
||||||
|
return CliConfig()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(pyproject_path, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
|
||||||
|
tool_config = data.get("tool", {}).get("fastapi-toolsets", {})
|
||||||
|
return CliConfig(
|
||||||
|
fixtures=tool_config.get("fixtures"),
|
||||||
|
db_context=tool_config.get("db_context"),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return CliConfig()
|
||||||
27
src/fastapi_toolsets/cli/utils.py
Normal file
27
src/fastapi_toolsets/cli/utils.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""CLI utility functions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def async_command(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
|
||||||
|
"""Decorator to run an async function as a sync CLI command.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@fixture_cli.command("load")
|
||||||
|
@async_command
|
||||||
|
async def load(ctx: typer.Context) -> None:
|
||||||
|
async with get_db_context() as session:
|
||||||
|
await load_fixtures(session, registry)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
|
return asyncio.run(func(*args, **kwargs))
|
||||||
|
|
||||||
|
return wrapper
|
||||||
15
src/fastapi_toolsets/crud/__init__.py
Normal file
15
src/fastapi_toolsets/crud/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
|
from ..exceptions import NoSearchableFieldsError
|
||||||
|
from .factory import CrudFactory
|
||||||
|
from .search import (
|
||||||
|
SearchConfig,
|
||||||
|
get_searchable_fields,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CrudFactory",
|
||||||
|
"get_searchable_fields",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
|
"SearchConfig",
|
||||||
|
]
|
||||||
@@ -12,35 +12,22 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
from sqlalchemy.sql.roles import WhereHavingRole
|
||||||
|
|
||||||
from .db import get_transaction
|
from ..db import get_transaction
|
||||||
from .exceptions import NotFoundError
|
from ..exceptions import NotFoundError
|
||||||
|
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||||
__all__ = [
|
|
||||||
"AsyncCrud",
|
|
||||||
"CrudFactory",
|
|
||||||
]
|
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
class AsyncCrud(Generic[ModelType]):
|
||||||
"""Generic async CRUD operations for SQLAlchemy models.
|
"""Generic async CRUD operations for SQLAlchemy models.
|
||||||
|
|
||||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
||||||
|
|
||||||
Example:
|
|
||||||
class UserCrud(AsyncCrud[User]):
|
|
||||||
model = User
|
|
||||||
|
|
||||||
# Or use the factory:
|
|
||||||
UserCrud = CrudFactory(User)
|
|
||||||
|
|
||||||
# Then use it:
|
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
|
||||||
users = await UserCrud.get_multi(session, limit=10)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: ClassVar[type[DeclarativeBase]]
|
model: ClassVar[type[DeclarativeBase]]
|
||||||
|
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
@@ -69,6 +56,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any],
|
filters: list[Any],
|
||||||
*,
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
with_for_update: bool = False,
|
with_for_update: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
) -> ModelType:
|
) -> ModelType:
|
||||||
@@ -77,6 +66,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
with_for_update: Lock the row for update
|
with_for_update: Lock the row for update
|
||||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||||
|
|
||||||
@@ -87,7 +78,15 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
NotFoundError: If no record found
|
NotFoundError: If no record found
|
||||||
MultipleResultsFound: If more than one record found
|
MultipleResultsFound: If more than one record found
|
||||||
"""
|
"""
|
||||||
q = select(cls.model).where(and_(*filters))
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
q = q.options(*load_options)
|
q = q.options(*load_options)
|
||||||
if with_for_update:
|
if with_for_update:
|
||||||
@@ -104,6 +103,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
*,
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
) -> ModelType | None:
|
) -> ModelType | None:
|
||||||
"""Get the first matching record, or None.
|
"""Get the first matching record, or None.
|
||||||
@@ -111,12 +112,21 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance or None
|
Model instance or None
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
@@ -130,6 +140,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: Any | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
@@ -140,6 +152,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
order_by: Column or list of columns to order by
|
order_by: Column or list of columns to order by
|
||||||
limit: Max number of rows to return
|
limit: Max number of rows to return
|
||||||
@@ -149,6 +163,13 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
List of model instances
|
List of model instances
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
@@ -268,17 +289,29 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Count records matching the filters.
|
"""Count records matching the filters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of matching records
|
Number of matching records
|
||||||
"""
|
"""
|
||||||
q = select(func.count()).select_from(cls.model)
|
q = select(func.count()).select_from(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
@@ -289,17 +322,30 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any],
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if a record exists.
|
"""Check if a record exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if at least one record matches
|
True if at least one record matches
|
||||||
"""
|
"""
|
||||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
q = q.where(and_(*filters)).exists().select()
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
return bool(result.scalar())
|
return bool(result.scalar())
|
||||||
|
|
||||||
@@ -309,37 +355,96 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: Any | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
items_per_page: int = 20,
|
items_per_page: int = 20,
|
||||||
|
search: str | SearchConfig | None = None,
|
||||||
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Get paginated results with metadata.
|
"""Get paginated results with metadata.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
order_by: Column or list of columns to order by
|
order_by: Column or list of columns to order by
|
||||||
page: Page number (1-indexed)
|
page: Page number (1-indexed)
|
||||||
items_per_page: Number of items per page
|
items_per_page: Number of items per page
|
||||||
|
search: Search query string or SearchConfig object
|
||||||
|
search_fields: Fields to search in (overrides class default)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with 'data' and 'pagination' keys
|
Dict with 'data' and 'pagination' keys
|
||||||
"""
|
"""
|
||||||
filters = filters or []
|
filters = list(filters) if filters else []
|
||||||
offset = (page - 1) * items_per_page
|
offset = (page - 1) * items_per_page
|
||||||
|
search_joins: list[Any] = []
|
||||||
|
|
||||||
items = await cls.get_multi(
|
# Build search filters
|
||||||
session,
|
if search:
|
||||||
filters=filters,
|
search_filters, search_joins = build_search_filters(
|
||||||
load_options=load_options,
|
cls.model,
|
||||||
order_by=order_by,
|
search,
|
||||||
limit=items_per_page,
|
search_fields=search_fields,
|
||||||
offset=offset,
|
default_fields=cls.searchable_fields,
|
||||||
)
|
)
|
||||||
|
filters.extend(search_filters)
|
||||||
|
|
||||||
total_count = await cls.count(session, filters=filters)
|
# Build query with joins
|
||||||
|
q = select(cls.model)
|
||||||
|
|
||||||
|
# Apply explicit joins
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search joins (always outer joins for search)
|
||||||
|
for join_rel in search_joins:
|
||||||
|
q = q.outerjoin(join_rel)
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
q = q.where(and_(*filters))
|
||||||
|
if load_options:
|
||||||
|
q = q.options(*load_options)
|
||||||
|
if order_by is not None:
|
||||||
|
q = q.order_by(order_by)
|
||||||
|
|
||||||
|
q = q.offset(offset).limit(items_per_page)
|
||||||
|
result = await session.execute(q)
|
||||||
|
items = result.unique().scalars().all()
|
||||||
|
|
||||||
|
# Count query (with same joins and filters)
|
||||||
|
pk_col = cls.model.__mapper__.primary_key[0]
|
||||||
|
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
||||||
|
count_q = count_q.select_from(cls.model)
|
||||||
|
|
||||||
|
# Apply explicit joins to count query
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
count_q = (
|
||||||
|
count_q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else count_q.join(model, condition)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search joins to count query
|
||||||
|
for join_rel in search_joins:
|
||||||
|
count_q = count_q.outerjoin(join_rel)
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
count_q = count_q.where(and_(*filters))
|
||||||
|
|
||||||
|
count_result = await session.execute(count_q)
|
||||||
|
total_count = count_result.scalar_one()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"data": items,
|
"data": items,
|
||||||
@@ -354,11 +459,14 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
def CrudFactory(
|
def CrudFactory(
|
||||||
model: type[ModelType],
|
model: type[ModelType],
|
||||||
|
*,
|
||||||
|
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||||
) -> type[AsyncCrud[ModelType]]:
|
) -> type[AsyncCrud[ModelType]]:
|
||||||
"""Create a CRUD class for a specific model.
|
"""Create a CRUD class for a specific model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: SQLAlchemy model class
|
model: SQLAlchemy model class
|
||||||
|
searchable_fields: Optional list of searchable fields
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AsyncCrud subclass bound to the model
|
AsyncCrud subclass bound to the model
|
||||||
@@ -370,9 +478,39 @@ def CrudFactory(
|
|||||||
UserCrud = CrudFactory(User)
|
UserCrud = CrudFactory(User)
|
||||||
PostCrud = CrudFactory(Post)
|
PostCrud = CrudFactory(Post)
|
||||||
|
|
||||||
|
# With searchable fields:
|
||||||
|
UserCrud = CrudFactory(
|
||||||
|
User,
|
||||||
|
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
||||||
|
)
|
||||||
|
|
||||||
# Usage
|
# Usage
|
||||||
user = await UserCrud.get(session, [User.id == 1])
|
user = await UserCrud.get(session, [User.id == 1])
|
||||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||||
|
|
||||||
|
# With search
|
||||||
|
result = await UserCrud.paginate(session, search="john")
|
||||||
|
|
||||||
|
# With joins (inner join by default):
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
session,
|
||||||
|
joins=[(Post, Post.user_id == User.id)],
|
||||||
|
filters=[Post.published == True],
|
||||||
|
)
|
||||||
|
|
||||||
|
# With outer join:
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
session,
|
||||||
|
joins=[(Post, Post.user_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
|
cls = type(
|
||||||
|
f"Async{model.__name__}Crud",
|
||||||
|
(AsyncCrud,),
|
||||||
|
{
|
||||||
|
"model": model,
|
||||||
|
"searchable_fields": searchable_fields,
|
||||||
|
},
|
||||||
|
)
|
||||||
return cast(type[AsyncCrud[ModelType]], cls)
|
return cast(type[AsyncCrud[ModelType]], cls)
|
||||||
146
src/fastapi_toolsets/crud/search.py
Normal file
146
src/fastapi_toolsets/crud/search.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Search utilities for AsyncCrud."""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import String, or_
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
|
||||||
|
from ..exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchConfig:
|
||||||
|
"""Advanced search configuration.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
query: The search string
|
||||||
|
fields: Fields to search (columns or tuples for relationships)
|
||||||
|
case_sensitive: Case-sensitive search (default: False)
|
||||||
|
match_mode: "any" (OR) or "all" (AND) to combine fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
fields: Sequence[SearchFieldType] | None = None
|
||||||
|
case_sensitive: bool = False
|
||||||
|
match_mode: Literal["any", "all"] = "any"
|
||||||
|
|
||||||
|
|
||||||
|
def get_searchable_fields(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
include_relationships: bool = True,
|
||||||
|
max_depth: int = 1,
|
||||||
|
) -> list[SearchFieldType]:
|
||||||
|
"""Auto-detect String fields on a model and its relationships.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
include_relationships: Include fields from many-to-one/one-to-one relationships
|
||||||
|
max_depth: Max depth for relationship traversal (default: 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of columns and tuples (relationship, column)
|
||||||
|
"""
|
||||||
|
fields: list[SearchFieldType] = []
|
||||||
|
mapper = model.__mapper__
|
||||||
|
|
||||||
|
# Direct String columns
|
||||||
|
for col in mapper.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append(getattr(model, col.key))
|
||||||
|
|
||||||
|
# Relationships (one-to-one, many-to-one only)
|
||||||
|
if include_relationships and max_depth > 0:
|
||||||
|
for rel_name, rel_prop in mapper.relationships.items():
|
||||||
|
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
|
||||||
|
continue
|
||||||
|
|
||||||
|
rel_attr = getattr(model, rel_name)
|
||||||
|
related_model = rel_prop.mapper.class_
|
||||||
|
|
||||||
|
for col in related_model.__mapper__.columns:
|
||||||
|
if isinstance(col.type, String):
|
||||||
|
fields.append((rel_attr, getattr(related_model, col.key)))
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def build_search_filters(
|
||||||
|
model: type[DeclarativeBase],
|
||||||
|
search: str | SearchConfig,
|
||||||
|
search_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
default_fields: Sequence[SearchFieldType] | None = None,
|
||||||
|
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
||||||
|
"""Build SQLAlchemy filter conditions for search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: SQLAlchemy model class
|
||||||
|
search: Search string or SearchConfig
|
||||||
|
search_fields: Fields specified per-call (takes priority)
|
||||||
|
default_fields: Default fields (from ClassVar)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (filter_conditions, joins_needed)
|
||||||
|
"""
|
||||||
|
# Normalize input
|
||||||
|
if isinstance(search, str):
|
||||||
|
config = SearchConfig(query=search, fields=search_fields)
|
||||||
|
else:
|
||||||
|
config = search
|
||||||
|
if search_fields is not None:
|
||||||
|
config = SearchConfig(
|
||||||
|
query=config.query,
|
||||||
|
fields=search_fields,
|
||||||
|
case_sensitive=config.case_sensitive,
|
||||||
|
match_mode=config.match_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not config.query or not config.query.strip():
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Determine which fields to search
|
||||||
|
fields = config.fields or default_fields or get_searchable_fields(model)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
raise NoSearchableFieldsError(model)
|
||||||
|
|
||||||
|
query = config.query.strip()
|
||||||
|
filters: list[ColumnElement[bool]] = []
|
||||||
|
joins: list[InstrumentedAttribute[Any]] = []
|
||||||
|
added_joins: set[str] = set()
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
if isinstance(field, tuple):
|
||||||
|
# Relationship: (User.role, Role.name) or deeper
|
||||||
|
for rel in field[:-1]:
|
||||||
|
rel_key = str(rel)
|
||||||
|
if rel_key not in added_joins:
|
||||||
|
joins.append(rel)
|
||||||
|
added_joins.add(rel_key)
|
||||||
|
column = field[-1]
|
||||||
|
else:
|
||||||
|
column = field
|
||||||
|
|
||||||
|
# Build the filter (cast to String for non-text columns)
|
||||||
|
column_as_string = column.cast(String)
|
||||||
|
if config.case_sensitive:
|
||||||
|
filters.append(column_as_string.like(f"%{query}%"))
|
||||||
|
else:
|
||||||
|
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||||
|
|
||||||
|
if not filters:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Combine based on match_mode
|
||||||
|
if config.match_mode == "any":
|
||||||
|
return [or_(*filters)], joins
|
||||||
|
else:
|
||||||
|
return filters, joins
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
ApiError,
|
||||||
ApiException,
|
ApiException,
|
||||||
ConflictError,
|
ConflictError,
|
||||||
ForbiddenError,
|
ForbiddenError,
|
||||||
|
NoSearchableFieldsError,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
UnauthorizedError,
|
UnauthorizedError,
|
||||||
generate_error_responses,
|
generate_error_responses,
|
||||||
@@ -9,11 +11,13 @@ from .exceptions import (
|
|||||||
from .handler import init_exceptions_handlers
|
from .handler import init_exceptions_handlers
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"init_exceptions_handlers",
|
"ApiError",
|
||||||
"generate_error_responses",
|
|
||||||
"ApiException",
|
"ApiException",
|
||||||
"ConflictError",
|
"ConflictError",
|
||||||
"ForbiddenError",
|
"ForbiddenError",
|
||||||
|
"generate_error_responses",
|
||||||
|
"init_exceptions_handlers",
|
||||||
|
"NoSearchableFieldsError",
|
||||||
"NotFoundError",
|
"NotFoundError",
|
||||||
"UnauthorizedError",
|
"UnauthorizedError",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -119,6 +119,25 @@ class RoleNotFoundError(NotFoundError):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NoSearchableFieldsError(ApiException):
|
||||||
|
"""Raised when search is requested but no searchable fields are available."""
|
||||||
|
|
||||||
|
api_error = ApiError(
|
||||||
|
code=400,
|
||||||
|
msg="No Searchable Fields",
|
||||||
|
desc="No searchable fields configured for this resource.",
|
||||||
|
err_code="SEARCH-400",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, model: type) -> None:
|
||||||
|
self.model = model
|
||||||
|
detail = (
|
||||||
|
f"No searchable fields found for model '{model.__name__}'. "
|
||||||
|
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
||||||
|
)
|
||||||
|
super().__init__(detail)
|
||||||
|
|
||||||
|
|
||||||
def generate_error_responses(
|
def generate_error_responses(
|
||||||
*errors: type[ApiException],
|
*errors: type[ApiException],
|
||||||
) -> dict[int | str, dict[str, Any]]:
|
) -> dict[int | str, dict[str, Any]]:
|
||||||
|
|||||||
@@ -50,8 +50,16 @@ class FixtureRegistry:
|
|||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
contexts: list[str | Context] | None = None,
|
||||||
|
) -> None:
|
||||||
self._fixtures: dict[str, Fixture] = {}
|
self._fixtures: dict[str, Fixture] = {}
|
||||||
|
self._default_contexts: list[str] | None = (
|
||||||
|
[c.value if isinstance(c, Context) else c for c in contexts]
|
||||||
|
if contexts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
@@ -85,10 +93,14 @@ class FixtureRegistry:
|
|||||||
fn: Callable[[], Sequence[DeclarativeBase]],
|
fn: Callable[[], Sequence[DeclarativeBase]],
|
||||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||||
fixture_name = name or cast(Any, fn).__name__
|
fixture_name = name or cast(Any, fn).__name__
|
||||||
fixture_contexts = [
|
if contexts is not None:
|
||||||
c.value if isinstance(c, Context) else c
|
fixture_contexts = [
|
||||||
for c in (contexts or [Context.BASE])
|
c.value if isinstance(c, Context) else c for c in contexts
|
||||||
]
|
]
|
||||||
|
elif self._default_contexts is not None:
|
||||||
|
fixture_contexts = self._default_contexts
|
||||||
|
else:
|
||||||
|
fixture_contexts = [Context.BASE.value]
|
||||||
|
|
||||||
self._fixtures[fixture_name] = Fixture(
|
self._fixtures[fixture_name] = Fixture(
|
||||||
name=fixture_name,
|
name=fixture_name,
|
||||||
@@ -102,6 +114,32 @@ class FixtureRegistry:
|
|||||||
return decorator(func)
|
return decorator(func)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||||
|
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: The `FixtureRegistry` to include
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a fixture name already exists in the current registry
|
||||||
|
|
||||||
|
Example:
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return [...]
|
||||||
|
|
||||||
|
registry.include_registry(registry=dev_registry)
|
||||||
|
"""
|
||||||
|
for name, fixture in registry._fixtures.items():
|
||||||
|
if name in self._fixtures:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' already exists in the current registry"
|
||||||
|
)
|
||||||
|
self._fixtures[name] = fixture
|
||||||
|
|
||||||
def get(self, name: str) -> Fixture:
|
def get(self, name: str) -> Fixture:
|
||||||
"""Get a fixture by name."""
|
"""Get a fixture by name."""
|
||||||
if name not in self._fixtures:
|
if name not in self._fixtures:
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""Shared pytest fixtures for fastapi-utils tests."""
|
"""Shared pytest fixtures for fastapi-utils tests."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import ForeignKey, String
|
from sqlalchemy import ForeignKey, String, Uuid
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ class Role(Base):
|
|||||||
|
|
||||||
__tablename__ = "roles"
|
__tablename__ = "roles"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
|
||||||
users: Mapped[list["User"]] = relationship(back_populates="role")
|
users: Mapped[list["User"]] = relationship(back_populates="role")
|
||||||
@@ -44,11 +45,13 @@ class User(Base):
|
|||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
is_active: Mapped[bool] = mapped_column(default=True)
|
is_active: Mapped[bool] = mapped_column(default=True)
|
||||||
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
|
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("roles.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||||
|
|
||||||
@@ -58,11 +61,11 @@ class Post(Base):
|
|||||||
|
|
||||||
__tablename__ = "posts"
|
__tablename__ = "posts"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||||
title: Mapped[str] = mapped_column(String(200))
|
title: Mapped[str] = mapped_column(String(200))
|
||||||
content: Mapped[str] = mapped_column(String(1000), default="")
|
content: Mapped[str] = mapped_column(String(1000), default="")
|
||||||
is_published: Mapped[bool] = mapped_column(default=False)
|
is_published: Mapped[bool] = mapped_column(default=False)
|
||||||
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -73,7 +76,7 @@ class Post(Base):
|
|||||||
class RoleCreate(BaseModel):
|
class RoleCreate(BaseModel):
|
||||||
"""Schema for creating a role."""
|
"""Schema for creating a role."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
@@ -86,11 +89,11 @@ class RoleUpdate(BaseModel):
|
|||||||
class UserCreate(BaseModel):
|
class UserCreate(BaseModel):
|
||||||
"""Schema for creating a user."""
|
"""Schema for creating a user."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
username: str
|
username: str
|
||||||
email: str
|
email: str
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(BaseModel):
|
class UserUpdate(BaseModel):
|
||||||
@@ -99,17 +102,17 @@ class UserUpdate(BaseModel):
|
|||||||
username: str | None = None
|
username: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
is_active: bool | None = None
|
is_active: bool | None = None
|
||||||
role_id: int | None = None
|
role_id: uuid.UUID | None = None
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
class PostCreate(BaseModel):
|
||||||
"""Schema for creating a post."""
|
"""Schema for creating a post."""
|
||||||
|
|
||||||
id: int | None = None
|
id: uuid.UUID | None = None
|
||||||
title: str
|
title: str
|
||||||
content: str = ""
|
content: str = ""
|
||||||
is_published: bool = False
|
is_published: bool = False
|
||||||
author_id: int
|
author_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
class PostUpdate(BaseModel):
|
class PostUpdate(BaseModel):
|
||||||
@@ -195,5 +198,5 @@ def sample_post_data() -> PostCreate:
|
|||||||
title="Test Post",
|
title="Test Post",
|
||||||
content="Test content",
|
content="Test content",
|
||||||
is_published=True,
|
is_published=True,
|
||||||
author_id=1,
|
author_id=uuid.uuid4(),
|
||||||
)
|
)
|
||||||
|
|||||||
322
tests/test_cli.py
Normal file
322
tests/test_cli.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
"""Tests for fastapi_toolsets.cli module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli.config import CliConfig, _import_from_string, load_config
|
||||||
|
from fastapi_toolsets.cli.utils import async_command
|
||||||
|
from fastapi_toolsets.fixtures import FixtureRegistry
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliConfig:
|
||||||
|
"""Tests for CliConfig dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Config has None defaults."""
|
||||||
|
config = CliConfig()
|
||||||
|
assert config.fixtures is None
|
||||||
|
assert config.db_context is None
|
||||||
|
|
||||||
|
def test_with_values(self):
|
||||||
|
"""Config stores provided values."""
|
||||||
|
config = CliConfig(
|
||||||
|
fixtures="app.fixtures:registry",
|
||||||
|
db_context="app.db:get_session",
|
||||||
|
)
|
||||||
|
assert config.fixtures == "app.fixtures:registry"
|
||||||
|
assert config.db_context == "app.db:get_session"
|
||||||
|
|
||||||
|
def test_get_fixtures_registry_without_config(self):
|
||||||
|
"""get_fixtures_registry raises error when not configured."""
|
||||||
|
config = CliConfig()
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
config.get_fixtures_registry()
|
||||||
|
assert "No fixtures registry configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_get_db_context_without_config(self):
|
||||||
|
"""get_db_context raises error when not configured."""
|
||||||
|
config = CliConfig()
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
config.get_db_context()
|
||||||
|
assert "No db_context configured" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportFromString:
|
||||||
|
"""Tests for _import_from_string function."""
|
||||||
|
|
||||||
|
def test_import_valid_path(self):
|
||||||
|
"""Import valid module:attribute path."""
|
||||||
|
result = _import_from_string("fastapi_toolsets.fixtures:FixtureRegistry")
|
||||||
|
assert result is FixtureRegistry
|
||||||
|
|
||||||
|
def test_import_without_colon_raises_error(self):
|
||||||
|
"""Import path without colon raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
_import_from_string("fastapi_toolsets.fixtures.FixtureRegistry")
|
||||||
|
assert "Expected format: 'module:attribute'" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_module_raises_error(self):
|
||||||
|
"""Import nonexistent module raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
_import_from_string("nonexistent.module:something")
|
||||||
|
assert "Cannot import module" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_import_nonexistent_attribute_raises_error(self):
|
||||||
|
"""Import nonexistent attribute raises error."""
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
_import_from_string("fastapi_toolsets.fixtures:NonexistentClass")
|
||||||
|
assert "has no attribute" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadConfig:
|
||||||
|
"""Tests for load_config function."""
|
||||||
|
|
||||||
|
def test_load_without_pyproject(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty config when no pyproject.toml exists."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
config = load_config()
|
||||||
|
assert config.fixtures is None
|
||||||
|
assert config.db_context is None
|
||||||
|
|
||||||
|
def test_load_without_tool_section(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty config when no [tool.fastapi-toolsets] section."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("[project]\nname = 'test'\n")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
assert config.fixtures is None
|
||||||
|
assert config.db_context is None
|
||||||
|
|
||||||
|
def test_load_with_fixtures_config(self, tmp_path, monkeypatch):
|
||||||
|
"""Loads fixtures config from pyproject.toml."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
'[tool.fastapi-toolsets]\nfixtures = "app.fixtures:registry"\n'
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
assert config.fixtures == "app.fixtures:registry"
|
||||||
|
assert config.db_context is None
|
||||||
|
|
||||||
|
def test_load_with_full_config(self, tmp_path, monkeypatch):
|
||||||
|
"""Loads full config from pyproject.toml."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'fixtures = "app.fixtures:registry"\n'
|
||||||
|
'db_context = "app.db:get_session"\n'
|
||||||
|
)
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
assert config.fixtures == "app.fixtures:registry"
|
||||||
|
assert config.db_context == "app.db:get_session"
|
||||||
|
|
||||||
|
def test_load_with_invalid_toml(self, tmp_path, monkeypatch):
|
||||||
|
"""Returns empty config when pyproject.toml is invalid."""
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text("invalid toml {{{")
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
assert config.fixtures is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliApp:
|
||||||
|
"""Tests for CLI application."""
|
||||||
|
|
||||||
|
def test_cli_import(self):
|
||||||
|
"""CLI can be imported."""
|
||||||
|
from fastapi_toolsets.cli import cli
|
||||||
|
|
||||||
|
assert cli is not None
|
||||||
|
|
||||||
|
def test_cli_help(self, tmp_path, monkeypatch):
|
||||||
|
"""CLI shows help without fixtures."""
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Need to reload the module to pick up new cwd
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app
|
||||||
|
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "CLI utilities for FastAPI projects" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixturesCli:
|
||||||
|
"""Tests for fixtures CLI commands."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_env(self, tmp_path, monkeypatch):
|
||||||
|
"""Set up CLI environment with fixtures config."""
|
||||||
|
# Create pyproject.toml
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text(
|
||||||
|
"[tool.fastapi-toolsets]\n"
|
||||||
|
'fixtures = "fixtures:registry"\n'
|
||||||
|
'db_context = "db:get_session"\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create fixtures module
|
||||||
|
fixtures_file = tmp_path / "fixtures.py"
|
||||||
|
fixtures_file.write_text(
|
||||||
|
"from fastapi_toolsets.fixtures import FixtureRegistry, Context\n"
|
||||||
|
"\n"
|
||||||
|
"registry = FixtureRegistry()\n"
|
||||||
|
"\n"
|
||||||
|
"@registry.register(contexts=[Context.BASE])\n"
|
||||||
|
"def roles():\n"
|
||||||
|
' return [{"id": 1, "name": "admin"}, {"id": 2, "name": "user"}]\n'
|
||||||
|
"\n"
|
||||||
|
'@registry.register(depends_on=["roles"], contexts=[Context.TESTING])\n'
|
||||||
|
"def users():\n"
|
||||||
|
' return [{"id": 1, "name": "alice", "role_id": 1}]\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create db module
|
||||||
|
db_file = tmp_path / "db.py"
|
||||||
|
db_file.write_text(
|
||||||
|
"from contextlib import asynccontextmanager\n"
|
||||||
|
"\n"
|
||||||
|
"@asynccontextmanager\n"
|
||||||
|
"async def get_session():\n"
|
||||||
|
" yield None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Add tmp_path to sys.path for imports
|
||||||
|
if str(tmp_path) not in sys.path:
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
# Reload the CLI module to pick up new config
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app, config
|
||||||
|
|
||||||
|
importlib.reload(config)
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
yield tmp_path, app.cli
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if str(tmp_path) in sys.path:
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
def test_fixtures_list(self, cli_env):
|
||||||
|
"""fixtures list shows registered fixtures."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" in result.output
|
||||||
|
assert "Total: 2 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_list_with_context(self, cli_env):
|
||||||
|
"""fixtures list --context filters by context."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "list", "--context", "base"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "users" not in result.output
|
||||||
|
assert "Total: 1 fixture(s)" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_dry_run(self, cli_env):
|
||||||
|
"""fixtures load --dry-run shows what would be loaded."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(cli, ["fixtures", "load", "base", "--dry-run"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Fixtures to load" in result.output
|
||||||
|
assert "roles" in result.output
|
||||||
|
assert "[Dry run - no changes made]" in result.output
|
||||||
|
|
||||||
|
def test_fixtures_load_invalid_strategy(self, cli_env):
|
||||||
|
"""fixtures load with invalid strategy shows error."""
|
||||||
|
tmp_path, cli = cli_env
|
||||||
|
result = runner.invoke(
|
||||||
|
cli, ["fixtures", "load", "base", "--strategy", "invalid"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "Invalid strategy" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliWithoutFixturesConfig:
|
||||||
|
"""Tests for CLI when fixtures is not configured."""
|
||||||
|
|
||||||
|
def test_no_fixtures_command(self, tmp_path, monkeypatch):
|
||||||
|
"""fixtures command is not available when not configured."""
|
||||||
|
# Create pyproject.toml without fixtures
|
||||||
|
pyproject = tmp_path / "pyproject.toml"
|
||||||
|
pyproject.write_text('[project]\nname = "test"\n')
|
||||||
|
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
# Reload the CLI module
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi_toolsets.cli import app, config
|
||||||
|
|
||||||
|
importlib.reload(config)
|
||||||
|
importlib.reload(app)
|
||||||
|
|
||||||
|
result = runner.invoke(app.cli, ["--help"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "fixtures" not in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommand:
|
||||||
|
"""Tests for async_command decorator."""
|
||||||
|
|
||||||
|
def test_async_command_runs_coroutine(self):
|
||||||
|
"""async_command runs async function synchronously."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(value: int) -> int:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
result = async_func(21)
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
def test_async_command_preserves_signature(self):
|
||||||
|
"""async_command preserves function signature."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func(name: str, count: int = 1) -> str:
|
||||||
|
return f"{name} x {count}"
|
||||||
|
|
||||||
|
result = async_func("test", count=3)
|
||||||
|
assert result == "test x 3"
|
||||||
|
|
||||||
|
def test_async_command_preserves_docstring(self):
|
||||||
|
"""async_command preserves function docstring."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def async_func() -> None:
|
||||||
|
"""This is a docstring."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert async_func.__doc__ == """This is a docstring."""
|
||||||
|
|
||||||
|
def test_async_command_preserves_name(self):
|
||||||
|
"""async_command preserves function name."""
|
||||||
|
|
||||||
|
@async_command
|
||||||
|
async def my_async_function() -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert my_async_function.__name__ == "my_async_function"
|
||||||
@@ -1,12 +1,18 @@
|
|||||||
"""Tests for fastapi_toolsets.crud module."""
|
"""Tests for fastapi_toolsets.crud module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from fastapi_toolsets.crud import AsyncCrud, CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
from fastapi_toolsets.exceptions import NotFoundError
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
|
Post,
|
||||||
|
PostCreate,
|
||||||
|
PostCrud,
|
||||||
Role,
|
Role,
|
||||||
RoleCreate,
|
RoleCreate,
|
||||||
RoleCrud,
|
RoleCrud,
|
||||||
@@ -88,8 +94,9 @@ class TestCrudGet:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_raises_not_found(self, db_session: AsyncSession):
|
async def test_get_raises_not_found(self, db_session: AsyncSession):
|
||||||
"""Get raises NotFoundError for missing records."""
|
"""Get raises NotFoundError for missing records."""
|
||||||
|
non_existent_id = uuid.uuid4()
|
||||||
with pytest.raises(NotFoundError):
|
with pytest.raises(NotFoundError):
|
||||||
await RoleCrud.get(db_session, [Role.id == 99999])
|
await RoleCrud.get(db_session, [Role.id == non_existent_id])
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_with_multiple_filters(self, db_session: AsyncSession):
|
async def test_get_with_multiple_filters(self, db_session: AsyncSession):
|
||||||
@@ -222,11 +229,12 @@ class TestCrudUpdate:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_update_raises_not_found(self, db_session: AsyncSession):
|
async def test_update_raises_not_found(self, db_session: AsyncSession):
|
||||||
"""Update raises NotFoundError for missing records."""
|
"""Update raises NotFoundError for missing records."""
|
||||||
|
non_existent_id = uuid.uuid4()
|
||||||
with pytest.raises(NotFoundError):
|
with pytest.raises(NotFoundError):
|
||||||
await RoleCrud.update(
|
await RoleCrud.update(
|
||||||
db_session,
|
db_session,
|
||||||
RoleUpdate(name="new"),
|
RoleUpdate(name="new"),
|
||||||
[Role.id == 99999],
|
[Role.id == non_existent_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -339,7 +347,8 @@ class TestCrudUpsert:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_upsert_insert_new_record(self, db_session: AsyncSession):
|
async def test_upsert_insert_new_record(self, db_session: AsyncSession):
|
||||||
"""Upsert inserts a new record when it doesn't exist."""
|
"""Upsert inserts a new record when it doesn't exist."""
|
||||||
data = RoleCreate(id=1, name="upsert_new")
|
role_id = uuid.uuid4()
|
||||||
|
data = RoleCreate(id=role_id, name="upsert_new")
|
||||||
role = await RoleCrud.upsert(
|
role = await RoleCrud.upsert(
|
||||||
db_session,
|
db_session,
|
||||||
data,
|
data,
|
||||||
@@ -352,12 +361,13 @@ class TestCrudUpsert:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_upsert_update_existing_record(self, db_session: AsyncSession):
|
async def test_upsert_update_existing_record(self, db_session: AsyncSession):
|
||||||
"""Upsert updates an existing record."""
|
"""Upsert updates an existing record."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
# First insert
|
# First insert
|
||||||
data = RoleCreate(id=100, name="original_name")
|
data = RoleCreate(id=role_id, name="original_name")
|
||||||
await RoleCrud.upsert(db_session, data, index_elements=["id"])
|
await RoleCrud.upsert(db_session, data, index_elements=["id"])
|
||||||
|
|
||||||
# Upsert with update
|
# Upsert with update
|
||||||
updated_data = RoleCreate(id=100, name="updated_name")
|
updated_data = RoleCreate(id=role_id, name="updated_name")
|
||||||
role = await RoleCrud.upsert(
|
role = await RoleCrud.upsert(
|
||||||
db_session,
|
db_session,
|
||||||
updated_data,
|
updated_data,
|
||||||
@@ -369,22 +379,23 @@ class TestCrudUpsert:
|
|||||||
assert role.name == "updated_name"
|
assert role.name == "updated_name"
|
||||||
|
|
||||||
# Verify only one record exists
|
# Verify only one record exists
|
||||||
count = await RoleCrud.count(db_session, [Role.id == 100])
|
count = await RoleCrud.count(db_session, [Role.id == role_id])
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
|
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
|
||||||
"""Upsert does nothing on conflict when set_ is not provided."""
|
"""Upsert does nothing on conflict when set_ is not provided."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
# First insert
|
# First insert
|
||||||
data = RoleCreate(id=200, name="do_nothing_original")
|
data = RoleCreate(id=role_id, name="do_nothing_original")
|
||||||
await RoleCrud.upsert(db_session, data, index_elements=["id"])
|
await RoleCrud.upsert(db_session, data, index_elements=["id"])
|
||||||
|
|
||||||
# Upsert without set_ (do nothing)
|
# Upsert without set_ (do nothing)
|
||||||
conflict_data = RoleCreate(id=200, name="do_nothing_conflict")
|
conflict_data = RoleCreate(id=role_id, name="do_nothing_conflict")
|
||||||
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
|
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
|
||||||
|
|
||||||
# Original value should be preserved
|
# Original value should be preserved
|
||||||
role = await RoleCrud.first(db_session, [Role.id == 200])
|
role = await RoleCrud.first(db_session, [Role.id == role_id])
|
||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "do_nothing_original"
|
assert role.name == "do_nothing_original"
|
||||||
|
|
||||||
@@ -473,3 +484,271 @@ class TestCrudPaginate:
|
|||||||
|
|
||||||
names = [r.name for r in result["data"]]
|
names = [r.name for r in result["data"]]
|
||||||
assert names == ["alpha", "bravo", "charlie"]
|
assert names == ["alpha", "bravo", "charlie"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrudJoins:
|
||||||
|
"""Tests for CRUD operations with joins."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Get with inner join filters correctly."""
|
||||||
|
# Create user with posts
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="author", email="author@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Post 1", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user with join on published posts
|
||||||
|
fetched = await UserCrud.get(
|
||||||
|
db_session,
|
||||||
|
filters=[User.id == user.id, Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert fetched.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_join(self, db_session: AsyncSession):
|
||||||
|
"""First with join returns matching record."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="writer", email="writer@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Draft", author_id=user.id, is_published=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find user with unpublished posts
|
||||||
|
result = await UserCrud.first(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == False], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""First with outer join includes records without related data."""
|
||||||
|
# User without posts
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_posts", email="no_posts@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# With outer join, user should be found even without posts
|
||||||
|
result = await UserCrud.first(
|
||||||
|
db_session,
|
||||||
|
filters=[User.id == user.id],
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_with_inner_join(self, db_session: AsyncSession):
|
||||||
|
"""Get multiple with inner join only returns matching records."""
|
||||||
|
# User with published post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="publisher", email="pub@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Published", author_id=user1.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without posts
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="lurker", email="lurk@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inner join should only return user with published post
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
)
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].username == "publisher"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""Get multiple with outer join includes all records."""
|
||||||
|
# User with post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="has_post", email="has@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="My Post", author_id=user1.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without posts
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_post", email="no@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Outer join should return both users
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
|
assert len(users) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_count_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Count with join counts correctly."""
|
||||||
|
# Create users with different post statuses
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="active_author", email="active@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Published 1", author_id=user1.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
user2 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="draft_author", email="draft@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Draft 1", author_id=user2.id, is_published=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count users with published posts
|
||||||
|
count = await UserCrud.count(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_exists_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Exists with join checks correctly."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="poster", email="poster@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Exists Post", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user with published post exists
|
||||||
|
exists = await UserCrud.exists(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert exists is True
|
||||||
|
|
||||||
|
# Check if user with specific title exists
|
||||||
|
exists = await UserCrud.exists(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.title == "Nonexistent"],
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert exists is False
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_paginate_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Paginate with join works correctly."""
|
||||||
|
# Create users with posts
|
||||||
|
for i in range(5):
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username=f"author{i}", email=f"author{i}@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(
|
||||||
|
title=f"Post {i}",
|
||||||
|
author_id=user.id,
|
||||||
|
is_published=i % 2 == 0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paginate users with published posts
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
page=1,
|
||||||
|
items_per_page=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 3
|
||||||
|
assert len(result["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""Paginate with outer join includes all records."""
|
||||||
|
# User with post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="with_post", email="with@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="A Post", author_id=user1.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without post
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="without_post", email="without@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paginate with outer join
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
page=1,
|
||||||
|
items_per_page=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
assert len(result["data"]) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_joins(self, db_session: AsyncSession):
|
||||||
|
"""Multiple joins can be applied."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="author_role"))
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(
|
||||||
|
username="multi_join",
|
||||||
|
email="multi@test.com",
|
||||||
|
role_id=role.id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Multi Join Post", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Join both Role and Post
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[
|
||||||
|
(Role, Role.id == User.role_id),
|
||||||
|
(Post, Post.author_id == User.id),
|
||||||
|
],
|
||||||
|
filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712
|
||||||
|
)
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].username == "multi_join"
|
||||||
|
|||||||
415
tests/test_crud_search.py
Normal file
415
tests/test_crud_search.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
"""Tests for CRUD search functionality."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
|
||||||
|
|
||||||
|
from .conftest import (
|
||||||
|
Role,
|
||||||
|
RoleCreate,
|
||||||
|
RoleCrud,
|
||||||
|
User,
|
||||||
|
UserCreate,
|
||||||
|
UserCrud,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPaginateSearch:
|
||||||
|
"""Tests for paginate() with search."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_single_column(self, db_session: AsyncSession):
|
||||||
|
"""Search on a single direct column."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="john_doe", email="john@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="jane_doe", email="jane@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob_smith", email="bob@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="doe",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
||||||
|
"""Search across multiple columns (OR logic)."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="alice@company.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="company_bob", email="bob@other.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="company",
|
||||||
|
search_fields=[User.username, User.email],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
||||||
|
"""Search through a relationship (depth 1)."""
|
||||||
|
admin_role = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
|
||||||
|
user_role = await RoleCrud.create(db_session, RoleCreate(name="basic_user"))
|
||||||
|
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="admin1", email="a1@test.com", role_id=admin_role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="admin2", email="a2@test.com", role_id=admin_role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="admin",
|
||||||
|
search_fields=[(User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
||||||
|
"""Search combining direct columns and relationships."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="john", email="john@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search "admin" in username OR role.name
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="admin",
|
||||||
|
search_fields=[User.username, (User.role, Role.name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
||||||
|
"""Search is case-insensitive by default."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="JohnDoe", email="j@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="johndoe",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
||||||
|
"""Case-sensitive search with SearchConfig."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="JohnDoe", email="j@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not find (case mismatch)
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
assert result["pagination"]["total_count"] == 0
|
||||||
|
|
||||||
|
# Should find (case match)
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_empty_query(self, db_session: AsyncSession):
|
||||||
|
"""Empty search returns all results."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="user1", email="u1@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="user2", email="u2@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(db_session, search="")
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(db_session, search=None)
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
||||||
|
"""Search combines with existing filters (AND)."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="active_john", email="aj@test.com", is_active=True),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
filters=[User.is_active == True], # noqa: E712
|
||||||
|
search="john",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
assert result["data"][0].username == "active_john"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
||||||
|
"""Auto-detect searchable fields when not specified."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="findme", email="other@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(db_session, search="findme")
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_no_results(self, db_session: AsyncSession):
|
||||||
|
"""Search with no matching results."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="john", email="j@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="nonexistent",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 0
|
||||||
|
assert result["data"] == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_with_pagination(self, db_session: AsyncSession):
|
||||||
|
"""Search respects pagination parameters."""
|
||||||
|
for i in range(15):
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="user_",
|
||||||
|
search_fields=[User.username],
|
||||||
|
page=1,
|
||||||
|
items_per_page=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 15
|
||||||
|
assert len(result["data"]) == 5
|
||||||
|
assert result["pagination"]["has_more"] is True
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_null_relationship(self, db_session: AsyncSession):
|
||||||
|
"""Users without relationship are included (outerjoin)."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="with_role", email="wr@test.com", role_id=role.id),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_role", email="nr@test.com", role_id=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search in username, not in role
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="role",
|
||||||
|
search_fields=[User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_with_order_by(self, db_session: AsyncSession):
|
||||||
|
"""Search works with order_by parameter."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="alice", email="a@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="bob", email="b@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="@test.com",
|
||||||
|
search_fields=[User.email],
|
||||||
|
order_by=User.username,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 3
|
||||||
|
usernames = [u.username for u in result["data"]]
|
||||||
|
assert usernames == ["alice", "bob", "charlie"]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_non_string_column(self, db_session: AsyncSession):
|
||||||
|
"""Search on non-string columns (e.g., UUID) works via cast."""
|
||||||
|
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="jane", email="jane@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search by UUID (partial match)
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search="12345678",
|
||||||
|
search_fields=[User.id, User.username],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
assert result["data"][0].id == user_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchConfig:
|
||||||
|
"""Tests for SearchConfig options."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_match_mode_all(self, db_session: AsyncSession):
|
||||||
|
"""match_mode='all' requires all fields to match (AND)."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="john_test", email="john_test@company.com"),
|
||||||
|
)
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="john_other", email="other@example.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 'john' must be in username AND email
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="john", match_mode="all"),
|
||||||
|
search_fields=[User.username, User.email],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
assert result["data"][0].username == "john_test"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
||||||
|
"""SearchConfig can specify fields directly."""
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session, UserCreate(username="test", email="findme@test.com")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
search=SearchConfig(query="findme", fields=[User.email]),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoSearchableFieldsError:
|
||||||
|
"""Tests for NoSearchableFieldsError exception."""
|
||||||
|
|
||||||
|
def test_error_is_api_exception(self):
|
||||||
|
"""NoSearchableFieldsError inherits from ApiException."""
|
||||||
|
from fastapi_toolsets.exceptions import ApiException, NoSearchableFieldsError
|
||||||
|
|
||||||
|
assert issubclass(NoSearchableFieldsError, ApiException)
|
||||||
|
|
||||||
|
def test_error_has_api_error_fields(self):
|
||||||
|
"""NoSearchableFieldsError has proper api_error configuration."""
|
||||||
|
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
assert NoSearchableFieldsError.api_error.code == 400
|
||||||
|
assert NoSearchableFieldsError.api_error.err_code == "SEARCH-400"
|
||||||
|
|
||||||
|
def test_error_message_contains_model_name(self):
|
||||||
|
"""Error message includes the model name."""
|
||||||
|
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
error = NoSearchableFieldsError(User)
|
||||||
|
assert "User" in str(error)
|
||||||
|
assert error.model is User
|
||||||
|
|
||||||
|
def test_error_raised_when_no_fields(self):
|
||||||
|
"""Error is raised when search has no searchable fields."""
|
||||||
|
from sqlalchemy import Integer
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from fastapi_toolsets.crud.search import build_search_filters
|
||||||
|
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||||
|
|
||||||
|
# Model with no String columns
|
||||||
|
class NoStringBase(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NoStringModel(NoStringBase):
|
||||||
|
__tablename__ = "no_strings"
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||||
|
count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
|
||||||
|
with pytest.raises(NoSearchableFieldsError) as exc_info:
|
||||||
|
build_search_filters(NoStringModel, "test")
|
||||||
|
|
||||||
|
assert exc_info.value.model is NoStringModel
|
||||||
|
assert "NoStringModel" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSearchableFields:
|
||||||
|
"""Tests for auto-detection of searchable fields."""
|
||||||
|
|
||||||
|
def test_detects_string_columns(self):
|
||||||
|
"""Detects String columns on the model."""
|
||||||
|
fields = get_searchable_fields(User, include_relationships=False)
|
||||||
|
|
||||||
|
# Should include username and email (String), not id or is_active
|
||||||
|
field_names = [str(f) for f in fields]
|
||||||
|
assert any("username" in f for f in field_names)
|
||||||
|
assert any("email" in f for f in field_names)
|
||||||
|
assert not any("id" in f and "role_id" not in f for f in field_names)
|
||||||
|
assert not any("is_active" in f for f in field_names)
|
||||||
|
|
||||||
|
def test_detects_relationship_fields(self):
|
||||||
|
"""Detects String fields on related models."""
|
||||||
|
fields = get_searchable_fields(User, include_relationships=True)
|
||||||
|
|
||||||
|
# Should include (User.role, Role.name)
|
||||||
|
has_role_name = any(isinstance(f, tuple) and len(f) == 2 for f in fields)
|
||||||
|
assert has_role_name
|
||||||
|
|
||||||
|
def test_skips_collection_relationships(self):
|
||||||
|
"""Skips one-to-many relationships."""
|
||||||
|
fields = get_searchable_fields(Role, include_relationships=True)
|
||||||
|
|
||||||
|
# Role.users is a collection, should not be included
|
||||||
|
field_strs = [str(f) for f in fields]
|
||||||
|
assert not any("users" in f for f in field_strs)
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Tests for fastapi_toolsets.fixtures module."""
|
"""Tests for fastapi_toolsets.fixtures module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_decorator(self):
|
def test_register_with_decorator(self):
|
||||||
"""Register fixture with decorator."""
|
"""Register fixture with decorator."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
assert "roles" in [f.name for f in registry.get_all()]
|
assert "roles" in [f.name for f in registry.get_all()]
|
||||||
|
|
||||||
def test_register_with_custom_name(self):
|
def test_register_with_custom_name(self):
|
||||||
"""Register fixture with custom name."""
|
"""Register fixture with custom name."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(name="custom_roles")
|
@registry.register(name="custom_roles")
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
fixture = registry.get("custom_roles")
|
fixture = registry.get("custom_roles")
|
||||||
assert fixture.name == "custom_roles"
|
assert fixture.name == "custom_roles"
|
||||||
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_dependencies(self):
|
def test_register_with_dependencies(self):
|
||||||
"""Register fixture with dependencies."""
|
"""Register fixture with dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
@registry.register(depends_on=["roles"])
|
||||||
def users():
|
def users():
|
||||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="admin",
|
||||||
|
email="admin@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
fixture = registry.get("users")
|
fixture = registry.get("users")
|
||||||
assert fixture.depends_on == ["roles"]
|
assert fixture.depends_on == ["roles"]
|
||||||
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
|
|||||||
def test_register_with_contexts(self):
|
def test_register_with_contexts(self):
|
||||||
"""Register fixture with contexts."""
|
"""Register fixture with contexts."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_data():
|
def test_data():
|
||||||
return [Role(id=100, name="test")]
|
return [Role(id=role_id, name="test")]
|
||||||
|
|
||||||
fixture = registry.get("test_data")
|
fixture = registry.get("test_data")
|
||||||
assert Context.TESTING.value in fixture.contexts
|
assert Context.TESTING.value in fixture.contexts
|
||||||
@@ -145,6 +159,178 @@ class TestFixtureRegistry:
|
|||||||
assert names == {"test_data"}
|
assert names == {"test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIncludeRegistry:
|
||||||
|
"""Tests for FixtureRegistry.include_registry method."""
|
||||||
|
|
||||||
|
def test_include_empty_registry(self):
|
||||||
|
"""Include an empty registry does nothing."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
assert len(main_registry.get_all()) == 1
|
||||||
|
|
||||||
|
def test_include_registry_adds_fixtures(self):
|
||||||
|
"""Include registry adds all fixtures from the other registry."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register
|
||||||
|
def posts():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"roles", "users", "posts"}
|
||||||
|
|
||||||
|
def test_include_registry_preserves_dependencies(self):
|
||||||
|
"""Include registry preserves fixture dependencies."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("users")
|
||||||
|
assert fixture.depends_on == ["roles"]
|
||||||
|
|
||||||
|
def test_include_registry_preserves_contexts(self):
|
||||||
|
"""Include registry preserves fixture contexts."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@other_registry.register(contexts=[Context.TESTING, Context.DEVELOPMENT])
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
fixture = main_registry.get("test_data")
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_include_registry_raises_on_duplicate(self):
|
||||||
|
"""Include registry raises ValueError on duplicate fixture names."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
other_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register(name="roles")
|
||||||
|
def roles_main():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@other_registry.register(name="roles")
|
||||||
|
def roles_other():
|
||||||
|
return []
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already exists"):
|
||||||
|
main_registry.include_registry(other_registry)
|
||||||
|
|
||||||
|
def test_include_multiple_registries(self):
|
||||||
|
"""Include multiple registries sequentially."""
|
||||||
|
main_registry = FixtureRegistry()
|
||||||
|
dev_registry = FixtureRegistry()
|
||||||
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@main_registry.register
|
||||||
|
def base():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@dev_registry.register
|
||||||
|
def dev_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@test_registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
main_registry.include_registry(dev_registry)
|
||||||
|
main_registry.include_registry(test_registry)
|
||||||
|
|
||||||
|
names = {f.name for f in main_registry.get_all()}
|
||||||
|
assert names == {"base", "dev_data", "test_data"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultContexts:
|
||||||
|
"""Tests for FixtureRegistry default contexts."""
|
||||||
|
|
||||||
|
def test_default_contexts_applied_to_fixtures(self):
|
||||||
|
"""Default contexts are applied when no contexts specified."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("test_data")
|
||||||
|
assert fixture.contexts == [Context.TESTING.value]
|
||||||
|
|
||||||
|
def test_explicit_contexts_override_default(self):
|
||||||
|
"""Explicit contexts override default contexts."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.PRODUCTION])
|
||||||
|
def prod_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("prod_data")
|
||||||
|
assert fixture.contexts == [Context.PRODUCTION.value]
|
||||||
|
|
||||||
|
def test_no_default_contexts_uses_base(self):
|
||||||
|
"""Without default contexts, BASE is used."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("data")
|
||||||
|
assert fixture.contexts == [Context.BASE.value]
|
||||||
|
|
||||||
|
def test_multiple_default_contexts(self):
|
||||||
|
"""Multiple default contexts are applied."""
|
||||||
|
registry = FixtureRegistry(contexts=[Context.DEVELOPMENT, Context.TESTING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def dev_test_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("dev_test_data")
|
||||||
|
assert Context.DEVELOPMENT.value in fixture.contexts
|
||||||
|
assert Context.TESTING.value in fixture.contexts
|
||||||
|
|
||||||
|
def test_default_contexts_with_string_values(self):
|
||||||
|
"""Default contexts work with string values."""
|
||||||
|
registry = FixtureRegistry(contexts=["custom_context"])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def custom_data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("custom_data")
|
||||||
|
assert fixture.contexts == ["custom_context"]
|
||||||
|
|
||||||
|
|
||||||
class TestDependencyResolution:
|
class TestDependencyResolution:
|
||||||
"""Tests for fixture dependency resolution."""
|
"""Tests for fixture dependency resolution."""
|
||||||
|
|
||||||
@@ -244,12 +430,14 @@ class TestLoadFixtures:
|
|||||||
async def test_load_single_fixture(self, db_session: AsyncSession):
|
async def test_load_single_fixture(self, db_session: AsyncSession):
|
||||||
"""Load a single fixture."""
|
"""Load a single fixture."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [
|
return [
|
||||||
Role(id=1, name="admin"),
|
Role(id=role_id_1, name="admin"),
|
||||||
Role(id=2, name="user"),
|
Role(id=role_id_2, name="user"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "roles")
|
result = await load_fixtures(db_session, registry, "roles")
|
||||||
@@ -266,14 +454,23 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with dependencies."""
|
"""Load fixtures with dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
@registry.register(depends_on=["roles"])
|
||||||
def users():
|
def users():
|
||||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="admin",
|
||||||
|
email="admin@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "users")
|
result = await load_fixtures(db_session, registry, "users")
|
||||||
|
|
||||||
@@ -289,10 +486,11 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with MERGE strategy updates existing."""
|
"""Load fixtures with MERGE strategy updates existing."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||||
@@ -306,10 +504,11 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with SKIP_EXISTING strategy."""
|
"""Load fixtures with SKIP_EXISTING strategy."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="original")]
|
return [Role(id=role_id, name="original")]
|
||||||
|
|
||||||
await load_fixtures(
|
await load_fixtures(
|
||||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||||
@@ -317,7 +516,7 @@ class TestLoadFixtures:
|
|||||||
|
|
||||||
@registry.register(name="roles_updated")
|
@registry.register(name="roles_updated")
|
||||||
def roles_v2():
|
def roles_v2():
|
||||||
return [Role(id=1, name="updated")]
|
return [Role(id=role_id, name="updated")]
|
||||||
|
|
||||||
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
||||||
|
|
||||||
@@ -327,7 +526,7 @@ class TestLoadFixtures:
|
|||||||
|
|
||||||
from .conftest import RoleCrud
|
from .conftest import RoleCrud
|
||||||
|
|
||||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
role = await RoleCrud.first(db_session, [Role.id == role_id])
|
||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "original"
|
assert role.name == "original"
|
||||||
|
|
||||||
@@ -335,12 +534,14 @@ class TestLoadFixtures:
|
|||||||
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with INSERT strategy."""
|
"""Load fixtures with INSERT strategy."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [
|
return [
|
||||||
Role(id=1, name="admin"),
|
Role(id=role_id_1, name="admin"),
|
||||||
Role(id=2, name="user"),
|
Role(id=role_id_2, name="user"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await load_fixtures(
|
result = await load_fixtures(
|
||||||
@@ -375,14 +576,16 @@ class TestLoadFixtures:
|
|||||||
):
|
):
|
||||||
"""Load multiple independent fixtures."""
|
"""Load multiple independent fixtures."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id_1 = uuid.uuid4()
|
||||||
|
role_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id_1, name="admin")]
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
def other_roles():
|
def other_roles():
|
||||||
return [Role(id=2, name="user")]
|
return [Role(id=role_id_2, name="user")]
|
||||||
|
|
||||||
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
||||||
|
|
||||||
@@ -402,14 +605,16 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_by_single_context(self, db_session: AsyncSession):
|
async def test_load_by_single_context(self, db_session: AsyncSession):
|
||||||
"""Load fixtures by single context."""
|
"""Load fixtures by single context."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
base_role_id = uuid.uuid4()
|
||||||
|
test_role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def base_roles():
|
def base_roles():
|
||||||
return [Role(id=1, name="base_role")]
|
return [Role(id=base_role_id, name="base_role")]
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_roles():
|
def test_roles():
|
||||||
return [Role(id=100, name="test_role")]
|
return [Role(id=test_role_id, name="test_role")]
|
||||||
|
|
||||||
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
||||||
|
|
||||||
@@ -418,7 +623,7 @@ class TestLoadFixturesByContext:
|
|||||||
count = await RoleCrud.count(db_session)
|
count = await RoleCrud.count(db_session)
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
role = await RoleCrud.first(db_session, [Role.id == base_role_id])
|
||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "base_role"
|
assert role.name == "base_role"
|
||||||
|
|
||||||
@@ -426,14 +631,16 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
||||||
"""Load fixtures by multiple contexts."""
|
"""Load fixtures by multiple contexts."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
base_role_id = uuid.uuid4()
|
||||||
|
test_role_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def base_roles():
|
def base_roles():
|
||||||
return [Role(id=1, name="base_role")]
|
return [Role(id=base_role_id, name="base_role")]
|
||||||
|
|
||||||
@registry.register(contexts=[Context.TESTING])
|
@registry.register(contexts=[Context.TESTING])
|
||||||
def test_roles():
|
def test_roles():
|
||||||
return [Role(id=100, name="test_role")]
|
return [Role(id=test_role_id, name="test_role")]
|
||||||
|
|
||||||
await load_fixtures_by_context(
|
await load_fixtures_by_context(
|
||||||
db_session, registry, Context.BASE, Context.TESTING
|
db_session, registry, Context.BASE, Context.TESTING
|
||||||
@@ -448,14 +655,23 @@ class TestLoadFixturesByContext:
|
|||||||
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
||||||
"""Load context fixtures with cross-context dependencies."""
|
"""Load context fixtures with cross-context dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
|
||||||
@registry.register(contexts=[Context.BASE])
|
@registry.register(contexts=[Context.BASE])
|
||||||
def roles():
|
def roles():
|
||||||
return [Role(id=1, name="admin")]
|
return [Role(id=role_id, name="admin")]
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
def test_users():
|
def test_users():
|
||||||
return [User(id=1, username="tester", email="test@test.com", role_id=1)]
|
return [
|
||||||
|
User(
|
||||||
|
id=user_id,
|
||||||
|
username="tester",
|
||||||
|
email="test@test.com",
|
||||||
|
role_id=role_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
||||||
|
|
||||||
@@ -471,20 +687,41 @@ class TestGetObjByAttr:
|
|||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
"""Set up test fixtures for each test."""
|
"""Set up test fixtures for each test."""
|
||||||
self.registry = FixtureRegistry()
|
self.registry = FixtureRegistry()
|
||||||
|
self.role_id_1 = uuid.uuid4()
|
||||||
|
self.role_id_2 = uuid.uuid4()
|
||||||
|
self.role_id_3 = uuid.uuid4()
|
||||||
|
self.user_id_1 = uuid.uuid4()
|
||||||
|
self.user_id_2 = uuid.uuid4()
|
||||||
|
|
||||||
|
role_id_1 = self.role_id_1
|
||||||
|
role_id_2 = self.role_id_2
|
||||||
|
role_id_3 = self.role_id_3
|
||||||
|
user_id_1 = self.user_id_1
|
||||||
|
user_id_2 = self.user_id_2
|
||||||
|
|
||||||
@self.registry.register
|
@self.registry.register
|
||||||
def roles() -> list[Role]:
|
def roles() -> list[Role]:
|
||||||
return [
|
return [
|
||||||
Role(id=1, name="admin"),
|
Role(id=role_id_1, name="admin"),
|
||||||
Role(id=2, name="user"),
|
Role(id=role_id_2, name="user"),
|
||||||
Role(id=3, name="moderator"),
|
Role(id=role_id_3, name="moderator"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@self.registry.register(depends_on=["roles"])
|
@self.registry.register(depends_on=["roles"])
|
||||||
def users() -> list[User]:
|
def users() -> list[User]:
|
||||||
return [
|
return [
|
||||||
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
User(
|
||||||
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
id=user_id_1,
|
||||||
|
username="alice",
|
||||||
|
email="alice@example.com",
|
||||||
|
role_id=role_id_1,
|
||||||
|
),
|
||||||
|
User(
|
||||||
|
id=user_id_2,
|
||||||
|
username="bob",
|
||||||
|
email="bob@example.com",
|
||||||
|
role_id=role_id_1,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
self.roles = roles
|
self.roles = roles
|
||||||
@@ -492,18 +729,18 @@ class TestGetObjByAttr:
|
|||||||
|
|
||||||
def test_get_by_id(self):
|
def test_get_by_id(self):
|
||||||
"""Get an object by its id attribute."""
|
"""Get an object by its id attribute."""
|
||||||
role = get_obj_by_attr(self.roles, "id", 1)
|
role = get_obj_by_attr(self.roles, "id", self.role_id_1)
|
||||||
assert role.name == "admin"
|
assert role.name == "admin"
|
||||||
|
|
||||||
def test_get_user_by_username(self):
|
def test_get_user_by_username(self):
|
||||||
"""Get a user by username."""
|
"""Get a user by username."""
|
||||||
user = get_obj_by_attr(self.users, "username", "bob")
|
user = get_obj_by_attr(self.users, "username", "bob")
|
||||||
assert user.id == 2
|
assert user.id == self.user_id_2
|
||||||
assert user.email == "bob@example.com"
|
assert user.email == "bob@example.com"
|
||||||
|
|
||||||
def test_returns_first_match(self):
|
def test_returns_first_match(self):
|
||||||
"""Returns the first matching object when multiple could match."""
|
"""Returns the first matching object when multiple could match."""
|
||||||
user = get_obj_by_attr(self.users, "role_id", 1)
|
user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
|
||||||
assert user.username == "alice"
|
assert user.username == "alice"
|
||||||
|
|
||||||
def test_no_match_raises_stop_iteration(self):
|
def test_no_match_raises_stop_iteration(self):
|
||||||
@@ -514,4 +751,4 @@ class TestGetObjByAttr:
|
|||||||
def test_no_match_on_wrong_value_type(self):
|
def test_no_match_on_wrong_value_type(self):
|
||||||
"""Raises StopIteration when value type doesn't match."""
|
"""Raises StopIteration when value type doesn't match."""
|
||||||
with pytest.raises(StopIteration):
|
with pytest.raises(StopIteration):
|
||||||
get_obj_by_attr(self.roles, "id", "1")
|
get_obj_by_attr(self.roles, "id", "not-a-uuid")
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Tests for fastapi_toolsets.pytest module."""
|
"""Tests for fastapi_toolsets.pytest module."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
@@ -18,27 +20,49 @@ from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
|||||||
|
|
||||||
test_registry = FixtureRegistry()
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
|
# Fixed UUIDs for test fixtures to allow consistent assertions
|
||||||
|
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
|
||||||
|
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
|
||||||
|
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
|
||||||
|
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
|
||||||
|
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(contexts=[Context.BASE])
|
@test_registry.register(contexts=[Context.BASE])
|
||||||
def roles() -> list[Role]:
|
def roles() -> list[Role]:
|
||||||
return [
|
return [
|
||||||
Role(id=1000, name="plugin_admin"),
|
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
|
||||||
Role(id=1001, name="plugin_user"),
|
Role(id=ROLE_USER_ID, name="plugin_user"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
||||||
def users() -> list[User]:
|
def users() -> list[User]:
|
||||||
return [
|
return [
|
||||||
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000),
|
User(
|
||||||
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001),
|
id=USER_ADMIN_ID,
|
||||||
|
username="plugin_admin",
|
||||||
|
email="padmin@test.com",
|
||||||
|
role_id=ROLE_ADMIN_ID,
|
||||||
|
),
|
||||||
|
User(
|
||||||
|
id=USER_USER_ID,
|
||||||
|
username="plugin_user",
|
||||||
|
email="puser@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||||
def extra_users() -> list[User]:
|
def extra_users() -> list[User]:
|
||||||
return [
|
return [
|
||||||
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001),
|
User(
|
||||||
|
id=USER_EXTRA_ID,
|
||||||
|
username="plugin_extra",
|
||||||
|
email="pextra@test.com",
|
||||||
|
role_id=ROLE_USER_ID,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -73,7 +97,7 @@ class TestGeneratedFixtures:
|
|||||||
assert fixture_roles[1].name == "plugin_user"
|
assert fixture_roles[1].name == "plugin_user"
|
||||||
|
|
||||||
# Verify data is in database
|
# Verify data is in database
|
||||||
count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
count = await RoleCrud.count(db_session)
|
||||||
assert count == 2
|
assert count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -86,11 +110,11 @@ class TestGeneratedFixtures:
|
|||||||
assert len(fixture_users) == 2
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
# Roles should also be in database
|
# Roles should also be in database
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
roles_count = await RoleCrud.count(db_session)
|
||||||
assert roles_count == 2
|
assert roles_count == 2
|
||||||
|
|
||||||
# Users should be in database
|
# Users should be in database
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
users_count = await UserCrud.count(db_session)
|
||||||
assert users_count == 2
|
assert users_count == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -100,7 +124,7 @@ class TestGeneratedFixtures:
|
|||||||
"""Fixture returns actual model instances."""
|
"""Fixture returns actual model instances."""
|
||||||
user = fixture_users[0]
|
user = fixture_users[0]
|
||||||
assert isinstance(user, User)
|
assert isinstance(user, User)
|
||||||
assert user.id == 1000
|
assert user.id == USER_ADMIN_ID
|
||||||
assert user.username == "plugin_admin"
|
assert user.username == "plugin_admin"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -111,7 +135,7 @@ class TestGeneratedFixtures:
|
|||||||
# Load user with role relationship
|
# Load user with role relationship
|
||||||
user = await UserCrud.get(
|
user = await UserCrud.get(
|
||||||
db_session,
|
db_session,
|
||||||
[User.id == 1000],
|
[User.id == USER_ADMIN_ID],
|
||||||
load_options=[selectinload(User.role)],
|
load_options=[selectinload(User.role)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,8 +151,8 @@ class TestGeneratedFixtures:
|
|||||||
assert len(fixture_extra_users) == 1
|
assert len(fixture_extra_users) == 1
|
||||||
|
|
||||||
# All fixtures should be loaded
|
# All fixtures should be loaded
|
||||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
roles_count = await RoleCrud.count(db_session)
|
||||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
users_count = await UserCrud.count(db_session)
|
||||||
|
|
||||||
assert roles_count == 2
|
assert roles_count == 2
|
||||||
assert users_count == 3 # 2 from users + 1 from extra_users
|
assert users_count == 3 # 2 from users + 1 from extra_users
|
||||||
@@ -141,8 +165,7 @@ class TestGeneratedFixtures:
|
|||||||
# Get all users loaded by fixture
|
# Get all users loaded by fixture
|
||||||
users = await UserCrud.get_multi(
|
users = await UserCrud.get_multi(
|
||||||
db_session,
|
db_session,
|
||||||
filters=[User.id >= 1000],
|
order_by=User.username,
|
||||||
order_by=User.id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(users) == 2
|
assert len(users) == 2
|
||||||
@@ -161,8 +184,8 @@ class TestGeneratedFixtures:
|
|||||||
assert len(fixture_users) == 2
|
assert len(fixture_users) == 2
|
||||||
|
|
||||||
# Both should be in database
|
# Both should be in database
|
||||||
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000])
|
roles = await RoleCrud.get_multi(db_session)
|
||||||
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000])
|
users = await UserCrud.get_multi(db_session)
|
||||||
|
|
||||||
assert len(roles) == 2
|
assert len(roles) == 2
|
||||||
assert len(users) == 2
|
assert len(users) == 2
|
||||||
@@ -215,14 +238,15 @@ class TestCreateDbSession:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_creates_working_session(self):
|
async def test_creates_working_session(self):
|
||||||
"""Session can perform database operations."""
|
"""Session can perform database operations."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
async with create_db_session(DATABASE_URL, Base) as session:
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
assert isinstance(session, AsyncSession)
|
assert isinstance(session, AsyncSession)
|
||||||
|
|
||||||
role = Role(id=9001, name="test_helper_role")
|
role = Role(id=role_id, name="test_helper_role")
|
||||||
session.add(role)
|
session.add(role)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
result = await session.execute(select(Role).where(Role.id == 9001))
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
fetched = result.scalar_one()
|
fetched = result.scalar_one()
|
||||||
assert fetched.name == "test_helper_role"
|
assert fetched.name == "test_helper_role"
|
||||||
|
|
||||||
@@ -237,8 +261,9 @@ class TestCreateDbSession:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_tables_dropped_after_session(self):
|
async def test_tables_dropped_after_session(self):
|
||||||
"""Tables are dropped after session closes when drop_tables=True."""
|
"""Tables are dropped after session closes when drop_tables=True."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
role = Role(id=9002, name="will_be_dropped")
|
role = Role(id=role_id, name="will_be_dropped")
|
||||||
session.add(role)
|
session.add(role)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@@ -250,14 +275,15 @@ class TestCreateDbSession:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_tables_preserved_when_drop_disabled(self):
|
async def test_tables_preserved_when_drop_disabled(self):
|
||||||
"""Tables are preserved when drop_tables=False."""
|
"""Tables are preserved when drop_tables=False."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
role = Role(id=9003, name="preserved_role")
|
role = Role(id=role_id, name="preserved_role")
|
||||||
session.add(role)
|
session.add(role)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Create another session without dropping
|
# Create another session without dropping
|
||||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
result = await session.execute(select(Role).where(Role.id == 9003))
|
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||||
fetched = result.scalar_one_or_none()
|
fetched = result.scalar_one_or_none()
|
||||||
assert fetched is not None
|
assert fetched is not None
|
||||||
assert fetched.name == "preserved_role"
|
assert fetched.name == "preserved_role"
|
||||||
|
|||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -220,7 +220,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "0.3.0"
|
version = "0.5.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "asyncpg" },
|
{ name = "asyncpg" },
|
||||||
|
|||||||
Reference in New Issue
Block a user