mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: rework async event system (#196)
* feat: rework async event system * docs: add v3 migration guide * feat: add cache * enhancements
This commit is contained in:
@@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]"
|
|||||||
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`, `@watch`) that fire after commit for insert, update, and delete events
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`)
|
||||||
|
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations
|
||||||
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|||||||
@@ -48,7 +48,8 @@ uv add "fastapi-toolsets[all]"
|
|||||||
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
- **Database**: Session management, transaction helpers, table locking, and polling-based row change detection
|
||||||
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
- **Dependencies**: FastAPI dependency factories (`PathDependency`, `BodyDependency`) for automatic DB lookups from path or body parameters
|
||||||
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
- **Fixtures**: Fixture system with dependency management, context support, and pytest integration
|
||||||
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`) and lifecycle callbacks (`WatchedFieldsMixin`) that fire after commit for insert, update, and delete events.
|
- **Model Mixins**: SQLAlchemy mixins for common column patterns (`UUIDMixin`, `UUIDv7Mixin`, `CreatedAtMixin`, `UpdatedAtMixin`, `TimestampMixin`).
|
||||||
|
- **Lifecycle Events**: Post-commit event system (`EventSession`, `listens_for`) that dispatches async/sync callbacks for insert, update, and delete operations.
|
||||||
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
- **Standardized API Responses**: Consistent response format with `Response`, `ErrorResponse`, `PaginatedResponse`, `CursorPaginatedResponse` and `OffsetPaginatedResponse`.
|
||||||
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
- **Exception Handling**: Structured error responses with automatic OpenAPI documentation
|
||||||
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
- **Logging**: Logging configuration with uvicorn integration via `configure_logging` and `get_logger`
|
||||||
|
|||||||
93
docs/migration/v3.md
Normal file
93
docs/migration/v3.md
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Migrating to v3.0
|
||||||
|
|
||||||
|
This page covers every breaking change introduced in **v3.0** and the steps required to update your code.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
The lifecycle event system has been rewritten. Callbacks are now registered with a module-level [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator and dispatched by [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession), replacing the mixin-based approach from `v2`.
|
||||||
|
|
||||||
|
### `WatchedFieldsMixin` and `@watch` removed
|
||||||
|
|
||||||
|
Importing `WatchedFieldsMixin` or `watch` will raise `ImportError`.
|
||||||
|
|
||||||
|
Model method callbacks (`on_create`, `on_delete`, `on_update`) and the `@watch` decorator are replaced by:
|
||||||
|
|
||||||
|
1. **`__watched_fields__`** — a plain class attribute to restrict which field changes trigger `UPDATE` events (replaces `@watch`).
|
||||||
|
2. **`@listens_for`** — a module-level decorator to register callbacks for one or more [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) types (replaces `on_create` / `on_delete` / `on_update` methods).
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import WatchedFieldsMixin, watch
|
||||||
|
|
||||||
|
@watch("status")
|
||||||
|
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
async def on_create(self):
|
||||||
|
await notify_new_order(self.id)
|
||||||
|
|
||||||
|
async def on_update(self, changes):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(self.id, changes["status"])
|
||||||
|
|
||||||
|
async def on_delete(self):
|
||||||
|
await notify_order_cancelled(self.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
|
||||||
|
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__tablename__ = "orders"
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
|
||||||
|
status: Mapped[str]
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.CREATE])
|
||||||
|
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_new_order(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.UPDATE])
|
||||||
|
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(order.id, changes["status"])
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.DELETE])
|
||||||
|
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
|
||||||
|
await notify_order_cancelled(order.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
### `EventSession` now required
|
||||||
|
|
||||||
|
Without `EventSession`, lifecycle callbacks will silently stop firing.
|
||||||
|
|
||||||
|
Callbacks are now dispatched inside `EventSession.commit()` rather than via background tasks. Pass it as the session class when creating your session factory:
|
||||||
|
|
||||||
|
=== "Before (`v2`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Now (`v3`)"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
from fastapi_toolsets.models import EventSession
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
If you use `create_db_session` from `fastapi_toolsets.pytest`, the session already uses `EventSession` — no changes needed in tests.
|
||||||
@@ -117,139 +117,118 @@ class Article(Base, UUIDMixin, TimestampMixin):
|
|||||||
title: Mapped[str]
|
title: Mapped[str]
|
||||||
```
|
```
|
||||||
|
|
||||||
### [`WatchedFieldsMixin`](../reference/models.md#fastapi_toolsets.models.WatchedFieldsMixin)
|
## Lifecycle events
|
||||||
|
|
||||||
!!! info "Added in `v2.4`"
|
The event system provides lifecycle callbacks that fire **after commit**. If the transaction rolls back, no callback fires.
|
||||||
|
|
||||||
`WatchedFieldsMixin` provides lifecycle callbacks that fire **after commit** — meaning the row is durably persisted when your callback runs. If the transaction rolls back, no callback fires.
|
### Setup
|
||||||
|
|
||||||
Three callbacks are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
|
Event dispatch requires [`EventSession`](../reference/models.md#fastapi_toolsets.models.EventSession). Pass it as the session class when creating your session factory:
|
||||||
|
|
||||||
| Callback | Event | Trigger |
|
|
||||||
|---|---|---|
|
|
||||||
| `on_create()` | `ModelEvent.CREATE` | After `INSERT` |
|
|
||||||
| `on_delete()` | `ModelEvent.DELETE` | After `DELETE` |
|
|
||||||
| `on_update(changes)` | `ModelEvent.UPDATE` | After `UPDATE` on a watched field |
|
|
||||||
|
|
||||||
Server-side defaults (e.g. `id`, `created_at`) are fully populated in all callbacks. All callbacks support both `async def` and plain `def`. Use `@watch` to restrict which fields trigger `on_update`:
|
|
||||||
|
|
||||||
| Decorator | `on_update` behaviour |
|
|
||||||
|---|---|
|
|
||||||
| `@watch("status", "role")` | Only fires when `status` or `role` changes |
|
|
||||||
| *(no decorator)* | Fires when **any** mapped field changes |
|
|
||||||
|
|
||||||
`@watch` is inherited through the class hierarchy. If a subclass does not declare its own `@watch`, it uses the filter from the nearest decorated parent. Applying `@watch` on the subclass overrides the parent's filter:
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@watch("status")
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
from fastapi_toolsets.models import EventSession
|
||||||
|
|
||||||
|
engine = create_async_engine("postgresql+asyncpg://...")
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=EventSession)
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Callbacks fire on `session.commit()` only — not on savepoints."
|
||||||
|
Savepoints created by [`get_transaction`](db.md) or `begin_nested()` do **not**
|
||||||
|
trigger callbacks. All events accumulated across flushes are dispatched once
|
||||||
|
when the outermost `commit()` is called.
|
||||||
|
|
||||||
|
### Events
|
||||||
|
|
||||||
|
Three event types are available, each corresponding to a [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) value:
|
||||||
|
|
||||||
|
| Event | Trigger |
|
||||||
|
|---|---|
|
||||||
|
| `ModelEvent.CREATE` | After `INSERT` commit |
|
||||||
|
| `ModelEvent.DELETE` | After `DELETE` commit |
|
||||||
|
| `ModelEvent.UPDATE` | After `UPDATE` commit on a watched field |
|
||||||
|
|
||||||
|
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
|
||||||
|
|
||||||
|
### Watched fields
|
||||||
|
|
||||||
|
Set `__watched_fields__` on the model to restrict which field changes trigger `UPDATE` events. It must be a `tuple[str, ...]` — any other type raises `TypeError`:
|
||||||
|
|
||||||
|
| Class attribute | `UPDATE` behaviour |
|
||||||
|
|---|---|
|
||||||
|
| `__watched_fields__ = ("status", "role")` | Only fires when `status` or `role` changes |
|
||||||
|
| *(not set)* | Fires when **any** mapped field changes |
|
||||||
|
|
||||||
|
`__watched_fields__` is inherited through the class hierarchy via normal Python MRO. A subclass can override it:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Order(Base, UUIDMixin):
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
...
|
...
|
||||||
|
|
||||||
class UrgentOrder(Order):
|
class UrgentOrder(Order):
|
||||||
# inherits @watch("status") — on_update fires only for status changes
|
# inherits __watched_fields__ = ("status",)
|
||||||
...
|
...
|
||||||
|
|
||||||
@watch("priority")
|
|
||||||
class PriorityOrder(Order):
|
class PriorityOrder(Order):
|
||||||
# overrides parent — on_update fires only for priority changes
|
__watched_fields__ = ("priority",)
|
||||||
|
# overrides parent — UPDATE fires only for priority changes
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Option 1 — catch-all with `on_event`
|
### Registering handlers
|
||||||
|
|
||||||
Override `on_event` to handle all event types in one place. The specific methods delegate here by default:
|
Register handlers with the [`listens_for`](../reference/models.md#fastapi_toolsets.models.listens_for) decorator. Every callback receives three arguments: the model instance, the [`ModelEvent`](../reference/models.md#fastapi_toolsets.models.ModelEvent) that triggered it, and a `changes` dict (`None` for `CREATE` and `DELETE`):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from fastapi_toolsets.models import ModelEvent, UUIDMixin, WatchedFieldsMixin, watch
|
from fastapi_toolsets.models import ModelEvent, UUIDMixin, listens_for
|
||||||
|
|
||||||
@watch("status")
|
class Order(Base, UUIDMixin):
|
||||||
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
|
||||||
__tablename__ = "orders"
|
__tablename__ = "orders"
|
||||||
|
__watched_fields__ = ("status",)
|
||||||
|
|
||||||
status: Mapped[str]
|
status: Mapped[str]
|
||||||
|
|
||||||
async def on_event(self, event: ModelEvent, changes: dict | None = None) -> None:
|
@listens_for(Order, [ModelEvent.CREATE])
|
||||||
if event == ModelEvent.CREATE:
|
async def on_order_created(order: Order, event_type: ModelEvent, changes: None):
|
||||||
await notify_new_order(self.id)
|
await notify_new_order(order.id)
|
||||||
elif event == ModelEvent.DELETE:
|
|
||||||
await notify_order_cancelled(self.id)
|
@listens_for(Order, [ModelEvent.DELETE])
|
||||||
elif event == ModelEvent.UPDATE:
|
async def on_order_deleted(order: Order, event_type: ModelEvent, changes: None):
|
||||||
await notify_status_change(self.id, changes["status"])
|
await notify_order_cancelled(order.id)
|
||||||
|
|
||||||
|
@listens_for(Order, [ModelEvent.UPDATE])
|
||||||
|
async def on_order_updated(order: Order, event_type: ModelEvent, changes: dict):
|
||||||
|
if "status" in changes:
|
||||||
|
await notify_status_change(order.id, changes["status"])
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Option 2 — targeted overrides
|
Multiple handlers can be registered for the same model and event. Handlers registered on a parent class also fire for subclass instances.
|
||||||
|
|
||||||
Override individual methods for more focused logic:
|
A single handler can listen for multiple events at once. When `event_types` is omitted, the handler fires for all events:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@watch("status")
|
@listens_for(Order, [ModelEvent.CREATE, ModelEvent.UPDATE])
|
||||||
class Order(Base, UUIDMixin, WatchedFieldsMixin):
|
async def on_order_changed(order: Order, event_type: ModelEvent, changes: dict | None):
|
||||||
__tablename__ = "orders"
|
await invalidate_cache(order.id)
|
||||||
|
|
||||||
status: Mapped[str]
|
@listens_for(Order) # all events
|
||||||
|
async def on_any_order_event(order: Order, event_type: ModelEvent, changes: dict | None):
|
||||||
async def on_create(self) -> None:
|
await audit_log(order.id, event_type)
|
||||||
await notify_new_order(self.id)
|
|
||||||
|
|
||||||
async def on_delete(self) -> None:
|
|
||||||
await notify_order_cancelled(self.id)
|
|
||||||
|
|
||||||
async def on_update(self, changes: dict) -> None:
|
|
||||||
if "status" in changes:
|
|
||||||
old = changes["status"]["old"]
|
|
||||||
new = changes["status"]["new"]
|
|
||||||
await notify_status_change(self.id, old, new)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Field changes format
|
### Field changes format
|
||||||
|
|
||||||
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included:
|
The `changes` dict maps each watched field that changed to `{"old": ..., "new": ...}`. Only fields that actually changed are included. For `CREATE` and `DELETE` events, `changes` is `None`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
# CREATE / DELETE → changes is None
|
||||||
# status changed → {"status": {"old": "pending", "new": "shipped"}}
|
# status changed → {"status": {"old": "pending", "new": "shipped"}}
|
||||||
# two fields changed → {"status": {...}, "assigned_to": {...}}
|
# two fields changed → {"status": {...}, "assigned_to": {...}}
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
|
!!! info "Multiple flushes in one transaction are merged: the earliest `old` and latest `new` are preserved, and `on_update` fires only once per commit."
|
||||||
|
|
||||||
!!! warning "Callbacks fire only for ORM-level changes. Rows updated via raw SQL (`UPDATE ... SET ...`) are not detected."
|
|
||||||
|
|
||||||
!!! warning "Callbacks fire when the **outermost active context** (savepoint or transaction) commits."
|
|
||||||
If you create several related objects using `CrudFactory.create` and need
|
|
||||||
callbacks to see all of them (including associations), wrap the whole
|
|
||||||
operation in a single [`get_transaction`](db.md) or [`lock_tables`](db.md)
|
|
||||||
block. Without it, each `create` call commits its own savepoint and
|
|
||||||
`on_create` fires before the remaining objects exist.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.db import get_transaction
|
|
||||||
|
|
||||||
async with get_transaction(session):
|
|
||||||
order = await OrderCrud.create(session, order_data)
|
|
||||||
item = await ItemCrud.create(session, item_data)
|
|
||||||
await session.refresh(order, attribute_names=["items"])
|
|
||||||
order.items.append(item)
|
|
||||||
# on_create fires here for both order and item,
|
|
||||||
# with the full association already committed.
|
|
||||||
```
|
|
||||||
|
|
||||||
## Composing mixins
|
|
||||||
|
|
||||||
All mixins can be combined in any order. The only constraint is that exactly one primary key must be defined — either via `UUIDMixin` or directly on the model.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi_toolsets.models import UUIDMixin, TimestampMixin
|
|
||||||
|
|
||||||
class Event(Base, UUIDMixin, TimestampMixin):
|
|
||||||
__tablename__ = "events"
|
|
||||||
name: Mapped[str]
|
|
||||||
|
|
||||||
class Counter(Base, UpdatedAtMixin):
|
|
||||||
__tablename__ = "counters"
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
|
||||||
value: Mapped[int]
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
[:material-api: API Reference](../reference/models.md)
|
[:material-api: API Reference](../reference/models.md)
|
||||||
|
|||||||
@@ -6,17 +6,19 @@ You can import them directly from `fastapi_toolsets.models`:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from fastapi_toolsets.models import (
|
from fastapi_toolsets.models import (
|
||||||
|
EventSession,
|
||||||
ModelEvent,
|
ModelEvent,
|
||||||
UUIDMixin,
|
UUIDMixin,
|
||||||
UUIDv7Mixin,
|
UUIDv7Mixin,
|
||||||
CreatedAtMixin,
|
CreatedAtMixin,
|
||||||
UpdatedAtMixin,
|
UpdatedAtMixin,
|
||||||
TimestampMixin,
|
TimestampMixin,
|
||||||
WatchedFieldsMixin,
|
listens_for,
|
||||||
watch,
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## ::: fastapi_toolsets.models.EventSession
|
||||||
|
|
||||||
## ::: fastapi_toolsets.models.ModelEvent
|
## ::: fastapi_toolsets.models.ModelEvent
|
||||||
|
|
||||||
## ::: fastapi_toolsets.models.UUIDMixin
|
## ::: fastapi_toolsets.models.UUIDMixin
|
||||||
@@ -29,6 +31,4 @@ from fastapi_toolsets.models import (
|
|||||||
|
|
||||||
## ::: fastapi_toolsets.models.TimestampMixin
|
## ::: fastapi_toolsets.models.TimestampMixin
|
||||||
|
|
||||||
## ::: fastapi_toolsets.models.WatchedFieldsMixin
|
## ::: fastapi_toolsets.models.listens_for
|
||||||
|
|
||||||
## ::: fastapi_toolsets.models.watch
|
|
||||||
|
|||||||
@@ -24,9 +24,12 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_SessionT = TypeVar("_SessionT", bound=AsyncSession)
|
||||||
|
|
||||||
|
|
||||||
def create_db_dependency(
|
def create_db_dependency(
|
||||||
session_maker: async_sessionmaker[AsyncSession],
|
session_maker: async_sessionmaker[_SessionT],
|
||||||
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
|
) -> Callable[[], AsyncGenerator[_SessionT, None]]:
|
||||||
"""Create a FastAPI dependency for database sessions.
|
"""Create a FastAPI dependency for database sessions.
|
||||||
|
|
||||||
Creates a dependency function that yields a session and auto-commits
|
Creates a dependency function that yields a session and auto-commits
|
||||||
@@ -54,7 +57,7 @@ def create_db_dependency(
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db() -> AsyncGenerator[_SessionT, None]:
|
||||||
async with session_maker() as session:
|
async with session_maker() as session:
|
||||||
await session.connection()
|
await session.connection()
|
||||||
yield session
|
yield session
|
||||||
@@ -65,8 +68,8 @@ def create_db_dependency(
|
|||||||
|
|
||||||
|
|
||||||
def create_db_context(
|
def create_db_context(
|
||||||
session_maker: async_sessionmaker[AsyncSession],
|
session_maker: async_sessionmaker[_SessionT],
|
||||||
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
|
) -> Callable[[], AbstractAsyncContextManager[_SessionT]]:
|
||||||
"""Create a context manager for database sessions.
|
"""Create a context manager for database sessions.
|
||||||
|
|
||||||
Creates a context manager for use outside of FastAPI request handlers,
|
Creates a context manager for use outside of FastAPI request handlers,
|
||||||
|
|||||||
@@ -7,15 +7,15 @@ from .columns import (
|
|||||||
UUIDv7Mixin,
|
UUIDv7Mixin,
|
||||||
UpdatedAtMixin,
|
UpdatedAtMixin,
|
||||||
)
|
)
|
||||||
from .watched import ModelEvent, WatchedFieldsMixin, watch
|
from .watched import EventSession, ModelEvent, listens_for
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"EventSession",
|
||||||
"ModelEvent",
|
"ModelEvent",
|
||||||
"UUIDMixin",
|
"UUIDMixin",
|
||||||
"UUIDv7Mixin",
|
"UUIDv7Mixin",
|
||||||
"CreatedAtMixin",
|
"CreatedAtMixin",
|
||||||
"UpdatedAtMixin",
|
"UpdatedAtMixin",
|
||||||
"TimestampMixin",
|
"TimestampMixin",
|
||||||
"WatchedFieldsMixin",
|
"listens_for",
|
||||||
"watch",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,14 +6,6 @@ from datetime import datetime
|
|||||||
from sqlalchemy import DateTime, Uuid, text
|
from sqlalchemy import DateTime, Uuid, text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"UUIDMixin",
|
|
||||||
"UUIDv7Mixin",
|
|
||||||
"CreatedAtMixin",
|
|
||||||
"UpdatedAtMixin",
|
|
||||||
"TimestampMixin",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class UUIDMixin:
|
class UUIDMixin:
|
||||||
"""Mixin that adds a UUID primary key auto-generated by the database."""
|
"""Mixin that adds a UUID primary key auto-generated by the database."""
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
"""Field-change monitoring via SQLAlchemy session events."""
|
"""Field-change monitoring via SQLAlchemy session events."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
import weakref
|
from collections.abc import Callable
|
||||||
from collections.abc import Awaitable
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
from sqlalchemy import inspect as sa_inspect
|
from sqlalchemy import inspect as sa_inspect
|
||||||
@@ -14,49 +12,81 @@ from sqlalchemy.orm.attributes import set_committed_value as _sa_set_committed_v
|
|||||||
|
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
|
|
||||||
__all__ = ["ModelEvent", "WatchedFieldsMixin", "watch"]
|
|
||||||
|
|
||||||
_logger = get_logger()
|
_logger = get_logger()
|
||||||
_T = TypeVar("_T")
|
|
||||||
_CALLBACK_ERROR_MSG = "WatchedFieldsMixin callback raised an unhandled exception"
|
|
||||||
_WATCHED_FIELDS: weakref.WeakKeyDictionary[type, list[str]] = (
|
|
||||||
weakref.WeakKeyDictionary()
|
|
||||||
)
|
|
||||||
_SESSION_PENDING_NEW = "_ft_pending_new"
|
|
||||||
_SESSION_CREATES = "_ft_creates"
|
|
||||||
_SESSION_DELETES = "_ft_deletes"
|
|
||||||
_SESSION_UPDATES = "_ft_updates"
|
|
||||||
_SESSION_SAVEPOINT_DEPTH = "_ft_sp_depth"
|
|
||||||
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
|
||||||
|
|
||||||
|
|
||||||
class ModelEvent(str, Enum):
|
class ModelEvent(str, Enum):
|
||||||
"""Event types emitted by :class:`WatchedFieldsMixin`."""
|
"""Event types dispatched by :class:`EventSession`."""
|
||||||
|
|
||||||
CREATE = "create"
|
CREATE = "create"
|
||||||
DELETE = "delete"
|
DELETE = "delete"
|
||||||
UPDATE = "update"
|
UPDATE = "update"
|
||||||
|
|
||||||
|
|
||||||
def watch(*fields: str) -> Any:
|
_CALLBACK_ERROR_MSG = "Event callback raised an unhandled exception"
|
||||||
"""Class decorator to filter which fields trigger ``on_update``.
|
_SESSION_CREATES = "_ft_creates"
|
||||||
|
_SESSION_DELETES = "_ft_deletes"
|
||||||
|
_SESSION_UPDATES = "_ft_updates"
|
||||||
|
_DEFERRED_STRATEGY_KEY = (("deferred", True), ("instrument", True))
|
||||||
|
_EVENT_HANDLERS: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||||
|
_WATCHED_MODELS: set[type] = set()
|
||||||
|
_WATCHED_CACHE: dict[type, bool] = {}
|
||||||
|
_HANDLER_CACHE: dict[tuple[type, ModelEvent], list[Callable[..., Any]]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _invalidate_caches() -> None:
|
||||||
|
"""Clear lookup caches after handler registration."""
|
||||||
|
_WATCHED_CACHE.clear()
|
||||||
|
_HANDLER_CACHE.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def listens_for(
|
||||||
|
model_class: type,
|
||||||
|
event_types: list[ModelEvent] | None = None,
|
||||||
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
|
"""Register a callback for one or more model lifecycle events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*fields: One or more field names to watch. At least one name is required.
|
model_class: The SQLAlchemy model class to listen on.
|
||||||
|
event_types: List of :class:`ModelEvent` values to listen for.
|
||||||
Raises:
|
Defaults to all event types.
|
||||||
ValueError: If called with no field names.
|
|
||||||
"""
|
"""
|
||||||
if not fields:
|
evs = event_types if event_types is not None else list(ModelEvent)
|
||||||
raise ValueError("@watch requires at least one field name.")
|
|
||||||
|
|
||||||
def decorator(cls: type[_T]) -> type[_T]:
|
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
_WATCHED_FIELDS[cls] = list(fields)
|
for ev in evs:
|
||||||
return cls
|
_EVENT_HANDLERS.setdefault((model_class, ev), []).append(fn)
|
||||||
|
_WATCHED_MODELS.add(model_class)
|
||||||
|
_invalidate_caches()
|
||||||
|
return fn
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _is_watched(obj: Any) -> bool:
|
||||||
|
"""Return True if *obj*'s type (or any ancestor) has registered handlers."""
|
||||||
|
cls = type(obj)
|
||||||
|
try:
|
||||||
|
return _WATCHED_CACHE[cls]
|
||||||
|
except KeyError:
|
||||||
|
result = any(klass in _WATCHED_MODELS for klass in cls.__mro__)
|
||||||
|
_WATCHED_CACHE[cls] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_handlers(cls: type, ev: ModelEvent) -> list[Callable[..., Any]]:
|
||||||
|
"""Return registered handlers for *cls* and *ev*, walking the MRO."""
|
||||||
|
key = (cls, ev)
|
||||||
|
try:
|
||||||
|
return _HANDLER_CACHE[key]
|
||||||
|
except KeyError:
|
||||||
|
handlers: list[Callable[..., Any]] = []
|
||||||
|
for klass in cls.__mro__:
|
||||||
|
handlers.extend(_EVENT_HANDLERS.get((klass, ev), []))
|
||||||
|
_HANDLER_CACHE[key] = handlers
|
||||||
|
return handlers
|
||||||
|
|
||||||
|
|
||||||
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
||||||
"""Read currently-loaded column values into a plain dict."""
|
"""Read currently-loaded column values into a plain dict."""
|
||||||
state = sa_inspect(obj) # InstanceState
|
state = sa_inspect(obj) # InstanceState
|
||||||
@@ -65,7 +95,7 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
|||||||
for prop in state.mapper.column_attrs:
|
for prop in state.mapper.column_attrs:
|
||||||
if prop.key in state_dict:
|
if prop.key in state_dict:
|
||||||
snapshot[prop.key] = state_dict[prop.key]
|
snapshot[prop.key] = state_dict[prop.key]
|
||||||
elif (
|
elif ( # pragma: no cover
|
||||||
not state.expired
|
not state.expired
|
||||||
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
|
and prop.strategy_key != _DEFERRED_STRATEGY_KEY
|
||||||
and all(
|
and all(
|
||||||
@@ -79,12 +109,17 @@ def _snapshot_column_attrs(obj: Any) -> dict[str, Any]:
|
|||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
def _get_watched_fields(cls: type) -> list[str] | None:
|
def _get_watched_fields(cls: type) -> tuple[str, ...] | None:
|
||||||
"""Return the watched fields for *cls*, walking the MRO to inherit from parents."""
|
"""Return the watched fields for *cls*."""
|
||||||
for klass in cls.__mro__:
|
fields = getattr(cls, "__watched_fields__", None)
|
||||||
if klass in _WATCHED_FIELDS:
|
if fields is not None and (
|
||||||
return _WATCHED_FIELDS[klass]
|
not isinstance(fields, tuple) or not all(isinstance(f, str) for f in fields)
|
||||||
return None
|
):
|
||||||
|
raise TypeError(
|
||||||
|
f"{cls.__name__}.__watched_fields__ must be a tuple[str, ...], "
|
||||||
|
f"got {type(fields).__name__}"
|
||||||
|
)
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
def _upsert_changes(
|
def _upsert_changes(
|
||||||
@@ -105,50 +140,32 @@ def _upsert_changes(
|
|||||||
pending[key] = (obj, changes)
|
pending[key] = (obj, changes)
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_create")
|
|
||||||
def _after_transaction_create(session: Any, transaction: Any) -> None:
|
|
||||||
if transaction.nested:
|
|
||||||
session.info[_SESSION_SAVEPOINT_DEPTH] = (
|
|
||||||
session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(AsyncSession.sync_session_class, "after_transaction_end")
|
|
||||||
def _after_transaction_end(session: Any, transaction: Any) -> None:
|
|
||||||
if transaction.nested:
|
|
||||||
depth = session.info.get(_SESSION_SAVEPOINT_DEPTH, 0)
|
|
||||||
if depth > 0: # pragma: no branch
|
|
||||||
session.info[_SESSION_SAVEPOINT_DEPTH] = depth - 1
|
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
@event.listens_for(AsyncSession.sync_session_class, "after_flush")
|
||||||
def _after_flush(session: Any, flush_context: Any) -> None:
|
def _after_flush(session: Any, flush_context: Any) -> None:
|
||||||
# New objects: capture references while session.new is still populated.
|
# New objects: capture reference. Attributes will be refreshed after commit.
|
||||||
# Values are read in _after_flush_postexec once RETURNING has been processed.
|
|
||||||
for obj in session.new:
|
for obj in session.new:
|
||||||
if isinstance(obj, WatchedFieldsMixin):
|
if _is_watched(obj):
|
||||||
session.info.setdefault(_SESSION_PENDING_NEW, []).append(obj)
|
session.info.setdefault(_SESSION_CREATES, []).append(obj)
|
||||||
|
|
||||||
# Deleted objects: capture before they leave the identity map.
|
# Deleted objects: snapshot now while attributes are still loaded.
|
||||||
for obj in session.deleted:
|
for obj in session.deleted:
|
||||||
if isinstance(obj, WatchedFieldsMixin):
|
if _is_watched(obj):
|
||||||
session.info.setdefault(_SESSION_DELETES, []).append(obj)
|
snapshot = _snapshot_column_attrs(obj)
|
||||||
|
session.info.setdefault(_SESSION_DELETES, []).append((obj, snapshot))
|
||||||
|
|
||||||
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
# Dirty objects: read old/new from SQLAlchemy attribute history.
|
||||||
for obj in session.dirty:
|
for obj in session.dirty:
|
||||||
if not isinstance(obj, WatchedFieldsMixin):
|
if not _is_watched(obj):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# None = not in dict = watch all fields; list = specific fields only
|
|
||||||
watched = _get_watched_fields(type(obj))
|
watched = _get_watched_fields(type(obj))
|
||||||
changes: dict[str, dict[str, Any]] = {}
|
changes: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
inst_attrs = sa_inspect(obj).attrs
|
||||||
attrs = (
|
attrs = (
|
||||||
# Specific fields
|
((field, inst_attrs[field]) for field in watched)
|
||||||
((field, sa_inspect(obj).attrs[field]) for field in watched)
|
|
||||||
if watched is not None
|
if watched is not None
|
||||||
# All mapped fields
|
else ((s.key, s) for s in inst_attrs)
|
||||||
else ((s.key, s) for s in sa_inspect(obj).attrs)
|
|
||||||
)
|
)
|
||||||
for field, attr_state in attrs:
|
for field, attr_state in attrs:
|
||||||
history = attr_state.history
|
history = attr_state.history
|
||||||
@@ -166,116 +183,101 @@ def _after_flush(session: Any, flush_context: Any) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(AsyncSession.sync_session_class, "after_flush_postexec")
|
|
||||||
def _after_flush_postexec(session: Any, flush_context: Any) -> None:
|
|
||||||
# New objects are now persistent and RETURNING values have been applied,
|
|
||||||
# so server defaults (id, created_at, …) are available via getattr.
|
|
||||||
pending_new: list[Any] = session.info.pop(_SESSION_PENDING_NEW, [])
|
|
||||||
if not pending_new:
|
|
||||||
return
|
|
||||||
session.info.setdefault(_SESSION_CREATES, []).extend(pending_new)
|
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
@event.listens_for(AsyncSession.sync_session_class, "after_rollback")
|
||||||
def _after_rollback(session: Any) -> None:
|
def _after_rollback(session: Any) -> None:
|
||||||
session.info.pop(_SESSION_PENDING_NEW, None)
|
if session.in_transaction():
|
||||||
|
return
|
||||||
session.info.pop(_SESSION_CREATES, None)
|
session.info.pop(_SESSION_CREATES, None)
|
||||||
session.info.pop(_SESSION_DELETES, None)
|
session.info.pop(_SESSION_DELETES, None)
|
||||||
session.info.pop(_SESSION_UPDATES, None)
|
session.info.pop(_SESSION_UPDATES, None)
|
||||||
|
|
||||||
|
|
||||||
def _task_error_handler(task: asyncio.Task[Any]) -> None:
|
async def _invoke_callback(
|
||||||
if not task.cancelled() and (exc := task.exception()):
|
fn: Callable[..., Any],
|
||||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
obj: Any,
|
||||||
|
event_type: ModelEvent,
|
||||||
|
changes: dict[str, dict[str, Any]] | None,
|
||||||
def _schedule_with_snapshot(
|
|
||||||
loop: asyncio.AbstractEventLoop, obj: Any, fn: Any, *args: Any
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Snapshot *obj*'s column attrs now (before expire_on_commit wipes them),
|
"""Call *fn* and await the result if it is awaitable."""
|
||||||
then schedule a coroutine that restores the snapshot and calls *fn*.
|
result = fn(obj, event_type, changes)
|
||||||
"""
|
if inspect.isawaitable(result):
|
||||||
snapshot = _snapshot_column_attrs(obj)
|
await result
|
||||||
|
|
||||||
async def _run(
|
|
||||||
obj: Any = obj,
|
|
||||||
fn: Any = fn,
|
|
||||||
snapshot: dict[str, Any] = snapshot,
|
|
||||||
args: tuple = args,
|
|
||||||
) -> None:
|
|
||||||
for key, value in snapshot.items():
|
|
||||||
_sa_set_committed_value(obj, key, value)
|
|
||||||
try:
|
|
||||||
result = fn(*args)
|
|
||||||
if inspect.isawaitable(result):
|
|
||||||
await result
|
|
||||||
except Exception as exc:
|
|
||||||
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
|
||||||
|
|
||||||
task = loop.create_task(_run())
|
|
||||||
task.add_done_callback(_task_error_handler)
|
|
||||||
|
|
||||||
|
|
||||||
@event.listens_for(AsyncSession.sync_session_class, "after_commit")
|
class EventSession(AsyncSession):
|
||||||
def _after_commit(session: Any) -> None:
|
"""AsyncSession subclass that dispatches lifecycle callbacks after commit."""
|
||||||
if session.info.get(_SESSION_SAVEPOINT_DEPTH, 0) > 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
creates: list[Any] = session.info.pop(_SESSION_CREATES, [])
|
async def commit(self) -> None: # noqa: C901
|
||||||
deletes: list[Any] = session.info.pop(_SESSION_DELETES, [])
|
await super().commit()
|
||||||
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = session.info.pop(
|
|
||||||
_SESSION_UPDATES, {}
|
|
||||||
)
|
|
||||||
|
|
||||||
if creates and deletes:
|
creates: list[Any] = self.info.pop(_SESSION_CREATES, [])
|
||||||
transient_ids = {id(o) for o in creates} & {id(o) for o in deletes}
|
deletes: list[tuple[Any, dict[str, Any]]] = self.info.pop(_SESSION_DELETES, [])
|
||||||
if transient_ids:
|
field_changes: dict[int, tuple[Any, dict[str, dict[str, Any]]]] = self.info.pop(
|
||||||
creates = [o for o in creates if id(o) not in transient_ids]
|
_SESSION_UPDATES, {}
|
||||||
deletes = [o for o in deletes if id(o) not in transient_ids]
|
)
|
||||||
|
|
||||||
|
if not creates and not deletes and not field_changes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Suppress transient objects (created + deleted in same transaction).
|
||||||
|
if creates and deletes:
|
||||||
|
created_ids = {id(o) for o in creates}
|
||||||
|
deleted_ids = {id(o) for o, _ in deletes}
|
||||||
|
transient_ids = created_ids & deleted_ids
|
||||||
|
if transient_ids:
|
||||||
|
creates = [o for o in creates if id(o) not in transient_ids]
|
||||||
|
deletes = [(o, s) for o, s in deletes if id(o) not in transient_ids]
|
||||||
|
field_changes = {
|
||||||
|
k: v for k, v in field_changes.items() if k not in transient_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
# Suppress updates for newly created objects (CREATE-only semantics).
|
||||||
|
if creates and field_changes:
|
||||||
|
create_ids = {id(o) for o in creates}
|
||||||
field_changes = {
|
field_changes = {
|
||||||
k: v for k, v in field_changes.items() if k not in transient_ids
|
k: v for k, v in field_changes.items() if k not in create_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
if not creates and not deletes and not field_changes:
|
# Dispatch CREATE callbacks.
|
||||||
return
|
for obj in creates:
|
||||||
|
try:
|
||||||
|
state = sa_inspect(obj, raiseerr=False)
|
||||||
|
if (
|
||||||
|
state is None or state.detached or state.transient
|
||||||
|
): # pragma: no cover
|
||||||
|
continue
|
||||||
|
await self.refresh(obj)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.CREATE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.CREATE, None)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
try:
|
# Dispatch DELETE callbacks (restore snapshot; row is gone).
|
||||||
loop = asyncio.get_running_loop()
|
for obj, snapshot in deletes:
|
||||||
except RuntimeError:
|
try:
|
||||||
return
|
for key, value in snapshot.items():
|
||||||
|
_sa_set_committed_value(obj, key, value)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.DELETE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.DELETE, None)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
for obj in creates:
|
# Dispatch UPDATE callbacks.
|
||||||
_schedule_with_snapshot(loop, obj, obj.on_create)
|
for obj, changes in field_changes.values():
|
||||||
|
try:
|
||||||
|
state = sa_inspect(obj, raiseerr=False)
|
||||||
|
if (
|
||||||
|
state is None or state.detached or state.transient
|
||||||
|
): # pragma: no cover
|
||||||
|
continue
|
||||||
|
await self.refresh(obj)
|
||||||
|
for handler in _get_handlers(type(obj), ModelEvent.UPDATE):
|
||||||
|
await _invoke_callback(handler, obj, ModelEvent.UPDATE, changes)
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.error(_CALLBACK_ERROR_MSG, exc_info=exc)
|
||||||
|
|
||||||
for obj in deletes:
|
async def rollback(self) -> None:
|
||||||
_schedule_with_snapshot(loop, obj, obj.on_delete)
|
await super().rollback()
|
||||||
|
self.info.pop(_SESSION_CREATES, None)
|
||||||
for obj, changes in field_changes.values():
|
self.info.pop(_SESSION_DELETES, None)
|
||||||
_schedule_with_snapshot(loop, obj, obj.on_update, changes)
|
self.info.pop(_SESSION_UPDATES, None)
|
||||||
|
|
||||||
|
|
||||||
class WatchedFieldsMixin:
|
|
||||||
"""Mixin that enables lifecycle callbacks for SQLAlchemy models."""
|
|
||||||
|
|
||||||
def on_event(
|
|
||||||
self, event: ModelEvent, changes: dict[str, dict[str, Any]] | None = None
|
|
||||||
) -> Awaitable[None] | None:
|
|
||||||
"""Catch-all callback fired for every lifecycle event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The event type (:attr:`ModelEvent.CREATE`, :attr:`ModelEvent.DELETE`,
|
|
||||||
or :attr:`ModelEvent.UPDATE`).
|
|
||||||
changes: Field changes for :attr:`ModelEvent.UPDATE`, ``None`` otherwise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_create(self) -> Awaitable[None] | None:
|
|
||||||
"""Called after INSERT commit."""
|
|
||||||
return self.on_event(ModelEvent.CREATE)
|
|
||||||
|
|
||||||
def on_delete(self) -> Awaitable[None] | None:
|
|
||||||
"""Called after DELETE commit."""
|
|
||||||
return self.on_event(ModelEvent.DELETE)
|
|
||||||
|
|
||||||
def on_update(self, changes: dict[str, dict[str, Any]]) -> Awaitable[None] | None:
|
|
||||||
"""Called after UPDATE commit when watched fields change."""
|
|
||||||
return self.on_event(ModelEvent.UPDATE, changes=changes)
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from sqlalchemy.orm import DeclarativeBase
|
|||||||
|
|
||||||
from ..db import cleanup_tables as _cleanup_tables
|
from ..db import cleanup_tables as _cleanup_tables
|
||||||
from ..db import create_database
|
from ..db import create_database
|
||||||
|
from ..models.watched import EventSession
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_tables(
|
async def cleanup_tables(
|
||||||
@@ -265,7 +266,9 @@ async def create_db_session(
|
|||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(base.metadata.create_all)
|
await conn.run_sync(base.metadata.create_all)
|
||||||
|
|
||||||
session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
|
session_maker = async_sessionmaker(
|
||||||
|
engine, expire_on_commit=expire_on_commit, class_=EventSession
|
||||||
|
)
|
||||||
async with session_maker() as session:
|
async with session_maker() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|||||||
1177
tests/test_models.py
1177
tests/test_models.py
File diff suppressed because it is too large
Load Diff
@@ -140,6 +140,7 @@ Examples = [
|
|||||||
|
|
||||||
[[project.nav]]
|
[[project.nav]]
|
||||||
Migration = [
|
Migration = [
|
||||||
|
{"v3.0" = "migration/v3.md"},
|
||||||
{"v2.0" = "migration/v2.md"},
|
{"v2.0" = "migration/v2.md"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user