mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
Compare commits
8 Commits
v2.4.3
...
104285c6e5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
104285c6e5 | ||
|
|
f5afbbe37f | ||
|
|
f4698bea8a | ||
|
|
5215b921ae | ||
|
9dad59e25d
|
|||
|
|
29326ab532 | ||
|
|
04afef7e33 | ||
|
|
666c621fda |
1
docs/CNAME
Normal file
1
docs/CNAME
Normal file
@@ -0,0 +1 @@
|
|||||||
|
fastapi-toolsets.d3vyce.fr
|
||||||
@@ -36,7 +36,7 @@ class UserCrud(AsyncCrud[User]):
|
|||||||
default_load_options = [selectinload(User.role)]
|
default_load_options = [selectinload(User.role)]
|
||||||
```
|
```
|
||||||
|
|
||||||
Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body.
|
Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body.
|
||||||
|
|
||||||
### Adding custom methods
|
### Adding custom methods
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ The distinct values are returned in the `filter_attributes` field of [`Paginated
|
|||||||
"filter_attributes": {
|
"filter_attributes": {
|
||||||
"status": ["active", "inactive"],
|
"status": ["active", "inactive"],
|
||||||
"country": ["DE", "FR", "US"],
|
"country": ["DE", "FR", "US"],
|
||||||
"name": ["admin", "editor", "viewer"]
|
"role__name": ["admin", "editor", "viewer"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
@@ -482,7 +482,7 @@ The distinct values are returned in the `filter_attributes` field of [`Paginated
|
|||||||
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
|
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
|
||||||
|
|
||||||
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
|
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
|
||||||
Keys are normally the terminal `column.key` (e.g. `"name"` for `Role.name`). When two facet fields share the same column key (e.g. `(Build.project, Project.name)` and `(Build.os, Os.name)`), the relationship name is prepended automatically: `"project__name"` and `"os__name"`.
|
Keys use `__` as a separator for the full relationship chain. A direct column `User.status` produces `"status"`. A relationship tuple `(User.role, Role.name)` produces `"role__name"`. A deeper chain `(User.role, Role.permission, Permission.name)` produces `"role__permission__name"`.
|
||||||
|
|
||||||
`filter_by` and `filters` can be combined — both are applied with AND logic.
|
`filter_by` and `filters` can be combined — both are applied with AND logic.
|
||||||
|
|
||||||
@@ -515,9 +515,9 @@ async def list_users(
|
|||||||
Both single-value and multi-value query parameters work:
|
Both single-value and multi-value query parameters work:
|
||||||
|
|
||||||
```
|
```
|
||||||
GET /users?status=active → filter_by={"status": ["active"]}
|
GET /users?status=active → filter_by={"status": ["active"]}
|
||||||
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
|
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
|
||||||
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause)
|
GET /users?role__name=admin&role__name=editor → filter_by={"role__name": ["admin", "editor"]} (IN clause)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Sorting
|
## Sorting
|
||||||
|
|||||||
@@ -38,18 +38,20 @@ By context with [`load_fixtures_by_context`](../reference/fixtures.md#fastapi_to
|
|||||||
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
from fastapi_toolsets.fixtures import load_fixtures_by_context
|
||||||
|
|
||||||
async with db_context() as session:
|
async with db_context() as session:
|
||||||
await load_fixtures_by_context(session=session, registry=fixtures, context=Context.TESTING)
|
await load_fixtures_by_context(session, fixtures, Context.TESTING)
|
||||||
```
|
```
|
||||||
|
|
||||||
Directly with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
Directly by name with [`load_fixtures`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.load_fixtures):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from fastapi_toolsets.fixtures import load_fixtures
|
from fastapi_toolsets.fixtures import load_fixtures
|
||||||
|
|
||||||
async with db_context() as session:
|
async with db_context() as session:
|
||||||
await load_fixtures(session=session, registry=fixtures)
|
await load_fixtures(session, fixtures, "roles", "test_users")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Both functions return a `dict[str, list[...]]` mapping each fixture name to the list of loaded instances.
|
||||||
|
|
||||||
## Contexts
|
## Contexts
|
||||||
|
|
||||||
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
[`Context`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.Context) is an enum with predefined values:
|
||||||
@@ -58,10 +60,60 @@ async with db_context() as session:
|
|||||||
|---------|-------------|
|
|---------|-------------|
|
||||||
| `Context.BASE` | Core data required in all environments |
|
| `Context.BASE` | Core data required in all environments |
|
||||||
| `Context.TESTING` | Data only loaded during tests |
|
| `Context.TESTING` | Data only loaded during tests |
|
||||||
|
| `Context.DEVELOPMENT` | Data only loaded in development |
|
||||||
| `Context.PRODUCTION` | Data only loaded in production |
|
| `Context.PRODUCTION` | Data only loaded in production |
|
||||||
|
|
||||||
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
A fixture with no `contexts` defined takes `Context.BASE` by default.
|
||||||
|
|
||||||
|
### Custom contexts
|
||||||
|
|
||||||
|
Plain strings and any `Enum` subclass are accepted wherever a `Context` enum is expected.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class AppContext(str, Enum):
|
||||||
|
STAGING = "staging"
|
||||||
|
DEMO = "demo"
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[AppContext.STAGING])
|
||||||
|
def staging_data():
|
||||||
|
return [Config(key="feature_x", enabled=True)]
|
||||||
|
|
||||||
|
await load_fixtures_by_context(session, fixtures, AppContext.STAGING)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Default context for a registry
|
||||||
|
|
||||||
|
Pass `contexts` to `FixtureRegistry` to set a default for all fixtures registered in it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
testing_registry = FixtureRegistry(contexts=[Context.TESTING])
|
||||||
|
|
||||||
|
@testing_registry.register # implicitly contexts=[Context.TESTING]
|
||||||
|
def test_orders():
|
||||||
|
return [Order(id=1, total=99)]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Same fixture name, multiple context variants
|
||||||
|
|
||||||
|
The same fixture name may be registered under different (non-overlapping) context sets. When multiple contexts are loaded together, all matching variants are merged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@fixtures.register(contexts=[Context.BASE])
|
||||||
|
def users():
|
||||||
|
return [User(id=1, username="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[Context.TESTING])
|
||||||
|
def users():
|
||||||
|
return [User(id=2, username="tester")]
|
||||||
|
|
||||||
|
# loads both admin and tester
|
||||||
|
await load_fixtures_by_context(session, fixtures, Context.BASE, Context.TESTING)
|
||||||
|
```
|
||||||
|
|
||||||
|
Registering two variants with overlapping context sets raises `ValueError`.
|
||||||
|
|
||||||
## Load strategies
|
## Load strategies
|
||||||
|
|
||||||
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
[`LoadStrategy`](../reference/fixtures.md#fastapi_toolsets.fixtures.enum.LoadStrategy) controls how the fixture loader handles rows that already exist:
|
||||||
@@ -69,20 +121,44 @@ A fixture with no `contexts` defined takes `Context.BASE` by default.
|
|||||||
| Strategy | Description |
|
| Strategy | Description |
|
||||||
|----------|-------------|
|
|----------|-------------|
|
||||||
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
| `LoadStrategy.INSERT` | Insert only, fail on duplicates |
|
||||||
| `LoadStrategy.UPSERT` | Insert or update on conflict |
|
| `LoadStrategy.MERGE` | Insert or update on conflict (default) |
|
||||||
| `LoadStrategy.SKIP` | Skip rows that already exist |
|
| `LoadStrategy.SKIP_EXISTING` | Skip rows that already exist |
|
||||||
|
|
||||||
|
```python
|
||||||
|
await load_fixtures_by_context(
|
||||||
|
session, fixtures, Context.BASE, strategy=LoadStrategy.SKIP_EXISTING
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Merging registries
|
## Merging registries
|
||||||
|
|
||||||
Split fixtures definitions across modules and merge them:
|
Split fixture definitions across modules and merge them:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from myapp.fixtures.dev import dev_fixtures
|
from myapp.fixtures.dev import dev_fixtures
|
||||||
from myapp.fixtures.prod import prod_fixtures
|
from myapp.fixtures.prod import prod_fixtures
|
||||||
|
|
||||||
fixtures = fixturesRegistry()
|
fixtures = FixtureRegistry()
|
||||||
fixtures.include_registry(registry=dev_fixtures)
|
fixtures.include_registry(registry=dev_fixtures)
|
||||||
fixtures.include_registry(registry=prod_fixtures)
|
fixtures.include_registry(registry=prod_fixtures)
|
||||||
|
```
|
||||||
|
|
||||||
|
Fixtures with the same name are allowed as long as their context sets do not overlap. Conflicting contexts raise `ValueError`.
|
||||||
|
|
||||||
|
## Looking up fixture instances
|
||||||
|
|
||||||
|
[`get_obj_by_attr`](../reference/fixtures.md#fastapi_toolsets.fixtures.utils.get_obj_by_attr) retrieves a specific instance from a fixture function by attribute value — useful when building cross-fixture `depends_on` relationships:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.fixtures import get_obj_by_attr
|
||||||
|
|
||||||
|
@fixtures.register(depends_on=["roles"])
|
||||||
|
def users():
|
||||||
|
admin_role = get_obj_by_attr(roles, "name", "admin")
|
||||||
|
return [User(id=1, username="alice", role_id=admin_role.id)]
|
||||||
|
```
|
||||||
|
|
||||||
|
Raises `StopIteration` if no matching instance is found.
|
||||||
|
|
||||||
## Pytest integration
|
## Pytest integration
|
||||||
|
|
||||||
@@ -111,7 +187,6 @@ async def test_user_can_login(fixture_users: list[User], fixture_roles: list[Rol
|
|||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
The load order is resolved automatically from the `depends_on` declarations in your registry. Each generated fixture receives `db_session` as a dependency and returns the list of loaded model instances.
|
||||||
|
|
||||||
## CLI integration
|
## CLI integration
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
from collections import Counter
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
@@ -151,7 +150,7 @@ def build_search_filters(
|
|||||||
|
|
||||||
|
|
||||||
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
||||||
"""Return a key for each facet field, disambiguating duplicate column keys.
|
"""Return a key for each facet field.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
facet_fields: Sequence of facet fields — either direct columns or
|
facet_fields: Sequence of facet fields — either direct columns or
|
||||||
@@ -160,22 +159,12 @@ def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
A list of string keys, one per facet field, in the same order.
|
A list of string keys, one per facet field, in the same order.
|
||||||
"""
|
"""
|
||||||
raw: list[tuple[str, str | None]] = []
|
keys: list[str] = []
|
||||||
for field in facet_fields:
|
for field in facet_fields:
|
||||||
if isinstance(field, tuple):
|
if isinstance(field, tuple):
|
||||||
rel = field[-2]
|
keys.append("__".join(el.key for el in field))
|
||||||
column = field[-1]
|
|
||||||
raw.append((column.key, rel.key))
|
|
||||||
else:
|
else:
|
||||||
raw.append((field.key, None))
|
keys.append(field.key)
|
||||||
|
|
||||||
counts = Counter(col_key for col_key, _ in raw)
|
|
||||||
keys: list[str] = []
|
|
||||||
for col_key, rel_key in raw:
|
|
||||||
if counts[col_key] > 1 and rel_key is not None:
|
|
||||||
keys.append(f"{rel_key}__{col_key}")
|
|
||||||
else:
|
|
||||||
keys.append(col_key)
|
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
@@ -12,6 +13,13 @@ from .enum import Context
|
|||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_contexts(
|
||||||
|
contexts: list[str | Enum] | tuple[str | Enum, ...],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Convert a sequence of any Enum subclass and/or plain strings to a list of strings."""
|
||||||
|
return [c.value if isinstance(c, Enum) else c for c in contexts]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fixture:
|
class Fixture:
|
||||||
"""A fixture definition with metadata."""
|
"""A fixture definition with metadata."""
|
||||||
@@ -50,26 +58,51 @@ class FixtureRegistry:
|
|||||||
Post(id=1, title="Test", user_id=1),
|
Post(id=1, title="Test", user_id=1),
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Fixtures with the same name may be registered for **different** contexts.
|
||||||
|
When multiple contexts are loaded together, their instances are merged:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@fixtures.register(contexts=[Context.BASE])
|
||||||
|
def users():
|
||||||
|
return [User(id=1, username="admin")]
|
||||||
|
|
||||||
|
@fixtures.register(contexts=[Context.TESTING])
|
||||||
|
def users():
|
||||||
|
return [User(id=2, username="tester")]
|
||||||
|
# load_fixtures_by_context(..., Context.BASE, Context.TESTING)
|
||||||
|
# → loads both User(admin) and User(tester) under the "users" name
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
contexts: list[str | Context] | None = None,
|
contexts: list[str | Enum] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._fixtures: dict[str, Fixture] = {}
|
self._fixtures: dict[str, list[Fixture]] = {}
|
||||||
self._default_contexts: list[str] | None = (
|
self._default_contexts: list[str] | None = (
|
||||||
[c.value if isinstance(c, Context) else c for c in contexts]
|
_normalize_contexts(contexts) if contexts else None
|
||||||
if contexts
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _validate_no_context_overlap(self, name: str, new_contexts: list[str]) -> None:
|
||||||
|
"""Raise ``ValueError`` if any existing variant for *name* overlaps."""
|
||||||
|
existing_variants = self._fixtures.get(name, [])
|
||||||
|
new_set = set(new_contexts)
|
||||||
|
for variant in existing_variants:
|
||||||
|
if set(variant.contexts) & new_set:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' already exists in the current registry "
|
||||||
|
f"with overlapping contexts. Use distinct context sets for "
|
||||||
|
f"each variant of the same fixture name."
|
||||||
|
)
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
|
func: Callable[[], Sequence[DeclarativeBase]] | None = None,
|
||||||
*,
|
*,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
depends_on: list[str] | None = None,
|
depends_on: list[str] | None = None,
|
||||||
contexts: list[str | Context] | None = None,
|
contexts: list[str | Enum] | None = None,
|
||||||
) -> Callable[..., Any]:
|
) -> Callable[..., Any]:
|
||||||
"""Register a fixture function.
|
"""Register a fixture function.
|
||||||
|
|
||||||
@@ -79,7 +112,8 @@ class FixtureRegistry:
|
|||||||
func: Fixture function returning list of model instances
|
func: Fixture function returning list of model instances
|
||||||
name: Fixture name (defaults to function name)
|
name: Fixture name (defaults to function name)
|
||||||
depends_on: List of fixture names this depends on
|
depends_on: List of fixture names this depends on
|
||||||
contexts: List of contexts this fixture belongs to
|
contexts: List of contexts this fixture belongs to. Both
|
||||||
|
:class:`Context` enum values and plain strings are accepted.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
@@ -90,7 +124,6 @@ class FixtureRegistry:
|
|||||||
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
@fixtures.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
def test_users():
|
def test_users():
|
||||||
return [User(id=1, username="test", role_id=1)]
|
return [User(id=1, username="test", role_id=1)]
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
@@ -98,19 +131,20 @@ class FixtureRegistry:
|
|||||||
) -> Callable[[], Sequence[DeclarativeBase]]:
|
) -> Callable[[], Sequence[DeclarativeBase]]:
|
||||||
fixture_name = name or cast(Any, fn).__name__
|
fixture_name = name or cast(Any, fn).__name__
|
||||||
if contexts is not None:
|
if contexts is not None:
|
||||||
fixture_contexts = [
|
fixture_contexts = _normalize_contexts(contexts)
|
||||||
c.value if isinstance(c, Context) else c for c in contexts
|
|
||||||
]
|
|
||||||
elif self._default_contexts is not None:
|
elif self._default_contexts is not None:
|
||||||
fixture_contexts = self._default_contexts
|
fixture_contexts = self._default_contexts
|
||||||
else:
|
else:
|
||||||
fixture_contexts = [Context.BASE.value]
|
fixture_contexts = [Context.BASE.value]
|
||||||
|
|
||||||
self._fixtures[fixture_name] = Fixture(
|
self._validate_no_context_overlap(fixture_name, fixture_contexts)
|
||||||
name=fixture_name,
|
self._fixtures.setdefault(fixture_name, []).append(
|
||||||
func=fn,
|
Fixture(
|
||||||
depends_on=depends_on or [],
|
name=fixture_name,
|
||||||
contexts=fixture_contexts,
|
func=fn,
|
||||||
|
depends_on=depends_on or [],
|
||||||
|
contexts=fixture_contexts,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
@@ -121,11 +155,14 @@ class FixtureRegistry:
|
|||||||
def include_registry(self, registry: "FixtureRegistry") -> None:
|
def include_registry(self, registry: "FixtureRegistry") -> None:
|
||||||
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
"""Include another `FixtureRegistry` in the same current `FixtureRegistry`.
|
||||||
|
|
||||||
|
Fixtures with the same name are allowed as long as their context sets
|
||||||
|
do not overlap. Conflicting contexts raise :class:`ValueError`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
registry: The `FixtureRegistry` to include
|
registry: The `FixtureRegistry` to include
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If a fixture name already exists in the current registry
|
ValueError: If a fixture name already exists with overlapping contexts
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
@@ -139,31 +176,73 @@ class FixtureRegistry:
|
|||||||
registry.include_registry(registry=dev_registry)
|
registry.include_registry(registry=dev_registry)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
for name, fixture in registry._fixtures.items():
|
for name, variants in registry._fixtures.items():
|
||||||
if name in self._fixtures:
|
for fixture in variants:
|
||||||
raise ValueError(
|
self._validate_no_context_overlap(name, fixture.contexts)
|
||||||
f"Fixture '{name}' already exists in the current registry"
|
self._fixtures.setdefault(name, []).append(fixture)
|
||||||
)
|
|
||||||
self._fixtures[name] = fixture
|
|
||||||
|
|
||||||
def get(self, name: str) -> Fixture:
|
def get(self, name: str) -> Fixture:
|
||||||
"""Get a fixture by name."""
|
"""Get a fixture by name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If no fixture with *name* is registered.
|
||||||
|
ValueError: If the fixture has multiple context variants — use
|
||||||
|
:meth:`get_variants` in that case.
|
||||||
|
"""
|
||||||
if name not in self._fixtures:
|
if name not in self._fixtures:
|
||||||
raise KeyError(f"Fixture '{name}' not found")
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
return self._fixtures[name]
|
variants = self._fixtures[name]
|
||||||
|
if len(variants) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Fixture '{name}' has {len(variants)} context variants. "
|
||||||
|
f"Use get_variants('{name}') to retrieve them."
|
||||||
|
)
|
||||||
|
return variants[0]
|
||||||
|
|
||||||
|
def get_variants(self, name: str, *contexts: str | Enum) -> list[Fixture]:
|
||||||
|
"""Return all registered variants for *name*, optionally filtered by context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Fixture name.
|
||||||
|
*contexts: If given, only return variants whose context set
|
||||||
|
intersects with these values. Both :class:`Context` enum
|
||||||
|
values and plain strings are accepted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching :class:`Fixture` objects (may be empty when a
|
||||||
|
context filter is applied and nothing matches).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If no fixture with *name* is registered.
|
||||||
|
"""
|
||||||
|
if name not in self._fixtures:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
variants = self._fixtures[name]
|
||||||
|
if not contexts:
|
||||||
|
return list(variants)
|
||||||
|
context_values = set(_normalize_contexts(contexts))
|
||||||
|
return [v for v in variants if set(v.contexts) & context_values]
|
||||||
|
|
||||||
def get_all(self) -> list[Fixture]:
|
def get_all(self) -> list[Fixture]:
|
||||||
"""Get all registered fixtures."""
|
"""Get all registered fixtures (all variants of all names)."""
|
||||||
return list(self._fixtures.values())
|
return [f for variants in self._fixtures.values() for f in variants]
|
||||||
|
|
||||||
def get_by_context(self, *contexts: str | Context) -> list[Fixture]:
|
def get_by_context(self, *contexts: str | Enum) -> list[Fixture]:
|
||||||
"""Get fixtures for specific contexts."""
|
"""Get fixtures for specific contexts."""
|
||||||
context_values = {c.value if isinstance(c, Context) else c for c in contexts}
|
context_values = set(_normalize_contexts(contexts))
|
||||||
return [f for f in self._fixtures.values() if set(f.contexts) & context_values]
|
return [
|
||||||
|
f
|
||||||
|
for variants in self._fixtures.values()
|
||||||
|
for f in variants
|
||||||
|
if set(f.contexts) & context_values
|
||||||
|
]
|
||||||
|
|
||||||
def resolve_dependencies(self, *names: str) -> list[str]:
|
def resolve_dependencies(self, *names: str) -> list[str]:
|
||||||
"""Resolve fixture dependencies in topological order.
|
"""Resolve fixture dependencies in topological order.
|
||||||
|
|
||||||
|
When a fixture name has multiple context variants, the union of all
|
||||||
|
variants' ``depends_on`` lists is used.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*names: Fixture names to resolve
|
*names: Fixture names to resolve
|
||||||
|
|
||||||
@@ -185,9 +264,20 @@ class FixtureRegistry:
|
|||||||
raise ValueError(f"Circular dependency detected: {name}")
|
raise ValueError(f"Circular dependency detected: {name}")
|
||||||
|
|
||||||
visiting.add(name)
|
visiting.add(name)
|
||||||
fixture = self.get(name)
|
variants = self._fixtures.get(name)
|
||||||
|
if variants is None:
|
||||||
|
raise KeyError(f"Fixture '{name}' not found")
|
||||||
|
|
||||||
for dep in fixture.depends_on:
|
# Union of depends_on across all variants, preserving first-seen order.
|
||||||
|
seen_deps: set[str] = set()
|
||||||
|
all_deps: list[str] = []
|
||||||
|
for variant in variants:
|
||||||
|
for dep in variant.depends_on:
|
||||||
|
if dep not in seen_deps:
|
||||||
|
all_deps.append(dep)
|
||||||
|
seen_deps.add(dep)
|
||||||
|
|
||||||
|
for dep in all_deps:
|
||||||
visit(dep)
|
visit(dep)
|
||||||
|
|
||||||
visiting.remove(name)
|
visiting.remove(name)
|
||||||
@@ -199,7 +289,7 @@ class FixtureRegistry:
|
|||||||
|
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
def resolve_context_dependencies(self, *contexts: str | Context) -> list[str]:
|
def resolve_context_dependencies(self, *contexts: str | Enum) -> list[str]:
|
||||||
"""Resolve all fixtures for contexts with dependencies.
|
"""Resolve all fixtures for contexts with dependencies.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -209,7 +299,9 @@ class FixtureRegistry:
|
|||||||
List of fixture names in load order
|
List of fixture names in load order
|
||||||
"""
|
"""
|
||||||
context_fixtures = self.get_by_context(*contexts)
|
context_fixtures = self.get_by_context(*contexts)
|
||||||
names = [f.name for f in context_fixtures]
|
# Deduplicate names while preserving first-seen order (a name can
|
||||||
|
# appear multiple times if it has variants in different contexts).
|
||||||
|
names = list(dict.fromkeys(f.name for f in context_fixtures))
|
||||||
|
|
||||||
all_deps: set[str] = set()
|
all_deps: set[str] = set()
|
||||||
for name in names:
|
for name in names:
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
"""Fixture loading utilities for database seeding."""
|
"""Fixture loading utilities for database seeding."""
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import inspect as sa_inspect
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
@@ -10,23 +13,163 @@ from ..db import get_transaction
|
|||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..types import ModelType
|
from ..types import ModelType
|
||||||
from .enum import LoadStrategy
|
from .enum import LoadStrategy
|
||||||
from .registry import Context, FixtureRegistry
|
from .registry import FixtureRegistry, _normalize_contexts
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
|
||||||
|
"""Extract column values from a model instance, skipping unset server-default columns."""
|
||||||
|
state = sa_inspect(instance)
|
||||||
|
state_dict = state.dict
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
for prop in state.mapper.column_attrs:
|
||||||
|
if prop.key not in state_dict:
|
||||||
|
continue
|
||||||
|
val = state_dict[prop.key]
|
||||||
|
if val is None:
|
||||||
|
col = prop.columns[0]
|
||||||
|
|
||||||
|
if col.server_default is not None or (
|
||||||
|
col.default is not None and col.default.is_callable
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
result[prop.key] = val
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Ensure all row dicts share the same key set."""
|
||||||
|
all_keys: set[str] = set().union(*dicts)
|
||||||
|
return [{k: d.get(k) for k in all_keys} for d in dicts]
|
||||||
|
|
||||||
|
|
||||||
|
def _group_by_type(
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
|
||||||
|
"""Group instances by their concrete model class, preserving insertion order."""
|
||||||
|
groups: dict[type[DeclarativeBase], list[DeclarativeBase]] = {}
|
||||||
|
for instance in instances:
|
||||||
|
groups.setdefault(type(instance), []).append(instance)
|
||||||
|
return list(groups.items())
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_insert(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""INSERT all instances — raises on conflict (no duplicate handling)."""
|
||||||
|
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
|
||||||
|
await session.execute(pg_insert(model_cls).values(dicts))
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_merge(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> None:
|
||||||
|
"""UPSERT: insert new rows, update existing ones with the provided values."""
|
||||||
|
mapper = model_cls.__mapper__
|
||||||
|
pk_names = [col.name for col in mapper.primary_key]
|
||||||
|
pk_names_set = set(pk_names)
|
||||||
|
non_pk_cols = [
|
||||||
|
prop.key
|
||||||
|
for prop in mapper.column_attrs
|
||||||
|
if not any(col.name in pk_names_set for col in prop.columns)
|
||||||
|
]
|
||||||
|
|
||||||
|
dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
|
||||||
|
stmt = pg_insert(model_cls).values(dicts)
|
||||||
|
|
||||||
|
inserted_keys = set(dicts[0]) if dicts else set()
|
||||||
|
update_cols = [col for col in non_pk_cols if col in inserted_keys]
|
||||||
|
|
||||||
|
if update_cols:
|
||||||
|
stmt = stmt.on_conflict_do_update(
|
||||||
|
index_elements=pk_names,
|
||||||
|
set_={col: stmt.excluded[col] for col in update_cols},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
|
||||||
|
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
async def _batch_skip_existing(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_cls: type[DeclarativeBase],
|
||||||
|
instances: list[DeclarativeBase],
|
||||||
|
) -> list[DeclarativeBase]:
|
||||||
|
"""INSERT only rows that do not already exist; return the inserted ones."""
|
||||||
|
mapper = model_cls.__mapper__
|
||||||
|
pk_names = [col.name for col in mapper.primary_key]
|
||||||
|
|
||||||
|
no_pk: list[DeclarativeBase] = []
|
||||||
|
with_pk_pairs: list[tuple[DeclarativeBase, Any]] = []
|
||||||
|
for inst in instances:
|
||||||
|
pk = _get_primary_key(inst)
|
||||||
|
if pk is None:
|
||||||
|
no_pk.append(inst)
|
||||||
|
else:
|
||||||
|
with_pk_pairs.append((inst, pk))
|
||||||
|
|
||||||
|
loaded: list[DeclarativeBase] = list(no_pk)
|
||||||
|
if no_pk:
|
||||||
|
await session.execute(
|
||||||
|
pg_insert(model_cls).values(
|
||||||
|
_normalize_rows([_instance_to_dict(i) for i in no_pk])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if with_pk_pairs:
|
||||||
|
with_pk = [i for i, _ in with_pk_pairs]
|
||||||
|
stmt = (
|
||||||
|
pg_insert(model_cls)
|
||||||
|
.values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
|
||||||
|
.on_conflict_do_nothing(index_elements=pk_names)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt.returning(*mapper.primary_key))
|
||||||
|
inserted_pks = {row[0] if len(pk_names) == 1 else tuple(row) for row in result}
|
||||||
|
loaded.extend(inst for inst, pk in with_pk_pairs if pk in inserted_pks)
|
||||||
|
|
||||||
|
return loaded
|
||||||
|
|
||||||
|
|
||||||
async def _load_ordered(
|
async def _load_ordered(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
registry: FixtureRegistry,
|
registry: FixtureRegistry,
|
||||||
ordered_names: list[str],
|
ordered_names: list[str],
|
||||||
strategy: LoadStrategy,
|
strategy: LoadStrategy,
|
||||||
|
contexts: tuple[str, ...] | None = None,
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
"""Load fixtures in order."""
|
"""Load fixtures in order using batch Core INSERT statements.
|
||||||
|
|
||||||
|
When *contexts* is provided only variants whose context set intersects with
|
||||||
|
*contexts* are called for each name; their instances are concatenated.
|
||||||
|
When *contexts* is ``None`` all variants of each name are loaded.
|
||||||
|
"""
|
||||||
results: dict[str, list[DeclarativeBase]] = {}
|
results: dict[str, list[DeclarativeBase]] = {}
|
||||||
|
|
||||||
for name in ordered_names:
|
for name in ordered_names:
|
||||||
fixture = registry.get(name)
|
variants = (
|
||||||
instances = list(fixture.func())
|
registry.get_variants(name, *contexts)
|
||||||
|
if contexts is not None
|
||||||
|
else registry.get_variants(name)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cross-context dependency fallback: if we're loading by context but
|
||||||
|
# no variant matches (e.g. a "base"-only fixture required by a
|
||||||
|
# "testing" fixture), load all available variants so the dependency
|
||||||
|
# is satisfied.
|
||||||
|
if contexts is not None and not variants:
|
||||||
|
variants = registry.get_variants(name)
|
||||||
|
|
||||||
|
if not variants:
|
||||||
|
results[name] = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
instances = [inst for v in variants for inst in v.func()]
|
||||||
|
|
||||||
if not instances:
|
if not instances:
|
||||||
results[name] = []
|
results[name] = []
|
||||||
@@ -36,25 +179,17 @@ async def _load_ordered(
|
|||||||
loaded: list[DeclarativeBase] = []
|
loaded: list[DeclarativeBase] = []
|
||||||
|
|
||||||
async with get_transaction(session):
|
async with get_transaction(session):
|
||||||
for instance in instances:
|
for model_cls, group in _group_by_type(instances):
|
||||||
if strategy == LoadStrategy.INSERT:
|
match strategy:
|
||||||
session.add(instance)
|
case LoadStrategy.INSERT:
|
||||||
loaded.append(instance)
|
await _batch_insert(session, model_cls, group)
|
||||||
|
loaded.extend(group)
|
||||||
elif strategy == LoadStrategy.MERGE:
|
case LoadStrategy.MERGE:
|
||||||
merged = await session.merge(instance)
|
await _batch_merge(session, model_cls, group)
|
||||||
loaded.append(merged)
|
loaded.extend(group)
|
||||||
|
case LoadStrategy.SKIP_EXISTING:
|
||||||
else: # LoadStrategy.SKIP_EXISTING
|
inserted = await _batch_skip_existing(session, model_cls, group)
|
||||||
pk = _get_primary_key(instance)
|
loaded.extend(inserted)
|
||||||
if pk is not None:
|
|
||||||
existing = await session.get(type(instance), pk)
|
|
||||||
if existing is None:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
else:
|
|
||||||
session.add(instance)
|
|
||||||
loaded.append(instance)
|
|
||||||
|
|
||||||
results[name] = loaded
|
results[name] = loaded
|
||||||
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
|
||||||
@@ -109,6 +244,8 @@ async def load_fixtures(
|
|||||||
) -> dict[str, list[DeclarativeBase]]:
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
"""Load specific fixtures by name with dependencies.
|
"""Load specific fixtures by name with dependencies.
|
||||||
|
|
||||||
|
All context variants of each requested fixture are loaded and merged.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: Database session
|
session: Database session
|
||||||
registry: Fixture registry
|
registry: Fixture registry
|
||||||
@@ -125,19 +262,27 @@ async def load_fixtures(
|
|||||||
async def load_fixtures_by_context(
|
async def load_fixtures_by_context(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
registry: FixtureRegistry,
|
registry: FixtureRegistry,
|
||||||
*contexts: str | Context,
|
*contexts: str | Enum,
|
||||||
strategy: LoadStrategy = LoadStrategy.MERGE,
|
strategy: LoadStrategy = LoadStrategy.MERGE,
|
||||||
) -> dict[str, list[DeclarativeBase]]:
|
) -> dict[str, list[DeclarativeBase]]:
|
||||||
"""Load all fixtures for specific contexts.
|
"""Load all fixtures for specific contexts.
|
||||||
|
|
||||||
|
For each fixture name, only the variants whose context set intersects with
|
||||||
|
*contexts* are loaded. When a name has variants in multiple of the
|
||||||
|
requested contexts, their instances are merged before being inserted.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: Database session
|
session: Database session
|
||||||
registry: Fixture registry
|
registry: Fixture registry
|
||||||
*contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
|
*contexts: Contexts to load (e.g., ``Context.BASE``, ``Context.TESTING``,
|
||||||
|
or plain strings for custom contexts)
|
||||||
strategy: How to handle existing records
|
strategy: How to handle existing records
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping fixture names to loaded instances
|
Dict mapping fixture names to loaded instances
|
||||||
"""
|
"""
|
||||||
|
context_strings = tuple(_normalize_contexts(contexts))
|
||||||
ordered = registry.resolve_context_dependencies(*contexts)
|
ordered = registry.resolve_context_dependencies(*contexts)
|
||||||
return await _load_ordered(session, registry, ordered, strategy)
|
return await _load_ordered(
|
||||||
|
session, registry, ordered, strategy, contexts=context_strings
|
||||||
|
)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ _SESSION_CREATES = "_ft_creates"
|
|||||||
_SESSION_DELETES = "_ft_deletes"
|
_SESSION_DELETES = "_ft_deletes"
|
||||||
_SESSION_UPDATES = "_ft_updates"
|
_SESSION_UPDATES = "_ft_updates"
|
||||||
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
|
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
|
||||||
|
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
||||||
|
|
||||||
|
|
||||||
class ModelEvent(str, Enum):
|
class ModelEvent(str, Enum):
|
||||||
@@ -60,11 +61,22 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
|||||||
"""Read currently-loaded column values into a plain dict."""
|
"""Read currently-loaded column values into a plain dict."""
|
||||||
state = sa_inspect(obj) # InstanceState
|
state = sa_inspect(obj) # InstanceState
|
||||||
state_dict = state.dict
|
state_dict = state.dict
|
||||||
return {
|
snapshot: dict[str, Any] = {}
|
||||||
prop.key: state_dict[prop.key]
|
for prop in state.mapper.column_attrs:
|
||||||
for prop in state.mapper.column_attrs
|
if prop.key in state_dict:
|
||||||
if prop.key in state_dict
|
snapshot[prop.key] = state_dict[prop.key]
|
||||||
}
|
elif (
|
||||||
|
not state.expired
|
||||||
|
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
|
||||||
|
and all(
|
||||||
|
col.nullable
|
||||||
|
and col.server_default is None
|
||||||
|
and col.server_onupdate is None
|
||||||
|
for col in prop.columns
|
||||||
|
)
|
||||||
|
):
|
||||||
|
snapshot[prop.key] = None
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
def _get_watched_fields(cls: type) -> list[str] | None:
|
def _get_watched_fields(cls: type) -> list[str] | None:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from contextlib import asynccontextmanager
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import text
|
||||||
from sqlalchemy.engine import make_url
|
from sqlalchemy.engine import make_url
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
AsyncSession,
|
AsyncSession,
|
||||||
@@ -15,13 +16,8 @@ from sqlalchemy.ext.asyncio import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from sqlalchemy import text
|
from ..db import cleanup_tables as _cleanup_tables
|
||||||
|
from ..db import create_database
|
||||||
from ..db import (
|
|
||||||
cleanup_tables as _cleanup_tables,
|
|
||||||
create_database,
|
|
||||||
create_db_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_tables(
|
async def cleanup_tables(
|
||||||
@@ -269,15 +265,12 @@ async def create_db_session(
|
|||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(base.metadata.create_all)
|
await conn.run_sync(base.metadata.create_all)
|
||||||
|
|
||||||
# Create session using existing db context utility
|
|
||||||
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
|
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
|
||||||
get_session = create_db_context(session_maker)
|
async with session_maker() as session:
|
||||||
|
|
||||||
async with get_session() as session:
|
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
if cleanup:
|
if cleanup:
|
||||||
await cleanup_tables(session, base)
|
await _cleanup_tables(session=session, base=base)
|
||||||
|
|
||||||
if drop_tables:
|
if drop_tables:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ class User(Base):
|
|||||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
is_active: Mapped[bool] = mapped_column(default=True)
|
is_active: Mapped[bool] = mapped_column(default=True)
|
||||||
|
notes: Mapped[str | None]
|
||||||
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
ForeignKey("roles.id"), nullable=True
|
ForeignKey("roles.id"), nullable=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -646,7 +646,7 @@ class TestFacetsRelationship:
|
|||||||
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
|
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
|
||||||
|
|
||||||
assert result.filter_attributes is not None
|
assert result.filter_attributes is not None
|
||||||
assert set(result.filter_attributes["name"]) == {"admin", "editor"}
|
assert set(result.filter_attributes["role__name"]) == {"admin", "editor"}
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_relationship_facet_none_excluded(self, db_session: AsyncSession):
|
async def test_relationship_facet_none_excluded(self, db_session: AsyncSession):
|
||||||
@@ -661,7 +661,7 @@ class TestFacetsRelationship:
|
|||||||
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
|
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
|
||||||
|
|
||||||
assert result.filter_attributes is not None
|
assert result.filter_attributes is not None
|
||||||
assert result.filter_attributes["name"] == []
|
assert result.filter_attributes["role__name"] == []
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_relationship_facet_deduplicates_join_with_search(
|
async def test_relationship_facet_deduplicates_join_with_search(
|
||||||
@@ -689,7 +689,7 @@ class TestFacetsRelationship:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result.filter_attributes is not None
|
assert result.filter_attributes is not None
|
||||||
assert result.filter_attributes["name"] == ["admin"]
|
assert result.filter_attributes["role__name"] == ["admin"]
|
||||||
|
|
||||||
|
|
||||||
class TestFilterBy:
|
class TestFilterBy:
|
||||||
@@ -755,7 +755,7 @@ class TestFilterBy:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = await UserRelFacetCrud.offset_paginate(
|
result = await UserRelFacetCrud.offset_paginate(
|
||||||
db_session, filter_by={"name": "admin"}, schema=UserRead
|
db_session, filter_by={"role__name": "admin"}, schema=UserRead
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result.pagination, OffsetPagination)
|
assert isinstance(result.pagination, OffsetPagination)
|
||||||
@@ -824,7 +824,7 @@ class TestFilterBy:
|
|||||||
|
|
||||||
result = await UserRoleFacetCrud.offset_paginate(
|
result = await UserRoleFacetCrud.offset_paginate(
|
||||||
db_session,
|
db_session,
|
||||||
filter_by={"name": "admin", "id": str(admin.id)},
|
filter_by={"role__name": "admin", "role__id": str(admin.id)},
|
||||||
schema=UserRead,
|
schema=UserRead,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -916,15 +916,15 @@ class TestFilterParamsSchema:
|
|||||||
param_names = set(inspect.signature(dep).parameters)
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
assert param_names == {"username", "email"}
|
assert param_names == {"username", "email"}
|
||||||
|
|
||||||
def test_relationship_facet_uses_column_key(self):
|
def test_relationship_facet_uses_full_chain_key(self):
|
||||||
"""Relationship tuple uses the terminal column's key."""
|
"""Relationship tuple uses the full chain joined by __ as the key."""
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
|
||||||
dep = UserRoleCrud.filter_params()
|
dep = UserRoleCrud.filter_params()
|
||||||
|
|
||||||
param_names = set(inspect.signature(dep).parameters)
|
param_names = set(inspect.signature(dep).parameters)
|
||||||
assert param_names == {"name"}
|
assert param_names == {"role__name"}
|
||||||
|
|
||||||
def test_raises_when_no_facet_fields(self):
|
def test_raises_when_no_facet_fields(self):
|
||||||
"""ValueError raised when no facet_fields are configured or provided."""
|
"""ValueError raised when no facet_fields are configured or provided."""
|
||||||
@@ -978,6 +978,22 @@ class TestFilterParamsSchema:
|
|||||||
keys = facet_keys([(rel_a, col_a), (rel_b, col_b)])
|
keys = facet_keys([(rel_a, col_a), (rel_b, col_b)])
|
||||||
assert keys == ["project__name", "os__name"]
|
assert keys == ["project__name", "os__name"]
|
||||||
|
|
||||||
|
def test_deep_chain_joins_all_segments(self):
|
||||||
|
"""Three-element tuple produces all relation segments joined by __."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from fastapi_toolsets.crud.search import facet_keys
|
||||||
|
|
||||||
|
rel_a = MagicMock()
|
||||||
|
rel_a.key = "role"
|
||||||
|
rel_b = MagicMock()
|
||||||
|
rel_b.key = "permission"
|
||||||
|
col = MagicMock()
|
||||||
|
col.key = "name"
|
||||||
|
|
||||||
|
keys = facet_keys([(rel_a, rel_b, col)])
|
||||||
|
assert keys == ["role__permission__name"]
|
||||||
|
|
||||||
def test_unique_column_keys_kept_plain(self):
|
def test_unique_column_keys_kept_plain(self):
|
||||||
"""Fields with unique column keys are not prefixed."""
|
"""Fields with unique column keys are not prefixed."""
|
||||||
from fastapi_toolsets.crud.search import facet_keys
|
from fastapi_toolsets.crud.search import facet_keys
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ class TestAppSessionDep:
|
|||||||
gen = get_db()
|
gen = get_db()
|
||||||
session = await gen.__anext__()
|
session = await gen.__anext__()
|
||||||
assert isinstance(session, AsyncSession)
|
assert isinstance(session, AsyncSession)
|
||||||
await session.close()
|
await gen.aclose()
|
||||||
|
|
||||||
|
|
||||||
class TestOffsetPagination:
|
class TestOffsetPagination:
|
||||||
@@ -182,8 +182,7 @@ class TestOffsetPagination:
|
|||||||
body = resp.json()
|
body = resp.json()
|
||||||
fa = body["filter_attributes"]
|
fa = body["filter_attributes"]
|
||||||
assert set(fa["status"]) == {"draft", "published"}
|
assert set(fa["status"]) == {"draft", "published"}
|
||||||
# "name" is unique across all facet fields — no prefix needed
|
assert set(fa["category__name"]) == {"backend", "python"}
|
||||||
assert set(fa["name"]) == {"backend", "python"}
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_filter_attributes_scoped_to_filter(
|
async def test_filter_attributes_scoped_to_filter(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tests for fastapi_toolsets.fixtures module."""
|
"""Tests for fastapi_toolsets.fixtures module."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@@ -13,10 +14,22 @@ from fastapi_toolsets.fixtures import (
|
|||||||
load_fixtures,
|
load_fixtures,
|
||||||
load_fixtures_by_context,
|
load_fixtures_by_context,
|
||||||
)
|
)
|
||||||
|
from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict
|
||||||
|
|
||||||
from fastapi_toolsets.fixtures.utils import _get_primary_key
|
from .conftest import IntRole, Permission, Role, RoleCreate, RoleCrud, User, UserCrud
|
||||||
|
|
||||||
from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud
|
|
||||||
|
class AppContext(str, Enum):
|
||||||
|
"""Example user-defined str+Enum context."""
|
||||||
|
|
||||||
|
STAGING = "staging"
|
||||||
|
DEMO = "demo"
|
||||||
|
|
||||||
|
|
||||||
|
class PlainEnumContext(Enum):
|
||||||
|
"""Example user-defined plain Enum context (no str mixin)."""
|
||||||
|
|
||||||
|
STAGING = "staging"
|
||||||
|
|
||||||
|
|
||||||
class TestContext:
|
class TestContext:
|
||||||
@@ -39,6 +52,86 @@ class TestContext:
|
|||||||
assert Context.TESTING.value == "testing"
|
assert Context.TESTING.value == "testing"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomEnumContext:
|
||||||
|
"""Custom Enum types are accepted wherever Context/str are expected."""
|
||||||
|
|
||||||
|
def test_cannot_subclass_context_with_members(self):
|
||||||
|
"""Python prohibits extending an Enum that already has members."""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
|
||||||
|
class MyContext(Context): # noqa: F841 # ty: ignore[subclass-of-final-class]
|
||||||
|
STAGING = "staging"
|
||||||
|
|
||||||
|
def test_custom_enum_values_interchangeable_with_context(self):
|
||||||
|
"""A custom enum with the same .value as a built-in Context member is
|
||||||
|
treated as the same context — fixtures registered under one are found
|
||||||
|
by the other."""
|
||||||
|
|
||||||
|
class AppContextFull(str, Enum):
|
||||||
|
BASE = "base"
|
||||||
|
STAGING = "staging"
|
||||||
|
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.BASE])
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
# AppContextFull.BASE has value "base" — same as Context.BASE
|
||||||
|
fixtures = registry.get_by_context(AppContextFull.BASE)
|
||||||
|
assert len(fixtures) == 1
|
||||||
|
|
||||||
|
def test_custom_enum_registry_default_contexts(self):
|
||||||
|
"""FixtureRegistry(contexts=[...]) accepts a custom Enum."""
|
||||||
|
registry = FixtureRegistry(contexts=[AppContext.STAGING])
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def data():
|
||||||
|
return []
|
||||||
|
|
||||||
|
fixture = registry.get("data")
|
||||||
|
assert fixture.contexts == ["staging"]
|
||||||
|
|
||||||
|
def test_custom_enum_resolve_context_dependencies(self):
|
||||||
|
"""resolve_context_dependencies accepts a custom Enum context."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register(contexts=[AppContext.STAGING])
|
||||||
|
def staging_roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
order = registry.resolve_context_dependencies(AppContext.STAGING)
|
||||||
|
assert "staging_roles" in order
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_custom_enum_e2e(self, db_session: AsyncSession):
|
||||||
|
"""End-to-end: register with custom Enum, load with the same Enum."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register(contexts=[AppContext.STAGING])
|
||||||
|
def staging_roles():
|
||||||
|
return [Role(id=uuid.uuid4(), name="staging-admin")]
|
||||||
|
|
||||||
|
result = await load_fixtures_by_context(
|
||||||
|
db_session, registry, AppContext.STAGING
|
||||||
|
)
|
||||||
|
assert len(result["staging_roles"]) == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_plain_enum_e2e(self, db_session: AsyncSession):
|
||||||
|
"""End-to-end: register with plain Enum, load with the same Enum."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register(contexts=[PlainEnumContext.STAGING])
|
||||||
|
def staging_roles():
|
||||||
|
return [Role(id=uuid.uuid4(), name="plain-staging-admin")]
|
||||||
|
|
||||||
|
result = await load_fixtures_by_context(
|
||||||
|
db_session, registry, PlainEnumContext.STAGING
|
||||||
|
)
|
||||||
|
assert len(result["staging_roles"]) == 1
|
||||||
|
|
||||||
|
|
||||||
class TestLoadStrategy:
|
class TestLoadStrategy:
|
||||||
"""Tests for LoadStrategy enum."""
|
"""Tests for LoadStrategy enum."""
|
||||||
|
|
||||||
@@ -407,6 +500,37 @@ class TestDependencyResolution:
|
|||||||
with pytest.raises(ValueError, match="Circular dependency"):
|
with pytest.raises(ValueError, match="Circular dependency"):
|
||||||
registry.resolve_dependencies("a")
|
registry.resolve_dependencies("a")
|
||||||
|
|
||||||
|
def test_resolve_raises_for_unknown_dependency(self):
|
||||||
|
"""KeyError when depends_on references an unregistered fixture."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register(depends_on=["ghost"])
|
||||||
|
def users():
|
||||||
|
return []
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="ghost"):
|
||||||
|
registry.resolve_dependencies("users")
|
||||||
|
|
||||||
|
def test_resolve_deduplicates_shared_depends_on_across_variants(self):
|
||||||
|
"""A dep shared by two same-name variants appears only once in the order."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.BASE])
|
||||||
|
def roles():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
||||||
|
def items():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||||
|
def items(): # noqa: F811
|
||||||
|
return []
|
||||||
|
|
||||||
|
order = registry.resolve_dependencies("items")
|
||||||
|
assert order.count("roles") == 1
|
||||||
|
assert order.index("roles") < order.index("items")
|
||||||
|
|
||||||
def test_resolve_context_dependencies(self):
|
def test_resolve_context_dependencies(self):
|
||||||
"""Resolve all fixtures for a context with dependencies."""
|
"""Resolve all fixtures for a context with dependencies."""
|
||||||
registry = FixtureRegistry()
|
registry = FixtureRegistry()
|
||||||
@@ -496,6 +620,52 @@ class TestLoadFixtures:
|
|||||||
count = await RoleCrud.count(db_session)
|
count = await RoleCrud.count(db_session)
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_merge_does_not_overwrite_omitted_nullable_columns(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""MERGE must not clear nullable columns that the fixture didn't set.
|
||||||
|
|
||||||
|
When a fixture omits a nullable column (e.g. role_id or notes), a re-merge
|
||||||
|
must leave the existing DB value untouched — not overwrite it with NULL.
|
||||||
|
"""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||||
|
uid = uuid.uuid4()
|
||||||
|
|
||||||
|
# First load: user has role_id and notes set
|
||||||
|
@registry.register
|
||||||
|
def users():
|
||||||
|
return [
|
||||||
|
User(
|
||||||
|
id=uid,
|
||||||
|
username="alice",
|
||||||
|
email="a@test.com",
|
||||||
|
role_id=admin.id,
|
||||||
|
notes="original",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.MERGE)
|
||||||
|
|
||||||
|
# Second load: fixture omits role_id and notes
|
||||||
|
registry2 = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry2.register
|
||||||
|
def users(): # noqa: F811
|
||||||
|
return [User(id=uid, username="alice-updated", email="a@test.com")]
|
||||||
|
|
||||||
|
await load_fixtures(db_session, registry2, "users", strategy=LoadStrategy.MERGE)
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
row = (
|
||||||
|
await db_session.execute(select(User).where(User.id == uid))
|
||||||
|
).scalar_one()
|
||||||
|
assert row.username == "alice-updated" # updated column changed
|
||||||
|
assert row.role_id == admin.id # omitted → preserved
|
||||||
|
assert row.notes == "original" # omitted → preserved
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
||||||
"""Load fixtures with SKIP_EXISTING strategy."""
|
"""Load fixtures with SKIP_EXISTING strategy."""
|
||||||
@@ -795,3 +965,145 @@ class TestGetPrimaryKey:
|
|||||||
instance = Permission(subject="post") # action is None
|
instance = Permission(subject="post") # action is None
|
||||||
pk = _get_primary_key(instance)
|
pk = _get_primary_key(instance)
|
||||||
assert pk is None
|
assert pk is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegistryGetVariants:
|
||||||
|
"""Tests for FixtureRegistry.get and get_variants edge cases."""
|
||||||
|
|
||||||
|
def test_get_raises_value_error_for_multi_variant(self):
|
||||||
|
"""get() raises ValueError when the fixture has multiple context variants."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.BASE])
|
||||||
|
def items():
|
||||||
|
return []
|
||||||
|
|
||||||
|
@registry.register(contexts=[Context.TESTING])
|
||||||
|
def items(): # noqa: F811
|
||||||
|
return []
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="get_variants"):
|
||||||
|
registry.get("items")
|
||||||
|
|
||||||
|
def test_get_variants_raises_key_error_for_unknown(self):
|
||||||
|
"""get_variants() raises KeyError for an unregistered name."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
with pytest.raises(KeyError, match="not found"):
|
||||||
|
registry.get_variants("no_such_fixture")
|
||||||
|
|
||||||
|
|
||||||
|
class TestInstanceToDict:
|
||||||
|
"""Unit tests for the _instance_to_dict helper."""
|
||||||
|
|
||||||
|
def test_explicit_values_included(self):
|
||||||
|
"""All explicitly set column values appear in the result."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
instance = Role(id=role_id, name="admin")
|
||||||
|
d = _instance_to_dict(instance)
|
||||||
|
assert d["id"] == role_id
|
||||||
|
assert d["name"] == "admin"
|
||||||
|
|
||||||
|
def test_callable_default_none_excluded(self):
|
||||||
|
"""A column whose value is None but has a callable Python-side default
|
||||||
|
(e.g. ``default=uuid.uuid4``) is excluded so the DB generates it."""
|
||||||
|
instance = Role(id=None, name="admin")
|
||||||
|
d = _instance_to_dict(instance)
|
||||||
|
assert "id" not in d
|
||||||
|
assert d["name"] == "admin"
|
||||||
|
|
||||||
|
def test_nullable_none_included(self):
|
||||||
|
"""None on a nullable column with no default is kept (explicit NULL)."""
|
||||||
|
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", role_id=None)
|
||||||
|
d = _instance_to_dict(instance)
|
||||||
|
assert "role_id" in d
|
||||||
|
assert d["role_id"] is None
|
||||||
|
|
||||||
|
def test_nullable_str_no_default_omitted_not_in_dict(self):
|
||||||
|
"""Mapped[str | None] with no default, not provided in constructor, is absent from dict."""
|
||||||
|
instance = User(id=uuid.uuid4(), username="u", email="e@e.com")
|
||||||
|
d = _instance_to_dict(instance)
|
||||||
|
assert "notes" not in d
|
||||||
|
|
||||||
|
def test_nullable_str_no_default_explicit_none_included(self):
|
||||||
|
"""Mapped[str | None] with no default, explicitly set to None, is included as NULL."""
|
||||||
|
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes=None)
|
||||||
|
d = _instance_to_dict(instance)
|
||||||
|
assert "notes" in d
|
||||||
|
assert d["notes"] is None
|
||||||
|
|
||||||
|
def test_nullable_str_no_default_with_value_included(self):
|
||||||
|
"""Mapped[str | None] with no default and a value set is included normally."""
|
||||||
|
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes="hello")
|
||||||
|
d = _instance_to_dict(instance)
|
||||||
|
assert d["notes"] == "hello"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nullable_str_no_default_insert_roundtrip(
|
||||||
|
self, db_session: AsyncSession
|
||||||
|
):
|
||||||
|
"""Fixture loading works for models with Mapped[str | None] (no default).
|
||||||
|
|
||||||
|
Both the omitted-value (→ NULL) and explicit-None paths must insert without error.
|
||||||
|
"""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
uid_a = uuid.uuid4()
|
||||||
|
uid_b = uuid.uuid4()
|
||||||
|
uid_c = uuid.uuid4()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def users():
|
||||||
|
return [
|
||||||
|
User(
|
||||||
|
id=uid_a, username="no_notes", email="a@test.com"
|
||||||
|
), # notes omitted
|
||||||
|
User(
|
||||||
|
id=uid_b, username="null_notes", email="b@test.com", notes=None
|
||||||
|
), # explicit None
|
||||||
|
User(
|
||||||
|
id=uid_c, username="has_notes", email="c@test.com", notes="hi"
|
||||||
|
), # value set
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await load_fixtures(db_session, registry, "users")
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
rows = (
|
||||||
|
(await db_session.execute(select(User).order_by(User.username)))
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
by_username = {r.username: r for r in rows}
|
||||||
|
|
||||||
|
assert by_username["no_notes"].notes is None
|
||||||
|
assert by_username["null_notes"].notes is None
|
||||||
|
assert by_username["has_notes"].notes == "hi"
|
||||||
|
assert len(result["users"]) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchMergeNonPkColumns:
|
||||||
|
"""Batch MERGE on a model with no non-PK columns (PK-only table)."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_merge_pk_only_model(self, db_session: AsyncSession):
|
||||||
|
"""MERGE strategy on a PK-only model uses on_conflict_do_nothing."""
|
||||||
|
registry = FixtureRegistry()
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
def permissions():
|
||||||
|
return [
|
||||||
|
Permission(subject="post", action="read"),
|
||||||
|
Permission(subject="post", action="write"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await load_fixtures(
|
||||||
|
db_session, registry, "permissions", strategy=LoadStrategy.MERGE
|
||||||
|
)
|
||||||
|
assert len(result["permissions"]) == 2
|
||||||
|
|
||||||
|
# Run again — conflicts are silently ignored.
|
||||||
|
result2 = await load_fixtures(
|
||||||
|
db_session, registry, "permissions", strategy=LoadStrategy.MERGE
|
||||||
|
)
|
||||||
|
assert len(result2["permissions"]) == 2
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from fastapi_toolsets.models.watched import (
|
|||||||
_after_flush,
|
_after_flush,
|
||||||
_after_flush_postexec,
|
_after_flush_postexec,
|
||||||
_after_rollback,
|
_after_rollback,
|
||||||
|
_snapshot_column_attrs,
|
||||||
_task_error_handler,
|
_task_error_handler,
|
||||||
_upsert_changes,
|
_upsert_changes,
|
||||||
)
|
)
|
||||||
@@ -213,20 +214,36 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
|||||||
__tablename__ = "mixin_attr_access_models"
|
__tablename__ = "mixin_attr_access_models"
|
||||||
|
|
||||||
name: Mapped[str] = mapped_column(String(50))
|
name: Mapped[str] = mapped_column(String(50))
|
||||||
|
callback_url: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||||
|
|
||||||
async def on_create(self) -> None:
|
async def on_create(self) -> None:
|
||||||
_attr_access_events.append(
|
_attr_access_events.append(
|
||||||
{"event": "create", "id": self.id, "name": self.name}
|
{
|
||||||
|
"event": "create",
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"callback_url": self.callback_url,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_delete(self) -> None:
|
async def on_delete(self) -> None:
|
||||||
_attr_access_events.append(
|
_attr_access_events.append(
|
||||||
{"event": "delete", "id": self.id, "name": self.name}
|
{
|
||||||
|
"event": "delete",
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"callback_url": self.callback_url,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_update(self, changes: dict) -> None:
|
async def on_update(self, changes: dict) -> None:
|
||||||
_attr_access_events.append(
|
_attr_access_events.append(
|
||||||
{"event": "update", "id": self.id, "name": self.name}
|
{
|
||||||
|
"event": "update",
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"callback_url": self.callback_url,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1279,3 +1296,67 @@ class TestAttributeAccessInCallbacks:
|
|||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert isinstance(events[0]["id"], uuid.UUID)
|
assert isinstance(events[0]["id"], uuid.UUID)
|
||||||
assert events[0]["name"] == "updated"
|
assert events[0]["name"] == "updated"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nullable_column_none_accessible_in_on_create(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""Nullable column left as None is accessible in on_create without greenlet error."""
|
||||||
|
obj = AttrAccessModel(name="no-url") # callback_url not set → None
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "create"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0]["callback_url"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nullable_column_with_value_accessible_in_on_create(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""Nullable column set to a value is accessible in on_create without greenlet error."""
|
||||||
|
obj = AttrAccessModel(name="with-url", callback_url="https://example.com/hook")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "create"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0]["callback_url"] == "https://example.com/hook"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nullable_column_accessible_after_update_to_none(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""Nullable column updated to None is accessible in on_update without greenlet error."""
|
||||||
|
obj = AttrAccessModel(name="x", callback_url="https://example.com/hook")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_attr_access_events.clear()
|
||||||
|
|
||||||
|
obj.callback_url = None
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "update"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0]["callback_url"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_expired_nullable_column_not_inferred_as_none(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""A nullable column with a real value that is expired (by a prior
|
||||||
|
expire_on_commit) must not be inferred as None in the snapshot — its
|
||||||
|
actual value is unknown without a DB refresh."""
|
||||||
|
obj = AttrAccessModel(name="original", callback_url="https://example.com/hook")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
# expire_on_commit fired → obj.state.expired=True, callback_url not in state.dict
|
||||||
|
|
||||||
|
snapshot = _snapshot_column_attrs(obj)
|
||||||
|
|
||||||
|
# callback_url has a real DB value but is expired — must not be snapshotted as None.
|
||||||
|
assert "callback_url" not in snapshot
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ from fastapi import Depends, FastAPI
|
|||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from sqlalchemy import select, text
|
from sqlalchemy import select, text
|
||||||
from sqlalchemy.engine import make_url
|
from sqlalchemy.engine import make_url
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from fastapi_toolsets.db import get_transaction
|
||||||
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
from fastapi_toolsets.fixtures import Context, FixtureRegistry
|
||||||
from fastapi_toolsets.pytest import (
|
from fastapi_toolsets.pytest import (
|
||||||
create_async_client,
|
create_async_client,
|
||||||
@@ -336,6 +337,55 @@ class TestCreateDbSession:
|
|||||||
result = await session.execute(select(Role))
|
result = await session.execute(select(Role))
|
||||||
assert result.all() == []
|
assert result.all() == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_transaction_commits_visible_to_separate_session(self):
|
||||||
|
"""Data written via get_transaction() is committed and visible to other sessions."""
|
||||||
|
role_id = uuid.uuid4()
|
||||||
|
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||||
|
# Simulate what _create_fixture_function does: insert via get_transaction
|
||||||
|
# with no explicit commit afterward.
|
||||||
|
async with get_transaction(session):
|
||||||
|
role = Role(id=role_id, name="visible_to_other_session")
|
||||||
|
session.add(role)
|
||||||
|
|
||||||
|
# The data must have been committed (begin/commit, not a savepoint),
|
||||||
|
# so a separate engine/session can read it.
|
||||||
|
other_engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
try:
|
||||||
|
other_session_maker = async_sessionmaker(
|
||||||
|
other_engine, expire_on_commit=False
|
||||||
|
)
|
||||||
|
async with other_session_maker() as other:
|
||||||
|
result = await other.execute(select(Role).where(Role.id == role_id))
|
||||||
|
fetched = result.scalar_one_or_none()
|
||||||
|
assert fetched is not None, (
|
||||||
|
"Fixture data inserted via get_transaction() must be committed "
|
||||||
|
"and visible to a separate session. If create_db_session uses "
|
||||||
|
"create_db_context, auto-begin forces get_transaction() into "
|
||||||
|
"savepoints instead of real commits."
|
||||||
|
)
|
||||||
|
assert fetched.name == "visible_to_other_session"
|
||||||
|
finally:
|
||||||
|
await other_engine.dispose()
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as _:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeprecatedCleanupTables:
|
||||||
|
"""Tests for the deprecated cleanup_tables re-export in fastapi_toolsets.pytest."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_emits_deprecation_warning(self):
|
||||||
|
"""cleanup_tables imported from fastapi_toolsets.pytest emits DeprecationWarning."""
|
||||||
|
from fastapi_toolsets.pytest.utils import cleanup_tables
|
||||||
|
|
||||||
|
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||||
|
with pytest.warns(DeprecationWarning, match="fastapi_toolsets.db"):
|
||||||
|
await cleanup_tables(session, Base)
|
||||||
|
|
||||||
|
|
||||||
class TestGetXdistWorker:
|
class TestGetXdistWorker:
|
||||||
"""Tests for _get_xdist_worker helper."""
|
"""Tests for _get_xdist_worker helper."""
|
||||||
|
|||||||
Reference in New Issue
Block a user