mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
Compare commits
2 Commits
f82225f995
...
v2.4.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
6d6fae5538
|
|||
|
|
fc9cd1f034 |
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "2.3.0"
|
||||
version = "2.4.0"
|
||||
description = "Production-ready utilities for FastAPI applications"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "2.3.0"
|
||||
__version__ = "2.4.0"
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any, TypeVar
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
@@ -53,6 +54,17 @@ def watch(*fields: str) -> Any:
|
||||
return decorator
|
||||
|
||||
|
||||
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||
"""Read currently-loaded column values into a plain dict."""
|
||||
state = sa_inspect(obj) # InstanceState
|
||||
state_dict = state.dict
|
||||
return {
|
||||
prop.key: state_dict[prop.key]
|
||||
for prop in state.mapper.column_attrs
|
||||
if prop.key in state_dict
|
||||
}
|
||||
|
||||
|
||||
def _upsert_changes(
|
||||
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||
obj: Any,
|
||||
@@ -139,15 +151,30 @@ def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
|
||||
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None:
|
||||
"""Dispatch *fn* with *args*, handling both sync and async callables."""
|
||||
def _schedule_with_snapshot(
|
||||
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
|
||||
) -> None:
|
||||
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
|
||||
then schedule a coroutine that restores the snapshot and calls *fn*.
|
||||
"""
|
||||
snapshot = _snapshot_column_attrs(obj)
|
||||
|
||||
async def _run(
|
||||
obj: Any = obj,
|
||||
fn: Any = fn,
|
||||
snapshot: dict[str, Any] = snapshot,
|
||||
args: tuple = args,
|
||||
) -> None:
|
||||
for key, value in snapshot.items():
|
||||
_sa_set_committed_value(obj, key, value)
|
||||
try:
|
||||
result = fn(*args)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
return
|
||||
if asyncio.iscoroutine(result):
|
||||
task = loop.create_task(result)
|
||||
|
||||
task = loop.create_task(_run())
|
||||
task.add_done_callback(_task_error_handler)
|
||||
|
||||
|
||||
@@ -168,13 +195,13 @@ def _after_commit(session: Any) -> None:
|
||||
return
|
||||
|
||||
for obj in creates:
|
||||
_call_callback(loop, obj.on_create)
|
||||
_schedule_with_snapshot(loop, obj, obj.on_create)
|
||||
|
||||
for obj in deletes:
|
||||
_call_callback(loop, obj.on_delete)
|
||||
_schedule_with_snapshot(loop, obj, obj.on_delete)
|
||||
|
||||
for obj, changes in field_changes.values():
|
||||
_call_callback(loop, obj.on_update, changes)
|
||||
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
|
||||
|
||||
|
||||
class WatchedFieldsMixin:
|
||||
|
||||
@@ -31,7 +31,6 @@ from fastapi_toolsets.models.watched import (
|
||||
_after_flush,
|
||||
_after_flush_postexec,
|
||||
_after_rollback,
|
||||
_call_callback,
|
||||
_task_error_handler,
|
||||
_upsert_changes,
|
||||
)
|
||||
@@ -128,6 +127,17 @@ class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
|
||||
|
||||
|
||||
class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||
"""Model whose on_create always raises to test exception logging."""
|
||||
|
||||
__tablename__ = "mixin_failing_callback_models"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
async def on_create(self) -> None:
|
||||
raise RuntimeError("callback intentionally failed")
|
||||
|
||||
|
||||
class NonWatchedModel(MixinBase):
|
||||
__tablename__ = "mixin_non_watched_models"
|
||||
|
||||
@@ -135,6 +145,32 @@ class NonWatchedModel(MixinBase):
|
||||
value: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
|
||||
_attr_access_events: list[dict] = []
|
||||
|
||||
|
||||
class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||
"""Model used to verify that self attributes are accessible in every callback."""
|
||||
|
||||
__tablename__ = "mixin_attr_access_models"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
async def on_create(self) -> None:
|
||||
_attr_access_events.append(
|
||||
{"event": "create", "id": self.id, "name": self.name}
|
||||
)
|
||||
|
||||
async def on_delete(self) -> None:
|
||||
_attr_access_events.append(
|
||||
{"event": "delete", "id": self.id, "name": self.name}
|
||||
)
|
||||
|
||||
async def on_update(self, changes: dict) -> None:
|
||||
_attr_access_events.append(
|
||||
{"event": "update", "id": self.id, "name": self.name}
|
||||
)
|
||||
|
||||
|
||||
_sync_events: list[dict] = []
|
||||
|
||||
|
||||
@@ -174,6 +210,25 @@ async def mixin_session():
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def mixin_session_expire():
|
||||
"""Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MixinBase.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=True)
|
||||
session = session_factory()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MixinBase.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestUUIDMixin:
|
||||
@pytest.mark.anyio
|
||||
async def test_uuid_generated_by_db(self, mixin_session):
|
||||
@@ -742,6 +797,16 @@ class TestWatchedFieldsMixin:
|
||||
|
||||
assert _test_events == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_callback_exception_is_logged(self, mixin_session):
|
||||
"""Exceptions raised inside on_create are logged, not propagated."""
|
||||
obj = FailingCallbackModel(name="boom")
|
||||
mixin_session.add(obj)
|
||||
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||
await mixin_session.commit()
|
||||
await asyncio.sleep(0)
|
||||
mock_error.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_watched_model_no_callback(self, mixin_session):
|
||||
"""Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
|
||||
@@ -903,65 +968,66 @@ class TestSyncCallbacks:
|
||||
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
|
||||
|
||||
|
||||
class TestCallCallback:
|
||||
class TestAttributeAccessInCallbacks:
|
||||
"""Verify that self attributes are accessible inside every callback type.
|
||||
|
||||
Uses expire_on_commit=True (the SQLAlchemy default) so the tests would fail
|
||||
without the snapshot-restore logic in _schedule_with_snapshot.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_events(self):
|
||||
_attr_access_events.clear()
|
||||
yield
|
||||
_attr_access_events.clear()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_callback_scheduled_as_task(self):
|
||||
"""_call_callback schedules async functions as tasks."""
|
||||
called = []
|
||||
|
||||
async def async_fn() -> None:
|
||||
called.append("async")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, async_fn)
|
||||
async def test_on_create_pk_and_field_accessible(self, mixin_session_expire):
|
||||
"""id (server default) and regular fields are readable inside on_create."""
|
||||
obj = AttrAccessModel(name="hello")
|
||||
mixin_session_expire.add(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
assert called == ["async"]
|
||||
|
||||
events = [e for e in _attr_access_events if e["event"] == "create"]
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0]["id"], uuid.UUID)
|
||||
assert events[0]["name"] == "hello"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sync_callback_called_directly(self):
|
||||
"""_call_callback invokes sync functions immediately."""
|
||||
called = []
|
||||
|
||||
def sync_fn() -> None:
|
||||
called.append("sync")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, sync_fn)
|
||||
assert called == ["sync"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sync_callback_exception_logged(self):
|
||||
"""_call_callback logs exceptions from sync callbacks."""
|
||||
|
||||
def failing_fn() -> None:
|
||||
raise RuntimeError("sync error")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||
_call_callback(loop, failing_fn)
|
||||
mock_error.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_callback_with_args(self):
|
||||
"""_call_callback passes arguments to async callbacks."""
|
||||
received = []
|
||||
|
||||
async def async_fn(changes: dict) -> None:
|
||||
received.append(changes)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}})
|
||||
async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire):
|
||||
"""id and regular fields are readable inside on_delete."""
|
||||
obj = AttrAccessModel(name="to-delete")
|
||||
mixin_session_expire.add(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
assert received == [{"status": {"old": "a", "new": "b"}}]
|
||||
_attr_access_events.clear()
|
||||
|
||||
await mixin_session_expire.delete(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
events = [e for e in _attr_access_events if e["event"] == "delete"]
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0]["id"], uuid.UUID)
|
||||
assert events[0]["name"] == "to-delete"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sync_callback_with_args(self):
|
||||
"""_call_callback passes arguments to sync callbacks."""
|
||||
received = []
|
||||
async def test_on_update_pk_and_updated_field_accessible(
|
||||
self, mixin_session_expire
|
||||
):
|
||||
"""id and the new field value are readable inside on_update."""
|
||||
obj = AttrAccessModel(name="original")
|
||||
mixin_session_expire.add(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
_attr_access_events.clear()
|
||||
|
||||
def sync_fn(changes: dict) -> None:
|
||||
received.append(changes)
|
||||
obj.name = "updated"
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, sync_fn, {"x": 1})
|
||||
assert received == [{"x": 1}]
|
||||
events = [e for e in _attr_access_events if e["event"] == "update"]
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0]["id"], uuid.UUID)
|
||||
assert events[0]["name"] == "updated"
|
||||
|
||||
Reference in New Issue
Block a user