mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat: add WatchedFieldsMixin (#148)
* feat/add WatchedFieldsMixin and watch_fields decorator for field-change monitoring * docs: add WatchedFieldsMixin * feat: add on_event, on_create and on_delete * docs: update README
This commit is contained in:
21
src/fastapi_toolsets/models/__init__.py
Normal file
21
src/fastapi_toolsets/models/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""SQLAlchemy model mixins for common column patterns."""
|
||||
|
||||
from .columns import (
|
||||
CreatedAtMixin,
|
||||
TimestampMixin,
|
||||
UUIDMixin,
|
||||
UUIDv7Mixin,
|
||||
UpdatedAtMixin,
|
||||
)
|
||||
from .watched import ModelEvent, WatchedFieldsMixin, watch
|
||||
|
||||
__all__ = [
|
||||
"ModelEvent",
|
||||
"UUIDMixin",
|
||||
"UUIDv7Mixin",
|
||||
"CreatedAtMixin",
|
||||
"UpdatedAtMixin",
|
||||
"TimestampMixin",
|
||||
"WatchedFieldsMixin",
|
||||
"watch",
|
||||
]
|
||||
58
src/fastapi_toolsets/models/columns.py
Normal file
58
src/fastapi_toolsets/models/columns.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""SQLAlchemy column mixins for common column patterns."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Uuid, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
__all__ = [
|
||||
"UUIDMixin",
|
||||
"UUIDv7Mixin",
|
||||
"CreatedAtMixin",
|
||||
"UpdatedAtMixin",
|
||||
"TimestampMixin",
|
||||
]
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin that adds a UUID primary key auto-generated by the database."""
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid,
|
||||
primary_key=True,
|
||||
server_default=text("gen_random_uuid()"),
|
||||
)
|
||||
|
||||
|
||||
class UUIDv7Mixin:
|
||||
"""Mixin that adds a UUIDv7 primary key auto-generated by the database."""
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid,
|
||||
primary_key=True,
|
||||
server_default=text("uuidv7()"),
|
||||
)
|
||||
|
||||
|
||||
class CreatedAtMixin:
|
||||
"""Mixin that adds a ``created_at`` timestamp column."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=text("clock_timestamp()"),
|
||||
)
|
||||
|
||||
|
||||
class UpdatedAtMixin:
|
||||
"""Mixin that adds an ``updated_at`` timestamp column."""
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=text("clock_timestamp()"),
|
||||
onupdate=text("clock_timestamp()"),
|
||||
)
|
||||
|
||||
|
||||
class TimestampMixin(CreatedAtMixin, UpdatedAtMixin):
|
||||
"""Mixin that combines ``created_at`` and ``updated_at`` timestamp columns."""
|
||||
204
src/fastapi_toolsets/models/watched.py
Normal file
204
src/fastapi_toolsets/models/watched.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Field-change monitoring via SQLAlchemy session events."""
|
||||
|
||||
import asyncio
|
||||
import weakref
|
||||
from collections.abc import Awaitable
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
|
||||
|
||||
_logger = get_logger()
|
||||
_T = TypeVar("_T")
|
||||
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
|
||||
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
|
||||
weakref.WeakKeyDictionary()
|
||||
)
|
||||
_SESSION_PENDING_NEW = "_ft_pending_new"
|
||||
_SESSION_CREATES = "_ft_creates"
|
||||
_SESSION_DELETES = "_ft_deletes"
|
||||
_SESSION_UPDATES = "_ft_updates"
|
||||
|
||||
|
||||
class ModelEvent(str, Enum):
|
||||
"""Event types emitted by :class:`WatchedFieldsMixin`."""
|
||||
|
||||
CREATE = "create"
|
||||
DELETE = "delete"
|
||||
UPDATE = "update"
|
||||
|
||||
|
||||
def watch(*fields: str) -> Any:
|
||||
"""Class decorator to filter which fields trigger ``on_update``.
|
||||
|
||||
Args:
|
||||
*fields: One or more field names to watch. At least one name is required.
|
||||
|
||||
Raises:
|
||||
ValueError: If called with no field names.
|
||||
"""
|
||||
if not fields:
|
||||
raise ValueError("@watch requires at least one field name.")
|
||||
|
||||
def decorator(cls: type[_T]) -> type[_T]:
|
||||
_WATCHED_FIELDS[cls] = list(fields)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _upsert_changes(
|
||||
pending: dict[int, tuple[Any, dict[str, dict[str, Any]]]],
|
||||
obj: Any,
|
||||
changes: dict[str, dict[str, Any]],
|
||||
) -> None:
|
||||
"""Insert or merge *changes* into *pending* for *obj*."""
|
||||
key = id(obj)
|
||||
if key in pending:
|
||||
existing = pending[key][1]
|
||||
for field, change in changes.items():
|
||||
if field in existing:
|
||||
existing[field]["new"] = change["new"]
|
||||
else:
|
||||
existing[field] = change
|
||||
else:
|
||||
pending[key] = (obj, changes)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
||||
def _after_flush(session: Any, flush_context: Any) -> None:
|
||||
# New objects: capture references while session.new is still populated.
|
||||
# Values are read in _after_flush_postexec once RETURNING has been processed.
|
||||
for obj in session.new:
|
||||
if isinstance(obj, WatchedFieldsMixin):
|
||||
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
|
||||
|
||||
# Deleted objects: capture before they leave the identity map.
|
||||
for obj in session.deleted:
|
||||
if isinstance(obj, WatchedFieldsMixin):
|
||||
session.info.setdefault(_SESSION_DELETES, []).append(obj)
|
||||
|
||||
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
||||
for obj in session.dirty:
|
||||
if not isinstance(obj, WatchedFieldsMixin):
|
||||
continue
|
||||
|
||||
# None = not in dict = watch all fields; list = specific fields only
|
||||
watched = _WATCHED_FIELDS.get(type(obj))
|
||||
changes: dict[str, dict[str, Any]] = {}
|
||||
|
||||
attrs = (
|
||||
# Specific fields
|
||||
((field, sa_inspect(obj).attrs[field]) for field in watched)
|
||||
if watched is not None
|
||||
# All mapped fields
|
||||
else ((s.key, s) for s in sa_inspect(obj).attrs)
|
||||
)
|
||||
for field, attr_state in attrs:
|
||||
history = attr_state.history
|
||||
if history.has_changes() and history.deleted:
|
||||
changes[field] = {
|
||||
"old": history.deleted[0],
|
||||
"new": history.added[0] if history.added else None,
|
||||
}
|
||||
|
||||
if changes:
|
||||
_upsert_changes(
|
||||
session.info.setdefault(_SESSION_UPDATES, {}),
|
||||
obj,
|
||||
changes,
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
|
||||
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
|
||||
# New objects are now persistent and RETURNING values have been applied,
|
||||
# so server defaults (id, created_at, …) are available via getattr.
|
||||
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
|
||||
if not pending_new:
|
||||
return
|
||||
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
||||
def _after_rollback(session: Any) -> None:
|
||||
session.info.pop(_SESSION_PENDING_NEW, None)
|
||||
session.info.pop(_SESSION_CREATES, None)
|
||||
session.info.pop(_SESSION_DELETES, None)
|
||||
session.info.pop(_SESSION_UPDATES, None)
|
||||
|
||||
|
||||
def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
||||
if not task.cancelled() and (exc := task.exception()):
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
|
||||
|
||||
def _call_callback(loop: asyncio.AbstractEventLoop, fn: Any, *args: Any) -> None:
|
||||
"""Dispatch *fn* with *args*, handling both sync and async callables."""
|
||||
try:
|
||||
result = fn(*args)
|
||||
except Exception as exc:
|
||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||
return
|
||||
if asyncio.iscoroutine(result):
|
||||
task = loop.create_task(result)
|
||||
task.add_done_callback(_task_error_handler)
|
||||
|
||||
|
||||
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
||||
def _after_commit(session: Any) -> None:
|
||||
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
|
||||
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
|
||||
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
|
||||
_SESSION_UPDATES, {}
|
||||
)
|
||||
|
||||
if not creates and not deletes and not field_changes:
|
||||
return
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
for obj in creates:
|
||||
_call_callback(loop, obj.on_create)
|
||||
|
||||
for obj in deletes:
|
||||
_call_callback(loop, obj.on_delete)
|
||||
|
||||
for obj, changes in field_changes.values():
|
||||
_call_callback(loop, obj.on_update, changes)
|
||||
|
||||
|
||||
class WatchedFieldsMixin:
|
||||
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
|
||||
|
||||
def on_event(
|
||||
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
|
||||
) -> Awaitable[None] | None:
|
||||
"""Catch-all callback fired for every lifecycle event.
|
||||
|
||||
Args:
|
||||
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
|
||||
or :attr:`ModelEvent.UPDATE`).
|
||||
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
|
||||
"""
|
||||
|
||||
def on_create(self) -> Awaitable[None] | None:
|
||||
"""Called after INSERT commit."""
|
||||
return self.on_event(ModelEvent.CREATE)
|
||||
|
||||
def on_delete(self) -> Awaitable[None] | None:
|
||||
"""Called after DELETE commit."""
|
||||
return self.on_event(ModelEvent.DELETE)
|
||||
|
||||
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
|
||||
"""Called after UPDATE commit when watched fields change."""
|
||||
return self.on_event(ModelEvent.UPDATE, changes=changes)
|
||||
Reference in New Issue
Block a user