Files
fastapi-toolsets/src/fastapi_toolsets/crud.py
2026-01-25 16:11:44 +01:00

379 lines
11 KiB
Python

"""Generic async CRUD operations for SQLAlchemy models."""
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
from pydantic import BaseModel
from sqlalchemy import and_, func, select
from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.roles import WhereHavingRole
from .db import get_transaction
from .exceptions import NotFoundError
__all__ = [
"AsyncCrud",
"CrudFactory",
]
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models.
Subclass this and set the `model` class variable, or use `CrudFactory`.
Example:
class UserCrud(AsyncCrud[User]):
model = User
# Or use the factory:
UserCrud = CrudFactory(User)
# Then use it:
user = await UserCrud.get(session, [User.id == 1])
users = await UserCrud.get_multi(session, limit=10)
"""
model: ClassVar[type[DeclarativeBase]]
@classmethod
async def create(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
) -> ModelType:
"""Create a new record in the database.
Args:
session: DB async session
obj: Pydantic model with data to create
Returns:
Created model instance
"""
async with get_transaction(session):
db_model = cls.model(**obj.model_dump())
session.add(db_model)
await session.refresh(db_model)
return cast(ModelType, db_model)
@classmethod
async def get(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
with_for_update: bool = False,
load_options: list[Any] | None = None,
) -> ModelType:
"""Get exactly one record. Raises NotFoundError if not found.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload)
Returns:
Model instance
Raises:
NotFoundError: If no record found
MultipleResultsFound: If more than one record found
"""
q = select(cls.model).where(and_(*filters))
if load_options:
q = q.options(*load_options)
if with_for_update:
q = q.with_for_update()
result = await session.execute(q)
item = result.unique().scalar_one_or_none()
if not item:
raise NotFoundError()
return cast(ModelType, item)
@classmethod
async def first(
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
*,
load_options: list[Any] | None = None,
) -> ModelType | None:
"""Get the first matching record, or None.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
load_options: SQLAlchemy loader options
Returns:
Model instance or None
"""
q = select(cls.model)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
result = await session.execute(q)
return cast(ModelType | None, result.unique().scalars().first())
@classmethod
async def get_multi(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
load_options: list[Any] | None = None,
order_by: Any | None = None,
limit: int | None = None,
offset: int | None = None,
) -> Sequence[ModelType]:
"""Get multiple records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
limit: Max number of rows to return
offset: Rows to skip
Returns:
List of model instances
"""
q = select(cls.model)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if order_by is not None:
q = q.order_by(order_by)
if offset is not None:
q = q.offset(offset)
if limit is not None:
q = q.limit(limit)
result = await session.execute(q)
return cast(Sequence[ModelType], result.unique().scalars().all())
@classmethod
async def update(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
) -> ModelType:
"""Update a record in the database.
Args:
session: DB async session
obj: Pydantic model with update data
filters: List of SQLAlchemy filter conditions
exclude_unset: Exclude fields not explicitly set in the schema
exclude_none: Exclude fields with None value
Returns:
Updated model instance
Raises:
NotFoundError: If no record found
"""
async with get_transaction(session):
db_model = await cls.get(session=session, filters=filters)
values = obj.model_dump(
exclude_unset=exclude_unset, exclude_none=exclude_none
)
for key, value in values.items():
setattr(db_model, key, value)
await session.refresh(db_model)
return db_model
@classmethod
async def upsert(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
index_elements: list[str],
*,
set_: BaseModel | None = None,
where: WhereHavingRole | None = None,
) -> ModelType | None:
"""Create or update a record (PostgreSQL only).
Uses INSERT ... ON CONFLICT for atomic upsert.
Args:
session: DB async session
obj: Pydantic model with data
index_elements: Columns for ON CONFLICT (unique constraint)
set_: Pydantic model for ON CONFLICT DO UPDATE SET
where: WHERE clause for ON CONFLICT DO UPDATE
Returns:
Model instance
"""
async with get_transaction(session):
values = obj.model_dump(exclude_unset=True)
q = insert(cls.model).values(**values)
if set_:
q = q.on_conflict_do_update(
index_elements=index_elements,
set_=set_.model_dump(exclude_unset=True),
where=where,
)
else:
q = q.on_conflict_do_nothing(index_elements=index_elements)
q = q.returning(cls.model)
result = await session.execute(q)
try:
db_model = result.unique().scalar_one()
except NoResultFound:
db_model = await cls.first(
session=session,
filters=[getattr(cls.model, k) == v for k, v in values.items()],
)
return cast(ModelType | None, db_model)
@classmethod
async def delete(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
) -> bool:
"""Delete records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
Returns:
True if deletion was executed
"""
async with get_transaction(session):
q = sql_delete(cls.model).where(and_(*filters))
await session.execute(q)
return True
@classmethod
async def count(
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
) -> int:
"""Count records matching the filters.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
Returns:
Number of matching records
"""
q = select(func.count()).select_from(cls.model)
if filters:
q = q.where(and_(*filters))
result = await session.execute(q)
return result.scalar_one()
@classmethod
async def exists(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
) -> bool:
"""Check if a record exists.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
Returns:
True if at least one record matches
"""
q = select(cls.model).where(and_(*filters)).exists().select()
result = await session.execute(q)
return bool(result.scalar())
@classmethod
async def paginate(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
load_options: list[Any] | None = None,
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
) -> dict[str, Any]:
"""Get paginated results with metadata.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
page: Page number (1-indexed)
items_per_page: Number of items per page
Returns:
Dict with 'data' and 'pagination' keys
"""
filters = filters or []
offset = (page - 1) * items_per_page
items = await cls.get_multi(
session,
filters=filters,
load_options=load_options,
order_by=order_by,
limit=items_per_page,
offset=offset,
)
total_count = await cls.count(session, filters=filters)
return {
"data": items,
"pagination": {
"total_count": total_count,
"items_per_page": items_per_page,
"page": page,
"has_more": page * items_per_page < total_count,
},
}
def CrudFactory(
model: type[ModelType],
) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model.
Args:
model: SQLAlchemy model class
Returns:
AsyncCrud subclass bound to the model
Example:
from fastapi_toolsets.crud import CrudFactory
from myapp.models import User, Post
UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post)
# Usage
user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
"""
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
return cast(type[AsyncCrud[ModelType]], cls)