mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
Compare commits
2 Commits
8a16f2808e
...
v2.4.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
6d6fae5538
|
|||
|
|
fc9cd1f034 |
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "2.3.0"
|
version = "2.4.0"
|
||||||
description = "Production-ready utilities for FastAPI applications"
|
description = "Production-ready utilities for FastAPI applications"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ Example usage:
|
|||||||
return Response(data={"user": user.username}, message="Success")
|
return Response(data={"user": user.username}, message="Success")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "2.3.0"
|
__version__ = "2.4.0"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import Any, TypeVar
|
|||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
from sqlalchemy import inspect as sa_inspect
|
from sqlalchemy import inspect as sa_inspect
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_value
|
||||||
|
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
|
|
||||||
@@ -53,6 +54,17 @@ def watch(*fields: str) -> Any:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||||
|
"""Read currently-loaded column values into a plain dict."""
|
||||||
|
state = sa_inspect(obj) # InstanceState
|
||||||
|
state_dict = state.dict
|
||||||
|
return {
|
||||||
|
prop.key: state_dict[prop.key]
|
||||||
|
for prop in state.mapper.column_attrs
|
||||||
|
if prop.key in state_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _upsert_changes(
|
def _upsert_changes(
|
||||||
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||||
obj: Any,
|
obj: Any,
|
||||||
@@ -139,16 +151,31 @@ def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
|||||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
|
||||||
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None:
|
def _schedule_with_snapshot(
|
||||||
"""Dispatch *fn* with *args*, handling both sync and async callables."""
|
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
|
||||||
try:
|
) -> None:
|
||||||
result = fn(*args)
|
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
|
||||||
except Exception as exc:
|
then schedule a coroutine that restores the snapshot and calls *fn*.
|
||||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
"""
|
||||||
return
|
snapshot = _snapshot_column_attrs(obj)
|
||||||
if asyncio.iscoroutine(result):
|
|
||||||
task = loop.create_task(result)
|
async def _run(
|
||||||
task.add_done_callback(_task_error_handler)
|
obj: Any = obj,
|
||||||
|
fn: Any = fn,
|
||||||
|
snapshot: dict[str, Any] = snapshot,
|
||||||
|
args: tuple = args,
|
||||||
|
) -> None:
|
||||||
|
for key, value in snapshot.items():
|
||||||
|
_sa_set_committed_value(obj, key, value)
|
||||||
|
try:
|
||||||
|
result = fn(*args)
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
|
task = loop.create_task(_run())
|
||||||
|
task.add_done_callback(_task_error_handler)
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
||||||
@@ -168,13 +195,13 @@ def _after_commit(session: Any) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for obj in creates:
|
for obj in creates:
|
||||||
_call_callback(loop, obj.on_create)
|
_schedule_with_snapshot(loop, obj, obj.on_create)
|
||||||
|
|
||||||
for obj in deletes:
|
for obj in deletes:
|
||||||
_call_callback(loop, obj.on_delete)
|
_schedule_with_snapshot(loop, obj, obj.on_delete)
|
||||||
|
|
||||||
for obj, changes in field_changes.values():
|
for obj, changes in field_changes.values():
|
||||||
_call_callback(loop, obj.on_update, changes)
|
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
|
||||||
|
|
||||||
|
|
||||||
class WatchedFieldsMixin:
|
class WatchedFieldsMixin:
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from fastapi_toolsets.models.watched import (
|
|||||||
_after_flush,
|
_after_flush,
|
||||||
_after_flush_postexec,
|
_after_flush_postexec,
|
||||||
_after_rollback,
|
_after_rollback,
|
||||||
_call_callback,
|
|
||||||
_task_error_handler,
|
_task_error_handler,
|
||||||
_upsert_changes,
|
_upsert_changes,
|
||||||
)
|
)
|
||||||
@@ -128,6 +127,17 @@ class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
|||||||
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
|
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
|
||||||
|
|
||||||
|
|
||||||
|
class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
"""Model whose on_create always raises to test exception logging."""
|
||||||
|
|
||||||
|
__tablename__ = "mixin_failing_callback_models"
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
async def on_create(self) -> None:
|
||||||
|
raise RuntimeError("callback intentionally failed")
|
||||||
|
|
||||||
|
|
||||||
class NonWatchedModel(MixinBase):
|
class NonWatchedModel(MixinBase):
|
||||||
__tablename__ = "mixin_non_watched_models"
|
__tablename__ = "mixin_non_watched_models"
|
||||||
|
|
||||||
@@ -135,6 +145,32 @@ class NonWatchedModel(MixinBase):
|
|||||||
value: Mapped[str] = mapped_column(String(50))
|
value: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
|
||||||
|
_attr_access_events: list[dict] = []
|
||||||
|
|
||||||
|
|
||||||
|
class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
"""Model used to verify that self attributes are accessible in every callback."""
|
||||||
|
|
||||||
|
__tablename__ = "mixin_attr_access_models"
|
||||||
|
|
||||||
|
name: Mapped[str] = mapped_column(String(50))
|
||||||
|
|
||||||
|
async def on_create(self) -> None:
|
||||||
|
_attr_access_events.append(
|
||||||
|
{"event": "create", "id": self.id, "name": self.name}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_delete(self) -> None:
|
||||||
|
_attr_access_events.append(
|
||||||
|
{"event": "delete", "id": self.id, "name": self.name}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_update(self, changes: dict) -> None:
|
||||||
|
_attr_access_events.append(
|
||||||
|
{"event": "update", "id": self.id, "name": self.name}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_sync_events: list[dict] = []
|
_sync_events: list[dict] = []
|
||||||
|
|
||||||
|
|
||||||
@@ -174,6 +210,25 @@ async def mixin_session():
|
|||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
async def mixin_session_expire():
|
||||||
|
"""Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(MixinBase.metadata.create_all)
|
||||||
|
|
||||||
|
session_factory = async_sessionmaker(engine, expire_on_commit=True)
|
||||||
|
session = session_factory()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(MixinBase.metadata.drop_all)
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
class TestUUIDMixin:
|
class TestUUIDMixin:
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_uuid_generated_by_db(self, mixin_session):
|
async def test_uuid_generated_by_db(self, mixin_session):
|
||||||
@@ -742,6 +797,16 @@ class TestWatchedFieldsMixin:
|
|||||||
|
|
||||||
assert _test_events == []
|
assert _test_events == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_callback_exception_is_logged(self, mixin_session):
|
||||||
|
"""Exceptions raised inside on_create are logged, not propagated."""
|
||||||
|
obj = FailingCallbackModel(name="boom")
|
||||||
|
mixin_session.add(obj)
|
||||||
|
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||||
|
await mixin_session.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
mock_error.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_non_watched_model_no_callback(self, mixin_session):
|
async def test_non_watched_model_no_callback(self, mixin_session):
|
||||||
"""Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
|
"""Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
|
||||||
@@ -903,65 +968,66 @@ class TestSyncCallbacks:
|
|||||||
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
|
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
|
||||||
|
|
||||||
|
|
||||||
class TestCallCallback:
|
class TestAttributeAccessInCallbacks:
|
||||||
|
"""Verify that self attributes are accessible inside every callback type.
|
||||||
|
|
||||||
|
Uses expire_on_commit=True (the SQLAlchemy default) so the tests would fail
|
||||||
|
without the snapshot-restore logic in _schedule_with_snapshot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_events(self):
|
||||||
|
_attr_access_events.clear()
|
||||||
|
yield
|
||||||
|
_attr_access_events.clear()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_async_callback_scheduled_as_task(self):
|
async def test_on_create_pk_and_field_accessible(self, mixin_session_expire):
|
||||||
"""_call_callback schedules async functions as tasks."""
|
"""id (server default) and regular fields are readable inside on_create."""
|
||||||
called = []
|
obj = AttrAccessModel(name="hello")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
async def async_fn() -> None:
|
await mixin_session_expire.commit()
|
||||||
called.append("async")
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
_call_callback(loop, async_fn)
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert called == ["async"]
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "create"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert isinstance(events[0]["id"], uuid.UUID)
|
||||||
|
assert events[0]["name"] == "hello"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sync_callback_called_directly(self):
|
async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire):
|
||||||
"""_call_callback invokes sync functions immediately."""
|
"""id and regular fields are readable inside on_delete."""
|
||||||
called = []
|
obj = AttrAccessModel(name="to-delete")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
def sync_fn() -> None:
|
await mixin_session_expire.commit()
|
||||||
called.append("sync")
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
_call_callback(loop, sync_fn)
|
|
||||||
assert called == ["sync"]
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_sync_callback_exception_logged(self):
|
|
||||||
"""_call_callback logs exceptions from sync callbacks."""
|
|
||||||
|
|
||||||
def failing_fn() -> None:
|
|
||||||
raise RuntimeError("sync error")
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
with patch.object(_watched_module._logger, "error") as mock_error:
|
|
||||||
_call_callback(loop, failing_fn)
|
|
||||||
mock_error.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_async_callback_with_args(self):
|
|
||||||
"""_call_callback passes arguments to async callbacks."""
|
|
||||||
received = []
|
|
||||||
|
|
||||||
async def async_fn(changes: dict) -> None:
|
|
||||||
received.append(changes)
|
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
_call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}})
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
assert received == [{"status": {"old": "a", "new": "b"}}]
|
_attr_access_events.clear()
|
||||||
|
|
||||||
|
await mixin_session_expire.delete(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "delete"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert isinstance(events[0]["id"], uuid.UUID)
|
||||||
|
assert events[0]["name"] == "to-delete"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_sync_callback_with_args(self):
|
async def test_on_update_pk_and_updated_field_accessible(
|
||||||
"""_call_callback passes arguments to sync callbacks."""
|
self, mixin_session_expire
|
||||||
received = []
|
):
|
||||||
|
"""id and the new field value are readable inside on_update."""
|
||||||
|
obj = AttrAccessModel(name="original")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_attr_access_events.clear()
|
||||||
|
|
||||||
def sync_fn(changes: dict) -> None:
|
obj.name = "updated"
|
||||||
received.append(changes)
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
events = [e for e in _attr_access_events if e["event"] == "update"]
|
||||||
_call_callback(loop, sync_fn, {"x": 1})
|
assert len(events) == 1
|
||||||
assert received == [{"x": 1}]
|
assert isinstance(events[0]["id"], uuid.UUID)
|
||||||
|
assert events[0]["name"] == "updated"
|
||||||
|
|||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -251,7 +251,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi-toolsets"
|
name = "fastapi-toolsets"
|
||||||
version = "2.3.0"
|
version = "2.4.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "asyncpg" },
|
{ name = "asyncpg" },
|
||||||
|
|||||||
Reference in New Issue
Block a user