mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
517 lines
16 KiB
Python
517 lines
16 KiB
Python
"""Generic async CRUD operations for SQLAlchemy models."""
|
|
|
|
from collections.abc import Sequence
|
|
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import and_, func, select
|
|
from sqlalchemy import delete as sql_delete
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
from sqlalchemy.exc import NoResultFound
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import DeclarativeBase
|
|
from sqlalchemy.sql.roles import WhereHavingRole
|
|
|
|
from ..db import get_transaction
|
|
from ..exceptions import NotFoundError
|
|
from .search import SearchConfig, SearchFieldType, build_search_filters
|
|
|
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
|
|
|
|
|
class AsyncCrud(Generic[ModelType]):
|
|
"""Generic async CRUD operations for SQLAlchemy models.
|
|
|
|
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
|
"""
|
|
|
|
model: ClassVar[type[DeclarativeBase]]
|
|
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
|
|
|
@classmethod
|
|
async def create(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
obj: BaseModel,
|
|
) -> ModelType:
|
|
"""Create a new record in the database.
|
|
|
|
Args:
|
|
session: DB async session
|
|
obj: Pydantic model with data to create
|
|
|
|
Returns:
|
|
Created model instance
|
|
"""
|
|
async with get_transaction(session):
|
|
db_model = cls.model(**obj.model_dump())
|
|
session.add(db_model)
|
|
await session.refresh(db_model)
|
|
return cast(ModelType, db_model)
|
|
|
|
@classmethod
|
|
async def get(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
filters: list[Any],
|
|
*,
|
|
joins: JoinType | None = None,
|
|
outer_join: bool = False,
|
|
with_for_update: bool = False,
|
|
load_options: list[Any] | None = None,
|
|
) -> ModelType:
|
|
"""Get exactly one record. Raises NotFoundError if not found.
|
|
|
|
Args:
|
|
session: DB async session
|
|
filters: List of SQLAlchemy filter conditions
|
|
joins: List of (model, condition) tuples for joining related tables
|
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
|
with_for_update: Lock the row for update
|
|
load_options: SQLAlchemy loader options (e.g., selectinload)
|
|
|
|
Returns:
|
|
Model instance
|
|
|
|
Raises:
|
|
NotFoundError: If no record found
|
|
MultipleResultsFound: If more than one record found
|
|
"""
|
|
q = select(cls.model)
|
|
if joins:
|
|
for model, condition in joins:
|
|
q = (
|
|
q.outerjoin(model, condition)
|
|
if outer_join
|
|
else q.join(model, condition)
|
|
)
|
|
q = q.where(and_(*filters))
|
|
if load_options:
|
|
q = q.options(*load_options)
|
|
if with_for_update:
|
|
q = q.with_for_update()
|
|
result = await session.execute(q)
|
|
item = result.unique().scalar_one_or_none()
|
|
if not item:
|
|
raise NotFoundError()
|
|
return cast(ModelType, item)
|
|
|
|
@classmethod
|
|
async def first(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
filters: list[Any] | None = None,
|
|
*,
|
|
joins: JoinType | None = None,
|
|
outer_join: bool = False,
|
|
load_options: list[Any] | None = None,
|
|
) -> ModelType | None:
|
|
"""Get the first matching record, or None.
|
|
|
|
Args:
|
|
session: DB async session
|
|
filters: List of SQLAlchemy filter conditions
|
|
joins: List of (model, condition) tuples for joining related tables
|
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
|
load_options: SQLAlchemy loader options
|
|
|
|
Returns:
|
|
Model instance or None
|
|
"""
|
|
q = select(cls.model)
|
|
if joins:
|
|
for model, condition in joins:
|
|
q = (
|
|
q.outerjoin(model, condition)
|
|
if outer_join
|
|
else q.join(model, condition)
|
|
)
|
|
if filters:
|
|
q = q.where(and_(*filters))
|
|
if load_options:
|
|
q = q.options(*load_options)
|
|
result = await session.execute(q)
|
|
return cast(ModelType | None, result.unique().scalars().first())
|
|
|
|
@classmethod
|
|
async def get_multi(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
*,
|
|
filters: list[Any] | None = None,
|
|
joins: JoinType | None = None,
|
|
outer_join: bool = False,
|
|
load_options: list[Any] | None = None,
|
|
order_by: Any | None = None,
|
|
limit: int | None = None,
|
|
offset: int | None = None,
|
|
) -> Sequence[ModelType]:
|
|
"""Get multiple records from the database.
|
|
|
|
Args:
|
|
session: DB async session
|
|
filters: List of SQLAlchemy filter conditions
|
|
joins: List of (model, condition) tuples for joining related tables
|
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
|
load_options: SQLAlchemy loader options
|
|
order_by: Column or list of columns to order by
|
|
limit: Max number of rows to return
|
|
offset: Rows to skip
|
|
|
|
Returns:
|
|
List of model instances
|
|
"""
|
|
q = select(cls.model)
|
|
if joins:
|
|
for model, condition in joins:
|
|
q = (
|
|
q.outerjoin(model, condition)
|
|
if outer_join
|
|
else q.join(model, condition)
|
|
)
|
|
if filters:
|
|
q = q.where(and_(*filters))
|
|
if load_options:
|
|
q = q.options(*load_options)
|
|
if order_by is not None:
|
|
q = q.order_by(order_by)
|
|
if offset is not None:
|
|
q = q.offset(offset)
|
|
if limit is not None:
|
|
q = q.limit(limit)
|
|
result = await session.execute(q)
|
|
return cast(Sequence[ModelType], result.unique().scalars().all())
|
|
|
|
@classmethod
|
|
async def update(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
obj: BaseModel,
|
|
filters: list[Any],
|
|
*,
|
|
exclude_unset: bool = True,
|
|
exclude_none: bool = False,
|
|
) -> ModelType:
|
|
"""Update a record in the database.
|
|
|
|
Args:
|
|
session: DB async session
|
|
obj: Pydantic model with update data
|
|
filters: List of SQLAlchemy filter conditions
|
|
exclude_unset: Exclude fields not explicitly set in the schema
|
|
exclude_none: Exclude fields with None value
|
|
|
|
Returns:
|
|
Updated model instance
|
|
|
|
Raises:
|
|
NotFoundError: If no record found
|
|
"""
|
|
async with get_transaction(session):
|
|
db_model = await cls.get(session=session, filters=filters)
|
|
values = obj.model_dump(
|
|
exclude_unset=exclude_unset, exclude_none=exclude_none
|
|
)
|
|
for key, value in values.items():
|
|
setattr(db_model, key, value)
|
|
await session.refresh(db_model)
|
|
return db_model
|
|
|
|
@classmethod
|
|
async def upsert(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
obj: BaseModel,
|
|
index_elements: list[str],
|
|
*,
|
|
set_: BaseModel | None = None,
|
|
where: WhereHavingRole | None = None,
|
|
) -> ModelType | None:
|
|
"""Create or update a record (PostgreSQL only).
|
|
|
|
Uses INSERT ... ON CONFLICT for atomic upsert.
|
|
|
|
Args:
|
|
session: DB async session
|
|
obj: Pydantic model with data
|
|
index_elements: Columns for ON CONFLICT (unique constraint)
|
|
set_: Pydantic model for ON CONFLICT DO UPDATE SET
|
|
where: WHERE clause for ON CONFLICT DO UPDATE
|
|
|
|
Returns:
|
|
Model instance
|
|
"""
|
|
async with get_transaction(session):
|
|
values = obj.model_dump(exclude_unset=True)
|
|
q = insert(cls.model).values(**values)
|
|
if set_:
|
|
q = q.on_conflict_do_update(
|
|
index_elements=index_elements,
|
|
set_=set_.model_dump(exclude_unset=True),
|
|
where=where,
|
|
)
|
|
else:
|
|
q = q.on_conflict_do_nothing(index_elements=index_elements)
|
|
q = q.returning(cls.model)
|
|
result = await session.execute(q)
|
|
try:
|
|
db_model = result.unique().scalar_one()
|
|
except NoResultFound:
|
|
db_model = await cls.first(
|
|
session=session,
|
|
filters=[getattr(cls.model, k) == v for k, v in values.items()],
|
|
)
|
|
return cast(ModelType | None, db_model)
|
|
|
|
@classmethod
|
|
async def delete(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
filters: list[Any],
|
|
) -> bool:
|
|
"""Delete records from the database.
|
|
|
|
Args:
|
|
session: DB async session
|
|
filters: List of SQLAlchemy filter conditions
|
|
|
|
Returns:
|
|
True if deletion was executed
|
|
"""
|
|
async with get_transaction(session):
|
|
q = sql_delete(cls.model).where(and_(*filters))
|
|
await session.execute(q)
|
|
return True
|
|
|
|
@classmethod
|
|
async def count(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
filters: list[Any] | None = None,
|
|
*,
|
|
joins: JoinType | None = None,
|
|
outer_join: bool = False,
|
|
) -> int:
|
|
"""Count records matching the filters.
|
|
|
|
Args:
|
|
session: DB async session
|
|
filters: List of SQLAlchemy filter conditions
|
|
joins: List of (model, condition) tuples for joining related tables
|
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
|
|
|
Returns:
|
|
Number of matching records
|
|
"""
|
|
q = select(func.count()).select_from(cls.model)
|
|
if joins:
|
|
for model, condition in joins:
|
|
q = (
|
|
q.outerjoin(model, condition)
|
|
if outer_join
|
|
else q.join(model, condition)
|
|
)
|
|
if filters:
|
|
q = q.where(and_(*filters))
|
|
result = await session.execute(q)
|
|
return result.scalar_one()
|
|
|
|
@classmethod
|
|
async def exists(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
filters: list[Any],
|
|
*,
|
|
joins: JoinType | None = None,
|
|
outer_join: bool = False,
|
|
) -> bool:
|
|
"""Check if a record exists.
|
|
|
|
Args:
|
|
session: DB async session
|
|
filters: List of SQLAlchemy filter conditions
|
|
joins: List of (model, condition) tuples for joining related tables
|
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
|
|
|
Returns:
|
|
True if at least one record matches
|
|
"""
|
|
q = select(cls.model)
|
|
if joins:
|
|
for model, condition in joins:
|
|
q = (
|
|
q.outerjoin(model, condition)
|
|
if outer_join
|
|
else q.join(model, condition)
|
|
)
|
|
q = q.where(and_(*filters)).exists().select()
|
|
result = await session.execute(q)
|
|
return bool(result.scalar())
|
|
|
|
@classmethod
|
|
async def paginate(
|
|
cls: type[Self],
|
|
session: AsyncSession,
|
|
*,
|
|
filters: list[Any] | None = None,
|
|
joins: JoinType | None = None,
|
|
outer_join: bool = False,
|
|
load_options: list[Any] | None = None,
|
|
order_by: Any | None = None,
|
|
page: int = 1,
|
|
items_per_page: int = 20,
|
|
search: str | SearchConfig | None = None,
|
|
search_fields: Sequence[SearchFieldType] | None = None,
|
|
) -> dict[str, Any]:
|
|
"""Get paginated results with metadata.
|
|
|
|
Args:
|
|
session: DB async session
|
|
filters: List of SQLAlchemy filter conditions
|
|
joins: List of (model, condition) tuples for joining related tables
|
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
|
load_options: SQLAlchemy loader options
|
|
order_by: Column or list of columns to order by
|
|
page: Page number (1-indexed)
|
|
items_per_page: Number of items per page
|
|
search: Search query string or SearchConfig object
|
|
search_fields: Fields to search in (overrides class default)
|
|
|
|
Returns:
|
|
Dict with 'data' and 'pagination' keys
|
|
"""
|
|
filters = list(filters) if filters else []
|
|
offset = (page - 1) * items_per_page
|
|
search_joins: list[Any] = []
|
|
|
|
# Build search filters
|
|
if search:
|
|
search_filters, search_joins = build_search_filters(
|
|
cls.model,
|
|
search,
|
|
search_fields=search_fields,
|
|
default_fields=cls.searchable_fields,
|
|
)
|
|
filters.extend(search_filters)
|
|
|
|
# Build query with joins
|
|
q = select(cls.model)
|
|
|
|
# Apply explicit joins
|
|
if joins:
|
|
for model, condition in joins:
|
|
q = (
|
|
q.outerjoin(model, condition)
|
|
if outer_join
|
|
else q.join(model, condition)
|
|
)
|
|
|
|
# Apply search joins (always outer joins for search)
|
|
for join_rel in search_joins:
|
|
q = q.outerjoin(join_rel)
|
|
|
|
if filters:
|
|
q = q.where(and_(*filters))
|
|
if load_options:
|
|
q = q.options(*load_options)
|
|
if order_by is not None:
|
|
q = q.order_by(order_by)
|
|
|
|
q = q.offset(offset).limit(items_per_page)
|
|
result = await session.execute(q)
|
|
items = result.unique().scalars().all()
|
|
|
|
# Count query (with same joins and filters)
|
|
pk_col = cls.model.__mapper__.primary_key[0]
|
|
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
|
count_q = count_q.select_from(cls.model)
|
|
|
|
# Apply explicit joins to count query
|
|
if joins:
|
|
for model, condition in joins:
|
|
count_q = (
|
|
count_q.outerjoin(model, condition)
|
|
if outer_join
|
|
else count_q.join(model, condition)
|
|
)
|
|
|
|
# Apply search joins to count query
|
|
for join_rel in search_joins:
|
|
count_q = count_q.outerjoin(join_rel)
|
|
|
|
if filters:
|
|
count_q = count_q.where(and_(*filters))
|
|
|
|
count_result = await session.execute(count_q)
|
|
total_count = count_result.scalar_one()
|
|
|
|
return {
|
|
"data": items,
|
|
"pagination": {
|
|
"total_count": total_count,
|
|
"items_per_page": items_per_page,
|
|
"page": page,
|
|
"has_more": page * items_per_page < total_count,
|
|
},
|
|
}
|
|
|
|
|
|
def CrudFactory(
|
|
model: type[ModelType],
|
|
*,
|
|
searchable_fields: Sequence[SearchFieldType] | None = None,
|
|
) -> type[AsyncCrud[ModelType]]:
|
|
"""Create a CRUD class for a specific model.
|
|
|
|
Args:
|
|
model: SQLAlchemy model class
|
|
searchable_fields: Optional list of searchable fields
|
|
|
|
Returns:
|
|
AsyncCrud subclass bound to the model
|
|
|
|
Example:
|
|
from fastapi_toolsets.crud import CrudFactory
|
|
from myapp.models import User, Post
|
|
|
|
UserCrud = CrudFactory(User)
|
|
PostCrud = CrudFactory(Post)
|
|
|
|
# With searchable fields:
|
|
UserCrud = CrudFactory(
|
|
User,
|
|
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
|
)
|
|
|
|
# Usage
|
|
user = await UserCrud.get(session, [User.id == 1])
|
|
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
|
|
|
# With search
|
|
result = await UserCrud.paginate(session, search="john")
|
|
|
|
# With joins (inner join by default):
|
|
users = await UserCrud.get_multi(
|
|
session,
|
|
joins=[(Post, Post.user_id == User.id)],
|
|
filters=[Post.published == True],
|
|
)
|
|
|
|
# With outer join:
|
|
users = await UserCrud.get_multi(
|
|
session,
|
|
joins=[(Post, Post.user_id == User.id)],
|
|
outer_join=True,
|
|
)
|
|
"""
|
|
cls = type(
|
|
f"Async{model.__name__}Crud",
|
|
(AsyncCrud,),
|
|
{
|
|
"model": model,
|
|
"searchable_fields": searchable_fields,
|
|
},
|
|
)
|
|
return cast(type[AsyncCrud[ModelType]], cls)
|