Compare commits

4 Commits

Author SHA1 Message Date
94e7d79d06 Version 3.1.0 2026-04-12 12:48:32 -04:00
d3vyce
9268b576b4 feat: add M2M helpers (#247) 2026-04-12 18:46:57 +02:00
d3vyce
863e6ce6e9 feat: add get_field_by_attr fixtures helper (#245) 2026-04-12 17:12:08 +02:00
d3vyce
c7397faea4 feat: auto eager-load relationships in register_fixtures (#243) 2026-04-12 17:04:44 +02:00
12 changed files with 1004 additions and 25 deletions

View File

@@ -118,6 +118,57 @@ async def clean(db_session):
await cleanup_tables(session=db_session, base=Base)
```
## Many-to-Many helpers
SQLAlchemy's ORM collection API triggers lazy-loads when you append to a relationship inside a savepoint (e.g. inside `lock_tables` or a nested `get_transaction`). The three `m2m_*` helpers bypass the ORM collection entirely and issue direct SQL against the association table.
### `m2m_add` — insert associations
[`m2m_add`](../reference/db.md#fastapi_toolsets.db.m2m_add) inserts one or more rows into a secondary table without touching the ORM collection:
```python
from fastapi_toolsets.db import lock_tables, m2m_add
async with lock_tables(session, [Tag]):
tag = await TagCrud.create(session, TagCreate(name="python"))
await m2m_add(session, post, Post.tags, tag)
```
Pass `ignore_conflicts=True` to silently skip associations that already exist:
```python
await m2m_add(session, post, Post.tags, tag, ignore_conflicts=True)
```
### `m2m_remove` — delete associations
[`m2m_remove`](../reference/db.md#fastapi_toolsets.db.m2m_remove) deletes specific association rows. Removing a non-existent association is a no-op:
```python
from fastapi_toolsets.db import get_transaction, m2m_remove
async with get_transaction(session):
await m2m_remove(session, post, Post.tags, tag1, tag2)
```
### `m2m_set` — replace the full set
[`m2m_set`](../reference/db.md#fastapi_toolsets.db.m2m_set) atomically replaces all associations: it deletes every existing row for the owner instance then inserts the new set. Passing no related instances clears the association entirely:
```python
from fastapi_toolsets.db import get_transaction, m2m_set
# Replace all tags
async with get_transaction(session):
await m2m_set(session, post, Post.tags, tag_a, tag_b)
# Clear all tags
async with get_transaction(session):
await m2m_set(session, post, Post.tags)
```
All three helpers raise `TypeError` if the relationship attribute is not a Many-to-Many (i.e. has no secondary table).
---
[:material-api: API Reference](../reference/db.md)

View File

@@ -13,6 +13,9 @@ from fastapi_toolsets.db import (
create_db_context,
get_transaction,
lock_tables,
m2m_add,
m2m_remove,
m2m_set,
wait_for_row_change,
)
```
@@ -32,3 +35,9 @@ from fastapi_toolsets.db import (
## ::: fastapi_toolsets.db.create_database
## ::: fastapi_toolsets.db.cleanup_tables
## ::: fastapi_toolsets.db.m2m_add
## ::: fastapi_toolsets.db.m2m_remove
## ::: fastapi_toolsets.db.m2m_set

View File

@@ -1,6 +1,6 @@
[project]
name = "fastapi-toolsets"
version = "3.0.3"
version = "3.1.0"
description = "Production-ready utilities for FastAPI applications"
readme = "README.md"
license = "MIT"

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success")
"""
__version__ = "3.0.3"
__version__ = "3.1.0"

View File

@@ -4,11 +4,13 @@ import asyncio
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from enum import Enum
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
from sqlalchemy import text
from sqlalchemy import Table, delete, text, tuple_
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
from sqlalchemy.orm.relationships import RelationshipProperty
from .exceptions import NotFoundError
@@ -20,6 +22,9 @@ __all__ = [
"create_db_dependency",
"get_transaction",
"lock_tables",
"m2m_add",
"m2m_remove",
"m2m_set",
"wait_for_row_change",
]
@@ -339,3 +344,140 @@ async def wait_for_row_change(
current = {col: getattr(instance, col) for col in watch_cols}
if current != initial:
return instance
def _m2m_prop(rel_attr: QueryableAttribute) -> RelationshipProperty: # type: ignore[type-arg]
"""Return the validated M2M RelationshipProperty for *rel_attr*.
Raises TypeError if *rel_attr* is not a Many-to-Many relationship.
"""
prop = rel_attr.property
if not isinstance(prop, RelationshipProperty) or prop.secondary is None:
raise TypeError(
f"m2m helpers require a Many-to-Many relationship attribute, "
f"got {rel_attr!r}. Use a relationship with a secondary table."
)
return prop
async def m2m_add(
session: AsyncSession,
instance: DeclarativeBase,
rel_attr: QueryableAttribute,
*related: DeclarativeBase,
ignore_conflicts: bool = False,
) -> None:
"""Insert rows into a Many-to-Many association table without loading the ORM collection.
Args:
session: DB async session.
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
*related: One or more related instances to associate with ``instance``.
ignore_conflicts: When ``True``, silently skip rows that already exist
in the association table (``ON CONFLICT DO NOTHING``).
Raises:
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
"""
prop = _m2m_prop(rel_attr)
if not related:
return
secondary = cast(Table, prop.secondary)
assert secondary is not None # guaranteed by _m2m_prop
sync_pairs = prop.secondary_synchronize_pairs
assert sync_pairs is not None # set whenever secondary is set
# synchronize_pairs: [(parent_col, assoc_col), ...]
# secondary_synchronize_pairs: [(related_col, assoc_col), ...]
rows: list[dict[str, Any]] = []
for rel_instance in related:
row: dict[str, Any] = {}
for parent_col, assoc_col in prop.synchronize_pairs:
row[assoc_col.name] = getattr(instance, cast(str, parent_col.key))
for related_col, assoc_col in sync_pairs:
row[assoc_col.name] = getattr(rel_instance, cast(str, related_col.key))
rows.append(row)
stmt = pg_insert(secondary).values(rows)
if ignore_conflicts:
stmt = stmt.on_conflict_do_nothing()
await session.execute(stmt)
async def m2m_remove(
session: AsyncSession,
instance: DeclarativeBase,
rel_attr: QueryableAttribute,
*related: DeclarativeBase,
) -> None:
"""Remove rows from a Many-to-Many association table without loading the ORM collection.
Args:
session: DB async session.
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
*related: One or more related instances to disassociate from ``instance``.
Raises:
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
"""
prop = _m2m_prop(rel_attr)
if not related:
return
secondary = cast(Table, prop.secondary)
assert secondary is not None # guaranteed by _m2m_prop
related_pairs = prop.secondary_synchronize_pairs
assert related_pairs is not None # set whenever secondary is set
parent_where = [
assoc_col == getattr(instance, cast(str, parent_col.key))
for parent_col, assoc_col in prop.synchronize_pairs
]
if len(related_pairs) == 1:
related_col, assoc_col = related_pairs[0]
related_values = [getattr(r, cast(str, related_col.key)) for r in related]
related_where = assoc_col.in_(related_values)
else:
assoc_cols = [ac for _, ac in related_pairs]
rel_cols = [rc for rc, _ in related_pairs]
related_values_t = [
tuple(getattr(r, cast(str, rc.key)) for rc in rel_cols) for r in related
]
related_where = tuple_(*assoc_cols).in_(related_values_t)
await session.execute(delete(secondary).where(*parent_where, related_where))
async def m2m_set(
session: AsyncSession,
instance: DeclarativeBase,
rel_attr: QueryableAttribute,
*related: DeclarativeBase,
) -> None:
"""Replace the entire Many-to-Many association set atomically.
Args:
session: DB async session.
instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
*related: The new complete set of related instances.
Raises:
TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
"""
prop = _m2m_prop(rel_attr)
secondary = cast(Table, prop.secondary)
assert secondary is not None # guaranteed by _m2m_prop
parent_where = [
assoc_col == getattr(instance, cast(str, parent_col.key))
for parent_col, assoc_col in prop.synchronize_pairs
]
await session.execute(delete(secondary).where(*parent_where))
if related:
await m2m_add(session, instance, rel_attr, *related)

View File

@@ -2,12 +2,18 @@
from .enum import LoadStrategy
from .registry import Context, FixtureRegistry
from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
from .utils import (
get_field_by_attr,
get_obj_by_attr,
load_fixtures,
load_fixtures_by_context,
)
__all__ = [
"Context",
"FixtureRegistry",
"LoadStrategy",
"get_field_by_attr",
"get_obj_by_attr",
"load_fixtures",
"load_fixtures_by_context",

View File

@@ -250,6 +250,31 @@ def get_obj_by_attr(
) from None
def get_field_by_attr(
fixtures: Callable[[], Sequence[ModelType]],
attr_name: str,
value: Any,
*,
field: str = "id",
) -> Any:
"""Get a single field value from a fixture object matched by an attribute.
Args:
fixtures: A fixture function registered via ``@registry.register``
that returns a sequence of SQLAlchemy model instances.
attr_name: Name of the attribute to match against.
value: Value to match.
field: Attribute name to return from the matched object (default: ``"id"``).
Returns:
The value of ``field`` on the first matching model instance.
Raises:
StopIteration: If no matching object is found in the fixture group.
"""
return getattr(get_obj_by_attr(fixtures, attr_name, value), field)
async def load_fixtures(
session: AsyncSession,
registry: FixtureRegistry,

View File

@@ -1,11 +1,13 @@
"""Pytest plugin for using FixtureRegistry fixtures in tests."""
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, cast
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, selectinload
from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption
from ..db import get_transaction
from ..fixtures import FixtureRegistry, LoadStrategy
@@ -112,7 +114,7 @@ def _create_fixture_function(
elif strategy == LoadStrategy.MERGE:
merged = await session.merge(instance)
loaded.append(merged)
elif strategy == LoadStrategy.SKIP_EXISTING:
elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch
pk = _get_primary_key(instance)
if pk is not None:
existing = await session.get(type(instance), pk)
@@ -125,6 +127,11 @@ def _create_fixture_function(
session.add(instance)
loaded.append(instance)
if loaded: # pragma: no branch
load_options = _relationship_load_options(type(loaded[0]))
if load_options:
return await _reload_with_relationships(session, loaded, load_options)
return loaded
# Update function signature to include dependencies
@@ -141,6 +148,54 @@ def _create_fixture_function(
return created_func
def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]:
"""Build selectinload options for all direct relationships on a model."""
return [
selectinload(getattr(model, rel.key)) for rel in model.__mapper__.relationships
]
async def _reload_with_relationships(
session: AsyncSession,
instances: list[DeclarativeBase],
load_options: list[ExecutableOption],
) -> list[DeclarativeBase]:
"""Reload instances in a single bulk query with relationship eager-loading.
Uses one SELECT … WHERE pk IN (…) so selectinload can batch all relationship
queries — 1 + N_relationships round-trips regardless of how many instances
there are, instead of one session.get() per instance.
Preserves the original insertion order.
"""
model = type(instances[0])
mapper = model.__mapper__
pk_cols = mapper.primary_key
if len(pk_cols) == 1:
pk_attr = getattr(model, pk_cols[0].key)
pks = [getattr(inst, pk_cols[0].key) for inst in instances]
result = await session.execute(
select(model).where(pk_attr.in_(pks)).options(*load_options)
)
by_pk = {getattr(row, pk_cols[0].key): row for row in result.unique().scalars()}
return [by_pk[pk] for pk in pks]
# Composite PK: fall back to per-instance reload
reloaded: list[DeclarativeBase] = []
for instance in instances:
pk = _get_primary_key(instance)
refreshed = await session.get(
model,
pk,
options=cast(list[ORMOption], load_options),
populate_existing=True,
)
if refreshed is not None: # pragma: no branch
reloaded.append(refreshed)
return reloaded
def _get_primary_key(instance: DeclarativeBase) -> Any | None:
"""Get the primary key value of a model instance."""
mapper = instance.__class__.__mapper__

View File

@@ -4,10 +4,26 @@ import asyncio
import uuid
import pytest
from sqlalchemy import text
from sqlalchemy import (
Column,
ForeignKey,
ForeignKeyConstraint,
String,
Table,
Uuid,
select,
text,
)
from sqlalchemy.engine import make_url
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
mapped_column,
relationship,
selectinload,
)
from fastapi_toolsets.db import (
LockMode,
@@ -17,12 +33,15 @@ from fastapi_toolsets.db import (
create_db_dependency,
get_transaction,
lock_tables,
m2m_add,
m2m_remove,
m2m_set,
wait_for_row_change,
)
from fastapi_toolsets.exceptions import NotFoundError
from fastapi_toolsets.pytest import create_db_session
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
from .conftest import DATABASE_URL, Base, Post, Role, RoleCrud, Tag, User, UserCrud
class TestCreateDbDependency:
@@ -81,6 +100,21 @@ class TestCreateDbDependency:
await engine.dispose()
@pytest.mark.anyio
async def test_no_commit_when_not_in_transaction(self):
"""Dependency skips commit if the session is no longer in a transaction on exit."""
engine = create_async_engine(DATABASE_URL, echo=False)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
get_db = create_db_dependency(session_factory)
async for session in get_db():
# Manually commit — session exits the transaction
await session.commit()
assert not session.in_transaction()
# The dependency's post-yield path must not call commit again (no error)
await engine.dispose()
@pytest.mark.anyio
async def test_update_after_lock_tables_is_persisted(self):
"""Changes made after lock_tables exits (before endpoint returns) are committed.
@@ -480,3 +514,417 @@ class TestCleanupTables:
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
# Should not raise
await cleanup_tables(session, EmptyBase)
class TestM2MAdd:
"""Tests for m2m_add helper."""
@pytest.mark.anyio
async def test_adds_single_related(self, db_session: AsyncSession):
"""Associates one related instance via the secondary table."""
user = User(username="m2m_author", email="m2m@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post A", author_id=user.id)
tag = Tag(name="python")
db_session.add_all([post, tag])
await db_session.flush()
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag)
result = await db_session.execute(
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
)
loaded = result.scalar_one()
assert len(loaded.tags) == 1
assert loaded.tags[0].id == tag.id
@pytest.mark.anyio
async def test_adds_multiple_related(self, db_session: AsyncSession):
"""Associates multiple related instances in a single call."""
user = User(username="m2m_author2", email="m2m2@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post B", author_id=user.id)
tag1 = Tag(name="web")
tag2 = Tag(name="api")
tag3 = Tag(name="async")
db_session.add_all([post, tag1, tag2, tag3])
await db_session.flush()
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag1, tag2, tag3)
result = await db_session.execute(
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
)
loaded = result.scalar_one()
assert {t.id for t in loaded.tags} == {tag1.id, tag2.id, tag3.id}
@pytest.mark.anyio
async def test_noop_for_empty_related(self, db_session: AsyncSession):
"""Calling with no related instances is a no-op."""
user = User(username="m2m_author3", email="m2m3@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post C", author_id=user.id)
db_session.add(post)
await db_session.flush()
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags) # no related instances
result = await db_session.execute(
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
)
loaded = result.scalar_one()
assert loaded.tags == []
@pytest.mark.anyio
async def test_ignore_conflicts_true(self, db_session: AsyncSession):
"""Duplicate inserts are silently skipped when ignore_conflicts=True."""
user = User(username="m2m_author4", email="m2m4@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post D", author_id=user.id)
tag = Tag(name="duplicate_tag")
db_session.add_all([post, tag])
await db_session.flush()
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag)
# Second call with ignore_conflicts=True must not raise
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag, ignore_conflicts=True)
result = await db_session.execute(
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
)
loaded = result.scalar_one()
assert len(loaded.tags) == 1
@pytest.mark.anyio
async def test_ignore_conflicts_false_raises(self, db_session: AsyncSession):
"""Duplicate inserts raise IntegrityError when ignore_conflicts=False (default)."""
user = User(username="m2m_author5", email="m2m5@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post E", author_id=user.id)
tag = Tag(name="conflict_tag")
db_session.add_all([post, tag])
await db_session.flush()
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag)
with pytest.raises(IntegrityError):
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag)
@pytest.mark.anyio
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
"""Passing a non-M2M relationship attribute raises TypeError."""
user = User(username="m2m_author6", email="m2m6@test.com")
db_session.add(user)
await db_session.flush()
role = Role(name="type_err_role")
db_session.add(role)
await db_session.flush()
with pytest.raises(TypeError, match="Many-to-Many"):
await m2m_add(db_session, user, User.role, role)
@pytest.mark.anyio
async def test_works_inside_lock_tables(self, db_session: AsyncSession):
"""m2m_add works correctly inside a lock_tables nested transaction."""
user = User(username="m2m_lock_author", email="m2m_lock@test.com")
db_session.add(user)
await db_session.flush()
async with lock_tables(db_session, [Tag]):
tag = Tag(name="locked_tag")
db_session.add(tag)
await db_session.flush()
post = Post(title="Post Lock", author_id=user.id)
db_session.add(post)
await db_session.flush()
await m2m_add(db_session, post, Post.tags, tag)
result = await db_session.execute(
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
)
loaded = result.scalar_one()
assert len(loaded.tags) == 1
assert loaded.tags[0].name == "locked_tag"
class _LocalBase(DeclarativeBase):
pass
_comp_assoc = Table(
"_comp_assoc",
_LocalBase.metadata,
Column("owner_id", Uuid, ForeignKey("_comp_owners.id"), primary_key=True),
Column("item_group", String(50), primary_key=True),
Column("item_code", String(50), primary_key=True),
ForeignKeyConstraint(
["item_group", "item_code"],
["_comp_items.group_id", "_comp_items.item_code"],
),
)
class _CompOwner(_LocalBase):
__tablename__ = "_comp_owners"
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
items: Mapped[list["_CompItem"]] = relationship(secondary=_comp_assoc)
class _CompItem(_LocalBase):
__tablename__ = "_comp_items"
group_id: Mapped[str] = mapped_column(String(50), primary_key=True)
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
class TestM2MRemove:
"""Tests for m2m_remove helper."""
async def _setup(
self, session: AsyncSession, username: str, email: str, *tag_names: str
):
"""Create a user, post, and tags; associate all tags with the post."""
user = User(username=username, email=email)
session.add(user)
await session.flush()
post = Post(title=f"Post {username}", author_id=user.id)
tags = [Tag(name=n) for n in tag_names]
session.add(post)
session.add_all(tags)
await session.flush()
async with get_transaction(session):
await m2m_add(session, post, Post.tags, *tags)
return post, tags
async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
result = await session.execute(
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
)
return result.scalar_one().tags
@pytest.mark.anyio
async def test_removes_single(self, db_session: AsyncSession):
"""Removes one association, leaving others intact."""
post, (tag1, tag2) = await self._setup(
db_session, "rm_author1", "rm1@test.com", "tag_rm_a", "tag_rm_b"
)
async with get_transaction(db_session):
await m2m_remove(db_session, post, Post.tags, tag1)
remaining = await self._load_tags(db_session, post)
assert len(remaining) == 1
assert remaining[0].id == tag2.id
@pytest.mark.anyio
async def test_removes_multiple(self, db_session: AsyncSession):
"""Removes multiple associations in one call."""
post, (tag1, tag2, tag3) = await self._setup(
db_session, "rm_author2", "rm2@test.com", "tag_rm_c", "tag_rm_d", "tag_rm_e"
)
async with get_transaction(db_session):
await m2m_remove(db_session, post, Post.tags, tag1, tag3)
remaining = await self._load_tags(db_session, post)
assert len(remaining) == 1
assert remaining[0].id == tag2.id
@pytest.mark.anyio
async def test_noop_for_empty_related(self, db_session: AsyncSession):
"""Calling with no related instances is a no-op."""
post, (tag,) = await self._setup(
db_session, "rm_author3", "rm3@test.com", "tag_rm_f"
)
async with get_transaction(db_session):
await m2m_remove(db_session, post, Post.tags)
remaining = await self._load_tags(db_session, post)
assert len(remaining) == 1
@pytest.mark.anyio
async def test_idempotent_for_missing_association(self, db_session: AsyncSession):
"""Removing a non-existent association does not raise."""
post, (tag1,) = await self._setup(
db_session, "rm_author4", "rm4@test.com", "tag_rm_g"
)
tag2 = Tag(name="tag_rm_h")
db_session.add(tag2)
await db_session.flush()
# tag2 was never associated — should not raise
async with get_transaction(db_session):
await m2m_remove(db_session, post, Post.tags, tag2)
remaining = await self._load_tags(db_session, post)
assert len(remaining) == 1
@pytest.mark.anyio
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
"""Passing a non-M2M relationship attribute raises TypeError."""
user = User(username="rm_author5", email="rm5@test.com")
db_session.add(user)
await db_session.flush()
role = Role(name="rm_type_err_role")
db_session.add(role)
await db_session.flush()
with pytest.raises(TypeError, match="Many-to-Many"):
await m2m_remove(db_session, user, User.role, role)
@pytest.mark.anyio
async def test_removes_composite_pk_related(self):
"""Composite-PK branch: DELETE uses tuple IN when related side has multi-col PK."""
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(_LocalBase.metadata.create_all)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
try:
async with session_factory() as session:
owner = _CompOwner()
item1 = _CompItem(group_id="g1", item_code="c1")
item2 = _CompItem(group_id="g1", item_code="c2")
session.add_all([owner, item1, item2])
await session.flush()
async with get_transaction(session):
await m2m_add(session, owner, _CompOwner.items, item1, item2)
async with get_transaction(session):
await m2m_remove(session, owner, _CompOwner.items, item1)
await session.commit()
async with session_factory() as verify:
from sqlalchemy import select
from sqlalchemy.orm import selectinload
result = await verify.execute(
select(_CompOwner)
.where(_CompOwner.id == owner.id)
.options(selectinload(_CompOwner.items))
)
loaded = result.scalar_one()
assert len(loaded.items) == 1
assert (loaded.items[0].group_id, loaded.items[0].item_code) == (
"g1",
"c2",
)
finally:
async with engine.begin() as conn:
await conn.run_sync(_LocalBase.metadata.drop_all)
await engine.dispose()
class TestM2MSet:
"""Tests for m2m_set helper."""
async def _load_tags(self, session: AsyncSession, post: Post) -> list[Tag]:
result = await session.execute(
select(Post).where(Post.id == post.id).options(selectinload(Post.tags))
)
return result.scalar_one().tags
@pytest.mark.anyio
async def test_replaces_existing_set(self, db_session: AsyncSession):
"""Replaces the full association set atomically."""
user = User(username="set_author1", email="set1@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post Set A", author_id=user.id)
tag1 = Tag(name="tag_set_a")
tag2 = Tag(name="tag_set_b")
tag3 = Tag(name="tag_set_c")
db_session.add_all([post, tag1, tag2, tag3])
await db_session.flush()
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag1, tag2)
async with get_transaction(db_session):
await m2m_set(db_session, post, Post.tags, tag3)
remaining = await self._load_tags(db_session, post)
assert len(remaining) == 1
assert remaining[0].id == tag3.id
@pytest.mark.anyio
async def test_clears_all_when_no_related(self, db_session: AsyncSession):
"""Passing no related instances clears all associations."""
user = User(username="set_author2", email="set2@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post Set B", author_id=user.id)
tag = Tag(name="tag_set_d")
db_session.add_all([post, tag])
await db_session.flush()
async with get_transaction(db_session):
await m2m_add(db_session, post, Post.tags, tag)
async with get_transaction(db_session):
await m2m_set(db_session, post, Post.tags)
remaining = await self._load_tags(db_session, post)
assert remaining == []
@pytest.mark.anyio
async def test_set_on_empty_then_populate(self, db_session: AsyncSession):
"""m2m_set works on a post with no existing associations."""
user = User(username="set_author3", email="set3@test.com")
db_session.add(user)
await db_session.flush()
post = Post(title="Post Set C", author_id=user.id)
tag1 = Tag(name="tag_set_e")
tag2 = Tag(name="tag_set_f")
db_session.add_all([post, tag1, tag2])
await db_session.flush()
async with get_transaction(db_session):
await m2m_set(db_session, post, Post.tags, tag1, tag2)
remaining = await self._load_tags(db_session, post)
assert {t.id for t in remaining} == {tag1.id, tag2.id}
@pytest.mark.anyio
async def test_non_m2m_raises_type_error(self, db_session: AsyncSession):
"""Passing a non-M2M relationship attribute raises TypeError."""
user = User(username="set_author4", email="set4@test.com")
db_session.add(user)
await db_session.flush()
role = Role(name="set_type_err_role")
db_session.add(role)
await db_session.flush()
with pytest.raises(TypeError, match="Many-to-Many"):
await m2m_set(db_session, user, User.role, role)

View File

@@ -10,6 +10,7 @@ from fastapi_toolsets.fixtures import (
Context,
FixtureRegistry,
LoadStrategy,
get_field_by_attr,
get_obj_by_attr,
load_fixtures,
load_fixtures_by_context,
@@ -951,6 +952,41 @@ class TestGetObjByAttr:
get_obj_by_attr(self.roles, "id", "not-a-uuid")
class TestGetFieldByAttr:
"""Tests for get_field_by_attr helper function."""
def setup_method(self):
self.registry = FixtureRegistry()
self.role_id_1 = uuid.uuid4()
self.role_id_2 = uuid.uuid4()
role_id_1 = self.role_id_1
role_id_2 = self.role_id_2
@self.registry.register
def roles() -> list[Role]:
return [
Role(id=role_id_1, name="admin"),
Role(id=role_id_2, name="user"),
]
self.roles = roles
def test_returns_id_by_default(self):
"""Returns the id field when no field is specified."""
result = get_field_by_attr(self.roles, "name", "admin")
assert result == self.role_id_1
def test_returns_specified_field(self):
"""Returns the requested field instead of id."""
result = get_field_by_attr(self.roles, "id", self.role_id_2, field="name")
assert result == "user"
def test_no_match_raises_stop_iteration(self):
"""Propagates StopIteration from get_obj_by_attr when no match found."""
with pytest.raises(StopIteration, match="No object with name=missing"):
get_field_by_attr(self.roles, "name", "missing")
class TestGetPrimaryKey:
"""Unit tests for the _get_primary_key helper (composite PK paths)."""

View File

@@ -1,17 +1,18 @@
"""Tests for fastapi_toolsets.pytest module."""
import uuid
from typing import cast
import pytest
from fastapi import Depends, FastAPI
from httpx import AsyncClient
from sqlalchemy import select, text
from sqlalchemy import ForeignKey, String, select, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from fastapi_toolsets.db import get_transaction
from fastapi_toolsets.fixtures import Context, FixtureRegistry
from fastapi_toolsets.fixtures import Context, FixtureRegistry, LoadStrategy
from fastapi_toolsets.pytest import (
create_async_client,
create_db_session,
@@ -19,9 +20,23 @@ from fastapi_toolsets.pytest import (
register_fixtures,
worker_database_url,
)
from fastapi_toolsets.pytest.plugin import (
_get_primary_key,
_relationship_load_options,
_reload_with_relationships,
)
from fastapi_toolsets.pytest.utils import _get_xdist_worker
from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
from .conftest import (
DATABASE_URL,
Base,
IntRole,
Permission,
Role,
RoleCrud,
User,
UserCrud,
)
test_registry = FixtureRegistry()
@@ -136,14 +151,8 @@ class TestGeneratedFixtures:
async def test_fixture_relationships_work(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Loaded fixtures have working relationships."""
# Load user with role relationship
user = await UserCrud.get(
db_session,
[User.id == USER_ADMIN_ID],
load_options=[selectinload(User.role)],
)
"""Loaded fixtures have working relationships directly accessible."""
user = next(u for u in fixture_users if u.id == USER_ADMIN_ID)
assert user.role is not None
assert user.role.name == "plugin_admin"
@@ -177,6 +186,15 @@ class TestGeneratedFixtures:
assert users[0].username == "plugin_admin"
assert users[1].username == "plugin_user"
@pytest.mark.anyio
async def test_fixture_auto_loads_relationships(
self, db_session: AsyncSession, fixture_users: list[User]
):
"""Fixtures automatically eager-load all direct relationships."""
user = next(u for u in fixture_users if u.username == "plugin_admin")
assert user.role is not None
assert user.role.name == "plugin_admin"
@pytest.mark.anyio
async def test_multiple_fixtures_in_same_test(
self,
@@ -516,3 +534,192 @@ class TestCreateWorkerDatabase:
)
assert result.scalar() is None
await engine.dispose()
class _LocalBase(DeclarativeBase):
pass
class _Group(_LocalBase):
__tablename__ = "_test_groups"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50))
class _CompositeItem(_LocalBase):
"""Model with composite PK and a relationship — exercises the fallback path."""
__tablename__ = "_test_composite_items"
group_id: Mapped[uuid.UUID] = mapped_column(
ForeignKey("_test_groups.id"), primary_key=True
)
item_code: Mapped[str] = mapped_column(String(50), primary_key=True)
group: Mapped["_Group"] = relationship()
class TestGetPrimaryKey:
"""Unit tests for _get_primary_key — no DB needed."""
def test_single_pk_returns_value(self):
rid = uuid.UUID("00000000-0000-0000-0000-000000000001")
role = Role(id=rid, name="x")
assert _get_primary_key(role) == rid
def test_composite_pk_all_set_returns_tuple(self):
perm = Permission(subject="posts", action="read")
assert _get_primary_key(perm) == ("posts", "read")
def test_composite_pk_partial_none_returns_none(self):
perm = Permission(subject=None, action="read")
assert _get_primary_key(perm) is None
def test_composite_pk_all_none_returns_none(self):
perm = Permission(subject=None, action=None)
assert _get_primary_key(perm) is None
class TestRelationshipLoadOptions:
"""Unit tests for _relationship_load_options — no DB needed."""
def test_empty_for_model_with_no_relationships(self):
assert _relationship_load_options(IntRole) == []
def test_returns_options_for_model_with_relationships(self):
opts = _relationship_load_options(User)
assert len(opts) >= 1
class TestFixtureStrategies:
"""Integration tests covering INSERT, SKIP_EXISTING, empty fixture, no-rels model."""
@pytest.mark.anyio
async def test_empty_fixture_returns_empty_list(self, db_session: AsyncSession):
"""Fixture function returning [] produces an empty list."""
registry = FixtureRegistry()
@registry.register()
def empty() -> list[Role]:
return []
local_ns: dict = {}
register_fixtures(registry, local_ns, session_fixture="db_session")
inner = local_ns["fixture_empty"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert result == []
@pytest.mark.anyio
async def test_insert_strategy_no_relationships(self, db_session: AsyncSession):
"""INSERT strategy adds instances; model with no rels skips reload (line 135)."""
registry = FixtureRegistry()
@registry.register()
def int_roles() -> list[IntRole]:
return [IntRole(name="insert_role")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.INSERT,
)
inner = local_ns["fixture_int_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "insert_role"
@pytest.mark.anyio
async def test_skip_existing_inserts_new_record(self, db_session: AsyncSession):
"""SKIP_EXISTING inserts when the record does not yet exist."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register()
def new_roles() -> list[Role]:
return [Role(id=role_id, name="skip_new")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_new_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].id == role_id
@pytest.mark.anyio
async def test_skip_existing_returns_existing_record(
self, db_session: AsyncSession
):
"""SKIP_EXISTING returns the existing DB record when PK already present."""
role_id = uuid.uuid4()
existing = Role(id=role_id, name="already_there")
db_session.add(existing)
await db_session.flush()
registry = FixtureRegistry()
@registry.register()
def dup_roles() -> list[Role]:
return [Role(id=role_id, name="should_not_overwrite")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_dup_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "already_there"
@pytest.mark.anyio
async def test_skip_existing_null_pk_inserts(self, db_session: AsyncSession):
"""SKIP_EXISTING with null PK (auto-increment) falls through to session.add()."""
registry = FixtureRegistry()
@registry.register()
def auto_roles() -> list[IntRole]:
return [IntRole(name="auto_int")]
local_ns: dict = {}
register_fixtures(
registry,
local_ns,
session_fixture="db_session",
strategy=LoadStrategy.SKIP_EXISTING,
)
inner = local_ns["fixture_auto_roles"].__wrapped__ # type: ignore[attr-defined]
result = await inner(db_session=db_session)
assert len(result) == 1
assert result[0].name == "auto_int"
class TestReloadWithRelationshipsCompositePK:
"""Integration test for _reload_with_relationships composite-PK fallback."""
@pytest.mark.anyio
async def test_composite_pk_fallback_loads_relationships(self):
"""Models with composite PKs are reloaded per-instance via session.get()."""
async with create_db_session(DATABASE_URL, _LocalBase) as session:
group = _Group(id=uuid.uuid4(), name="g1")
session.add(group)
await session.flush()
item = _CompositeItem(group_id=group.id, item_code="A")
session.add(item)
await session.flush()
load_opts = _relationship_load_options(_CompositeItem)
assert load_opts # _CompositeItem has 'group' relationship
reloaded = await _reload_with_relationships(session, [item], load_opts)
assert len(reloaded) == 1
reloaded_item = cast(_CompositeItem, reloaded[0])
assert reloaded_item.group is not None
assert reloaded_item.group.name == "g1"

2
uv.lock generated
View File

@@ -251,7 +251,7 @@ wheels = [
[[package]]
name = "fastapi-toolsets"
version = "3.0.3"
version = "3.1.0"
source = { editable = "." }
dependencies = [
{ name = "asyncpg" },