Compare commits

...

8 Commits

14 changed files with 889 additions and 123 deletions

1
docs/CNAME Normal file
View File

@@ -0,0 +1 @@
fastapi-toolsets.d3vyce.fr

View File

@@ -474,7 +474,7 @@ The distinct values are returned in the `filter_attributes` field of [`Paginated
"filter_attributes": { "filter_attributes": {
"status": ["active", "inactive"], "status": ["active", "inactive"],
"country": ["DE", "FR", "US"], "country": ["DE", "FR", "US"],
"name": ["admin", "editor", "viewer"] "role__name": ["admin", "editor", "viewer"]
} }
} }
``` ```
@@ -482,7 +482,7 @@ The distinct values are returned in the `filter_attributes` field of [`Paginated
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError). Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`." !!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
Keys are normally the terminal `column.key` (e.g. `"name"` for `Role.name`). When two facet fields share the same column key (e.g. `(Build.project, Project.name)` and `(Build.os, Os.name)`), the relationship name is prepended automatically: `"project__name"` and `"os__name"`. Keys use `__` as a separator for the full relationship chain. A direct column `User.status` produces `"status"`. A relationship tuple `(User.role, Role.name)` produces `"role__name"`. A deeper chain `(User.role, Role.permission, Permission.name)` produces `"role__permission__name"`.
`filter_by` and `filters` can be combined — both are applied with AND logic. `filter_by` and `filters` can be combined — both are applied with AND logic.
@@ -515,9 +515,9 @@ async def list_users(
Both single-value and multi-value query parameters work: Both single-value and multi-value query parameters work:
``` ```
GET /users?status=active → filter_by={"status": ["active"]} GET /users?status=active → filter_by={"status": ["active"]}
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]} GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause) GET /users?role__name=admin&role__name=editor → filter_by={"role__name": ["admin", "editor"]} (IN clause)
``` ```
## Sorting ## Sorting

View File

@@ -38,18 +38,20 @@ By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_to
from fastapi_toolsets.fixtures import load_fixtures_by_context from fastapi_toolsets.fixtures import load_fixtures_by_context
async with db_context() as session: async with db_context() as session:
await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING) await load_fixtures_by_context(session, fixtures, Context.TESTING)
``` ```
Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures): Directly by name with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
```python ```python
from fastapi_toolsets.fixtures import load_fixtures from fastapi_toolsets.fixtures import load_fixtures
async with db_context() as session: async with db_context() as session:
await load_fixtures(session=session, registry=fixtures) await load_fixtures(session, fixtures, "roles", "test_users")
``` ```
Both functions return a `dict[str, list[...]]` mapping each fixture name to the list of loaded instances.
## Contexts ## Contexts
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values: [`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
@@ -58,10 +60,60 @@ async with db_context() as session:
|---------|-------------| |---------|-------------|
| `Context.BASE` | Core data required in all environments | | `Context.BASE` | Core data required in all environments |
| `Context.TESTING` | Data only loaded during tests | | `Context.TESTING` | Data only loaded during tests |
| `Context.DEVELOPMENT` | Data only loaded in development |
| `Context.PRODUCTION` | Data only loaded in production | | `Context.PRODUCTION` | Data only loaded in production |
A fixture with no `contexts` defined takes `Context.BASE` by default. A fixture with no `contexts` defined takes `Context.BASE` by default.
### Custom contexts
Plain strings and any `Enum` subclass are accepted wherever a `Context` enum is expected.
```python
from enum import Enum
class AppContext(str, Enum):
STAGING = "staging"
DEMO = "demo"
@fixtures.register(contexts=[AppContext.STAGING])
def staging_data():
return [Config(key="feature_x", enabled=True)]
await load_fixtures_by_context(session, fixtures, AppContext.STAGING)
```
### Default context for a registry
Pass `contexts` to `FixtureRegistry` to set a default for all fixtures registered in it:
```python
testing_registry = FixtureRegistry(contexts=[Context.TESTING])
@testing_registry.register # implicitly contexts=[Context.TESTING]
def test_orders():
return [Order(id=1, total=99)]
```
### Same fixture name, multiple context variants
The same fixture name may be registered under different (non-overlapping) context sets. When multiple contexts are loaded together, all matching variants are merged:
```python
@fixtures.register(contexts=[Context.BASE])
def users():
return [User(id=1, username="admin")]
@fixtures.register(contexts=[Context.TESTING])
def users():
return [User(id=2, username="tester")]
# loads both admin and tester
await load_fixtures_by_context(session, fixtures, Context.BASE, Context.TESTING)
```
Registering two variants with overlapping context sets raises `ValueError`.
## Load strategies ## Load strategies
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist: [`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
@@ -69,20 +121,44 @@ A fixture with no `contexts` defined takes `Context.BASE` by default.
| Strategy | Description | | Strategy | Description |
|----------|-------------| |----------|-------------|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates | | `LoadStrategy.INSERT` | Insert only, fail on duplicates |
| `LoadStrategy.UPSERT` | Insert or update on conflict | | `LoadStrategy.MERGE` | Insert or update on conflict (default) |
| `LoadStrategy.SKIP` | Skip rows that already exist | | `LoadStrategy.SKIP_EXISTING` | Skip rows that already exist |
```python
await load_fixtures_by_context(
session, fixtures, Context.BASE, strategy=LoadStrategy.SKIP_EXISTING
)
```
## Merging registries ## Merging registries
Split fixtures definitions across modules and merge them: Split fixture definitions across modules and merge them:
```python ```python
from myapp.fixtures.dev import dev_fixtures from myapp.fixtures.dev import dev_fixtures
from myapp.fixtures.prod import prod_fixtures from myapp.fixtures.prod import prod_fixtures
fixtures = fixturesRegistry() fixtures = FixtureRegistry()
fixtures.include_registry(registry=dev_fixtures) fixtures.include_registry(registry=dev_fixtures)
fixtures.include_registry(registry=prod_fixtures) fixtures.include_registry(registry=prod_fixtures)
```
Fixtures with the same name are allowed as long as their context sets do not overlap. Conflicting contexts raise `ValueError`.
## Looking up fixture instances
[`get_obj_by_attr`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.get_obj_by_attr) retrieves a specific instance from a fixture function by attribute value — useful when building cross-fixture `depends_on` relationships:
```python
from fastapi_toolsets.fixtures import get_obj_by_attr
@fixtures.register(depends_on=["roles"])
def users():
admin_role = get_obj_by_attr(roles, "name", "admin")
return [User(id=1, username="alice", role_id=admin_role.id)]
```
Raises `StopIteration` if no matching instance is found.
## Pytest integration ## Pytest integration
@@ -111,7 +187,6 @@ async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Rol
... ...
``` ```
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances. The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
## CLI integration ## CLI integration

View File

@@ -2,7 +2,6 @@
import asyncio import asyncio
import functools import functools
from collections import Counter
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
@@ -151,7 +150,7 @@ def build_search_filters(
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]: def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
"""Return a key for each facet field, disambiguating duplicate column keys. """Return a key for each facet field.
Args: Args:
facet_fields: Sequence of facet fields — either direct columns or facet_fields: Sequence of facet fields — either direct columns or
@@ -160,22 +159,12 @@ def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
Returns: Returns:
A list of string keys, one per facet field, in the same order. A list of string keys, one per facet field, in the same order.
""" """
raw: list[tuple[str, str | None]] = [] keys: list[str] = []
for field in facet_fields: for field in facet_fields:
if isinstance(field, tuple): if isinstance(field, tuple):
rel = field[-2] keys.append("__".join(el.key for el in field))
column = field[-1]
raw.append((column.key, rel.key))
else: else:
raw.append((field.key, None)) keys.append(field.key)
counts = Counter(col_key for col_key, _ in raw)
keys: list[str] = []
for col_key, rel_key in raw:
if counts[col_key] > 1 and rel_key is not None:
keys.append(f"{rel_key}__{col_key}")
else:
keys.append(col_key)
return keys return keys

View File

@@ -2,6 +2,7 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from typing import Any, cast from typing import Any, cast
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
@@ -12,6 +13,13 @@ from .enum import Context
logger = get_logger() logger = get_logger()
def _normalize_contexts(
contexts: list[str | Enum] | tuple[str | Enum, ...],
) -> list[str]:
"""Convert a sequence of any Enum subclass and/or plain strings to a list of strings."""
return [c.value if isinstance(c, Enum) else c for c in contexts]
@dataclass @dataclass
class Fixture: class Fixture:
"""A fixture definition with metadata.""" """A fixture definition with metadata."""
@@ -50,26 +58,51 @@ class FixtureRegistry:
Post(id=1, title="Test", user_id=1), Post(id=1, title="Test", user_id=1),
] ]
``` ```
Fixtures with the same name may be registered for **different** contexts.
When multiple contexts are loaded together, their instances are merged:
```python
@fixtures.register(contexts=[Context.BASE])
def users():
return [User(id=1, username="admin")]
@fixtures.register(contexts=[Context.TESTING])
def users():
return [User(id=2, username="tester")]
# load_fixtures_by_context(..., Context.BASE, Context.TESTING)
# → loads both User(admin) and User(tester) under the "users" name
```
""" """
def __init__( def __init__(
self, self,
contexts: list[str | Context] | None = None, contexts: list[str | Enum] | None = None,
) -> None: ) -> None:
self._fixtures: dict[str, Fixture] = {} self._fixtures: dict[str, list[Fixture]] = {}
self._default_contexts: list[str] | None = ( self._default_contexts: list[str] | None = (
[c.value if isinstance(c, Context) else c for c in contexts] _normalize_contexts(contexts) if contexts else None
if contexts
else None
) )
def _validate_no_context_overlap(self, name: str, new_contexts: list[str]) -> None:
"""Raise ``ValueError`` if any existing variant for *name* overlaps."""
existing_variants = self._fixtures.get(name, [])
new_set = set(new_contexts)
for variant in existing_variants:
if set(variant.contexts) & new_set:
raise ValueError(
f"Fixture '{name}' already exists in the current registry "
f"with overlapping contexts. Use distinct context sets for "
f"each variant of the same fixture name."
)
def register( def register(
self, self,
func: Callable[[], Sequence[DeclarativeBase]] | None = None, func: Callable[[], Sequence[DeclarativeBase]] | None = None,
*, *,
name: str | None = None, name: str | None = None,
depends_on: list[str] | None = None, depends_on: list[str] | None = None,
contexts: list[str | Context] | None = None, contexts: list[str | Enum] | None = None,
) -> Callable[..., Any]: ) -> Callable[..., Any]:
"""Register a fixture function. """Register a fixture function.
@@ -79,7 +112,8 @@ class FixtureRegistry:
func: Fixture function returning list of model instances func: Fixture function returning list of model instances
name: Fixture name (defaults to function name) name: Fixture name (defaults to function name)
depends_on: List of fixture names this depends on depends_on: List of fixture names this depends on
contexts: List of contexts this fixture belongs to contexts: List of contexts this fixture belongs to. Both
:class:`Context` enum values and plain strings are accepted.
Example: Example:
```python ```python
@@ -90,7 +124,6 @@ class FixtureRegistry:
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING]) @fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users(): def test_users():
return [User(id=1, username="test", role_id=1)] return [User(id=1, username="test", role_id=1)]
```
""" """
def decorator( def decorator(
@@ -98,19 +131,20 @@ class FixtureRegistry:
) -> Callable[[], Sequence[DeclarativeBase]]: ) -> Callable[[], Sequence[DeclarativeBase]]:
fixture_name = name or cast(Any, fn).__name__ fixture_name = name or cast(Any, fn).__name__
if contexts is not None: if contexts is not None:
fixture_contexts = [ fixture_contexts = _normalize_contexts(contexts)
c.value if isinstance(c, Context) else c for c in contexts
]
elif self._default_contexts is not None: elif self._default_contexts is not None:
fixture_contexts = self._default_contexts fixture_contexts = self._default_contexts
else: else:
fixture_contexts = [Context.BASE.value] fixture_contexts = [Context.BASE.value]
self._fixtures[fixture_name] = Fixture( self._validate_no_context_overlap(fixture_name, fixture_contexts)
name=fixture_name, self._fixtures.setdefault(fixture_name, []).append(
func=fn, Fixture(
depends_on=depends_on or [], name=fixture_name,
contexts=fixture_contexts, func=fn,
depends_on=depends_on or [],
contexts=fixture_contexts,
)
) )
return fn return fn
@@ -121,11 +155,14 @@ class FixtureRegistry:
def include_registry(self, registry: "FixtureRegistry") -> None: def include_registry(self, registry: "FixtureRegistry") -> None:
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`. """Include another `FixtureRegistry` in the same current `FixtureRegistry`.
Fixtures with the same name are allowed as long as their context sets
do not overlap. Conflicting contexts raise :class:`ValueError`.
Args: Args:
registry: The `FixtureRegistry` to include registry: The `FixtureRegistry` to include
Raises: Raises:
ValueError: If a fixture name already exists in the current registry ValueError: If a fixture name already exists with overlapping contexts
Example: Example:
```python ```python
@@ -139,31 +176,73 @@ class FixtureRegistry:
registry.include_registry(registry=dev_registry) registry.include_registry(registry=dev_registry)
``` ```
""" """
for name, fixture in registry._fixtures.items(): for name, variants in registry._fixtures.items():
if name in self._fixtures: for fixture in variants:
raise ValueError( self._validate_no_context_overlap(name, fixture.contexts)
f"Fixture '{name}' already exists in the current registry" self._fixtures.setdefault(name, []).append(fixture)
)
self._fixtures[name] = fixture
def get(self, name: str) -> Fixture: def get(self, name: str) -> Fixture:
"""Get a fixture by name.""" """Get a fixture by name.
Raises:
KeyError: If no fixture with *name* is registered.
ValueError: If the fixture has multiple context variants — use
:meth:`get_variants` in that case.
"""
if name not in self._fixtures: if name not in self._fixtures:
raise KeyError(f"Fixture '{name}' not found") raise KeyError(f"Fixture '{name}' not found")
return self._fixtures[name] variants = self._fixtures[name]
if len(variants) > 1:
raise ValueError(
f"Fixture '{name}' has {len(variants)} context variants. "
f"Use get_variants('{name}') to retrieve them."
)
return variants[0]
def get_variants(self, name: str, *contexts: str | Enum) -> list[Fixture]:
"""Return all registered variants for *name*, optionally filtered by context.
Args:
name: Fixture name.
*contexts: If given, only return variants whose context set
intersects with these values. Both :class:`Context` enum
values and plain strings are accepted.
Returns:
List of matching :class:`Fixture` objects (may be empty when a
context filter is applied and nothing matches).
Raises:
KeyError: If no fixture with *name* is registered.
"""
if name not in self._fixtures:
raise KeyError(f"Fixture '{name}' not found")
variants = self._fixtures[name]
if not contexts:
return list(variants)
context_values = set(_normalize_contexts(contexts))
return [v for v in variants if set(v.contexts) & context_values]
def get_all(self) -> list[Fixture]: def get_all(self) -> list[Fixture]:
"""Get all registered fixtures.""" """Get all registered fixtures (all variants of all names)."""
return list(self._fixtures.values()) return [f for variants in self._fixtures.values() for f in variants]
def get_by_context(self, *contexts: str | Context) -> list[Fixture]: def get_by_context(self, *contexts: str | Enum) -> list[Fixture]:
"""Get fixtures for specific contexts.""" """Get fixtures for specific contexts."""
context_values = {c.value if isinstance(c, Context) else c for c in contexts} context_values = set(_normalize_contexts(contexts))
return [f for f in self._fixtures.values() if set(f.contexts) & context_values] return [
f
for variants in self._fixtures.values()
for f in variants
if set(f.contexts) & context_values
]
def resolve_dependencies(self, *names: str) -> list[str]: def resolve_dependencies(self, *names: str) -> list[str]:
"""Resolve fixture dependencies in topological order. """Resolve fixture dependencies in topological order.
When a fixture name has multiple context variants, the union of all
variants' ``depends_on`` lists is used.
Args: Args:
*names: Fixture names to resolve *names: Fixture names to resolve
@@ -185,9 +264,20 @@ class FixtureRegistry:
raise ValueError(f"Circular dependency detected: {name}") raise ValueError(f"Circular dependency detected: {name}")
visiting.add(name) visiting.add(name)
fixture = self.get(name) variants = self._fixtures.get(name)
if variants is None:
raise KeyError(f"Fixture '{name}' not found")
for dep in fixture.depends_on: # Union of depends_on across all variants, preserving first-seen order.
seen_deps: set[str] = set()
all_deps: list[str] = []
for variant in variants:
for dep in variant.depends_on:
if dep not in seen_deps:
all_deps.append(dep)
seen_deps.add(dep)
for dep in all_deps:
visit(dep) visit(dep)
visiting.remove(name) visiting.remove(name)
@@ -199,7 +289,7 @@ class FixtureRegistry:
return resolved return resolved
def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]: def resolve_context_dependencies(self, *contexts: str | Enum) -> list[str]:
"""Resolve all fixtures for contexts with dependencies. """Resolve all fixtures for contexts with dependencies.
Args: Args:
@@ -209,7 +299,9 @@ class FixtureRegistry:
List of fixture names in load order List of fixture names in load order
""" """
context_fixtures = self.get_by_context(*contexts) context_fixtures = self.get_by_context(*contexts)
names = [f.name for f in context_fixtures] # Deduplicate names while preserving first-seen order (a name can
# appear multiple times if it has variants in different contexts).
names = list(dict.fromkeys(f.name for f in context_fixtures))
all_deps: set[str] = set() all_deps: set[str] = set()
for name in names: for name in names:

View File

@@ -1,8 +1,11 @@
"""Fixture loading utilities for database seeding.""" """Fixture loading utilities for database seeding."""
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any from typing import Any
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
@@ -10,23 +13,163 @@ from ..db import get_transaction
from ..logger import get_logger from ..logger import get_logger
from ..types import ModelType from ..types import ModelType
from .enum import LoadStrategy from .enum import LoadStrategy
from .registry import Context, FixtureRegistry from .registry import FixtureRegistry, _normalize_contexts
logger = get_logger() logger = get_logger()
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
"""Extract column values from a model instance, skipping unset server-default columns."""
state = sa_inspect(instance)
state_dict = state.dict
result: dict[str, Any] = {}
for prop in state.mapper.column_attrs:
if prop.key not in state_dict:
continue
val = state_dict[prop.key]
if val is None:
col = prop.columns[0]
if col.server_default is not None or (
col.default is not None and col.default.is_callable
):
continue
result[prop.key] = val
return result
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Ensure all row dicts share the same key set."""
all_keys: set[str] = set().union(*dicts)
return [{k: d.get(k) for k in all_keys} for d in dicts]
def _group_by_type(
instances: list[DeclarativeBase],
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
"""Group instances by their concrete model class, preserving insertion order."""
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
for instance in instances:
groups.setdefault(type(instance), []).append(instance)
return list(groups.items())
async def _batch_insert(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> None:
"""INSERT all instances — raises on conflict (no duplicate handling)."""
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
await session.execute(pg_insert(model_cls).values(dicts))
async def _batch_merge(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> None:
"""UPSERT: insert new rows, update existing ones with the provided values."""
mapper = model_cls.__mapper__
pk_names = [col.name for col in mapper.primary_key]
pk_names_set = set(pk_names)
non_pk_cols = [
prop.key
for prop in mapper.column_attrs
if not any(col.name in pk_names_set for col in prop.columns)
]
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
stmt = pg_insert(model_cls).values(dicts)
inserted_keys = set(dicts[0]) if dicts else set()
update_cols = [col for col in non_pk_cols if col in inserted_keys]
if update_cols:
stmt = stmt.on_conflict_do_update(
index_elements=pk_names,
set_={col: stmt.excluded[col] for col in update_cols},
)
else:
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
await session.execute(stmt)
async def _batch_skip_existing(
session: AsyncSession,
model_cls: type[DeclarativeBase],
instances: list[DeclarativeBase],
) -> list[DeclarativeBase]:
"""INSERT only rows that do not already exist; return the inserted ones."""
mapper = model_cls.__mapper__
pk_names = [col.name for col in mapper.primary_key]
no_pk: list[DeclarativeBase] = []
with_pk_pairs: list[tuple[DeclarativeBase, Any]] = []
for inst in instances:
pk = _get_primary_key(inst)
if pk is None:
no_pk.append(inst)
else:
with_pk_pairs.append((inst, pk))
loaded: list[DeclarativeBase] = list(no_pk)
if no_pk:
await session.execute(
pg_insert(model_cls).values(
_normalize_rows([_instance_to_dict(i) for i in no_pk])
)
)
if with_pk_pairs:
with_pk = [i for i, _ in with_pk_pairs]
stmt = (
pg_insert(model_cls)
.values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
.on_conflict_do_nothing(index_elements=pk_names)
)
result = await session.execute(stmt.returning(*mapper.primary_key))
inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result}
loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks)
return loaded
async def _load_ordered( async def _load_ordered(
session: AsyncSession, session: AsyncSession,
registry: FixtureRegistry, registry: FixtureRegistry,
ordered_names: list[str], ordered_names: list[str],
strategy: LoadStrategy, strategy: LoadStrategy,
contexts: tuple[str, ...] | None = None,
) -> dict[str, list[DeclarativeBase]]: ) -> dict[str, list[DeclarativeBase]]:
"""Load fixtures in order.""" """Load fixtures in order using batch Core INSERT statements.
When *contexts* is provided only variants whose context set intersects with
*contexts* are called for each name; their instances are concatenated.
When *contexts* is ``None`` all variants of each name are loaded.
"""
results: dict[str, list[DeclarativeBase]] = {} results: dict[str, list[DeclarativeBase]] = {}
for name in ordered_names: for name in ordered_names:
fixture = registry.get(name) variants = (
instances = list(fixture.func()) registry.get_variants(name, *contexts)
if contexts is not None
else registry.get_variants(name)
)
# Cross-context dependency fallback: if we're loading by context but
# no variant matches (e.g. a "base"-only fixture required by a
# "testing" fixture), load all available variants so the dependency
# is satisfied.
if contexts is not None and not variants:
variants = registry.get_variants(name)
if not variants:
results[name] = []
continue
instances = [inst for v in variants for inst in v.func()]
if not instances: if not instances:
results[name] = [] results[name] = []
@@ -36,25 +179,17 @@ async def _load_ordered(
loaded: list[DeclarativeBase] = [] loaded: list[DeclarativeBase] = []
async with get_transaction(session): async with get_transaction(session):
for instance in instances: for model_cls, group in _group_by_type(instances):
if strategy == LoadStrategy.INSERT: match strategy:
session.add(instance) case LoadStrategy.INSERT:
loaded.append(instance) await _batch_insert(session, model_cls, group)
loaded.extend(group)
elif strategy == LoadStrategy.MERGE: case LoadStrategy.MERGE:
merged = await session.merge(instance) await _batch_merge(session, model_cls, group)
loaded.append(merged) loaded.extend(group)
case LoadStrategy.SKIP_EXISTING:
else: # LoadStrategy.SKIP_EXISTING inserted = await _batch_skip_existing(session, model_cls, group)
pk = _get_primary_key(instance) loaded.extend(inserted)
if pk is not None:
existing = await session.get(type(instance), pk)
if existing is None:
session.add(instance)
loaded.append(instance)
else:
session.add(instance)
loaded.append(instance)
results[name] = loaded results[name] = loaded
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)") logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
@@ -109,6 +244,8 @@ async def load_fixtures(
) -> dict[str, list[DeclarativeBase]]: ) -> dict[str, list[DeclarativeBase]]:
"""Load specific fixtures by name with dependencies. """Load specific fixtures by name with dependencies.
All context variants of each requested fixture are loaded and merged.
Args: Args:
session: Database session session: Database session
registry: Fixture registry registry: Fixture registry
@@ -125,19 +262,27 @@ async def load_fixtures(
async def load_fixtures_by_context( async def load_fixtures_by_context(
session: AsyncSession, session: AsyncSession,
registry: FixtureRegistry, registry: FixtureRegistry,
*contexts: str | Context, *contexts: str | Enum,
strategy: LoadStrategy = LoadStrategy.MERGE, strategy: LoadStrategy = LoadStrategy.MERGE,
) -> dict[str, list[DeclarativeBase]]: ) -> dict[str, list[DeclarativeBase]]:
"""Load all fixtures for specific contexts. """Load all fixtures for specific contexts.
For each fixture name, only the variants whose context set intersects with
*contexts* are loaded. When a name has variants in multiple of the
requested contexts, their instances are merged before being inserted.
Args: Args:
session: Database session session: Database session
registry: Fixture registry registry: Fixture registry
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING) *contexts: Contexts to load (e.g., ``Context.BASE``, ``Context.TESTING``,
or plain strings for custom contexts)
strategy: How to handle existing records strategy: How to handle existing records
Returns: Returns:
Dict mapping fixture names to loaded instances Dict mapping fixture names to loaded instances
""" """
context_strings = tuple(_normalize_contexts(contexts))
ordered = registry.resolve_context_dependencies(*contexts) ordered = registry.resolve_context_dependencies(*contexts)
return await _load_ordered(session, registry, ordered, strategy) return await _load_ordered(
session, registry, ordered, strategy, contexts=context_strings
)

View File

@@ -27,6 +27,7 @@ _SESSION_CREATES = "_ft_creates"
_SESSION_DELETES = "_ft_deletes" _SESSION_DELETES = "_ft_deletes"
_SESSION_UPDATES = "_ft_updates" _SESSION_UPDATES = "_ft_updates"
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth" _SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
class ModelEvent(str, Enum): class ModelEvent(str, Enum):
@@ -60,11 +61,22 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
"""Read currently-loaded column values into a plain dict.""" """Read currently-loaded column values into a plain dict."""
state = sa_inspect(obj) # InstanceState state = sa_inspect(obj) # InstanceState
state_dict = state.dict state_dict = state.dict
return { snapshot: dict[str, Any] = {}
prop.key: state_dict[prop.key] for prop in state.mapper.column_attrs:
for prop in state.mapper.column_attrs if prop.key in state_dict:
if prop.key in state_dict snapshot[prop.key] = state_dict[prop.key]
} elif (
not state.expired
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
and all(
col.nullable
and col.server_default is None
and col.server_onupdate is None
for col in prop.columns
)
):
snapshot[prop.key] = None
return snapshot
def _get_watched_fields(cls: type) -> list[str] | None: def _get_watched_fields(cls: type) -> list[str] | None:

View File

@@ -7,6 +7,7 @@ from contextlib import asynccontextmanager
from typing import Any from typing import Any
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.engine import make_url from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
AsyncSession, AsyncSession,
@@ -15,13 +16,8 @@ from sqlalchemy.ext.asyncio import (
) )
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import text from ..db import cleanup_tables as _cleanup_tables
from ..db import create_database
from ..db import (
cleanup_tables as _cleanup_tables,
create_database,
create_db_context,
)
async def cleanup_tables( async def cleanup_tables(
@@ -269,15 +265,12 @@ async def create_db_session(
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all) await conn.run_sync(base.metadata.create_all)
# Create session using existing db context utility
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit) session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
get_session = create_db_context(session_maker) async with session_maker() as session:
async with get_session() as session:
yield session yield session
if cleanup: if cleanup:
await cleanup_tables(session, base) await _cleanup_tables(session=session, base=base)
if drop_tables: if drop_tables:
async with engine.begin() as conn: async with engine.begin() as conn:

View File

@@ -57,6 +57,7 @@ class User(Base):
username: Mapped[str] = mapped_column(String(50), unique=True) username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True) email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True) is_active: Mapped[bool] = mapped_column(default=True)
notes: Mapped[str | None]
role_id: Mapped[uuid.UUID | None] = mapped_column( role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True ForeignKey("roles.id"), nullable=True
) )

View File

@@ -646,7 +646,7 @@ class TestFacetsRelationship:
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead) result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert set(result.filter_attributes["name"]) == {"admin", "editor"} assert set(result.filter_attributes["role__name"]) == {"admin", "editor"}
@pytest.mark.anyio @pytest.mark.anyio
async def test_relationship_facet_none_excluded(self, db_session: AsyncSession): async def test_relationship_facet_none_excluded(self, db_session: AsyncSession):
@@ -661,7 +661,7 @@ class TestFacetsRelationship:
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead) result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert result.filter_attributes["name"] == [] assert result.filter_attributes["role__name"] == []
@pytest.mark.anyio @pytest.mark.anyio
async def test_relationship_facet_deduplicates_join_with_search( async def test_relationship_facet_deduplicates_join_with_search(
@@ -689,7 +689,7 @@ class TestFacetsRelationship:
) )
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert result.filter_attributes["name"] == ["admin"] assert result.filter_attributes["role__name"] == ["admin"]
class TestFilterBy: class TestFilterBy:
@@ -755,7 +755,7 @@ class TestFilterBy:
) )
result = await UserRelFacetCrud.offset_paginate( result = await UserRelFacetCrud.offset_paginate(
db_session, filter_by={"name": "admin"}, schema=UserRead db_session, filter_by={"role__name": "admin"}, schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -824,7 +824,7 @@ class TestFilterBy:
result = await UserRoleFacetCrud.offset_paginate( result = await UserRoleFacetCrud.offset_paginate(
db_session, db_session,
filter_by={"name": "admin", "id": str(admin.id)}, filter_by={"role__name": "admin", "role__id": str(admin.id)},
schema=UserRead, schema=UserRead,
) )
@@ -916,15 +916,15 @@ class TestFilterParamsSchema:
param_names = set(inspect.signature(dep).parameters) param_names = set(inspect.signature(dep).parameters)
assert param_names == {"username", "email"} assert param_names == {"username", "email"}
def test_relationship_facet_uses_column_key(self): def test_relationship_facet_uses_full_chain_key(self):
"""Relationship tuple uses the terminal column's key.""" """Relationship tuple uses the full chain joined by __ as the key."""
import inspect import inspect
UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)]) UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
dep = UserRoleCrud.filter_params() dep = UserRoleCrud.filter_params()
param_names = set(inspect.signature(dep).parameters) param_names = set(inspect.signature(dep).parameters)
assert param_names == {"name"} assert param_names == {"role__name"}
def test_raises_when_no_facet_fields(self): def test_raises_when_no_facet_fields(self):
"""ValueError raised when no facet_fields are configured or provided.""" """ValueError raised when no facet_fields are configured or provided."""
@@ -978,6 +978,22 @@ class TestFilterParamsSchema:
keys = facet_keys([(rel_a, col_a), (rel_b, col_b)]) keys = facet_keys([(rel_a, col_a), (rel_b, col_b)])
assert keys == ["project__name", "os__name"] assert keys == ["project__name", "os__name"]
def test_deep_chain_joins_all_segments(self):
"""Three-element tuple produces all relation segments joined by __."""
from unittest.mock import MagicMock
from fastapi_toolsets.crud.search import facet_keys
rel_a = MagicMock()
rel_a.key = "role"
rel_b = MagicMock()
rel_b.key = "permission"
col = MagicMock()
col.key = "name"
keys = facet_keys([(rel_a, rel_b, col)])
assert keys == ["role__permission__name"]
def test_unique_column_keys_kept_plain(self): def test_unique_column_keys_kept_plain(self):
"""Fields with unique column keys are not prefixed.""" """Fields with unique column keys are not prefixed."""
from fastapi_toolsets.crud.search import facet_keys from fastapi_toolsets.crud.search import facet_keys

View File

@@ -97,7 +97,7 @@ class TestAppSessionDep:
gen = get_db() gen = get_db()
session = await gen.__anext__() session = await gen.__anext__()
assert isinstance(session, AsyncSession) assert isinstance(session, AsyncSession)
await session.close() await gen.aclose()
class TestOffsetPagination: class TestOffsetPagination:
@@ -182,8 +182,7 @@ class TestOffsetPagination:
body = resp.json() body = resp.json()
fa = body["filter_attributes"] fa = body["filter_attributes"]
assert set(fa["status"]) == {"draft", "published"} assert set(fa["status"]) == {"draft", "published"}
# "name" is unique across all facet fields — no prefix needed assert set(fa["category__name"]) == {"backend", "python"}
assert set(fa["name"]) == {"backend", "python"}
@pytest.mark.anyio @pytest.mark.anyio
async def test_filter_attributes_scoped_to_filter( async def test_filter_attributes_scoped_to_filter(

View File

@@ -1,6 +1,7 @@
"""Tests for fastapi_toolsets.fixtures module.""" """Tests for fastapi_toolsets.fixtures module."""
import uuid import uuid
from enum import Enum
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,10 +14,22 @@ from fastapi_toolsets.fixtures import (
load_fixtures, load_fixtures,
load_fixtures_by_context, load_fixtures_by_context,
) )
from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict
from fastapi_toolsets.fixtures.utils import _get_primary_key from .conftest import IntRole, Permission, Role, RoleCreate, RoleCrud, User, UserCrud
from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud
class AppContext(str, Enum):
"""Example user-defined str+Enum context."""
STAGING = "staging"
DEMO = "demo"
class PlainEnumContext(Enum):
"""Example user-defined plain Enum context (no str mixin)."""
STAGING = "staging"
class TestContext: class TestContext:
@@ -39,6 +52,86 @@ class TestContext:
assert Context.TESTING.value == "testing" assert Context.TESTING.value == "testing"
class TestCustomEnumContext:
"""Custom Enum types are accepted wherever Context/str are expected."""
def test_cannot_subclass_context_with_members(self):
"""Python prohibits extending an Enum that already has members."""
with pytest.raises(TypeError):
class MyContext(Context): # noqa: F841 # ty: ignore[subclass-of-final-class]
STAGING = "staging"
def test_custom_enum_values_interchangeable_with_context(self):
"""A custom enum with the same .value as a built-in Context member is
treated as the same context — fixtures registered under one are found
by the other."""
class AppContextFull(str, Enum):
BASE = "base"
STAGING = "staging"
registry = FixtureRegistry()
@registry.register(contexts=[Context.BASE])
def roles():
return []
# AppContextFull.BASE has value "base" — same as Context.BASE
fixtures = registry.get_by_context(AppContextFull.BASE)
assert len(fixtures) == 1
def test_custom_enum_registry_default_contexts(self):
"""FixtureRegistry(contexts=[...]) accepts a custom Enum."""
registry = FixtureRegistry(contexts=[AppContext.STAGING])
@registry.register
def data():
return []
fixture = registry.get("data")
assert fixture.contexts == ["staging"]
def test_custom_enum_resolve_context_dependencies(self):
"""resolve_context_dependencies accepts a custom Enum context."""
registry = FixtureRegistry()
@registry.register(contexts=[AppContext.STAGING])
def staging_roles():
return []
order = registry.resolve_context_dependencies(AppContext.STAGING)
assert "staging_roles" in order
@pytest.mark.anyio
async def test_custom_enum_e2e(self, db_session: AsyncSession):
"""End-to-end: register with custom Enum, load with the same Enum."""
registry = FixtureRegistry()
@registry.register(contexts=[AppContext.STAGING])
def staging_roles():
return [Role(id=uuid.uuid4(), name="staging-admin")]
result = await load_fixtures_by_context(
db_session, registry, AppContext.STAGING
)
assert len(result["staging_roles"]) == 1
@pytest.mark.anyio
async def test_plain_enum_e2e(self, db_session: AsyncSession):
"""End-to-end: register with plain Enum, load with the same Enum."""
registry = FixtureRegistry()
@registry.register(contexts=[PlainEnumContext.STAGING])
def staging_roles():
return [Role(id=uuid.uuid4(), name="plain-staging-admin")]
result = await load_fixtures_by_context(
db_session, registry, PlainEnumContext.STAGING
)
assert len(result["staging_roles"]) == 1
class TestLoadStrategy: class TestLoadStrategy:
"""Tests for LoadStrategy enum.""" """Tests for LoadStrategy enum."""
@@ -407,6 +500,37 @@ class TestDependencyResolution:
with pytest.raises(ValueError, match="Circular dependency"): with pytest.raises(ValueError, match="Circular dependency"):
registry.resolve_dependencies("a") registry.resolve_dependencies("a")
def test_resolve_raises_for_unknown_dependency(self):
"""KeyError when depends_on references an unregistered fixture."""
registry = FixtureRegistry()
@registry.register(depends_on=["ghost"])
def users():
return []
with pytest.raises(KeyError, match="ghost"):
registry.resolve_dependencies("users")
def test_resolve_deduplicates_shared_depends_on_across_variants(self):
"""A dep shared by two same-name variants appears only once in the order."""
registry = FixtureRegistry()
@registry.register(contexts=[Context.BASE])
def roles():
return []
@registry.register(depends_on=["roles"], contexts=[Context.BASE])
def items():
return []
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
def items(): # noqa: F811
return []
order = registry.resolve_dependencies("items")
assert order.count("roles") == 1
assert order.index("roles") < order.index("items")
def test_resolve_context_dependencies(self): def test_resolve_context_dependencies(self):
"""Resolve all fixtures for a context with dependencies.""" """Resolve all fixtures for a context with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
@@ -496,6 +620,52 @@ class TestLoadFixtures:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
@pytest.mark.anyio
async def test_merge_does_not_overwrite_omitted_nullable_columns(
self, db_session: AsyncSession
):
"""MERGE must not clear nullable columns that the fixture didn't set.
When a fixture omits a nullable column (e.g. role_id or notes), a re-merge
must leave the existing DB value untouched — not overwrite it with NULL.
"""
registry = FixtureRegistry()
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
uid = uuid.uuid4()
# First load: user has role_id and notes set
@registry.register
def users():
return [
User(
id=uid,
username="alice",
email="a@test.com",
role_id=admin.id,
notes="original",
)
]
await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.MERGE)
# Second load: fixture omits role_id and notes
registry2 = FixtureRegistry()
@registry2.register
def users(): # noqa: F811
return [User(id=uid, username="alice-updated", email="a@test.com")]
await load_fixtures(db_session, registry2, "users", strategy=LoadStrategy.MERGE)
from sqlalchemy import select
row = (
await db_session.execute(select(User).where(User.id == uid))
).scalar_one()
assert row.username == "alice-updated" # updated column changed
assert row.role_id == admin.id # omitted → preserved
assert row.notes == "original" # omitted → preserved
@pytest.mark.anyio @pytest.mark.anyio
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession): async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
"""Load fixtures with SKIP_EXISTING strategy.""" """Load fixtures with SKIP_EXISTING strategy."""
@@ -795,3 +965,145 @@ class TestGetPrimaryKey:
instance = Permission(subject="post") # action is None instance = Permission(subject="post") # action is None
pk = _get_primary_key(instance) pk = _get_primary_key(instance)
assert pk is None assert pk is None
class TestRegistryGetVariants:
"""Tests for FixtureRegistry.get and get_variants edge cases."""
def test_get_raises_value_error_for_multi_variant(self):
"""get() raises ValueError when the fixture has multiple context variants."""
registry = FixtureRegistry()
@registry.register(contexts=[Context.BASE])
def items():
return []
@registry.register(contexts=[Context.TESTING])
def items(): # noqa: F811
return []
with pytest.raises(ValueError, match="get_variants"):
registry.get("items")
def test_get_variants_raises_key_error_for_unknown(self):
"""get_variants() raises KeyError for an unregistered name."""
registry = FixtureRegistry()
with pytest.raises(KeyError, match="not found"):
registry.get_variants("no_such_fixture")
class TestInstanceToDict:
"""Unit tests for the _instance_to_dict helper."""
def test_explicit_values_included(self):
"""All explicitly set column values appear in the result."""
role_id = uuid.uuid4()
instance = Role(id=role_id, name="admin")
d = _instance_to_dict(instance)
assert d["id"] == role_id
assert d["name"] == "admin"
def test_callable_default_none_excluded(self):
"""A column whose value is None but has a callable Python-side default
(e.g. ``default=uuid.uuid4``) is excluded so the DB generates it."""
instance = Role(id=None, name="admin")
d = _instance_to_dict(instance)
assert "id" not in d
assert d["name"] == "admin"
def test_nullable_none_included(self):
"""None on a nullable column with no default is kept (explicit NULL)."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", role_id=None)
d = _instance_to_dict(instance)
assert "role_id" in d
assert d["role_id"] is None
def test_nullable_str_no_default_omitted_not_in_dict(self):
"""Mapped[str | None] with no default, not provided in constructor, is absent from dict."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com")
d = _instance_to_dict(instance)
assert "notes" not in d
def test_nullable_str_no_default_explicit_none_included(self):
"""Mapped[str | None] with no default, explicitly set to None, is included as NULL."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes=None)
d = _instance_to_dict(instance)
assert "notes" in d
assert d["notes"] is None
def test_nullable_str_no_default_with_value_included(self):
"""Mapped[str | None] with no default and a value set is included normally."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes="hello")
d = _instance_to_dict(instance)
assert d["notes"] == "hello"
@pytest.mark.anyio
async def test_nullable_str_no_default_insert_roundtrip(
self, db_session: AsyncSession
):
"""Fixture loading works for models with Mapped[str | None] (no default).
Both the omitted-value (→ NULL) and explicit-None paths must insert without error.
"""
registry = FixtureRegistry()
uid_a = uuid.uuid4()
uid_b = uuid.uuid4()
uid_c = uuid.uuid4()
@registry.register
def users():
return [
User(
id=uid_a, username="no_notes", email="a@test.com"
), # notes omitted
User(
id=uid_b, username="null_notes", email="b@test.com", notes=None
), # explicit None
User(
id=uid_c, username="has_notes", email="c@test.com", notes="hi"
), # value set
]
result = await load_fixtures(db_session, registry, "users")
from sqlalchemy import select
rows = (
(await db_session.execute(select(User).order_by(User.username)))
.scalars()
.all()
)
by_username = {r.username: r for r in rows}
assert by_username["no_notes"].notes is None
assert by_username["null_notes"].notes is None
assert by_username["has_notes"].notes == "hi"
assert len(result["users"]) == 3
class TestBatchMergeNonPkColumns:
"""Batch MERGE on a model with no non-PK columns (PK-only table)."""
@pytest.mark.anyio
async def test_merge_pk_only_model(self, db_session: AsyncSession):
"""MERGE strategy on a PK-only model uses on_conflict_do_nothing."""
registry = FixtureRegistry()
@registry.register
def permissions():
return [
Permission(subject="post", action="read"),
Permission(subject="post", action="write"),
]
result = await load_fixtures(
db_session, registry, "permissions", strategy=LoadStrategy.MERGE
)
assert len(result["permissions"]) == 2
# Run again — conflicts are silently ignored.
result2 = await load_fixtures(
db_session, registry, "permissions", strategy=LoadStrategy.MERGE
)
assert len(result2["permissions"]) == 2

View File

@@ -32,6 +32,7 @@ from fastapi_toolsets.models.watched import (
_after_flush, _after_flush,
_after_flush_postexec, _after_flush_postexec,
_after_rollback, _after_rollback,
_snapshot_column_attrs,
_task_error_handler, _task_error_handler,
_upsert_changes, _upsert_changes,
) )
@@ -213,20 +214,36 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
__tablename__ = "mixin_attr_access_models" __tablename__ = "mixin_attr_access_models"
name: Mapped[str] = mapped_column(String(50)) name: Mapped[str] = mapped_column(String(50))
callback_url: Mapped[str | None] = mapped_column(String(200), nullable=True)
async def on_create(self) -> None: async def on_create(self) -> None:
_attr_access_events.append( _attr_access_events.append(
{"event": "create", "id": self.id, "name": self.name} {
"event": "create",
"id": self.id,
"name": self.name,
"callback_url": self.callback_url,
}
) )
async def on_delete(self) -> None: async def on_delete(self) -> None:
_attr_access_events.append( _attr_access_events.append(
{"event": "delete", "id": self.id, "name": self.name} {
"event": "delete",
"id": self.id,
"name": self.name,
"callback_url": self.callback_url,
}
) )
async def on_update(self, changes: dict) -> None: async def on_update(self, changes: dict) -> None:
_attr_access_events.append( _attr_access_events.append(
{"event": "update", "id": self.id, "name": self.name} {
"event": "update",
"id": self.id,
"name": self.name,
"callback_url": self.callback_url,
}
) )
@@ -1279,3 +1296,67 @@ class TestAttributeAccessInCallbacks:
assert len(events) == 1 assert len(events) == 1
assert isinstance(events[0]["id"], uuid.UUID) assert isinstance(events[0]["id"], uuid.UUID)
assert events[0]["name"] == "updated" assert events[0]["name"] == "updated"
@pytest.mark.anyio
async def test_nullable_column_none_accessible_in_on_create(
self, mixin_session_expire
):
"""Nullable column left as None is accessible in on_create without greenlet error."""
obj = AttrAccessModel(name="no-url") # callback_url not set → None
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
events = [e for e in _attr_access_events if e["event"] == "create"]
assert len(events) == 1
assert events[0]["callback_url"] is None
@pytest.mark.anyio
async def test_nullable_column_with_value_accessible_in_on_create(
self, mixin_session_expire
):
"""Nullable column set to a value is accessible in on_create without greenlet error."""
obj = AttrAccessModel(name="with-url", callback_url="https://example.com/hook")
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
events = [e for e in _attr_access_events if e["event"] == "create"]
assert len(events) == 1
assert events[0]["callback_url"] == "https://example.com/hook"
@pytest.mark.anyio
async def test_nullable_column_accessible_after_update_to_none(
self, mixin_session_expire
):
"""Nullable column updated to None is accessible in on_update without greenlet error."""
obj = AttrAccessModel(name="x", callback_url="https://example.com/hook")
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
await asyncio.sleep(0)
_attr_access_events.clear()
obj.callback_url = None
await mixin_session_expire.commit()
await asyncio.sleep(0)
events = [e for e in _attr_access_events if e["event"] == "update"]
assert len(events) == 1
assert events[0]["callback_url"] is None
@pytest.mark.anyio
async def test_expired_nullable_column_not_inferred_as_none(
self, mixin_session_expire
):
"""A nullable column with a real value that is expired (by a prior
expire_on_commit) must not be inferred as None in the snapshot — its
actual value is unknown without a DB refresh."""
obj = AttrAccessModel(name="original", callback_url="https://example.com/hook")
mixin_session_expire.add(obj)
await mixin_session_expire.commit()
# expire_on_commit fired → obj.state.expired=True, callback_url not in state.dict
snapshot = _snapshot_column_attrs(obj)
# callback_url has a real DB value but is expired — must not be snapshotted as None.
assert "callback_url" not in snapshot

View File

@@ -7,9 +7,10 @@ from fastapi import Depends, FastAPI
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy import select, text from sqlalchemy import select, text
from sqlalchemy.engine import make_url from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from fastapi_toolsets.db import get_transaction
from fastapi_toolsets.fixtures import Context, FixtureRegistry from fastapi_toolsets.fixtures import Context, FixtureRegistry
from fastapi_toolsets.pytest import ( from fastapi_toolsets.pytest import (
create_async_client, create_async_client,
@@ -336,6 +337,55 @@ class TestCreateDbSession:
result = await session.execute(select(Role)) result = await session.execute(select(Role))
assert result.all() == [] assert result.all() == []
@pytest.mark.anyio
async def test_get_transaction_commits_visible_to_separate_session(self):
"""Data written via get_transaction() is committed and visible to other sessions."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
# Simulate what _create_fixture_function does: insert via get_transaction
# with no explicit commit afterward.
async with get_transaction(session):
role = Role(id=role_id, name="visible_to_other_session")
session.add(role)
# The data must have been committed (begin/commit, not a savepoint),
# so a separate engine/session can read it.
other_engine = create_async_engine(DATABASE_URL, echo=False)
try:
other_session_maker = async_sessionmaker(
other_engine, expire_on_commit=False
)
async with other_session_maker() as other:
result = await other.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none()
assert fetched is not None, (
"Fixture data inserted via get_transaction() must be committed "
"and visible to a separate session. If create_db_session uses "
"create_db_context, auto-begin forces get_transaction() into "
"savepoints instead of real commits."
)
assert fetched.name == "visible_to_other_session"
finally:
await other_engine.dispose()
# Cleanup
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
pass
class TestDeprecatedCleanupTables:
"""Tests for the deprecated cleanup_tables re-export in fastapi_toolsets.pytest."""
@pytest.mark.anyio
async def test_emits_deprecation_warning(self):
"""cleanup_tables imported from fastapi_toolsets.pytest emits DeprecationWarning."""
from fastapi_toolsets.pytest.utils import cleanup_tables
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
with pytest.warns(DeprecationWarning, match="fastapi_toolsets.db"):
await cleanup_tables(session, Base)
class TestGetXdistWorker: class TestGetXdistWorker:
"""Tests for _get_xdist_worker helper.""" """Tests for _get_xdist_worker helper."""