mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
test: add missing tests for fixtures/utils.py
This commit is contained in:
@@ -15,6 +15,67 @@ from .registry import Context, FixtureRegistry
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
ordered_names: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
|
||||
if not instances:
|
||||
results[name] = []
|
||||
continue
|
||||
|
||||
model_name = type(instances[0]).__name__
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
else: # LoadStrategy.SKIP_EXISTING
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
|
||||
|
||||
def get_obj_by_attr(
|
||||
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
|
||||
) -> ModelType:
|
||||
@@ -56,13 +117,6 @@ async def load_fixtures(
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Loads 'roles' first (dependency), then 'users'
|
||||
result = await load_fixtures(session, fixtures, "users")
|
||||
print(result["users"]) # [User(...), ...]
|
||||
```
|
||||
"""
|
||||
ordered = registry.resolve_dependencies(*names)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
@@ -84,76 +138,6 @@ async def load_fixtures_by_context(
|
||||
|
||||
Returns:
|
||||
Dict mapping fixture names to loaded instances
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Load base + testing fixtures
|
||||
await load_fixtures_by_context(
|
||||
session, fixtures,
|
||||
Context.BASE, Context.TESTING
|
||||
)
|
||||
```
|
||||
"""
|
||||
ordered = registry.resolve_context_dependencies(*contexts)
|
||||
return await _load_ordered(session, registry, ordered, strategy)
|
||||
|
||||
|
||||
async def _load_ordered(
|
||||
session: AsyncSession,
|
||||
registry: FixtureRegistry,
|
||||
ordered_names: list[str],
|
||||
strategy: LoadStrategy,
|
||||
) -> dict[str, list[DeclarativeBase]]:
|
||||
"""Load fixtures in order."""
|
||||
results: dict[str, list[DeclarativeBase]] = {}
|
||||
|
||||
for name in ordered_names:
|
||||
fixture = registry.get(name)
|
||||
instances = list(fixture.func())
|
||||
|
||||
if not instances:
|
||||
results[name] = []
|
||||
continue
|
||||
|
||||
model_name = type(instances[0]).__name__
|
||||
loaded: list[DeclarativeBase] = []
|
||||
|
||||
async with get_transaction(session):
|
||||
for instance in instances:
|
||||
if strategy == LoadStrategy.INSERT:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
elif strategy == LoadStrategy.MERGE:
|
||||
merged = await session.merge(instance)
|
||||
loaded.append(merged)
|
||||
|
||||
elif strategy == LoadStrategy.SKIP_EXISTING:
|
||||
pk = _get_primary_key(instance)
|
||||
if pk is not None:
|
||||
existing = await session.get(type(instance), pk)
|
||||
if existing is None:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
else:
|
||||
session.add(instance)
|
||||
loaded.append(instance)
|
||||
|
||||
results[name] = loaded
|
||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
|
||||
"""Get the primary key value of a model instance."""
|
||||
mapper = instance.__class__.__mapper__
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
if len(pk_cols) == 1:
|
||||
return getattr(instance, pk_cols[0].name, None)
|
||||
|
||||
pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
|
||||
if all(v is not None for v in pk_values):
|
||||
return pk_values
|
||||
return None
|
||||
|
||||
@@ -92,6 +92,15 @@ class IntRole(Base):
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
|
||||
class Permission(Base):
|
||||
"""Test model with composite primary key."""
|
||||
|
||||
__tablename__ = "permissions"
|
||||
|
||||
subject: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
action: Mapped[str] = mapped_column(String(50), primary_key=True)
|
||||
|
||||
|
||||
class Event(Base):
|
||||
"""Test model with DateTime and Date cursor columns."""
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@ from fastapi_toolsets.fixtures import (
|
||||
load_fixtures_by_context,
|
||||
)
|
||||
|
||||
from .conftest import Role, User
|
||||
from fastapi_toolsets.fixtures.utils import _get_primary_key
|
||||
|
||||
from .conftest import IntRole, Permission, Role, User
|
||||
|
||||
|
||||
class TestContext:
|
||||
@@ -597,6 +599,46 @@ class TestLoadFixtures:
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_skips_if_record_exists(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING returns empty loaded list when the record already exists."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
# First load — inserts the record.
|
||||
result1 = await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert len(result1["roles"]) == 1
|
||||
|
||||
# Remove from identity map so session.get() queries the DB in the second load.
|
||||
db_session.expunge_all()
|
||||
|
||||
# Second load — record exists in DB, nothing should be added.
|
||||
result2 = await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert result2["roles"] == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
|
||||
"""SKIP_EXISTING inserts when the instance has no PK set (auto-increment)."""
|
||||
registry = FixtureRegistry()
|
||||
|
||||
@registry.register
|
||||
def int_roles():
|
||||
# No id provided — PK is None before INSERT (autoincrement).
|
||||
return [IntRole(name="member")]
|
||||
|
||||
result = await load_fixtures(
|
||||
db_session, registry, "int_roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
)
|
||||
assert len(result["int_roles"]) == 1
|
||||
|
||||
|
||||
class TestLoadFixturesByContext:
|
||||
"""Tests for load_fixtures_by_context function."""
|
||||
@@ -755,3 +797,19 @@ class TestGetObjByAttr:
|
||||
"""Raises StopIteration when value type doesn't match."""
|
||||
with pytest.raises(StopIteration):
|
||||
get_obj_by_attr(self.roles, "id", "not-a-uuid")
|
||||
|
||||
|
||||
class TestGetPrimaryKey:
|
||||
"""Unit tests for the _get_primary_key helper (composite PK paths)."""
|
||||
|
||||
def test_composite_pk_all_set(self):
|
||||
"""Returns a tuple when all composite PK values are set."""
|
||||
instance = Permission(subject="post", action="read")
|
||||
pk = _get_primary_key(instance)
|
||||
assert pk == ("post", "read")
|
||||
|
||||
def test_composite_pk_partial_none(self):
|
||||
"""Returns None when any composite PK value is None."""
|
||||
instance = Permission(subject="post") # action is None
|
||||
pk = _get_primary_key(instance)
|
||||
assert pk is None
|
||||
|
||||
Reference in New Issue
Block a user