Files
fastapi-toolsets/src/fastapi_toolsets/crud/factory.py
2026-02-01 15:01:10 +01:00

517 lines
16 KiB
Python

"""Generic async CRUD operations for SQLAlchemy models."""
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
from pydantic import BaseModel
from sqlalchemy import and_, func, select
from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]
class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models.
Subclass this and set the `model` class variable, or use `CrudFactory`.
"""
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
@classmethod
async def create(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
) -> ModelType:
"""Create a new record in the database.
Args:
session: DB async session
obj: Pydantic model with data to create
Returns:
Created model instance
"""
async with get_transaction(session):
db_model = cls.model(**obj.model_dump())
session.add(db_model)
await session.refresh(db_model)
return cast(ModelType, db_model)
@classmethod
async def get(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
) -> ModelType:
"""Get exactly one record. Raises NotFoundError if not found.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload)
Returns:
Model instance
Raises:
NotFoundError: If no record found
MultipleResultsFound: If more than one record found
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if with_for_update:
q = q.with_for_update()
result = await session.execute(q)
item = result.unique().scalar_one_or_none()
if not item:
raise NotFoundError()
return cast(ModelType, item)
@classmethod
async def first(
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
) -> ModelType | None:
"""Get the first matching record, or None.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
Returns:
Model instance or None
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
result = await session.execute(q)
return cast(ModelType | None, result.unique().scalars().first())
@classmethod
async def get_multi(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
order_by: Any | None = None,
limit: int | None = None,
offset: int | None = None,
) -> Sequence[ModelType]:
"""Get multiple records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
limit: Max number of rows to return
offset: Rows to skip
Returns:
List of model instances
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if order_by is not None:
q = q.order_by(order_by)
if offset is not None:
q = q.offset(offset)
if limit is not None:
q = q.limit(limit)
result = await session.execute(q)
return cast(Sequence[ModelType], result.unique().scalars().all())
@classmethod
async def update(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
filters: list[Any],
*,
exclude_unset: bool = True,
exclude_none: bool = False,
) -> ModelType:
"""Update a record in the database.
Args:
session: DB async session
obj: Pydantic model with update data
filters: List of SQLAlchemy filter conditions
exclude_unset: Exclude fields not explicitly set in the schema
exclude_none: Exclude fields with None value
Returns:
Updated model instance
Raises:
NotFoundError: If no record found
"""
async with get_transaction(session):
db_model = await cls.get(session=session, filters=filters)
values = obj.model_dump(
exclude_unset=exclude_unset, exclude_none=exclude_none
)
for key, value in values.items():
setattr(db_model, key, value)
await session.refresh(db_model)
return db_model
@classmethod
async def upsert(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
index_elements: list[str],
*,
set_: BaseModel | None = None,
where: WhereHavingRole | None = None,
) -> ModelType | None:
"""Create or update a record (PostgreSQL only).
Uses INSERT ... ON CONFLICT for atomic upsert.
Args:
session: DB async session
obj: Pydantic model with data
index_elements: Columns for ON CONFLICT (unique constraint)
set_: Pydantic model for ON CONFLICT DO UPDATE SET
where: WHERE clause for ON CONFLICT DO UPDATE
Returns:
Model instance
"""
async with get_transaction(session):
values = obj.model_dump(exclude_unset=True)
q = insert(cls.model).values(**values)
if set_:
q = q.on_conflict_do_update(
index_elements=index_elements,
set_=set_.model_dump(exclude_unset=True),
where=where,
)
else:
q = q.on_conflict_do_nothing(index_elements=index_elements)
q = q.returning(cls.model)
result = await session.execute(q)
try:
db_model = result.unique().scalar_one()
except NoResultFound:
db_model = await cls.first(
session=session,
filters=[getattr(cls.model, k) == v for k, v in values.items()],
)
return cast(ModelType | None, db_model)
@classmethod
async def delete(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
) -> bool:
"""Delete records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
Returns:
True if deletion was executed
"""
async with get_transaction(session):
q = sql_delete(cls.model).where(and_(*filters))
await session.execute(q)
return True
@classmethod
async def count(
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> int:
"""Count records matching the filters.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns:
Number of matching records
"""
q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
result = await session.execute(q)
return result.scalar_one()
@classmethod
async def exists(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> bool:
"""Check if a record exists.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns:
True if at least one record matches
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q)
return bool(result.scalar())
@classmethod
async def paginate(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
) -> dict[str, Any]:
"""Get paginated results with metadata.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
page: Page number (1-indexed)
items_per_page: Number of items per page
search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default)
Returns:
Dict with 'data' and 'pagination' keys
"""
filters = list(filters) if filters else []
offset = (page - 1) * items_per_page
search_joins: list[Any] = []
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
# Build query with joins
q = select(cls.model)
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if order_by is not None:
q = q.order_by(order_by)
q = q.offset(offset).limit(items_per_page)
result = await session.execute(q)
items = result.unique().scalars().all()
# Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel)
if filters:
count_q = count_q.where(and_(*filters))
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
return {
"data": items,
"pagination": {
"total_count": total_count,
"items_per_page": items_per_page,
"page": page,
"has_more": page * items_per_page < total_count,
},
}
def CrudFactory(
model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model.
Args:
model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
Returns:
AsyncCrud subclass bound to the model
Example:
from fastapi_toolsets.crud import CrudFactory
from myapp.models import User, Post
UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post)
# With searchable fields:
UserCrud = CrudFactory(
User,
searchable_fields=[User.username, User.email, (User.role, Role.name)]
)
# Usage
user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
# With search
result = await UserCrud.paginate(session, search="john")
# With joins (inner join by default):
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
filters=[Post.published == True],
)
# With outer join:
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
outer_join=True,
)
"""
cls = type(
f"Async{model.__name__}Crud",
(AsyncCrud,),
{
"model": model,
"searchable_fields": searchable_fields,
},
)
return cast(type[AsyncCrud[ModelType]], cls)