Files
fastapi-toolsets/src/fastapi_toolsets/crud/factory.py

1109 lines
40 KiB
Python

"""Generic async CRUD operations for SQLAlchemy models."""
from __future__ import annotations
import base64
import inspect
import json
import uuid as uuid_module
from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy import delete as sql_delete
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import InvalidOrderFieldError, NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from ..types import (
FacetFieldType,
JoinType,
M2MFieldType,
ModelType,
OrderByClause,
SchemaType,
SearchFieldType,
)
from .search import (
SearchConfig,
build_facets,
build_filter_by,
build_search_filters,
facet_keys,
)
def _encode_cursor(value: Any) -> str:
"""Encode cursor column value as an base64 string."""
return base64.b64encode(json.dumps(str(value)).encode()).decode()
def _decode_cursor(cursor: str) -> str:
"""Decode cursor base64 string."""
return json.loads(base64.b64decode(cursor.encode()).decode())
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
"""Apply a list of (model, condition) joins to a SQLAlchemy select query."""
if not joins:
return q
for model, condition in joins:
q = q.outerjoin(model, condition) if outer_join else q.join(model, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
for join_rel in search_joins:
q = q.outerjoin(join_rel)
return q
class AsyncCrud(Generic[ModelType]):
    """Generic async CRUD operations for SQLAlchemy models.

    Subclass this and set the `model` class variable, or use `CrudFactory`.
    """

    model: ClassVar[type[DeclarativeBase]]
    searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
    facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
    order_fields: ClassVar[Sequence[QueryableAttribute[Any]] | None] = None
    m2m_fields: ClassVar[M2MFieldType | None] = None
    default_load_options: ClassVar[list[ExecutableOption] | None] = None
    cursor_column: ClassVar[Any | None] = None

    @classmethod
    def _resolve_load_options(
        cls, load_options: list[ExecutableOption] | None
    ) -> list[ExecutableOption] | None:
        """Return load_options if provided, else fall back to default_load_options."""
        if load_options is not None:
            return load_options
        return cls.default_load_options

    @staticmethod
    def _merge_joins(base: list[Any], extra: Sequence[Any]) -> None:
        """Append to ``base`` every join of ``extra`` not already present.

        ``filter_by`` and ``search`` may both request an outer join on the same
        relationship; keeping a single copy avoids joining the same table twice.
        Identity is used because SQLAlchemy attributes overload ``__eq__``.
        """
        for rel in extra:
            if not any(rel is existing for existing in base):
                base.append(rel)

    @classmethod
    async def _resolve_m2m(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        *,
        only_set: bool = False,
    ) -> dict[str, list[Any]]:
        """Resolve M2M fields from a Pydantic schema into related model instances.

        Args:
            session: DB async session
            obj: Pydantic model containing M2M ID fields
            only_set: If True, only process fields explicitly set on the schema

        Returns:
            Dict mapping relationship attr names to lists of related instances.
            A field whose value is ``None`` or empty maps to ``[]``.

        Raises:
            NotFoundError: If some of the given IDs have no matching related row.
        """
        result: dict[str, list[Any]] = {}
        if not cls.m2m_fields:
            return result
        for schema_field, rel in cls.m2m_fields.items():
            if only_set and schema_field not in obj.model_fields_set:
                continue
            rel_attr = rel.property.key
            related_model = rel.property.mapper.class_
            ids = getattr(obj, schema_field, None)
            if not ids:
                # None or an empty list clears the relationship; no query needed.
                result[rel_attr] = []
                continue
            # NOTE(review): assumes the related model's key attribute is `id`.
            related = (
                (
                    await session.execute(
                        select(related_model).where(related_model.id.in_(ids))
                    )
                )
                .scalars()
                .all()
            )
            # Compare against the *unique* IDs: duplicates in the payload match a
            # single row and must not be reported as missing.
            unique_ids = set(ids)
            if len(related) != len(unique_ids):
                missing = unique_ids - {r.id for r in related}
                raise NotFoundError(
                    f"Related {related_model.__name__} not found for IDs: {missing}"
                )
            result[rel_attr] = list(related)
        return result

    @classmethod
    def _m2m_schema_fields(cls: type[Self]) -> set[str]:
        """Return the set of schema field names that are M2M fields."""
        if not cls.m2m_fields:
            return set()
        return set(cls.m2m_fields.keys())

    @classmethod
    def _resolve_facet_fields(
        cls: type[Self],
        facet_fields: Sequence[FacetFieldType] | None,
    ) -> Sequence[FacetFieldType] | None:
        """Return facet_fields if given, otherwise fall back to the class-level default."""
        return facet_fields if facet_fields is not None else cls.facet_fields

    @classmethod
    def _prepare_filter_by(
        cls: type[Self],
        filter_by: dict[str, Any] | BaseModel | None,
        facet_fields: Sequence[FacetFieldType] | None,
    ) -> tuple[list[Any], list[Any]]:
        """Normalize filter_by and return (filters, joins) to apply to the query."""
        if isinstance(filter_by, BaseModel):
            filter_by = filter_by.model_dump(exclude_none=True)
        if not filter_by:
            return [], []
        resolved = cls._resolve_facet_fields(facet_fields)
        return build_filter_by(filter_by, resolved or [])

    @classmethod
    async def _build_filter_attributes(
        cls: type[Self],
        session: AsyncSession,
        facet_fields: Sequence[FacetFieldType] | None,
        filters: list[Any],
        search_joins: list[Any],
    ) -> dict[str, list[Any]] | None:
        """Build facet filter_attributes, or return None if no facet fields configured."""
        resolved = cls._resolve_facet_fields(facet_fields)
        if not resolved:
            return None
        return await build_facets(
            session,
            cls.model,
            resolved,
            base_filters=filters,
            base_joins=search_joins,
        )

    @classmethod
    def filter_params(
        cls: type[Self],
        *,
        facet_fields: Sequence[FacetFieldType] | None = None,
    ) -> Callable[..., Awaitable[dict[str, list[str]]]]:
        """Return a FastAPI dependency that collects facet filter values from query parameters.

        Args:
            facet_fields: Override the facet fields for this dependency. Falls back to the
                class-level ``facet_fields`` if not provided.

        Returns:
            An async dependency function named ``{Model}FilterParams`` that resolves to a
            ``dict[str, list[str]]`` containing only the keys that were supplied in the
            request (absent/``None`` parameters are excluded).

        Raises:
            ValueError: If no facet fields are configured on this CRUD class and none are
                provided via ``facet_fields``.
        """
        fields = cls._resolve_facet_fields(facet_fields)
        if not fields:
            raise ValueError(
                f"{cls.__name__} has no facet_fields configured. "
                "Pass facet_fields= or set them on CrudFactory."
            )
        keys = facet_keys(fields)

        async def dependency(**kwargs: Any) -> dict[str, list[str]]:
            return {k: v for k, v in kwargs.items() if v is not None}

        dependency.__name__ = f"{cls.model.__name__}FilterParams"
        # FastAPI inspects the signature to discover query parameters, so expose
        # one keyword-only ``list[str] | None`` query parameter per facet key.
        dependency.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=[
                inspect.Parameter(
                    k,
                    inspect.Parameter.KEYWORD_ONLY,
                    annotation=list[str] | None,
                    default=Query(default=None),
                )
                for k in keys
            ]
        )
        return dependency

    @classmethod
    def order_params(
        cls: type[Self],
        *,
        order_fields: Sequence[QueryableAttribute[Any]] | None = None,
        default_field: QueryableAttribute[Any] | None = None,
        default_order: Literal["asc", "desc"] = "asc",
    ) -> Callable[..., Awaitable[OrderByClause | None]]:
        """Return a FastAPI dependency that resolves order query params into an order_by clause.

        Args:
            order_fields: Override the allowed order fields. Falls back to the class-level
                ``order_fields`` if not provided.
            default_field: Field to order by when ``order_by`` query param is absent.
                If ``None`` and no ``order_by`` is provided, no ordering is applied.
            default_order: Default order direction when ``order`` is absent
                (``"asc"`` or ``"desc"``).

        Returns:
            An async dependency function named ``{Model}OrderParams`` that resolves to an
            ``OrderByClause`` (or ``None``). Pass it to ``Depends()`` in your route.

        Raises:
            ValueError: If no order fields are configured on this CRUD class and none are
                provided via ``order_fields``.
            InvalidOrderFieldError: When the request provides an unknown ``order_by`` value.
        """
        fields = order_fields if order_fields is not None else cls.order_fields
        if not fields:
            raise ValueError(
                f"{cls.__name__} has no order_fields configured. "
                "Pass order_fields= or set them on CrudFactory."
            )
        field_map: dict[str, QueryableAttribute[Any]] = {f.key: f for f in fields}
        valid_keys = sorted(field_map.keys())

        async def dependency(
            order_by: str | None = Query(
                None, description=f"Field to order by. Valid values: {valid_keys}"
            ),
            order: Literal["asc", "desc"] = Query(
                default_order, description="Sort direction"
            ),
        ) -> OrderByClause | None:
            if order_by is None:
                if default_field is None:
                    return None
                field = default_field
            elif order_by not in field_map:
                raise InvalidOrderFieldError(order_by, valid_keys)
            else:
                field = field_map[order_by]
            return field.asc() if order == "asc" else field.desc()

        dependency.__name__ = f"{cls.model.__name__}OrderParams"
        return dependency

    @overload
    @classmethod
    async def create(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        *,
        schema: type[SchemaType],
    ) -> Response[SchemaType]: ...

    @overload
    @classmethod
    async def create(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        *,
        schema: None = ...,
    ) -> ModelType: ...

    @classmethod
    async def create(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        *,
        schema: type[BaseModel] | None = None,
    ) -> ModelType | Response[Any]:
        """Create a new record in the database.

        Args:
            session: DB async session
            obj: Pydantic model with data to create
            schema: Pydantic schema to serialize the result into. When provided,
                the result is automatically wrapped in a ``Response[schema]``.

        Returns:
            Created model instance, or ``Response[schema]`` when ``schema`` is given.

        Raises:
            NotFoundError: If an M2M field references IDs that do not exist.
        """
        async with get_transaction(session):
            m2m_exclude = cls._m2m_schema_fields()
            data = (
                obj.model_dump(exclude=m2m_exclude) if m2m_exclude else obj.model_dump()
            )
            db_model = cls.model(**data)
            if m2m_exclude:
                m2m_resolved = await cls._resolve_m2m(session, obj)
                for rel_attr, related_instances in m2m_resolved.items():
                    setattr(db_model, rel_attr, related_instances)
            session.add(db_model)
            # A pending (un-flushed) object cannot be refreshed: flush first so
            # the INSERT is emitted, then reload server-generated values.
            await session.flush()
            await session.refresh(db_model)
            result = cast(ModelType, db_model)
            # Serialize while still inside the transaction block, before
            # get_transaction exits (which may commit and expire attributes).
            if schema is not None:
                return Response(data=schema.model_validate(result))
            return result

    @overload
    @classmethod
    async def get(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        joins: JoinType | None = None,
        outer_join: bool = False,
        with_for_update: bool = False,
        load_options: list[ExecutableOption] | None = None,
        schema: type[SchemaType],
    ) -> Response[SchemaType]: ...

    @overload
    @classmethod
    async def get(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        joins: JoinType | None = None,
        outer_join: bool = False,
        with_for_update: bool = False,
        load_options: list[ExecutableOption] | None = None,
        schema: None = ...,
    ) -> ModelType: ...

    @classmethod
    async def get(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        joins: JoinType | None = None,
        outer_join: bool = False,
        with_for_update: bool = False,
        load_options: list[ExecutableOption] | None = None,
        schema: type[BaseModel] | None = None,
    ) -> ModelType | Response[Any]:
        """Get exactly one record. Raises NotFoundError if not found.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            joins: List of (model, condition) tuples for joining related tables
            outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
            with_for_update: Lock the row for update
            load_options: SQLAlchemy loader options (e.g., selectinload)
            schema: Pydantic schema to serialize the result into. When provided,
                the result is automatically wrapped in a ``Response[schema]``.

        Returns:
            Model instance, or ``Response[schema]`` when ``schema`` is given.

        Raises:
            NotFoundError: If no record found
            MultipleResultsFound: If more than one record found
        """
        q = select(cls.model)
        q = _apply_joins(q, joins, outer_join)
        if filters:
            q = q.where(and_(*filters))
        if resolved := cls._resolve_load_options(load_options):
            q = q.options(*resolved)
        if with_for_update:
            q = q.with_for_update()
        rows = await session.execute(q)
        item = rows.unique().scalar_one_or_none()
        # Compare to None explicitly: a model may define __bool__/__len__.
        if item is None:
            raise NotFoundError()
        result = cast(ModelType, item)
        if schema is not None:
            return Response(data=schema.model_validate(result))
        return result

    @classmethod
    async def first(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any] | None = None,
        *,
        joins: JoinType | None = None,
        outer_join: bool = False,
        load_options: list[ExecutableOption] | None = None,
    ) -> ModelType | None:
        """Get the first matching record, or None.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            joins: List of (model, condition) tuples for joining related tables
            outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
            load_options: SQLAlchemy loader options

        Returns:
            Model instance or None
        """
        q = select(cls.model)
        q = _apply_joins(q, joins, outer_join)
        if filters:
            q = q.where(and_(*filters))
        if resolved := cls._resolve_load_options(load_options):
            q = q.options(*resolved)
        result = await session.execute(q)
        return cast(ModelType | None, result.unique().scalars().first())

    @classmethod
    async def get_multi(
        cls: type[Self],
        session: AsyncSession,
        *,
        filters: list[Any] | None = None,
        joins: JoinType | None = None,
        outer_join: bool = False,
        load_options: list[ExecutableOption] | None = None,
        order_by: OrderByClause | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> Sequence[ModelType]:
        """Get multiple records from the database.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            joins: List of (model, condition) tuples for joining related tables
            outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
            load_options: SQLAlchemy loader options
            order_by: Column or list of columns to order by
            limit: Max number of rows to return
            offset: Rows to skip

        Returns:
            List of model instances
        """
        q = select(cls.model)
        q = _apply_joins(q, joins, outer_join)
        if filters:
            q = q.where(and_(*filters))
        if resolved := cls._resolve_load_options(load_options):
            q = q.options(*resolved)
        if order_by is not None:
            q = q.order_by(order_by)
        if offset is not None:
            q = q.offset(offset)
        if limit is not None:
            q = q.limit(limit)
        result = await session.execute(q)
        return cast(Sequence[ModelType], result.unique().scalars().all())

    @overload
    @classmethod
    async def update(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        filters: list[Any],
        *,
        exclude_unset: bool = True,
        exclude_none: bool = False,
        schema: type[SchemaType],
    ) -> Response[SchemaType]: ...

    @overload
    @classmethod
    async def update(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        filters: list[Any],
        *,
        exclude_unset: bool = True,
        exclude_none: bool = False,
        schema: None = ...,
    ) -> ModelType: ...

    @classmethod
    async def update(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        filters: list[Any],
        *,
        exclude_unset: bool = True,
        exclude_none: bool = False,
        schema: type[BaseModel] | None = None,
    ) -> ModelType | Response[Any]:
        """Update a record in the database.

        Args:
            session: DB async session
            obj: Pydantic model with update data
            filters: List of SQLAlchemy filter conditions
            exclude_unset: Exclude fields not explicitly set in the schema
            exclude_none: Exclude fields with None value
            schema: Pydantic schema to serialize the result into. When provided,
                the result is automatically wrapped in a ``Response[schema]``.

        Returns:
            Updated model instance, or ``Response[schema]`` when ``schema`` is given.

        Raises:
            NotFoundError: If no record found, or an M2M field references
                IDs that do not exist.
        """
        async with get_transaction(session):
            m2m_exclude = cls._m2m_schema_fields()
            # Eagerly load M2M relationships that will be updated so that
            # setattr does not trigger a lazy load (which fails in async).
            m2m_load_options: list[ExecutableOption] = []
            if cls.m2m_fields:
                for schema_field, rel in cls.m2m_fields.items():
                    if schema_field in obj.model_fields_set:
                        m2m_load_options.append(selectinload(rel))
            db_model = await cls.get(
                session=session,
                filters=filters,
                load_options=m2m_load_options or None,
            )
            values = obj.model_dump(
                exclude_unset=exclude_unset,
                exclude_none=exclude_none,
                exclude=m2m_exclude,
            )
            for key, value in values.items():
                setattr(db_model, key, value)
            if m2m_exclude:
                m2m_resolved = await cls._resolve_m2m(session, obj, only_set=True)
                for rel_attr, related_instances in m2m_resolved.items():
                    setattr(db_model, rel_attr, related_instances)
            # refresh() first expires the instance, which discards un-flushed
            # attribute changes; flush so the UPDATE is emitted before reloading.
            await session.flush()
            await session.refresh(db_model)
            if schema is not None:
                return Response(data=schema.model_validate(db_model))
            return db_model

    @classmethod
    async def upsert(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        index_elements: list[str],
        *,
        set_: BaseModel | None = None,
        where: WhereHavingRole | None = None,
    ) -> ModelType | None:
        """Create or update a record (PostgreSQL only).

        Uses INSERT ... ON CONFLICT for atomic upsert.

        Args:
            session: DB async session
            obj: Pydantic model with data
            index_elements: Columns for ON CONFLICT (unique constraint)
            set_: Pydantic model for ON CONFLICT DO UPDATE SET
            where: WHERE clause for ON CONFLICT DO UPDATE

        Returns:
            The inserted/updated model instance. When the statement returns no
            row (DO NOTHING, or a DO UPDATE whose ``where`` did not match), the
            existing conflicting row looked up by ``index_elements``, or
            ``None`` if it cannot be found.
        """
        async with get_transaction(session):
            values = obj.model_dump(exclude_unset=True)
            q = insert(cls.model).values(**values)
            if set_ is not None:
                q = q.on_conflict_do_update(
                    index_elements=index_elements,
                    set_=set_.model_dump(exclude_unset=True),
                    where=where,
                )
            else:
                q = q.on_conflict_do_nothing(index_elements=index_elements)
            q = q.returning(cls.model)
            result = await session.execute(q)
            try:
                db_model = result.unique().scalar_one()
            except NoResultFound:
                # The conflicting row is identified by the conflict columns; its
                # other columns may differ from the values we tried to insert.
                # Fall back to all values if no conflict column was provided.
                lookup = {k: values[k] for k in index_elements if k in values} or values
                db_model = await cls.first(
                    session=session,
                    filters=[getattr(cls.model, k) == v for k, v in lookup.items()],
                )
            return cast(ModelType | None, db_model)

    @overload
    @classmethod
    async def delete(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        return_response: Literal[True],
    ) -> Response[None]: ...

    @overload
    @classmethod
    async def delete(  # pragma: no cover
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        return_response: Literal[False] = ...,
    ) -> None: ...

    @classmethod
    async def delete(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        return_response: bool = False,
    ) -> None | Response[None]:
        """Delete records from the database.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions. An empty list applies
                no WHERE clause and therefore deletes every row of the table.
            return_response: When ``True``, returns ``Response[None]`` instead
                of ``None``. Useful for API endpoints that expect a consistent
                response envelope.

        Returns:
            ``None``, or ``Response[None]`` when ``return_response=True``.
        """
        async with get_transaction(session):
            q = sql_delete(cls.model)
            if filters:
                q = q.where(and_(*filters))
            await session.execute(q)
        if return_response:
            return Response(data=None)
        return None

    @classmethod
    async def count(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any] | None = None,
        *,
        joins: JoinType | None = None,
        outer_join: bool = False,
    ) -> int:
        """Count records matching the filters.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            joins: List of (model, condition) tuples for joining related tables
            outer_join: Use LEFT OUTER JOIN instead of INNER JOIN

        Returns:
            Number of matching records
        """
        if joins:
            # A one-to-many join repeats the base row; count distinct values of
            # the first primary-key column (as offset_paginate does) so the
            # result is a number of records, not joined rows.
            pk_col = cls.model.__mapper__.primary_key[0]
            q = select(func.count(func.distinct(pk_col))).select_from(cls.model)
        else:
            q = select(func.count()).select_from(cls.model)
        q = _apply_joins(q, joins, outer_join)
        if filters:
            q = q.where(and_(*filters))
        result = await session.execute(q)
        return result.scalar_one()

    @classmethod
    async def exists(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        joins: JoinType | None = None,
        outer_join: bool = False,
    ) -> bool:
        """Check if a record exists.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            joins: List of (model, condition) tuples for joining related tables
            outer_join: Use LEFT OUTER JOIN instead of INNER JOIN

        Returns:
            True if at least one record matches
        """
        q = select(cls.model)
        q = _apply_joins(q, joins, outer_join)
        if filters:
            q = q.where(and_(*filters))
        result = await session.execute(q.exists().select())
        return bool(result.scalar())

    @classmethod
    async def offset_paginate(
        cls: type[Self],
        session: AsyncSession,
        *,
        filters: list[Any] | None = None,
        joins: JoinType | None = None,
        outer_join: bool = False,
        load_options: list[ExecutableOption] | None = None,
        order_by: OrderByClause | None = None,
        page: int = 1,
        items_per_page: int = 20,
        search: str | SearchConfig | None = None,
        search_fields: Sequence[SearchFieldType] | None = None,
        facet_fields: Sequence[FacetFieldType] | None = None,
        filter_by: dict[str, Any] | BaseModel | None = None,
        schema: type[BaseModel],
    ) -> PaginatedResponse[Any]:
        """Get paginated results using offset-based pagination.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            joins: List of (model, condition) tuples for joining related tables
            outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
            load_options: SQLAlchemy loader options
            order_by: Column or list of columns to order by
            page: Page number (1-indexed)
            items_per_page: Number of items per page
            search: Search query string or SearchConfig object
            search_fields: Fields to search in (overrides class default)
            facet_fields: Columns to compute distinct values for (overrides class default)
            filter_by: Dict of {column_key: value} to filter by declared facet fields.
                Keys must match the column.key of a facet field. Scalar → equality,
                list → IN clause. Raises InvalidFacetFilterError for unknown keys.
            schema: Pydantic schema to serialize each item into.

        Returns:
            PaginatedResponse with OffsetPagination metadata
        """
        filters = list(filters) if filters else []
        offset = (page - 1) * items_per_page
        fb_filters, fb_joins = cls._prepare_filter_by(filter_by, facet_fields)
        filters.extend(fb_filters)
        search_joins: list[Any] = list(fb_joins)

        if search:
            search_filters, new_search_joins = build_search_filters(
                cls.model,
                search,
                search_fields=search_fields,
                default_fields=cls.searchable_fields,
            )
            filters.extend(search_filters)
            cls._merge_joins(search_joins, new_search_joins)

        # Page query: explicit joins, then search/filter_by joins (always outer).
        q = select(cls.model)
        q = _apply_joins(q, joins, outer_join)
        q = _apply_search_joins(q, search_joins)
        if filters:
            q = q.where(and_(*filters))
        if resolved := cls._resolve_load_options(load_options):
            q = q.options(*resolved)
        if order_by is not None:
            q = q.order_by(order_by)
        q = q.offset(offset).limit(items_per_page)
        result = await session.execute(q)
        raw_items = cast(list[ModelType], result.unique().scalars().all())
        items: list[Any] = [schema.model_validate(item) for item in raw_items]

        # Count query with the same joins and filters. Joins may repeat rows,
        # so count distinct primary keys. The Column object is used directly:
        # its ``.name`` is the DB column name, which may differ from the mapped
        # attribute name, so ``getattr(model, pk_col.name)`` is not reliable.
        pk_col = cls.model.__mapper__.primary_key[0]
        count_q = select(func.count(func.distinct(pk_col))).select_from(cls.model)
        count_q = _apply_joins(count_q, joins, outer_join)
        count_q = _apply_search_joins(count_q, search_joins)
        if filters:
            count_q = count_q.where(and_(*filters))
        count_result = await session.execute(count_q)
        total_count = count_result.scalar_one()

        filter_attributes = await cls._build_filter_attributes(
            session, facet_fields, filters, search_joins
        )
        return PaginatedResponse(
            data=items,
            pagination=OffsetPagination(
                total_count=total_count,
                items_per_page=items_per_page,
                page=page,
                has_more=page * items_per_page < total_count,
            ),
            filter_attributes=filter_attributes,
        )

    @staticmethod
    def _parse_cursor_value(cursor_column: Any, cursor: str) -> Any:
        """Decode ``cursor`` and convert it to the Python type of ``cursor_column``.

        Raises:
            ValueError: If the column type is not supported; malformed cursors
                propagate the underlying decode/conversion error.
        """
        raw_val = _decode_cursor(cursor)
        col_type = cursor_column.property.columns[0].type
        # Integer also covers BigInteger/SmallInteger (subclasses).
        if isinstance(col_type, Integer):
            return int(raw_val)
        if isinstance(col_type, Uuid):
            return uuid_module.UUID(raw_val)
        if isinstance(col_type, DateTime):
            return datetime.fromisoformat(raw_val)
        if isinstance(col_type, Date):
            return date.fromisoformat(raw_val)
        if isinstance(col_type, (Float, Numeric)):
            return Decimal(raw_val)
        raise ValueError(
            f"Unsupported cursor column type: {type(col_type).__name__!r}. "
            "Supported types: Integer, BigInteger, SmallInteger, Uuid, "
            "DateTime, Date, Float, Numeric."
        )

    @classmethod
    async def cursor_paginate(
        cls: type[Self],
        session: AsyncSession,
        *,
        cursor: str | None = None,
        filters: list[Any] | None = None,
        joins: JoinType | None = None,
        outer_join: bool = False,
        load_options: list[ExecutableOption] | None = None,
        order_by: OrderByClause | None = None,
        items_per_page: int = 20,
        search: str | SearchConfig | None = None,
        search_fields: Sequence[SearchFieldType] | None = None,
        facet_fields: Sequence[FacetFieldType] | None = None,
        filter_by: dict[str, Any] | BaseModel | None = None,
        schema: type[BaseModel],
    ) -> PaginatedResponse[Any]:
        """Get paginated results using cursor-based pagination.

        Args:
            session: DB async session.
            cursor: Cursor string from a previous ``CursorPagination``.
                Omit (or pass ``None`` or an empty string) to start from the
                beginning.
            filters: List of SQLAlchemy filter conditions.
            joins: List of ``(model, condition)`` tuples for joining related
                tables.
            outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
            load_options: SQLAlchemy loader options. Falls back to
                ``default_load_options`` when not provided.
            order_by: Additional ordering applied after the cursor column.
            items_per_page: Number of items per page (default 20).
            search: Search query string or SearchConfig object.
            search_fields: Fields to search in (overrides class default).
            facet_fields: Columns to compute distinct values for (overrides class default).
            filter_by: Dict of {column_key: value} to filter by declared facet fields.
                Keys must match the column.key of a facet field. Scalar → equality,
                list → IN clause. Raises InvalidFacetFilterError for unknown keys.
            schema: Pydantic schema to serialize each item into (required).

        Returns:
            PaginatedResponse with CursorPagination metadata

        Raises:
            ValueError: If ``cursor_column`` is not configured or its type is
                unsupported.
        """
        if cls.cursor_column is None:
            raise ValueError(
                f"{cls.__name__}.cursor_column is not set. "
                "Pass cursor_column=<column> to CrudFactory() to use cursor_paginate."
            )
        cursor_column: Any = cls.cursor_column
        cursor_col_name: str = cursor_column.key

        filters = list(filters) if filters else []
        fb_filters, fb_joins = cls._prepare_filter_by(filter_by, facet_fields)
        filters.extend(fb_filters)
        search_joins: list[Any] = list(fb_joins)

        if search:
            search_filters, new_search_joins = build_search_filters(
                cls.model,
                search,
                search_fields=search_fields,
                default_fields=cls.searchable_fields,
            )
            filters.extend(search_filters)
            cls._merge_joins(search_joins, new_search_joins)

        # The cursor condition only selects the page. It is kept out of
        # ``filters`` so facet values describe the whole filtered result set
        # (as in offset_paginate) instead of shrinking while paging.
        page_filters = list(filters)
        if cursor:
            page_filters.append(
                cursor_column > cls._parse_cursor_value(cursor_column, cursor)
            )

        q = select(cls.model)
        q = _apply_joins(q, joins, outer_join)
        q = _apply_search_joins(q, search_joins)
        if page_filters:
            q = q.where(and_(*page_filters))
        if resolved := cls._resolve_load_options(load_options):
            q = q.options(*resolved)
        # Cursor column is always the primary sort
        q = q.order_by(cursor_column)
        if order_by is not None:
            q = q.order_by(order_by)
        # Fetch one extra to detect whether a next page exists
        q = q.limit(items_per_page + 1)
        result = await session.execute(q)
        raw_items = cast(list[ModelType], result.unique().scalars().all())
        has_more = len(raw_items) > items_per_page
        items_page = raw_items[:items_per_page]

        # next_cursor points past the last item on this page
        next_cursor: str | None = None
        if has_more and items_page:
            next_cursor = _encode_cursor(getattr(items_page[-1], cursor_col_name))
        # prev_cursor points to the first item on this page, or None on the first page
        prev_cursor: str | None = None
        if cursor and items_page:
            prev_cursor = _encode_cursor(getattr(items_page[0], cursor_col_name))

        items: list[Any] = [schema.model_validate(item) for item in items_page]
        filter_attributes = await cls._build_filter_attributes(
            session, facet_fields, filters, search_joins
        )
        return PaginatedResponse(
            data=items,
            pagination=CursorPagination(
                next_cursor=next_cursor,
                prev_cursor=prev_cursor,
                items_per_page=items_per_page,
                has_more=has_more,
            ),
            filter_attributes=filter_attributes,
        )
def CrudFactory(
    model: type[ModelType],
    *,
    searchable_fields: Sequence[SearchFieldType] | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    order_fields: Sequence[QueryableAttribute[Any]] | None = None,
    m2m_fields: M2MFieldType | None = None,
    default_load_options: list[ExecutableOption] | None = None,
    cursor_column: Any | None = None,
) -> type[AsyncCrud[ModelType]]:
    """Create a CRUD class for a specific model.

    Args:
        model: SQLAlchemy model class
        searchable_fields: Optional list of searchable fields
        facet_fields: Optional list of columns to compute distinct values for in paginated
            responses. Supports direct columns (``User.status``) and relationship tuples
            (``(User.role, Role.name)``). Can be overridden per call.
        order_fields: Optional list of model attributes that callers are allowed to order by
            via ``order_params()``. Can be overridden per call.
        m2m_fields: Optional mapping for many-to-many relationships.
            Maps schema field names (containing lists of IDs) to
            SQLAlchemy relationship attributes.
        default_load_options: Default SQLAlchemy loader options applied to all read
            queries when no explicit ``load_options`` are passed. Use this
            instead of ``lazy="selectin"`` on the model so that loading
            strategy is explicit and per-CRUD. Overridden entirely (not
            merged) when ``load_options`` is provided at call-site.
        cursor_column: Required to call ``cursor_paginate``.
            Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
            See the cursor pagination docs for supported column types.

    Returns:
        AsyncCrud subclass bound to the model, named ``Async{Model}Crud``.

    Example:
        ```python
        from fastapi_toolsets.crud import CrudFactory
        from myapp.models import User, Post

        UserCrud = CrudFactory(User)
        PostCrud = CrudFactory(Post)

        # With searchable fields:
        UserCrud = CrudFactory(
            User,
            searchable_fields=[User.username, User.email, (User.role, Role.name)]
        )

        # With many-to-many fields:
        # Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
        PostCrud = CrudFactory(
            Post,
            m2m_fields={"tag_ids": Post.tags},
        )

        # With facet fields for filter dropdowns / faceted search:
        UserCrud = CrudFactory(
            User,
            facet_fields=[User.status, User.country, (User.role, Role.name)],
        )

        # With a fixed cursor column for cursor_paginate:
        PostCrud = CrudFactory(
            Post,
            cursor_column=Post.created_at,
        )
        page = await PostCrud.cursor_paginate(session, schema=PostRead)

        # With default load strategy (replaces lazy="selectin" on the model):
        ArticleCrud = CrudFactory(
            Article,
            default_load_options=[selectinload(Article.category), selectinload(Article.tags)],
        )

        # Override default_load_options for a specific call:
        article = await ArticleCrud.get(
            session,
            [Article.id == 1],
            load_options=[selectinload(Article.category)],  # tags won't load
        )

        # Usage
        user = await UserCrud.get(session, [User.id == 1])
        posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])

        # Create with M2M - tag_ids are automatically resolved
        post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))

        # With search (the paginate methods require a ``schema``):
        result = await UserCrud.offset_paginate(session, search="john", schema=UserRead)

        # With joins (inner join by default):
        users = await UserCrud.get_multi(
            session,
            joins=[(Post, Post.user_id == User.id)],
            filters=[Post.published == True],
        )

        # With outer join:
        users = await UserCrud.get_multi(
            session,
            joins=[(Post, Post.user_id == User.id)],
            outer_join=True,
        )
        ```
    """
    cls = type(
        f"Async{model.__name__}Crud",
        (AsyncCrud,),
        {
            "model": model,
            "searchable_fields": searchable_fields,
            "facet_fields": facet_fields,
            "order_fields": order_fields,
            "m2m_fields": m2m_fields,
            "default_load_options": default_load_options,
            "cursor_column": cursor_column,
        },
    )
    return cast(type[AsyncCrud[ModelType]], cls)