mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
fix: snapshot nullable columns correctly in WatchedFieldsMixin callback (#194)
This commit is contained in:
@@ -27,6 +27,7 @@ _SESSION_CREATES = "_ft_creates"
|
|||||||
_SESSION_DELETES = "_ft_deletes"
|
_SESSION_DELETES = "_ft_deletes"
|
||||||
_SESSION_UPDATES = "_ft_updates"
|
_SESSION_UPDATES = "_ft_updates"
|
||||||
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
|
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
|
||||||
|
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
||||||
|
|
||||||
|
|
||||||
class ModelEvent(str, Enum):
|
class ModelEvent(str, Enum):
|
||||||
@@ -60,11 +61,22 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
|||||||
"""Read currently-loaded column values into a plain dict."""
|
"""Read currently-loaded column values into a plain dict."""
|
||||||
state = sa_inspect(obj) # InstanceState
|
state = sa_inspect(obj) # InstanceState
|
||||||
state_dict = state.dict
|
state_dict = state.dict
|
||||||
return {
|
snapshot: dict[str, Any] = {}
|
||||||
prop.key: state_dict[prop.key]
|
for prop in state.mapper.column_attrs:
|
||||||
for prop in state.mapper.column_attrs
|
if prop.key in state_dict:
|
||||||
if prop.key in state_dict
|
snapshot[prop.key] = state_dict[prop.key]
|
||||||
}
|
elif (
|
||||||
|
not state.expired
|
||||||
|
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
|
||||||
|
and all(
|
||||||
|
col.nullable
|
||||||
|
and col.server_default is None
|
||||||
|
and col.server_onupdate is None
|
||||||
|
for col in prop.columns
|
||||||
|
)
|
||||||
|
):
|
||||||
|
snapshot[prop.key] = None
|
||||||
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
def _get_watched_fields(cls: type) -> list[str] | None:
|
def _get_watched_fields(cls: type) -> list[str] | None:
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from fastapi_toolsets.models.watched import (
|
|||||||
_after_flush,
|
_after_flush,
|
||||||
_after_flush_postexec,
|
_after_flush_postexec,
|
||||||
_after_rollback,
|
_after_rollback,
|
||||||
|
_snapshot_column_attrs,
|
||||||
_task_error_handler,
|
_task_error_handler,
|
||||||
_upsert_changes,
|
_upsert_changes,
|
||||||
)
|
)
|
||||||
@@ -213,20 +214,36 @@ class AttrAccessModel(MixinBase, UUIDMixin, WatchedFieldsMixin):
|
|||||||
__tablename__ = "mixin_attr_access_models"
|
__tablename__ = "mixin_attr_access_models"
|
||||||
|
|
||||||
name: Mapped[str] = mapped_column(String(50))
|
name: Mapped[str] = mapped_column(String(50))
|
||||||
|
callback_url: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||||
|
|
||||||
async def on_create(self) -> None:
|
async def on_create(self) -> None:
|
||||||
_attr_access_events.append(
|
_attr_access_events.append(
|
||||||
{"event": "create", "id": self.id, "name": self.name}
|
{
|
||||||
|
"event": "create",
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"callback_url": self.callback_url,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_delete(self) -> None:
|
async def on_delete(self) -> None:
|
||||||
_attr_access_events.append(
|
_attr_access_events.append(
|
||||||
{"event": "delete", "id": self.id, "name": self.name}
|
{
|
||||||
|
"event": "delete",
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"callback_url": self.callback_url,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_update(self, changes: dict) -> None:
|
async def on_update(self, changes: dict) -> None:
|
||||||
_attr_access_events.append(
|
_attr_access_events.append(
|
||||||
{"event": "update", "id": self.id, "name": self.name}
|
{
|
||||||
|
"event": "update",
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"callback_url": self.callback_url,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1279,3 +1296,67 @@ class TestAttributeAccessInCallbacks:
|
|||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert isinstance(events[0]["id"], uuid.UUID)
|
assert isinstance(events[0]["id"], uuid.UUID)
|
||||||
assert events[0]["name"] == "updated"
|
assert events[0]["name"] == "updated"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nullable_column_none_accessible_in_on_create(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""Nullable column left as None is accessible in on_create without greenlet error."""
|
||||||
|
obj = AttrAccessModel(name="no-url") # callback_url not set → None
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "create"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0]["callback_url"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nullable_column_with_value_accessible_in_on_create(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""Nullable column set to a value is accessible in on_create without greenlet error."""
|
||||||
|
obj = AttrAccessModel(name="with-url", callback_url="https://example.com/hook")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "create"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0]["callback_url"] == "https://example.com/hook"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_nullable_column_accessible_after_update_to_none(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""Nullable column updated to None is accessible in on_update without greenlet error."""
|
||||||
|
obj = AttrAccessModel(name="x", callback_url="https://example.com/hook")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
_attr_access_events.clear()
|
||||||
|
|
||||||
|
obj.callback_url = None
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
events = [e for e in _attr_access_events if e["event"] == "update"]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0]["callback_url"] is None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_expired_nullable_column_not_inferred_as_none(
|
||||||
|
self, mixin_session_expire
|
||||||
|
):
|
||||||
|
"""A nullable column with a real value that is expired (by a prior
|
||||||
|
expire_on_commit) must not be inferred as None in the snapshot — its
|
||||||
|
actual value is unknown without a DB refresh."""
|
||||||
|
obj = AttrAccessModel(name="original", callback_url="https://example.com/hook")
|
||||||
|
mixin_session_expire.add(obj)
|
||||||
|
await mixin_session_expire.commit()
|
||||||
|
# expire_on_commit fired → obj.state.expired=True, callback_url not in state.dict
|
||||||
|
|
||||||
|
snapshot = _snapshot_column_attrs(obj)
|
||||||
|
|
||||||
|
# callback_url has a real DB value but is expired — must not be snapshotted as None.
|
||||||
|
assert "callback_url" not in snapshot
|
||||||
|
|||||||
Reference in New Issue
Block a user