Compare commits

..

2 Commits

7 changed files with 163 additions and 39 deletions

View File

@@ -474,7 +474,7 @@ The distinct values are returned in the `filter_attributes` field of [`Paginated
"filter_attributes": { "filter_attributes": {
"status": ["active", "inactive"], "status": ["active", "inactive"],
"country": ["DE", "FR", "US"], "country": ["DE", "FR", "US"],
"name": ["admin", "editor", "viewer"] "role__name": ["admin", "editor", "viewer"]
} }
} }
``` ```
@@ -482,7 +482,7 @@ The distinct values are returned in the `filter_attributes` field of [`Paginated
Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError). Use `filter_by` to pass the client's chosen filter values directly — no need to build SQLAlchemy conditions by hand. Any unknown key raises [`InvalidFacetFilterError`](../reference/exceptions.md#fastapi_toolsets.exceptions.exceptions.InvalidFacetFilterError).
!!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`." !!! info "The keys in `filter_by` are the same keys the client received in `filter_attributes`."
Keys are normally the terminal `column.key` (e.g. `"name"` for `Role.name`). When two facet fields share the same column key (e.g. `(Build.project, Project.name)` and `(Build.os, Os.name)`), the relationship name is prepended automatically: `"project__name"` and `"os__name"`. Keys use `__` as a separator for the full relationship chain. A direct column `User.status` produces `"status"`. A relationship tuple `(User.role, Role.name)` produces `"role__name"`. A deeper chain `(User.role, Role.permission, Permission.name)` produces `"role__permission__name"`.
`filter_by` and `filters` can be combined — both are applied with AND logic. `filter_by` and `filters` can be combined — both are applied with AND logic.
@@ -515,9 +515,9 @@ async def list_users(
Both single-value and multi-value query parameters work: Both single-value and multi-value query parameters work:
``` ```
GET /users?status=active → filter_by={"status": ["active"]} GET /users?status=active → filter_by={"status": ["active"]}
GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]} GET /users?status=active&country=FR → filter_by={"status": ["active"], "country": ["FR"]}
GET /users?role=admin&role=editor → filter_by={"role": ["admin", "editor"]} (IN clause) GET /users?role__name=admin&role__name=editor → filter_by={"role__name": ["admin", "editor"]} (IN clause)
``` ```
## Sorting ## Sorting

View File

@@ -2,7 +2,6 @@
import asyncio import asyncio
import functools import functools
from collections import Counter
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
@@ -151,7 +150,7 @@ def build_search_filters(
def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]: def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
"""Return a key for each facet field, disambiguating duplicate column keys. """Return a key for each facet field.
Args: Args:
facet_fields: Sequence of facet fields — either direct columns or facet_fields: Sequence of facet fields — either direct columns or
@@ -160,22 +159,12 @@ def facet_keys(facet_fields: Sequence[FacetFieldType]) -> list[str]:
Returns: Returns:
A list of string keys, one per facet field, in the same order. A list of string keys, one per facet field, in the same order.
""" """
raw: list[tuple[str, str | None]] = [] keys: list[str] = []
for field in facet_fields: for field in facet_fields:
if isinstance(field, tuple): if isinstance(field, tuple):
rel = field[-2] keys.append("__".join(el.key for el in field))
column = field[-1]
raw.append((column.key, rel.key))
else: else:
raw.append((field.key, None)) keys.append(field.key)
counts = Counter(col_key for col_key, _ in raw)
keys: list[str] = []
for col_key, rel_key in raw:
if counts[col_key] > 1 and rel_key is not None:
keys.append(f"{rel_key}__{col_key}")
else:
keys.append(col_key)
return keys return keys

View File

@@ -38,6 +38,12 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
return result return result
def _normalize_rows(dicts: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Ensure all row dicts share the same key set."""
all_keys: set[str] = set().union(*dicts)
return [{k: d.get(k) for k in all_keys} for d in dicts]
def _group_by_type( def _group_by_type(
instances: list[DeclarativeBase], instances: list[DeclarativeBase],
) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]: ) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
@@ -54,7 +60,7 @@ async def _batch_insert(
instances: list[DeclarativeBase], instances: list[DeclarativeBase],
) -> None: ) -> None:
"""INSERT all instances — raises on conflict (no duplicate handling).""" """INSERT all instances — raises on conflict (no duplicate handling)."""
dicts = [_instance_to_dict(i) for i in instances] dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
await session.execute(pg_insert(model_cls).values(dicts)) await session.execute(pg_insert(model_cls).values(dicts))
@@ -73,13 +79,16 @@ async def _batch_merge(
if not any(col.name in pk_names_set for col in prop.columns) if not any(col.name in pk_names_set for col in prop.columns)
] ]
dicts = [_instance_to_dict(i) for i in instances] dicts = _normalize_rows([_instance_to_dict(i) for i in instances])
stmt = pg_insert(model_cls).values(dicts) stmt = pg_insert(model_cls).values(dicts)
if non_pk_cols: inserted_keys = set(dicts[0]) if dicts else set()
update_cols = [col for col in non_pk_cols if col in inserted_keys]
if update_cols:
stmt = stmt.on_conflict_do_update( stmt = stmt.on_conflict_do_update(
index_elements=pk_names, index_elements=pk_names,
set_={col: stmt.excluded[col] for col in non_pk_cols}, set_={col: stmt.excluded[col] for col in update_cols},
) )
else: else:
stmt = stmt.on_conflict_do_nothing(index_elements=pk_names) stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
@@ -108,14 +117,16 @@ async def _batch_skip_existing(
loaded: list[DeclarativeBase] = list(no_pk) loaded: list[DeclarativeBase] = list(no_pk)
if no_pk: if no_pk:
await session.execute( await session.execute(
pg_insert(model_cls).values([_instance_to_dict(i) for i in no_pk]) pg_insert(model_cls).values(
_normalize_rows([_instance_to_dict(i) for i in no_pk])
)
) )
if with_pk_pairs: if with_pk_pairs:
with_pk = [i for i, _ in with_pk_pairs] with_pk = [i for i, _ in with_pk_pairs]
stmt = ( stmt = (
pg_insert(model_cls) pg_insert(model_cls)
.values([_instance_to_dict(i) for i in with_pk]) .values(_normalize_rows([_instance_to_dict(i) for i in with_pk]))
.on_conflict_do_nothing(index_elements=pk_names) .on_conflict_do_nothing(index_elements=pk_names)
) )
result = await session.execute(stmt.returning(*mapper.primary_key)) result = await session.execute(stmt.returning(*mapper.primary_key))

View File

@@ -57,6 +57,7 @@ class User(Base):
username: Mapped[str] = mapped_column(String(50), unique=True) username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True) email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True) is_active: Mapped[bool] = mapped_column(default=True)
notes: Mapped[str | None]
role_id: Mapped[uuid.UUID | None] = mapped_column( role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True ForeignKey("roles.id"), nullable=True
) )

View File

@@ -646,7 +646,7 @@ class TestFacetsRelationship:
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead) result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert set(result.filter_attributes["name"]) == {"admin", "editor"} assert set(result.filter_attributes["role__name"]) == {"admin", "editor"}
@pytest.mark.anyio @pytest.mark.anyio
async def test_relationship_facet_none_excluded(self, db_session: AsyncSession): async def test_relationship_facet_none_excluded(self, db_session: AsyncSession):
@@ -661,7 +661,7 @@ class TestFacetsRelationship:
result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead) result = await UserRelFacetCrud.offset_paginate(db_session, schema=UserRead)
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert result.filter_attributes["name"] == [] assert result.filter_attributes["role__name"] == []
@pytest.mark.anyio @pytest.mark.anyio
async def test_relationship_facet_deduplicates_join_with_search( async def test_relationship_facet_deduplicates_join_with_search(
@@ -689,7 +689,7 @@ class TestFacetsRelationship:
) )
assert result.filter_attributes is not None assert result.filter_attributes is not None
assert result.filter_attributes["name"] == ["admin"] assert result.filter_attributes["role__name"] == ["admin"]
class TestFilterBy: class TestFilterBy:
@@ -755,7 +755,7 @@ class TestFilterBy:
) )
result = await UserRelFacetCrud.offset_paginate( result = await UserRelFacetCrud.offset_paginate(
db_session, filter_by={"name": "admin"}, schema=UserRead db_session, filter_by={"role__name": "admin"}, schema=UserRead
) )
assert isinstance(result.pagination, OffsetPagination) assert isinstance(result.pagination, OffsetPagination)
@@ -824,7 +824,7 @@ class TestFilterBy:
result = await UserRoleFacetCrud.offset_paginate( result = await UserRoleFacetCrud.offset_paginate(
db_session, db_session,
filter_by={"name": "admin", "id": str(admin.id)}, filter_by={"role__name": "admin", "role__id": str(admin.id)},
schema=UserRead, schema=UserRead,
) )
@@ -916,15 +916,15 @@ class TestFilterParamsSchema:
param_names = set(inspect.signature(dep).parameters) param_names = set(inspect.signature(dep).parameters)
assert param_names == {"username", "email"} assert param_names == {"username", "email"}
def test_relationship_facet_uses_column_key(self): def test_relationship_facet_uses_full_chain_key(self):
"""Relationship tuple uses the terminal column's key.""" """Relationship tuple uses the full chain joined by __ as the key."""
import inspect import inspect
UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)]) UserRoleCrud = CrudFactory(User, facet_fields=[(User.role, Role.name)])
dep = UserRoleCrud.filter_params() dep = UserRoleCrud.filter_params()
param_names = set(inspect.signature(dep).parameters) param_names = set(inspect.signature(dep).parameters)
assert param_names == {"name"} assert param_names == {"role__name"}
def test_raises_when_no_facet_fields(self): def test_raises_when_no_facet_fields(self):
"""ValueError raised when no facet_fields are configured or provided.""" """ValueError raised when no facet_fields are configured or provided."""
@@ -978,6 +978,22 @@ class TestFilterParamsSchema:
keys = facet_keys([(rel_a, col_a), (rel_b, col_b)]) keys = facet_keys([(rel_a, col_a), (rel_b, col_b)])
assert keys == ["project__name", "os__name"] assert keys == ["project__name", "os__name"]
def test_deep_chain_joins_all_segments(self):
"""Three-element tuple produces all relation segments joined by __."""
from unittest.mock import MagicMock
from fastapi_toolsets.crud.search import facet_keys
rel_a = MagicMock()
rel_a.key = "role"
rel_b = MagicMock()
rel_b.key = "permission"
col = MagicMock()
col.key = "name"
keys = facet_keys([(rel_a, rel_b, col)])
assert keys == ["role__permission__name"]
def test_unique_column_keys_kept_plain(self): def test_unique_column_keys_kept_plain(self):
"""Fields with unique column keys are not prefixed.""" """Fields with unique column keys are not prefixed."""
from fastapi_toolsets.crud.search import facet_keys from fastapi_toolsets.crud.search import facet_keys

View File

@@ -182,8 +182,7 @@ class TestOffsetPagination:
body = resp.json() body = resp.json()
fa = body["filter_attributes"] fa = body["filter_attributes"]
assert set(fa["status"]) == {"draft", "published"} assert set(fa["status"]) == {"draft", "published"}
# "name" is unique across all facet fields — no prefix needed assert set(fa["category__name"]) == {"backend", "python"}
assert set(fa["name"]) == {"backend", "python"}
@pytest.mark.anyio @pytest.mark.anyio
async def test_filter_attributes_scoped_to_filter( async def test_filter_attributes_scoped_to_filter(

View File

@@ -14,10 +14,9 @@ from fastapi_toolsets.fixtures import (
load_fixtures, load_fixtures,
load_fixtures_by_context, load_fixtures_by_context,
) )
from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict from fastapi_toolsets.fixtures.utils import _get_primary_key, _instance_to_dict
from .conftest import IntRole, Permission, Role, RoleCrud, User, UserCrud from .conftest import IntRole, Permission, Role, RoleCreate, RoleCrud, User, UserCrud
class AppContext(str, Enum): class AppContext(str, Enum):
@@ -621,6 +620,52 @@ class TestLoadFixtures:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
@pytest.mark.anyio
async def test_merge_does_not_overwrite_omitted_nullable_columns(
self, db_session: AsyncSession
):
"""MERGE must not clear nullable columns that the fixture didn't set.
When a fixture omits a nullable column (e.g. role_id or notes), a re-merge
must leave the existing DB value untouched — not overwrite it with NULL.
"""
registry = FixtureRegistry()
admin = await RoleCrud.create(db_session, RoleCreate(name="admin"))
uid = uuid.uuid4()
# First load: user has role_id and notes set
@registry.register
def users():
return [
User(
id=uid,
username="alice",
email="a@test.com",
role_id=admin.id,
notes="original",
)
]
await load_fixtures(db_session, registry, "users", strategy=LoadStrategy.MERGE)
# Second load: fixture omits role_id and notes
registry2 = FixtureRegistry()
@registry2.register
def users(): # noqa: F811
return [User(id=uid, username="alice-updated", email="a@test.com")]
await load_fixtures(db_session, registry2, "users", strategy=LoadStrategy.MERGE)
from sqlalchemy import select
row = (
await db_session.execute(select(User).where(User.id == uid))
).scalar_one()
assert row.username == "alice-updated" # updated column changed
assert row.role_id == admin.id # omitted → preserved
assert row.notes == "original" # omitted → preserved
@pytest.mark.anyio @pytest.mark.anyio
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession): async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
"""Load fixtures with SKIP_EXISTING strategy.""" """Load fixtures with SKIP_EXISTING strategy."""
@@ -973,6 +1018,69 @@ class TestInstanceToDict:
assert "role_id" in d assert "role_id" in d
assert d["role_id"] is None assert d["role_id"] is None
def test_nullable_str_no_default_omitted_not_in_dict(self):
"""Mapped[str | None] with no default, not provided in constructor, is absent from dict."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com")
d = _instance_to_dict(instance)
assert "notes" not in d
def test_nullable_str_no_default_explicit_none_included(self):
"""Mapped[str | None] with no default, explicitly set to None, is included as NULL."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes=None)
d = _instance_to_dict(instance)
assert "notes" in d
assert d["notes"] is None
def test_nullable_str_no_default_with_value_included(self):
"""Mapped[str | None] with no default and a value set is included normally."""
instance = User(id=uuid.uuid4(), username="u", email="e@e.com", notes="hello")
d = _instance_to_dict(instance)
assert d["notes"] == "hello"
@pytest.mark.anyio
async def test_nullable_str_no_default_insert_roundtrip(
self, db_session: AsyncSession
):
"""Fixture loading works for models with Mapped[str | None] (no default).
Both the omitted-value (→ NULL) and explicit-None paths must insert without error.
"""
registry = FixtureRegistry()
uid_a = uuid.uuid4()
uid_b = uuid.uuid4()
uid_c = uuid.uuid4()
@registry.register
def users():
return [
User(
id=uid_a, username="no_notes", email="a@test.com"
), # notes omitted
User(
id=uid_b, username="null_notes", email="b@test.com", notes=None
), # explicit None
User(
id=uid_c, username="has_notes", email="c@test.com", notes="hi"
), # value set
]
result = await load_fixtures(db_session, registry, "users")
from sqlalchemy import select
rows = (
(await db_session.execute(select(User).order_by(User.username)))
.scalars()
.all()
)
by_username = {r.username: r for r in rows}
assert by_username["no_notes"].notes is None
assert by_username["null_notes"].notes is None
assert by_username["has_notes"].notes == "hi"
assert len(result["users"]) == 3
class TestBatchMergeNonPkColumns: class TestBatchMergeNonPkColumns:
"""Batch MERGE on a model with no non-PK columns (PK-only table).""" """Batch MERGE on a model with no non-PK columns (PK-only table)."""