mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add pytest helpers (#8)
This commit is contained in:
@@ -1,11 +1,6 @@
|
|||||||
from .fixtures import (
|
from .enum import LoadStrategy
|
||||||
Context,
|
from .registry import Context, FixtureRegistry
|
||||||
FixtureRegistry,
|
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
|
||||||
LoadStrategy,
|
|
||||||
load_fixtures,
|
|
||||||
load_fixtures_by_context,
|
|
||||||
)
|
|
||||||
from .utils import get_obj_by_attr
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Context",
|
"Context",
|
||||||
@@ -16,12 +11,3 @@ __all__ = [
|
|||||||
"load_fixtures_by_context",
|
"load_fixtures_by_context",
|
||||||
"register_fixtures",
|
"register_fixtures",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI
|
|
||||||
def __getattr__(name: str):
|
|
||||||
if name == "register_fixtures":
|
|
||||||
from .pytest_plugin import register_fixtures
|
|
||||||
|
|
||||||
return register_fixtures
|
|
||||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
||||||
|
|||||||
30
src/fastapi_toolsets/fixtures/enum.py
Normal file
30
src/fastapi_toolsets/fixtures/enum.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class LoadStrategy(str, Enum):
|
||||||
|
"""Strategy for loading fixtures into the database."""
|
||||||
|
|
||||||
|
INSERT = "insert"
|
||||||
|
"""Insert new records. Fails if record already exists."""
|
||||||
|
|
||||||
|
MERGE = "merge"
|
||||||
|
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
||||||
|
|
||||||
|
SKIP_EXISTING = "skip_existing"
|
||||||
|
"""Insert only if record doesn't exist (based on primary key)."""
|
||||||
|
|
||||||
|
|
||||||
|
class Context(str, Enum):
|
||||||
|
"""Predefined fixture contexts."""
|
||||||
|
|
||||||
|
BASE = "base"
|
||||||
|
"""Base fixtures loaded in all environments."""
|
||||||
|
|
||||||
|
PRODUCTION = "production"
|
||||||
|
"""Production-only fixtures."""
|
||||||
|
|
||||||
|
DEVELOPMENT = "development"
|
||||||
|
"""Development fixtures."""
|
||||||
|
|
||||||
|
TESTING = "testing"
|
||||||
|
"""Test fixtures."""
|
||||||
@@ -3,46 +3,15 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from .enum import Context
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoadStrategy(str, Enum):
|
|
||||||
"""Strategy for loading fixtures into the database."""
|
|
||||||
|
|
||||||
INSERT = "insert"
|
|
||||||
"""Insert new records. Fails if record already exists."""
|
|
||||||
|
|
||||||
MERGE = "merge"
|
|
||||||
"""Insert or update based on primary key (SQLAlchemy merge)."""
|
|
||||||
|
|
||||||
SKIP_EXISTING = "skip_existing"
|
|
||||||
"""Insert only if record doesn't exist (based on primary key)."""
|
|
||||||
|
|
||||||
|
|
||||||
class Context(str, Enum):
|
|
||||||
"""Predefined fixture contexts."""
|
|
||||||
|
|
||||||
BASE = "base"
|
|
||||||
"""Base fixtures loaded in all environments."""
|
|
||||||
|
|
||||||
PRODUCTION = "production"
|
|
||||||
"""Production-only fixtures."""
|
|
||||||
|
|
||||||
DEVELOPMENT = "development"
|
|
||||||
"""Development fixtures."""
|
|
||||||
|
|
||||||
TESTING = "testing"
|
|
||||||
"""Test fixtures."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fixture:
|
class Fixture:
|
||||||
"""A fixture definition with metadata."""
|
"""A fixture definition with metadata."""
|
||||||
@@ -204,118 +173,3 @@ class FixtureRegistry:
|
|||||||
all_deps.update(deps)
|
all_deps.update(deps)
|
||||||
|
|
||||||
return self.resolve_dependencies(*all_deps)
|
return self.resolve_dependencies(*all_deps)
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*names: str,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load specific fixtures by name with dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*names: Fixture names to load (dependencies auto-resolved)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Loads 'roles' first (dependency), then 'users'
|
|
||||||
result = await load_fixtures(session, fixtures, "users")
|
|
||||||
print(result["users"]) # [User(...), ...]
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_dependencies(*names)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def load_fixtures_by_context(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
*contexts: str | Context,
|
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load all fixtures for specific contexts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
registry: Fixture registry
|
|
||||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
|
||||||
strategy: How to handle existing records
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping fixture names to loaded instances
|
|
||||||
|
|
||||||
Example:
|
|
||||||
# Load base + testing fixtures
|
|
||||||
await load_fixtures_by_context(
|
|
||||||
session, fixtures,
|
|
||||||
Context.BASE, Context.TESTING
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
ordered = registry.resolve_context_dependencies(*contexts)
|
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
|
||||||
|
|
||||||
|
|
||||||
async def _load_ordered(
|
|
||||||
session: AsyncSession,
|
|
||||||
registry: FixtureRegistry,
|
|
||||||
ordered_names: list[str],
|
|
||||||
strategy: LoadStrategy,
|
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
|
||||||
"""Load fixtures in order."""
|
|
||||||
results: dict[str, list[DeclarativeBase]] = {}
|
|
||||||
|
|
||||||
for name in ordered_names:
|
|
||||||
fixture = registry.get(name)
|
|
||||||
instances = list(fixture.func())
|
|
||||||
|
|
||||||
if not instances:
|
|
||||||
results[name] = []
|
|
||||||
continue
|
|
||||||
|
|
||||||
model_name = type(instances[0]).__name__
|
|
||||||
loaded: list[DeclarativeBase] = []
|
|
||||||
|
|
||||||
async with get_transaction(session):
|
|
||||||
for instance in instances:
|
|
||||||
if strategy == LoadStrategy.INSERT:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.MERGE:
|
|
||||||
merged = await session.merge(instance)
|
|
||||||
loaded.append(merged)
|
|
||||||
|
|
||||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
|
||||||
pk = _get_primary_key(instance)
|
|
||||||
if pk is not None:
|
|
||||||
existing = await session.get(type(instance), pk)
|
|
||||||
if existing is None:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
else:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
results[name] = loaded
|
|
||||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
|
||||||
"""Get the primary key value of a model instance."""
|
|
||||||
mapper = instance.__class__.__mapper__
|
|
||||||
pk_cols = mapper.primary_key
|
|
||||||
|
|
||||||
if len(pk_cols) == 1:
|
|
||||||
return getattr(instance, pk_cols[0].name, None)
|
|
||||||
|
|
||||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
|
||||||
if all(v is not None for v in pk_values):
|
|
||||||
return pk_values
|
|
||||||
return None
|
|
||||||
@@ -1,8 +1,16 @@
|
|||||||
|
import logging
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import get_transaction
|
||||||
|
from .enum import LoadStrategy
|
||||||
|
from .registry import Context, FixtureRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T", bound=DeclarativeBase)
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
|
||||||
@@ -24,3 +32,118 @@ def get_obj_by_attr(
|
|||||||
StopIteration: If no matching object is found.
|
StopIteration: If no matching object is found.
|
||||||
"""
|
"""
|
||||||
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*names: str,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load specific fixtures by name with dependencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*names: Fixture names to load (dependencies auto-resolved)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Loads 'roles' first (dependency), then 'users'
|
||||||
|
result = await load_fixtures(session, fixtures, "users")
|
||||||
|
print(result["users"]) # [User(...), ...]
|
||||||
|
"""
|
||||||
|
ordered = registry.resolve_dependencies(*names)
|
||||||
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_fixtures_by_context(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
*contexts: str | Context,
|
||||||
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load all fixtures for specific contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
registry: Fixture registry
|
||||||
|
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
||||||
|
strategy: How to handle existing records
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping fixture names to loaded instances
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Load base + testing fixtures
|
||||||
|
await load_fixtures_by_context(
|
||||||
|
session, fixtures,
|
||||||
|
Context.BASE, Context.TESTING
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
ordered = registry.resolve_context_dependencies(*contexts)
|
||||||
|
return await _load_ordered(session, registry, ordered, strategy)
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_ordered(
|
||||||
|
session: AsyncSession,
|
||||||
|
registry: FixtureRegistry,
|
||||||
|
ordered_names: list[str],
|
||||||
|
strategy: LoadStrategy,
|
||||||
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
|
"""Load fixtures in order."""
|
||||||
|
results: dict[str, list[DeclarativeBase]] = {}
|
||||||
|
|
||||||
|
for name in ordered_names:
|
||||||
|
fixture = registry.get(name)
|
||||||
|
instances = list(fixture.func())
|
||||||
|
|
||||||
|
if not instances:
|
||||||
|
results[name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_name = type(instances[0]).__name__
|
||||||
|
loaded: list[DeclarativeBase] = []
|
||||||
|
|
||||||
|
async with get_transaction(session):
|
||||||
|
for instance in instances:
|
||||||
|
if strategy == LoadStrategy.INSERT:
|
||||||
|
session.add(instance)
|
||||||
|
loaded.append(instance)
|
||||||
|
|
||||||
|
elif strategy == LoadStrategy.MERGE:
|
||||||
|
merged = await session.merge(instance)
|
||||||
|
loaded.append(merged)
|
||||||
|
|
||||||
|
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||||
|
pk = _get_primary_key(instance)
|
||||||
|
if pk is not None:
|
||||||
|
existing = await session.get(type(instance), pk)
|
||||||
|
if existing is None:
|
||||||
|
session.add(instance)
|
||||||
|
loaded.append(instance)
|
||||||
|
else:
|
||||||
|
session.add(instance)
|
||||||
|
loaded.append(instance)
|
||||||
|
|
||||||
|
results[name] = loaded
|
||||||
|
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||||
|
"""Get the primary key value of a model instance."""
|
||||||
|
mapper = instance.__class__.__mapper__
|
||||||
|
pk_cols = mapper.primary_key
|
||||||
|
|
||||||
|
if len(pk_cols) == 1:
|
||||||
|
return getattr(instance, pk_cols[0].name, None)
|
||||||
|
|
||||||
|
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||||
|
if all(v is not None for v in pk_values):
|
||||||
|
return pk_values
|
||||||
|
return None
|
||||||
|
|||||||
8
src/fastapi_toolsets/pytest/__init__.py
Normal file
8
src/fastapi_toolsets/pytest/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from .plugin import register_fixtures
|
||||||
|
from .utils import create_async_client, create_db_session
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_async_client",
|
||||||
|
"create_db_session",
|
||||||
|
"register_fixtures",
|
||||||
|
]
|
||||||
@@ -59,7 +59,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from .fixtures import FixtureRegistry, LoadStrategy
|
from ..fixtures import FixtureRegistry, LoadStrategy
|
||||||
|
|
||||||
|
|
||||||
def register_fixtures(
|
def register_fixtures(
|
||||||
110
src/fastapi_toolsets/pytest/utils.py
Normal file
110
src/fastapi_toolsets/pytest/utils.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Pytest helper utilities for FastAPI testing."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from ..db import create_db_context
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_async_client(
|
||||||
|
app: Any,
|
||||||
|
base_url: str = "http://test",
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
"""Create an async httpx client for testing FastAPI applications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance.
|
||||||
|
base_url: Base URL for requests. Defaults to "http://test".
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncClient configured for the app.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi_toolsets.pytest import create_async_client
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client():
|
||||||
|
async with create_async_client(app) as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
async def test_endpoint(client: AsyncClient):
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url=base_url) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_db_session(
|
||||||
|
database_url: str,
|
||||||
|
base: type[DeclarativeBase],
|
||||||
|
*,
|
||||||
|
echo: bool = False,
|
||||||
|
expire_on_commit: bool = False,
|
||||||
|
drop_tables: bool = True,
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Create a database session for testing.
|
||||||
|
|
||||||
|
Creates tables before yielding the session and optionally drops them after.
|
||||||
|
Each call creates a fresh engine and session for test isolation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_url: Database connection URL (e.g., "postgresql+asyncpg://...").
|
||||||
|
base: SQLAlchemy DeclarativeBase class containing model metadata.
|
||||||
|
echo: Enable SQLAlchemy query logging. Defaults to False.
|
||||||
|
expire_on_commit: Expire objects after commit. Defaults to False.
|
||||||
|
drop_tables: Drop tables after test. Defaults to True.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
An AsyncSession ready for database operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.pytest import create_db_session
|
||||||
|
from app.models import Base
|
||||||
|
|
||||||
|
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session():
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def test_create_user(db_session: AsyncSession):
|
||||||
|
user = User(name="test")
|
||||||
|
db_session.add(user)
|
||||||
|
await db_session.commit()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
engine = create_async_engine(database_url, echo=echo)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create tables
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.create_all)
|
||||||
|
|
||||||
|
# Create session using existing db context utility
|
||||||
|
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
|
||||||
|
get_session = create_db_context(session_maker)
|
||||||
|
|
||||||
|
async with get_session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
if drop_tables:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(base.metadata.drop_all)
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
@@ -7,6 +7,7 @@ from fastapi_toolsets.fixtures import (
|
|||||||
Context,
|
Context,
|
||||||
FixtureRegistry,
|
FixtureRegistry,
|
||||||
LoadStrategy,
|
LoadStrategy,
|
||||||
|
get_obj_by_attr,
|
||||||
load_fixtures,
|
load_fixtures,
|
||||||
load_fixtures_by_context,
|
load_fixtures_by_context,
|
||||||
)
|
)
|
||||||
@@ -330,6 +331,69 @@ class TestLoadFixtures:
|
|||||||
assert role is not None
|
assert role is not None
|
||||||
assert role.name == "original"
|
assert role.name == "original"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
||||||
|
"""Load fixtures with INSERT strategy."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def roles():
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await load_fixtures(
|
||||||
|
db_session, registry, "roles", strategy=LoadStrategy.INSERT
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "roles" in result
|
||||||
|
assert len(result["roles"]) == 2
|
||||||
|
|
||||||
|
from .conftest import RoleCrud
|
||||||
|
|
||||||
|
count = await RoleCrud.count(db_session)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_load_empty_fixture(self, db_session: AsyncSession):
|
||||||
|
"""Load a fixture that returns an empty list."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def empty_roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await load_fixtures(db_session, registry, "empty_roles")
|
||||||
|
|
||||||
|
assert "empty_roles" in result
|
||||||
|
assert result["empty_roles"] == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_load_multiple_fixtures_without_dependencies(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Load multiple independent fixtures."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def roles():
|
||||||
|
return [Role(id=1, name="admin")]
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def other_roles():
|
||||||
|
return [Role(id=2, name="user")]
|
||||||
|
|
||||||
|
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
||||||
|
|
||||||
|
assert "roles" in result
|
||||||
|
assert "other_roles" in result
|
||||||
|
|
||||||
|
from .conftest import RoleCrud
|
||||||
|
|
||||||
|
count = await RoleCrud.count(db_session)
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
class TestLoadFixturesByContext:
|
class TestLoadFixturesByContext:
|
||||||
"""Tests for load_fixtures_by_context function."""
|
"""Tests for load_fixtures_by_context function."""
|
||||||
@@ -399,3 +463,55 @@ class TestLoadFixturesByContext:
|
|||||||
|
|
||||||
assert await RoleCrud.count(db_session) == 1
|
assert await RoleCrud.count(db_session) == 1
|
||||||
assert await UserCrud.count(db_session) == 1
|
assert await UserCrud.count(db_session) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetObjByAttr:
|
||||||
|
"""Tests for get_obj_by_attr helper function."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Set up test fixtures for each test."""
|
||||||
|
self.registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@self.registry.register
|
||||||
|
def roles() -> list[Role]:
|
||||||
|
return [
|
||||||
|
Role(id=1, name="admin"),
|
||||||
|
Role(id=2, name="user"),
|
||||||
|
Role(id=3, name="moderator"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@self.registry.register(depends_on=["roles"])
|
||||||
|
def users() -> list[User]:
|
||||||
|
return [
|
||||||
|
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
||||||
|
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.roles = roles
|
||||||
|
self.users = users
|
||||||
|
|
||||||
|
def test_get_by_id(self):
|
||||||
|
"""Get an object by its id attribute."""
|
||||||
|
role = get_obj_by_attr(self.roles, "id", 1)
|
||||||
|
assert role.name == "admin"
|
||||||
|
|
||||||
|
def test_get_user_by_username(self):
|
||||||
|
"""Get a user by username."""
|
||||||
|
user = get_obj_by_attr(self.users, "username", "bob")
|
||||||
|
assert user.id == 2
|
||||||
|
assert user.email == "bob@example.com"
|
||||||
|
|
||||||
|
def test_returns_first_match(self):
|
||||||
|
"""Returns the first matching object when multiple could match."""
|
||||||
|
user = get_obj_by_attr(self.users, "role_id", 1)
|
||||||
|
assert user.username == "alice"
|
||||||
|
|
||||||
|
def test_no_match_raises_stop_iteration(self):
|
||||||
|
"""Raises StopIteration when no object matches."""
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
get_obj_by_attr(self.roles, "name", "nonexistent")
|
||||||
|
|
||||||
|
def test_no_match_on_wrong_value_type(self):
|
||||||
|
"""Raises StopIteration when value type doesn't match."""
|
||||||
|
with pytest.raises(StopIteration):
|
||||||
|
get_obj_by_attr(self.roles, "id", "1")
|
||||||
|
|||||||
@@ -1,57 +0,0 @@
|
|||||||
"""Tests for fastapi_toolsets.fixtures.utils."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import FixtureRegistry
|
|
||||||
from fastapi_toolsets.fixtures.utils import get_obj_by_attr
|
|
||||||
|
|
||||||
from .conftest import Role, User
|
|
||||||
|
|
||||||
registry = FixtureRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
def roles() -> list[Role]:
|
|
||||||
return [
|
|
||||||
Role(id=1, name="admin"),
|
|
||||||
Role(id=2, name="user"),
|
|
||||||
Role(id=3, name="moderator"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register(depends_on=["roles"])
|
|
||||||
def users() -> list[User]:
|
|
||||||
return [
|
|
||||||
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
|
||||||
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetObjByAttr:
|
|
||||||
"""Tests for get_obj_by_attr."""
|
|
||||||
|
|
||||||
def test_get_by_id(self):
|
|
||||||
"""Get an object by its id attribute."""
|
|
||||||
role = get_obj_by_attr(roles, "id", 1)
|
|
||||||
assert role.name == "admin"
|
|
||||||
|
|
||||||
def test_get_user_by_username(self):
|
|
||||||
"""Get a user by username."""
|
|
||||||
user = get_obj_by_attr(users, "username", "bob")
|
|
||||||
assert user.id == 2
|
|
||||||
assert user.email == "bob@example.com"
|
|
||||||
|
|
||||||
def test_returns_first_match(self):
|
|
||||||
"""Returns the first matching object when multiple could match."""
|
|
||||||
user = get_obj_by_attr(users, "role_id", 1)
|
|
||||||
assert user.username == "alice"
|
|
||||||
|
|
||||||
def test_no_match_raises_stop_iteration(self):
|
|
||||||
"""Raises StopIteration when no object matches."""
|
|
||||||
with pytest.raises(StopIteration):
|
|
||||||
get_obj_by_attr(roles, "name", "nonexistent")
|
|
||||||
|
|
||||||
def test_no_match_on_wrong_value_type(self):
|
|
||||||
"""Raises StopIteration when value type doesn't match."""
|
|
||||||
with pytest.raises(StopIteration):
|
|
||||||
get_obj_by_attr(roles, "id", "1")
|
|
||||||
@@ -1,12 +1,20 @@
|
|||||||
"""Tests for fastapi_toolsets.pytest_plugin module."""
|
"""Tests for fastapi_toolsets.pytest module."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry, register_fixtures
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||||
|
from fastapi_toolsets.pytest import (
|
||||||
|
create_async_client,
|
||||||
|
create_db_session,
|
||||||
|
register_fixtures,
|
||||||
|
)
|
||||||
|
|
||||||
from .conftest import Role, RoleCrud, User, UserCrud
|
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
test_registry = FixtureRegistry()
|
test_registry = FixtureRegistry()
|
||||||
|
|
||||||
@@ -158,3 +166,102 @@ class TestGeneratedFixtures:
|
|||||||
|
|
||||||
assert len(roles) == 2
|
assert len(roles) == 2
|
||||||
assert len(users) == 2
|
assert len(users) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateAsyncClient:
|
||||||
|
"""Tests for create_async_client helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_client(self):
|
||||||
|
"""Client can make requests to the app."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
assert isinstance(client, AsyncClient)
|
||||||
|
response = await client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_custom_base_url(self):
|
||||||
|
"""Client uses custom base URL."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
async def test_endpoint():
|
||||||
|
return {"url": "test"}
|
||||||
|
|
||||||
|
async with create_async_client(app, base_url="http://custom") as client:
|
||||||
|
assert str(client.base_url) == "http://custom"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_closes_properly(self):
|
||||||
|
"""Client is properly closed after context exit."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
async with create_async_client(app) as client:
|
||||||
|
client_ref = client
|
||||||
|
|
||||||
|
assert client_ref.is_closed
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateDbSession:
|
||||||
|
"""Tests for create_db_session helper."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_creates_working_session(self):
|
||||||
|
"""Session can perform database operations."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
assert isinstance(session, AsyncSession)
|
||||||
|
|
||||||
|
role = Role(id=9001, name="test_helper_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
result = await session.execute(select(Role).where(Role.id == 9001))
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.name == "test_helper_role"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_created_before_session(self):
|
||||||
|
"""Tables exist when session is yielded."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
# Should not raise - tables exist
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_dropped_after_session(self):
|
||||||
|
"""Tables are dropped after session closes when drop_tables=True."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
role = Role(id=9002, name="will_be_dropped")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Verify tables were dropped by creating new session
|
||||||
|
async with create_db_session(DATABASE_URL, Base) as session:
|
||||||
|
result = await session.execute(select(Role))
|
||||||
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_tables_preserved_when_drop_disabled(self):
|
||||||
|
"""Tables are preserved when drop_tables=False."""
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
role = Role(id=9003, name="preserved_role")
|
||||||
|
session.add(role)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Create another session without dropping
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
result = await session.execute(select(Role).where(Role.id == 9003))
|
||||||
|
fetched = result.scalar_one_or_none()
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.name == "preserved_role"
|
||||||
|
|
||||||
|
# Cleanup: drop tables manually
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user