mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
fix: resolve MissingGreenlet error when accessing self attributes in WatchedFieldsMixin callbacks (#154)
This commit is contained in:
@@ -31,7 +31,6 @@ from fastapi_toolsets.models.watched import (
|
||||
_after_flush,
|
||||
_after_flush_postexec,
|
||||
_after_rollback,
|
||||
_call_callback,
|
||||
_task_error_handler,
|
||||
_upsert_changes,
|
||||
)
|
||||
@@ -128,6 +127,17 @@ class WatchAllModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||
_test_events.append({"event": "update", "obj_id": self.id, "changes": changes})
|
||||
|
||||
|
||||
class FailingCallbackModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||
"""Model whose on_create always raises to test exception logging."""
|
||||
|
||||
__tablename__ = "mixin_failing_callback_models"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
async def on_create(self) -> None:
|
||||
raise RuntimeError("callback intentionally failed")
|
||||
|
||||
|
||||
class NonWatchedModel(MixinBase):
|
||||
__tablename__ = "mixin_non_watched_models"
|
||||
|
||||
@@ -135,6 +145,32 @@ class NonWatchedModel(MixinBase):
|
||||
value: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
|
||||
_attr_access_events: list[dict] = []
|
||||
|
||||
|
||||
class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
||||
"""Model used to verify that self attributes are accessible in every callback."""
|
||||
|
||||
__tablename__ = "mixin_attr_access_models"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
async def on_create(self) -> None:
|
||||
_attr_access_events.append(
|
||||
{"event": "create", "id": self.id, "name": self.name}
|
||||
)
|
||||
|
||||
async def on_delete(self) -> None:
|
||||
_attr_access_events.append(
|
||||
{"event": "delete", "id": self.id, "name": self.name}
|
||||
)
|
||||
|
||||
async def on_update(self, changes: dict) -> None:
|
||||
_attr_access_events.append(
|
||||
{"event": "update", "id": self.id, "name": self.name}
|
||||
)
|
||||
|
||||
|
||||
_sync_events: list[dict] = []
|
||||
|
||||
|
||||
@@ -174,6 +210,25 @@ async def mixin_session():
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def mixin_session_expire():
|
||||
"""Session with expire_on_commit=True (the default) to exercise attribute access after commit."""
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MixinBase.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=True)
|
||||
session = session_factory()
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MixinBase.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestUUIDMixin:
|
||||
@pytest.mark.anyio
|
||||
async def test_uuid_generated_by_db(self, mixin_session):
|
||||
@@ -742,6 +797,16 @@ class TestWatchedFieldsMixin:
|
||||
|
||||
assert _test_events == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_callback_exception_is_logged(self, mixin_session):
|
||||
"""Exceptions raised inside on_create are logged, not propagated."""
|
||||
obj = FailingCallbackModel(name="boom")
|
||||
mixin_session.add(obj)
|
||||
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||
await mixin_session.commit()
|
||||
await asyncio.sleep(0)
|
||||
mock_error.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_watched_model_no_callback(self, mixin_session):
|
||||
"""Dirty objects whose type is not a WatchedFieldsMixin are skipped."""
|
||||
@@ -903,65 +968,66 @@ class TestSyncCallbacks:
|
||||
assert updates[0]["changes"]["status"] == {"old": "initial", "new": "updated"}
|
||||
|
||||
|
||||
class TestCallCallback:
|
||||
class TestAttributeAccessInCallbacks:
|
||||
"""Verify that self attributes are accessible inside every callback type.
|
||||
|
||||
Uses expire_on_commit=True (the SQLAlchemy default) so the tests would fail
|
||||
without the snapshot-restore logic in _schedule_with_snapshot.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_events(self):
|
||||
_attr_access_events.clear()
|
||||
yield
|
||||
_attr_access_events.clear()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_callback_scheduled_as_task(self):
|
||||
"""_call_callback schedules async functions as tasks."""
|
||||
called = []
|
||||
|
||||
async def async_fn() -> None:
|
||||
called.append("async")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, async_fn)
|
||||
async def test_on_create_pk_and_field_accessible(self, mixin_session_expire):
|
||||
"""id (server default) and regular fields are readable inside on_create."""
|
||||
obj = AttrAccessModel(name="hello")
|
||||
mixin_session_expire.add(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
assert called == ["async"]
|
||||
|
||||
events = [e for e in _attr_access_events if e["event"] == "create"]
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0]["id"], uuid.UUID)
|
||||
assert events[0]["name"] == "hello"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sync_callback_called_directly(self):
|
||||
"""_call_callback invokes sync functions immediately."""
|
||||
called = []
|
||||
|
||||
def sync_fn() -> None:
|
||||
called.append("sync")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, sync_fn)
|
||||
assert called == ["sync"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sync_callback_exception_logged(self):
|
||||
"""_call_callback logs exceptions from sync callbacks."""
|
||||
|
||||
def failing_fn() -> None:
|
||||
raise RuntimeError("sync error")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
with patch.object(_watched_module._logger, "error") as mock_error:
|
||||
_call_callback(loop, failing_fn)
|
||||
mock_error.assert_called_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_callback_with_args(self):
|
||||
"""_call_callback passes arguments to async callbacks."""
|
||||
received = []
|
||||
|
||||
async def async_fn(changes: dict) -> None:
|
||||
received.append(changes)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, async_fn, {"status": {"old": "a", "new": "b"}})
|
||||
async def test_on_delete_pk_and_field_accessible(self, mixin_session_expire):
|
||||
"""id and regular fields are readable inside on_delete."""
|
||||
obj = AttrAccessModel(name="to-delete")
|
||||
mixin_session_expire.add(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
assert received == [{"status": {"old": "a", "new": "b"}}]
|
||||
_attr_access_events.clear()
|
||||
|
||||
await mixin_session_expire.delete(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
events = [e for e in _attr_access_events if e["event"] == "delete"]
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0]["id"], uuid.UUID)
|
||||
assert events[0]["name"] == "to-delete"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_sync_callback_with_args(self):
|
||||
"""_call_callback passes arguments to sync callbacks."""
|
||||
received = []
|
||||
async def test_on_update_pk_and_updated_field_accessible(
|
||||
self, mixin_session_expire
|
||||
):
|
||||
"""id and the new field value are readable inside on_update."""
|
||||
obj = AttrAccessModel(name="original")
|
||||
mixin_session_expire.add(obj)
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
_attr_access_events.clear()
|
||||
|
||||
def sync_fn(changes: dict) -> None:
|
||||
received.append(changes)
|
||||
obj.name = "updated"
|
||||
await mixin_session_expire.commit()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_call_callback(loop, sync_fn, {"x": 1})
|
||||
assert received == [{"x": 1}]
|
||||
events = [e for e in _attr_access_events if e["event"] == "update"]
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0]["id"], uuid.UUID)
|
||||
assert events[0]["name"] == "updated"
|
||||
|
||||
Reference in New Issue
Block a user