# File: fastapi-toolsets/src/fastapi_toolsets/crud/factory.py
# (1768 lines, 65 KiB, Python)
"""Generic async CRUD operations for SQLAlchemy models."""
from __future__ import annotations
import base64
import inspect
import json
import uuid as uuid_module
from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import InvalidOrderFieldError, NotFoundError
from ..schemas import (
CursorPaginatedResponse,
CursorPagination,
OffsetPaginatedResponse,
OffsetPagination,
PaginationType,
Response,
)
from ..types import (
FacetFieldType,
JoinType,
M2MFieldType,
ModelType,
OrderByClause,
SchemaType,
SearchFieldType,
)
from .search import (
SearchConfig,
build_facets,
build_filter_by,
build_search_filters,
facet_keys,
search_field_keys,
)
class _CursorDirection(str, Enum):
    """Navigation direction stored inside a pagination cursor token.

    Because members are also ``str`` instances, they can be placed directly
    in the payload passed to ``json.dumps`` by ``_encode_cursor``, and
    ``_decode_cursor`` rebuilds them from the raw string value.
    """

    # NOTE(review): presumably "page forward from the cursor"; the consumer
    # is not visible in this chunk — confirm.
    NEXT = "next"
    # NOTE(review): presumably "page backward from the cursor" — confirm.
    PREV = "prev"
def _encode_cursor(
    value: Any, *, direction: _CursorDirection = _CursorDirection.NEXT
) -> str:
    """Serialize a cursor position into an opaque, URL-safe token.

    The column value is stringified and bundled with the navigation
    direction in a small JSON object. That object is base64url-encoded and
    its trailing ``=`` padding is stripped, so the token can be placed in a
    query string as-is.

    Args:
        value: Cursor column value; converted with ``str()``.
        direction: Navigation direction embedded in the token.

    Returns:
        The unpadded base64url token.
    """
    payload = {"val": str(value), "dir": direction}
    raw_bytes = json.dumps(payload).encode()
    token = base64.urlsafe_b64encode(raw_bytes).decode()
    return token.rstrip("=")
def _decode_cursor(cursor: str) -> tuple[str, _CursorDirection]:
    """Invert ``_encode_cursor``.

    Restores the base64 padding the encoder stripped, parses the JSON
    payload, and returns the raw value (still a string) together with the
    navigation direction. Malformed tokens propagate whatever error
    ``base64``, ``json``, the key lookup, or the enum constructor raises;
    nothing is normalized here.

    Args:
        cursor: Token previously produced by ``_encode_cursor``.

    Returns:
        ``(raw_value, direction)``.
    """
    missing_padding = -len(cursor) % 4
    decoded_text = base64.urlsafe_b64decode(cursor + "=" * missing_padding).decode()
    payload = json.loads(decoded_text)
    raw_value = payload["val"]
    direction = _CursorDirection(payload["dir"])
    return raw_value, direction
def _page_size_query(default: int, max_size: int) -> int:
    """Build the FastAPI ``Query`` descriptor used for ``items_per_page``.

    Accepted values are constrained to ``1 <= items_per_page <= max_size``.
    The ``-> int`` annotation reflects how the result is used (as the
    default of an ``int`` parameter). The object actually returned is
    whatever ``Query(...)`` builds.

    Args:
        default: Value used when the client omits the parameter.
        max_size: Inclusive upper bound, also echoed in the description.
    """
    help_text = f"Number of items per page (max {max_size})"
    return Query(default, ge=1, le=max_size, description=help_text)
def _parse_cursor_value(raw_val: str, col_type: Any) -> Any:
    """Convert a decoded cursor string back into the column's Python type.

    Candidate SQLAlchemy types are tried in order with ``isinstance``, so
    subclasses resolve to their base entry (for example ``BigInteger`` via
    ``Integer``). ``Float`` and ``Numeric`` values are both rebuilt as
    ``Decimal``.

    Args:
        raw_val: String value extracted from the cursor token.
        col_type: SQLAlchemy type instance of the cursor column.

    Returns:
        The parsed value.

    Raises:
        ValueError: If ``col_type`` is not one of the supported types.
            Errors raised by the selected parser (e.g. ``ValueError`` from
            ``int``, ``decimal.InvalidOperation`` from ``Decimal``)
            propagate unchanged.
    """
    parsers = (
        (Integer, int),
        (Uuid, uuid_module.UUID),
        (DateTime, datetime.fromisoformat),
        (Date, date.fromisoformat),
        ((Float, Numeric), Decimal),
    )
    for sa_types, parse in parsers:
        if isinstance(col_type, sa_types):
            return parse(raw_val)
    raise ValueError(
        f"Unsupported cursor column type: {type(col_type).__name__!r}. "
        "Supported types: Integer, BigInteger, SmallInteger, Uuid, "
        "DateTime, Date, Float, Numeric."
    )
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
    """Attach explicit ``(model, onclause)`` joins to a select statement.

    Args:
        q: Statement to extend; returned unchanged when ``joins`` is empty.
        joins: ``(model, condition)`` pairs, applied in order.
        outer_join: Use ``LEFT OUTER JOIN`` for every pair instead of an
            inner join.

    Returns:
        The statement with all joins applied.
    """
    statement = q
    for target, onclause in joins or ():
        statement = (
            statement.outerjoin(target, onclause)
            if outer_join
            else statement.join(target, onclause)
        )
    return statement
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
    """LEFT OUTER JOIN each relationship collected by the search/filter_by helpers.

    Args:
        q: Statement to extend.
        search_joins: Relationship join targets, each applied in order with
            a single-argument ``outerjoin``.

    Returns:
        The extended statement (``q`` itself when ``search_joins`` is empty).
    """
    statement = q
    for relationship in search_joins:
        statement = statement.outerjoin(relationship)
    return statement
class AsyncCrud(Generic[ModelType]):
    """Generic async CRUD operations for SQLAlchemy models.

    Subclass this and set the `model` class variable, or use `CrudFactory`.

    The class-level settings below are defaults read by the classmethods;
    several of them (search, facet, order fields and load options) can be
    overridden per call.
    """

    # SQLAlchemy declarative model this CRUD class operates on.
    model: ClassVar[type[DeclarativeBase]]
    # Default searchable fields; __init_subclass__ ensures the model's
    # primary key is included.
    searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
    # Default fields used for facet values and filter_by parameters.
    facet_fields: ClassVar[Sequence[FacetFieldType] | None] = None
    # Fields exposed through the generated order_by query parameter.
    order_fields: ClassVar[Sequence[QueryableAttribute[Any]] | None] = None
    # Maps schema field names holding related IDs to M2M relationship
    # attributes (consumed by _resolve_m2m / _m2m_schema_fields).
    m2m_fields: ClassVar[M2MFieldType | None] = None
    # Loader options applied when a read method receives load_options=None.
    default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
    # NOTE(review): column used for cursor pagination; its consumer is not
    # visible in this chunk — confirm semantics.
    cursor_column: ClassVar[Any | None] = None
@classmethod
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Finalize search configuration for subclasses that declare a ``model``.

    The model's first primary-key column is always made searchable. If the
    subclass does not declare ``searchable_fields``, the default is
    ``[pk]``. If it declares fields without the primary key, the key is
    prepended. Tuple entries are skipped when looking for the primary key.
    """
    super().__init_subclass__(**kwargs)
    own_attrs = cls.__dict__
    if "model" not in own_attrs:
        # Intermediate subclasses without their own model are left untouched.
        return

    model: type[DeclarativeBase] = own_attrs["model"]
    pk_key = model.__mapper__.primary_key[0].key
    assert pk_key is not None
    pk_col = getattr(model, pk_key)

    declared: Sequence[SearchFieldType] | None = own_attrs.get("searchable_fields")
    if declared is None:
        cls.searchable_fields = [pk_col]
        return

    pk_listed = any(
        not isinstance(field, tuple) and field.key == pk_key
        for field in declared
    )
    if not pk_listed:
        cls.searchable_fields = [pk_col, *declared]
@classmethod
def _resolve_load_options(
    cls, load_options: Sequence[ExecutableOption] | None
) -> Sequence[ExecutableOption] | None:
    """Pick the loader options for a query.

    An explicit ``load_options`` argument always wins, even an empty
    sequence. Only ``None`` falls back to ``cls.default_load_options``.
    """
    return cls.default_load_options if load_options is None else load_options
@classmethod
async def _resolve_m2m(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
*,
only_set: bool = False,
) -> dict[str, list[Any]]:
"""Resolve M2M fields from a Pydantic schema into related model instances.
Args:
session: DB async session
obj: Pydantic model containing M2M ID fields
only_set: If True, only process fields explicitly set on the schema
Returns:
Dict mapping relationship attr names to lists of related instances
"""
result: dict[str, list[Any]] = {}
if not cls.m2m_fields:
return result
for schema_field, rel in cls.m2m_fields.items():
rel_attr = rel.property.key
related_model = rel.property.mapper.class_
if only_set and schema_field not in obj.model_fields_set:
continue
ids = getattr(obj, schema_field, None)
if ids is not None:
related = (
(
await session.execute(
select(related_model).where(related_model.id.in_(ids))
)
)
.scalars()
.all()
)
if len(related) != len(ids):
found_ids = {r.id for r in related}
missing = set(ids) - found_ids
raise NotFoundError(
f"Related {related_model.__name__} not found for IDs: {missing}"
)
result[rel_attr] = list(related)
else:
result[rel_attr] = []
return result
@classmethod
def _m2m_schema_fields(cls: type[Self]) -> set[str]:
"""Return the set of schema field names that are M2M fields."""
if not cls.m2m_fields:
return set()
return set(cls.m2m_fields.keys())
@classmethod
def _resolve_facet_fields(
cls: type[Self],
facet_fields: Sequence[FacetFieldType] | None,
) -> Sequence[FacetFieldType] | None:
"""Return facet_fields if given, otherwise fall back to the class-level default."""
return facet_fields if facet_fields is not None else cls.facet_fields
@classmethod
def _prepare_filter_by(
cls: type[Self],
filter_by: dict[str, Any] | BaseModel | None,
facet_fields: Sequence[FacetFieldType] | None,
) -> tuple[list[Any], list[Any]]:
"""Normalize filter_by and return (filters, joins) to apply to the query."""
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True)
if not filter_by:
return [], []
resolved = cls._resolve_facet_fields(facet_fields)
return build_filter_by(filter_by, resolved or [])
@classmethod
async def _build_filter_attributes(
cls: type[Self],
session: AsyncSession,
facet_fields: Sequence[FacetFieldType] | None,
filters: list[Any],
search_joins: list[Any],
) -> dict[str, list[Any]] | None:
"""Build facet filter_attributes, or return None if no facet fields configured."""
resolved = cls._resolve_facet_fields(facet_fields)
if not resolved:
return None
return await build_facets(
session,
cls.model,
resolved,
base_filters=filters,
base_joins=search_joins,
)
@classmethod
def _resolve_search_columns(
cls: type[Self],
search_fields: Sequence[SearchFieldType] | None,
) -> list[str] | None:
"""Return search column keys, or None if no searchable fields configured."""
fields = search_fields if search_fields is not None else cls.searchable_fields
if not fields:
return None
return search_field_keys(fields)
@classmethod
def _build_paginate_params(
    cls: type[Self],
    *,
    pagination_params: list[inspect.Parameter],
    pagination_fixed: dict[str, Any],
    dep_name: str,
    search: bool,
    filter: bool,
    order: bool,
    search_fields: Sequence[SearchFieldType] | None,
    facet_fields: Sequence[FacetFieldType] | None,
    order_fields: Sequence[QueryableAttribute[Any]] | None,
    default_order_field: QueryableAttribute[Any] | None,
    default_order: Literal["asc", "desc"],
) -> Callable[..., Awaitable[dict[str, Any]]]:
    """Build a consolidated FastAPI dependency that merges pagination, search, filter, and order params.

    The returned coroutine function gets a synthesized ``__signature__`` of
    keyword-only parameters, in this order:

    * the given pagination parameters;
    * ``search`` / ``search_column``, when search is enabled and searchable
      fields resolve;
    * one ``list[str]`` query parameter per facet key, when filtering is
      enabled and facet fields resolve;
    * ``order_by`` / ``order``, when ordering is enabled and order fields
      resolve.

    When called, it returns a dict seeded from ``pagination_fixed``, which
    is meant to be unpacked into the matching paginate method.

    Args:
        pagination_params: Strategy-specific parameters (page, cursor, ...).
        pagination_fixed: Non-query values copied into every result dict.
        dep_name: ``__name__`` assigned to the generated dependency.
        search: Expose search parameters if searchable fields resolve.
        filter: Expose facet filter parameters if facet fields resolve.
        order: Expose ordering parameters if order fields resolve.
        search_fields: Override for ``cls.searchable_fields``.
        facet_fields: Override for ``cls.facet_fields``.
        order_fields: Override for ``cls.order_fields``.
        default_order_field: Field used when the request omits ``order_by``.
        default_order: Direction used when the request omits ``order``.

    Returns:
        The async dependency callable.

    Raises:
        ValueError: If a facet key collides with a pagination or search
            parameter name.
    """
    all_params: list[inspect.Parameter] = list(pagination_params)
    pagination_param_names = tuple(p.name for p in pagination_params)
    # Names already used in the signature; facet keys must not reuse them.
    reserved_names: set[str] = set(pagination_param_names)

    search_keys: list[str] | None = None
    if search:
        search_keys = cls._resolve_search_columns(search_fields)
        if search_keys:
            all_params.extend(
                [
                    inspect.Parameter(
                        "search",
                        inspect.Parameter.KEYWORD_ONLY,
                        annotation=str | None,
                        default=Query(
                            default=None, description="Search query string"
                        ),
                    ),
                    inspect.Parameter(
                        "search_column",
                        inspect.Parameter.KEYWORD_ONLY,
                        annotation=str | None,
                        default=Query(
                            default=None,
                            description="Restrict search to a single column",
                            enum=search_keys,
                        ),
                    ),
                ]
            )
            reserved_names.update({"search", "search_column"})

    filter_keys: list[str] | None = None
    if filter:
        resolved_facets = cls._resolve_facet_fields(facet_fields)
        if resolved_facets:
            filter_keys = facet_keys(resolved_facets)
            # NOTE(review): "order_by"/"order" are not reserved yet at this
            # point; a facet key with either name would only surface as a
            # duplicate-name error from inspect.Signature below.
            for k in filter_keys:
                if k in reserved_names:
                    raise ValueError(
                        f"Facet field key {k!r} conflicts with a reserved "
                        f"parameter name. Reserved names: {sorted(reserved_names)}"
                    )
            all_params.extend(
                inspect.Parameter(
                    k,
                    inspect.Parameter.KEYWORD_ONLY,
                    annotation=list[str] | None,
                    default=Query(default=None),
                )
                for k in filter_keys
            )
            reserved_names.update(filter_keys)

    order_field_map: dict[str, QueryableAttribute[Any]] | None = None
    order_valid_keys: list[str] | None = None
    if order:
        resolved_order = (
            order_fields if order_fields is not None else cls.order_fields
        )
        if resolved_order:
            # Public order_by values are the attribute keys of the fields.
            order_field_map = {f.key: f for f in resolved_order}
            order_valid_keys = sorted(order_field_map.keys())
            all_params.extend(
                [
                    inspect.Parameter(
                        "order_by",
                        inspect.Parameter.KEYWORD_ONLY,
                        annotation=str | None,
                        default=Query(
                            None,
                            description=f"Field to order by. Valid values: {order_valid_keys}",
                            enum=order_valid_keys,
                        ),
                    ),
                    inspect.Parameter(
                        "order",
                        inspect.Parameter.KEYWORD_ONLY,
                        annotation=Literal["asc", "desc"],
                        default=Query(default_order, description="Sort direction"),
                    ),
                ]
            )

    async def dependency(**kwargs: Any) -> dict[str, Any]:
        # Request-time resolver. Only keys whose feature was enabled above
        # are written into the result.
        result: dict[str, Any] = dict(pagination_fixed)
        for name in pagination_param_names:
            result[name] = kwargs[name]
        if search_keys is not None:
            search_val = kwargs.get("search")
            if search_val is not None:
                result["search"] = search_val
            search_col_val = kwargs.get("search_column")
            if search_col_val is not None:
                result["search_column"] = search_col_val
        if filter_keys is not None:
            # Drop facets the client did not send; an empty mapping becomes None.
            filter_by = {
                k: kwargs[k] for k in filter_keys if kwargs.get(k) is not None
            }
            result["filter_by"] = filter_by or None
        if order_field_map is not None:
            order_by_val = kwargs.get("order_by")
            order_dir = kwargs.get("order", default_order)
            if order_by_val is None:
                field = default_order_field
            elif order_by_val not in order_field_map:
                # Raised at request time for unknown order_by values.
                raise InvalidOrderFieldError(order_by_val, order_valid_keys or [])
            else:
                field = order_field_map[order_by_val]
            if field is not None:
                result["order_by"] = (
                    field.asc() if order_dir == "asc" else field.desc()
                )
            else:
                result["order_by"] = None
        return result

    # FastAPI inspects __signature__ to discover the query parameters.
    dependency.__name__ = dep_name
    dependency.__signature__ = inspect.Signature(  # type: ignore[attr-defined] # ty:ignore[unresolved-attribute]
        parameters=all_params,
    )
    return dependency
@classmethod
def offset_paginate_params(
    cls: type[Self],
    *,
    default_page_size: int = 20,
    max_page_size: int = 100,
    include_total: bool = True,
    search: bool = True,
    filter: bool = True,
    order: bool = True,
    search_fields: Sequence[SearchFieldType] | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    order_fields: Sequence[QueryableAttribute[Any]] | None = None,
    default_order_field: QueryableAttribute[Any] | None = None,
    default_order: Literal["asc", "desc"] = "asc",
) -> Callable[..., Awaitable[dict[str, Any]]]:
    """Create the FastAPI dependency that feeds :meth:`offset_paginate`.

    Exposes ``page`` (>= 1) and ``items_per_page`` as query parameters,
    plus whichever search, facet filter and ordering parameters are enabled
    and configured. ``include_total`` is not a query parameter; it is copied
    as a fixed value into every resolved dict.

    Args:
        default_page_size: ``items_per_page`` used when the client omits it.
        max_page_size: Upper bound accepted for ``items_per_page``.
        include_total: Forwarded unchanged to :meth:`offset_paginate`.
        search: Add ``search`` / ``search_column`` parameters.
        filter: Add one query parameter per facet key.
        order: Add ``order_by`` / ``order`` parameters.
        search_fields: Replace the class's searchable fields.
        facet_fields: Replace the class's facet fields.
        order_fields: Replace the class's order fields.
        default_order_field: Ordering field when ``order_by`` is absent.
        default_order: Direction when ``order`` is absent.

    Returns:
        Async dependency resolving to keyword arguments for
        :meth:`offset_paginate`.
    """
    page = inspect.Parameter(
        "page",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=int,
        default=Query(1, ge=1, description="Page number (1-indexed)"),
    )
    page_size = inspect.Parameter(
        "items_per_page",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=int,
        default=_page_size_query(default_page_size, max_page_size),
    )
    model_name = cls.model.__name__
    return cls._build_paginate_params(
        pagination_params=[page, page_size],
        pagination_fixed={"include_total": include_total},
        dep_name=f"{model_name}OffsetPaginateParams",
        search=search,
        filter=filter,
        order=order,
        search_fields=search_fields,
        facet_fields=facet_fields,
        order_fields=order_fields,
        default_order_field=default_order_field,
        default_order=default_order,
    )
@classmethod
def cursor_paginate_params(
    cls: type[Self],
    *,
    default_page_size: int = 20,
    max_page_size: int = 100,
    search: bool = True,
    filter: bool = True,
    order: bool = True,
    search_fields: Sequence[SearchFieldType] | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    order_fields: Sequence[QueryableAttribute[Any]] | None = None,
    default_order_field: QueryableAttribute[Any] | None = None,
    default_order: Literal["asc", "desc"] = "asc",
) -> Callable[..., Awaitable[dict[str, Any]]]:
    """Create the FastAPI dependency that feeds :meth:`cursor_paginate`.

    Exposes an optional ``cursor`` token and ``items_per_page`` as query
    parameters, plus whichever search, facet filter and ordering parameters
    are enabled and configured. No fixed (non-query) values are added.

    Args:
        default_page_size: ``items_per_page`` used when the client omits it.
        max_page_size: Upper bound accepted for ``items_per_page``.
        search: Add ``search`` / ``search_column`` parameters.
        filter: Add one query parameter per facet key.
        order: Add ``order_by`` / ``order`` parameters.
        search_fields: Replace the class's searchable fields.
        facet_fields: Replace the class's facet fields.
        order_fields: Replace the class's order fields.
        default_order_field: Ordering field when ``order_by`` is absent.
        default_order: Direction when ``order`` is absent.

    Returns:
        Async dependency resolving to keyword arguments for
        :meth:`cursor_paginate`.
    """
    cursor_param = inspect.Parameter(
        "cursor",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=str | None,
        default=Query(None, description="Cursor token from a previous response"),
    )
    page_size = inspect.Parameter(
        "items_per_page",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=int,
        default=_page_size_query(default_page_size, max_page_size),
    )
    model_name = cls.model.__name__
    return cls._build_paginate_params(
        pagination_params=[cursor_param, page_size],
        pagination_fixed={},
        dep_name=f"{model_name}CursorPaginateParams",
        search=search,
        filter=filter,
        order=order,
        search_fields=search_fields,
        facet_fields=facet_fields,
        order_fields=order_fields,
        default_order_field=default_order_field,
        default_order=default_order,
    )
@classmethod
def paginate_params(
    cls: type[Self],
    *,
    default_page_size: int = 20,
    max_page_size: int = 100,
    default_pagination_type: PaginationType = PaginationType.OFFSET,
    include_total: bool = True,
    search: bool = True,
    filter: bool = True,
    order: bool = True,
    search_fields: Sequence[SearchFieldType] | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    order_fields: Sequence[QueryableAttribute[Any]] | None = None,
    default_order_field: QueryableAttribute[Any] | None = None,
    default_order: Literal["asc", "desc"] = "asc",
) -> Callable[..., Awaitable[dict[str, Any]]]:
    """Create the FastAPI dependency that feeds :meth:`paginate`.

    Exposes the parameters of both strategies as query parameters:
    ``pagination_type``, ``page`` (offset only), ``cursor`` (cursor only)
    and ``items_per_page``. Whichever search, facet filter and ordering
    parameters are enabled and configured are added as well.
    ``include_total`` is copied as a fixed value into every resolved dict.

    Args:
        default_page_size: ``items_per_page`` used when the client omits it.
        max_page_size: Upper bound accepted for ``items_per_page``.
        default_pagination_type: Strategy used when the client omits it.
        include_total: Forwarded unchanged to :meth:`paginate`.
        search: Add ``search`` / ``search_column`` parameters.
        filter: Add one query parameter per facet key.
        order: Add ``order_by`` / ``order`` parameters.
        search_fields: Replace the class's searchable fields.
        facet_fields: Replace the class's facet fields.
        order_fields: Replace the class's order fields.
        default_order_field: Ordering field when ``order_by`` is absent.
        default_order: Direction when ``order`` is absent.

    Returns:
        Async dependency resolving to keyword arguments for :meth:`paginate`.
    """
    strategy = inspect.Parameter(
        "pagination_type",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=PaginationType,
        default=Query(default_pagination_type, description="Pagination strategy"),
    )
    page = inspect.Parameter(
        "page",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=int,
        default=Query(1, ge=1, description="Page number (1-indexed, offset only)"),
    )
    cursor_param = inspect.Parameter(
        "cursor",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=str | None,
        default=Query(
            None,
            description="Cursor token from a previous response (cursor only)",
        ),
    )
    page_size = inspect.Parameter(
        "items_per_page",
        inspect.Parameter.KEYWORD_ONLY,
        annotation=int,
        default=_page_size_query(default_page_size, max_page_size),
    )
    model_name = cls.model.__name__
    return cls._build_paginate_params(
        pagination_params=[strategy, page, cursor_param, page_size],
        pagination_fixed={"include_total": include_total},
        dep_name=f"{model_name}PaginateParams",
        search=search,
        filter=filter,
        order=order,
        search_fields=search_fields,
        facet_fields=facet_fields,
        order_fields=order_fields,
        default_order_field=default_order_field,
        default_order=default_order,
    )
@overload
@classmethod
async def create(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    obj: BaseModel,
    *,
    schema: type[SchemaType],
) -> Response[SchemaType]: ...

@overload
@classmethod
async def create(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    obj: BaseModel,
    *,
    schema: None = ...,
) -> ModelType: ...

@classmethod
async def create(
    cls: type[Self],
    session: AsyncSession,
    obj: BaseModel,
    *,
    schema: type[BaseModel] | None = None,
) -> ModelType | Response[Any]:
    """Insert a new row built from ``obj``.

    Column values come from ``obj.model_dump()``, excluding the fields
    declared in ``m2m_fields``. Those fields are instead resolved to related
    instances and assigned to their relationship attributes. The instance
    is added inside ``get_transaction`` and refreshed afterwards.

    Args:
        session: DB async session
        obj: Pydantic model with the data to insert
        schema: Optional Pydantic schema. When given, the result is
            validated into it and wrapped in ``Response[schema]``.

    Returns:
        The created model instance, or ``Response[schema]``.

    Raises:
        NotFoundError: Propagated from ``_resolve_m2m`` when related IDs do
            not exist.
    """
    async with get_transaction(session):
        m2m_names = cls._m2m_schema_fields()
        if m2m_names:
            payload = obj.model_dump(exclude=m2m_names)
        else:
            payload = obj.model_dump()
        instance = cls.model(**payload)
        if m2m_names:
            # M2M ID fields were excluded from the payload above; attach the
            # resolved related instances instead.
            resolved = await cls._resolve_m2m(session, obj)
            for attr_name, members in resolved.items():
                setattr(instance, attr_name, members)
        session.add(instance)
    # NOTE(review): refresh runs after get_transaction exits, presumably once
    # the new row has been flushed/committed — confirm against get_transaction.
    await session.refresh(instance)
    created = cast(ModelType, instance)
    if not schema:
        return created
    return Response(data=schema.model_validate(created))
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: Sequence[ExecutableOption] | None = None,
schema: type[SchemaType],
) -> Response[SchemaType]: ...
@overload
@classmethod
async def get( # pragma: no cover
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: Sequence[ExecutableOption] | None = None,
schema: None = ...,
) -> ModelType: ...
@classmethod
async def get(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: Sequence[ExecutableOption] | None = None,
schema: type[BaseModel] | None = None,
) -> ModelType | Response[Any]:
"""Get exactly one record. Raises NotFoundError if not found.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload)
schema: Pydantic schema to serialize the result into. When provided,
the result is automatically wrapped in a ``Response[schema]``.
Returns:
Model instance, or ``Response[schema]`` when ``schema`` is given.
Raises:
NotFoundError: If no record found
MultipleResultsFound: If more than one record found
"""
result = await cls.get_or_none(
session,
filters,
joins=joins,
outer_join=outer_join,
with_for_update=with_for_update,
load_options=load_options,
schema=schema,
)
if result is None:
raise NotFoundError()
return result
@overload
@classmethod
async def get_or_none(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any],
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
    with_for_update: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    schema: type[SchemaType],
) -> Response[SchemaType] | None: ...

@overload
@classmethod
async def get_or_none(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any],
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
    with_for_update: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    schema: None = ...,
) -> ModelType | None: ...

@classmethod
async def get_or_none(
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any],
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
    with_for_update: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    schema: type[BaseModel] | None = None,
) -> ModelType | Response[Any] | None:
    """Get exactly one record, or ``None`` if not found.

    Like :meth:`get` but returns ``None`` instead of raising
    :class:`~fastapi_toolsets.exceptions.NotFoundError` when no record
    matches the filters.

    Args:
        session: DB async session
        filters: List of SQLAlchemy filter conditions. An empty list applies
            no WHERE clause.
        joins: List of (model, condition) tuples for joining related tables
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
        with_for_update: Lock the row for update
        load_options: SQLAlchemy loader options (e.g., selectinload). ``None``
            falls back to ``default_load_options``.
        schema: Pydantic schema to serialize the result into. When provided,
            the result is automatically wrapped in a ``Response[schema]``.

    Returns:
        Model instance, ``Response[schema]`` when ``schema`` is given,
        or ``None`` when no record matches.

    Raises:
        MultipleResultsFound: If more than one record found
    """
    q = select(cls.model)
    q = _apply_joins(q, joins, outer_join)
    if filters:
        # Same guard as first()/get_multi()/count(): and_() without
        # arguments is deprecated in SQLAlchemy, and an empty list means
        # "no WHERE clause" either way.
        q = q.where(and_(*filters))
    if resolved := cls._resolve_load_options(load_options):
        q = q.options(*resolved)
    if with_for_update:
        q = q.with_for_update()
    result = await session.execute(q)
    # unique() is needed when loader options use joined eager loading.
    item = result.unique().scalar_one_or_none()
    if item is None:
        return None
    db_model = cast(ModelType, item)
    if schema:
        return Response(data=schema.model_validate(db_model))
    return db_model
@overload
@classmethod
async def first(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any] | None = None,
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
    with_for_update: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    schema: type[SchemaType],
) -> Response[SchemaType] | None: ...

@overload
@classmethod
async def first(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any] | None = None,
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
    with_for_update: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    schema: None = ...,
) -> ModelType | None: ...

@classmethod
async def first(
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any] | None = None,
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
    with_for_update: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    schema: type[BaseModel] | None = None,
) -> ModelType | Response[Any] | None:
    """Return the first matching row, or ``None`` when nothing matches.

    Unlike :meth:`get_or_none`, several matches are not an error; the first
    row produced by the query is returned. No ORDER BY is applied here.

    Args:
        session: DB async session
        filters: Optional SQLAlchemy filter conditions
        joins: ``(model, condition)`` pairs to join
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
        with_for_update: Lock the selected row (``SELECT ... FOR UPDATE``)
        load_options: Loader options; ``None`` uses ``default_load_options``
        schema: Optional Pydantic schema; wraps the result in
            ``Response[schema]``.

    Returns:
        Model instance, ``Response[schema]``, or ``None``.
    """
    stmt = _apply_joins(select(cls.model), joins, outer_join)
    if filters:
        stmt = stmt.where(and_(*filters))
    options = cls._resolve_load_options(load_options)
    if options:
        stmt = stmt.options(*options)
    if with_for_update:
        stmt = stmt.with_for_update()
    row = (await session.execute(stmt)).unique().scalars().first()
    if row is None:
        return None
    instance = cast(ModelType, row)
    return Response(data=schema.model_validate(instance)) if schema else instance
@classmethod
async def get_multi(
cls: type[Self],
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: Sequence[ExecutableOption] | None = None,
order_by: OrderByClause | None = None,
limit: int | None = None,
offset: int | None = None,
) -> Sequence[ModelType]:
"""Get multiple records from the database.
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
limit: Max number of rows to return
offset: Rows to skip
Returns:
List of model instances
"""
q = select(cls.model)
q = _apply_joins(q, joins, outer_join)
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
if order_by is not None:
q = q.order_by(order_by)
if offset is not None:
q = q.offset(offset)
if limit is not None:
q = q.limit(limit)
result = await session.execute(q)
return cast(Sequence[ModelType], result.unique().scalars().all())
@overload
@classmethod
async def update(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    obj: BaseModel,
    filters: list[Any],
    *,
    exclude_unset: bool = True,
    exclude_none: bool = False,
    schema: type[SchemaType],
) -> Response[SchemaType]: ...

@overload
@classmethod
async def update(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    obj: BaseModel,
    filters: list[Any],
    *,
    exclude_unset: bool = True,
    exclude_none: bool = False,
    schema: None = ...,
) -> ModelType: ...

@classmethod
async def update(
    cls: type[Self],
    session: AsyncSession,
    obj: BaseModel,
    filters: list[Any],
    *,
    exclude_unset: bool = True,
    exclude_none: bool = False,
    schema: type[BaseModel] | None = None,
) -> ModelType | Response[Any]:
    """Apply ``obj`` to the single row matching ``filters``.

    Plain fields from ``obj.model_dump(...)`` are assigned directly, with
    ``m2m_fields`` entries excluded. M2M fields that were explicitly set on
    ``obj`` replace their relationship collections.

    Args:
        session: DB async session
        obj: Pydantic model carrying the new values
        filters: SQLAlchemy filter conditions locating the row
        exclude_unset: Skip fields not explicitly set on ``obj``
        exclude_none: Skip fields whose value is ``None``
        schema: Optional Pydantic schema; wraps the result in
            ``Response[schema]``.

    Returns:
        The updated instance, or ``Response[schema]``.

    Raises:
        NotFoundError: If no row matches (raised by :meth:`get`), or if
            related M2M IDs do not exist.
    """
    async with get_transaction(session):
        m2m_names = cls._m2m_schema_fields()
        # Eagerly load only the M2M relationships being replaced, so the
        # later setattr does not trigger a lazy load (which fails in async).
        eager_options: list[ExecutableOption] = []
        if m2m_names and cls.m2m_fields:
            eager_options = [
                selectinload(rel)
                for field_name, rel in cls.m2m_fields.items()
                if field_name in obj.model_fields_set
            ]
        db_model = await cls.get(
            session=session,
            filters=filters,
            load_options=eager_options or None,
        )
        changes = obj.model_dump(
            exclude_unset=exclude_unset,
            exclude_none=exclude_none,
            exclude=m2m_names,
        )
        for attr_name, new_value in changes.items():
            setattr(db_model, attr_name, new_value)
        if m2m_names:
            replacements = await cls._resolve_m2m(session, obj, only_set=True)
            for attr_name, members in replacements.items():
                setattr(db_model, attr_name, members)
    # NOTE(review): refresh runs after get_transaction exits; inside the
    # block it would expire the unflushed changes — confirm this placement.
    await session.refresh(db_model)
    if schema:
        return Response(data=schema.model_validate(db_model))
    return db_model
@classmethod
async def upsert(
cls: type[Self],
session: AsyncSession,
obj: BaseModel,
index_elements: list[str],
*,
set_: BaseModel | None = None,
where: WhereHavingRole | None = None,
) -> ModelType | None:
"""Create or update a record (PostgreSQL only).
Uses INSERT ... ON CONFLICT for atomic upsert.
Args:
session: DB async session
obj: Pydantic model with data
index_elements: Columns for ON CONFLICT (unique constraint)
set_: Pydantic model for ON CONFLICT DO UPDATE SET
where: WHERE clause for ON CONFLICT DO UPDATE
Returns:
Model instance
"""
async with get_transaction(session):
values = obj.model_dump(exclude_unset=True)
q = insert(cls.model).values(**values)
if set_:
q = q.on_conflict_do_update(
index_elements=index_elements,
set_=set_.model_dump(exclude_unset=True),
where=where,
)
else:
q = q.on_conflict_do_nothing(index_elements=index_elements)
q = q.returning(cls.model)
result = await session.execute(q)
try:
db_model = result.unique().scalar_one()
except NoResultFound:
db_model = await cls.first(
session=session,
filters=[getattr(cls.model, k) == v for k, v in values.items()],
)
return cast(ModelType | None, db_model)
@overload
@classmethod
async def delete(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any],
    *,
    return_response: Literal[True],
) -> Response[None]: ...

@overload
@classmethod
async def delete(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any],
    *,
    return_response: Literal[False] = ...,
) -> None: ...

@classmethod
async def delete(
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any],
    *,
    return_response: bool = False,
) -> None | Response[None]:
    """Delete every row matching ``filters``.

    Matching rows are loaded and removed one by one with
    ``session.delete``, inside ``get_transaction``.

    NOTE(review): an empty ``filters`` list is passed as ``and_()`` with no
    arguments, which SQLAlchemy deprecates and which applies no WHERE
    clause. Confirm that "delete everything" is the intended behavior.

    Args:
        session: DB async session
        filters: SQLAlchemy filter conditions
        return_response: Return ``Response[None]`` instead of ``None``, for
            endpoints that expect a consistent response envelope.

    Returns:
        ``None``, or ``Response[None]`` when ``return_response=True``.
    """
    async with get_transaction(session):
        matches = (
            (await session.execute(select(cls.model).where(and_(*filters))))
            .scalars()
            .all()
        )
        for row in matches:
            # ORM-level delete per instance (presumably so ORM cascades and
            # events apply).
            await session.delete(row)
    return Response(data=None) if return_response else None
@classmethod
async def count(
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any] | None = None,
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
) -> int:
    """Return the number of rows matching ``filters``.

    The count is taken over the joined result, so a one-to-many join can
    count a parent row more than once.

    Args:
        session: DB async session
        filters: Optional SQLAlchemy filter conditions
        joins: ``(model, condition)`` pairs to join
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN

    Returns:
        Row count.
    """
    base = select(func.count()).select_from(cls.model)
    stmt = _apply_joins(base, joins, outer_join)
    if filters:
        stmt = stmt.where(and_(*filters))
    return (await session.execute(stmt)).scalar_one()
@classmethod
async def exists(
    cls: type[Self],
    session: AsyncSession,
    filters: list[Any],
    *,
    joins: JoinType | None = None,
    outer_join: bool = False,
) -> bool:
    """Check if a record exists.

    Args:
        session: DB async session
        filters: List of SQLAlchemy filter conditions. An empty list checks
            whether the table (after joins) has any row at all.
        joins: List of (model, condition) tuples for joining related tables
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN

    Returns:
        True if at least one record matches
    """
    q = select(cls.model)
    q = _apply_joins(q, joins, outer_join)
    if filters:
        # Skip and_() when empty: calling it without arguments is
        # deprecated in SQLAlchemy, and an empty list means "no WHERE".
        q = q.where(and_(*filters))
    # Wrap as SELECT EXISTS (...) so the database stops at the first match.
    q = q.exists().select()
    result = await session.execute(q)
    return bool(result.scalar())
@classmethod
async def offset_paginate(
    cls: type[Self],
    session: AsyncSession,
    *,
    filters: list[Any] | None = None,
    joins: JoinType | None = None,
    outer_join: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    order_by: OrderByClause | None = None,
    page: int = 1,
    items_per_page: int = 20,
    include_total: bool = True,
    search: str | SearchConfig | None = None,
    search_fields: Sequence[SearchFieldType] | None = None,
    search_column: str | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    filter_by: dict[str, Any] | BaseModel | None = None,
    schema: type[BaseModel],
) -> OffsetPaginatedResponse[Any]:
    """Get paginated results using offset-based pagination.

    Args:
        session: DB async session
        filters: List of SQLAlchemy filter conditions
        joins: List of (model, condition) tuples for joining related tables
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
        load_options: SQLAlchemy loader options
        order_by: Column or list of columns to order by
        page: Page number (1-indexed)
        items_per_page: Number of items per page
        include_total: When ``False``, skip the ``COUNT`` query;
            ``pagination.total_count`` will be ``None`` and ``has_more`` is
            detected by fetching one extra row.
        search: Search query string or SearchConfig object
        search_fields: Fields to search in (overrides class default)
        search_column: Restrict search to a single column key.
        facet_fields: Columns to compute distinct values for (overrides class default)
        filter_by: Dict of {column_key: value} to filter by declared facet fields.
            Keys must match the column.key of a facet field. Scalar → equality,
            list → IN clause. Raises InvalidFacetFilterError for unknown keys.
        schema: Pydantic schema to serialize each item into.

    Returns:
        OffsetPaginatedResponse with OffsetPagination metadata

    Raises:
        ValueError: If ``page`` or ``items_per_page`` is lower than 1.
    """
    # Same checks as ``paginate``: reject inputs that would produce a negative
    # OFFSET or an empty LIMIT when this method is called directly.
    if items_per_page < 1:
        raise ValueError(f"items_per_page must be >= 1, got {items_per_page}")
    if page < 1:
        raise ValueError(f"page must be >= 1, got {page}")

    filters = list(filters) if filters else []
    offset = (page - 1) * items_per_page
    fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
    filters.extend(fb_filters)

    # Build search filters
    if search:
        search_filters, new_search_joins = build_search_filters(
            cls.model,
            search,
            search_fields=search_fields,
            default_fields=cls.searchable_fields,
            search_column=search_column,
        )
        filters.extend(search_filters)
        search_joins.extend(new_search_joins)

    # Page query: explicit joins, then search joins (always outer joins).
    q = select(cls.model)
    q = _apply_joins(q, joins, outer_join)
    q = _apply_search_joins(q, search_joins)
    if filters:
        q = q.where(and_(*filters))
    if resolved := cls._resolve_load_options(load_options):
        q = q.options(*resolved)
    if order_by is not None:
        q = q.order_by(order_by)

    total_count: int | None
    if include_total:
        result = await session.execute(q.offset(offset).limit(items_per_page))
        raw_items = cast(list[ModelType], result.unique().scalars().all())

        # Count distinct primary keys under the same joins and filters so that
        # one-to-many joins do not inflate the total. Attributes are resolved
        # through the mapper because the attribute key may differ from the DB
        # column name; the subquery handles composite primary keys as well.
        mapper = cls.model.__mapper__
        pk_attrs = [
            getattr(cls.model, mapper.get_property_by_column(col).key)
            for col in mapper.primary_key
        ]
        pk_q = select(*pk_attrs).select_from(cls.model)
        pk_q = _apply_joins(pk_q, joins, outer_join)
        pk_q = _apply_search_joins(pk_q, search_joins)
        if filters:
            pk_q = pk_q.where(and_(*filters))
        count_q = select(func.count()).select_from(pk_q.distinct().subquery())
        count_value: int = (await session.execute(count_q)).scalar_one()
        total_count = count_value
        has_more = page * items_per_page < count_value
    else:
        # Fetch one extra row to detect if a next page exists without COUNT
        result = await session.execute(q.offset(offset).limit(items_per_page + 1))
        raw_items = cast(list[ModelType], result.unique().scalars().all())
        has_more = len(raw_items) > items_per_page
        raw_items = raw_items[:items_per_page]
        total_count = None

    items: list[Any] = [schema.model_validate(item) for item in raw_items]
    filter_attributes = await cls._build_filter_attributes(
        session, facet_fields, filters, search_joins
    )
    search_columns = cls._resolve_search_columns(search_fields)
    return OffsetPaginatedResponse(
        data=items,
        pagination=OffsetPagination(
            total_count=total_count,
            items_per_page=items_per_page,
            page=page,
            has_more=has_more,
        ),
        filter_attributes=filter_attributes,
        search_columns=search_columns,
    )
@classmethod
async def cursor_paginate(
    cls: type[Self],
    session: AsyncSession,
    *,
    cursor: str | None = None,
    filters: list[Any] | None = None,
    joins: JoinType | None = None,
    outer_join: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    order_by: OrderByClause | None = None,
    items_per_page: int = 20,
    search: str | SearchConfig | None = None,
    search_fields: Sequence[SearchFieldType] | None = None,
    search_column: str | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    filter_by: dict[str, Any] | BaseModel | None = None,
    schema: type[BaseModel],
) -> CursorPaginatedResponse[Any]:
    """Get paginated results using cursor-based pagination.

    The class-level ``cursor_column`` is always the primary sort key. The
    cursor only *positions* the page: its comparison restricts the rows
    fetched for ``data`` but is not used when computing
    ``filter_attributes``. Facet values therefore describe the whole
    filtered result set on every page, as with ``offset_paginate``.

    Args:
        session: DB async session.
        cursor: Cursor string from a previous ``CursorPagination``.
            Omit (or pass ``None``) to start from the beginning.
        filters: List of SQLAlchemy filter conditions.
        joins: List of ``(model, condition)`` tuples for joining related
            tables.
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
        load_options: SQLAlchemy loader options. Falls back to
            ``default_load_options`` when not provided.
        order_by: Additional ordering applied after the cursor column.
        items_per_page: Number of items per page (default 20).
        search: Search query string or SearchConfig object.
        search_fields: Fields to search in (overrides class default).
        search_column: Restrict search to a single column key.
        facet_fields: Columns to compute distinct values for (overrides class default).
        filter_by: Dict of {column_key: value} to filter by declared facet fields.
            Keys must match the column.key of a facet field. Scalar → equality,
            list → IN clause. Raises InvalidFacetFilterError for unknown keys.
        schema: Pydantic schema to serialize each item into (required).

    Returns:
        CursorPaginatedResponse with CursorPagination metadata. ``has_more``
        reports whether another page exists in the direction just traversed.

    Raises:
        ValueError: If ``items_per_page`` is lower than 1, or if
            ``cursor_column`` is not configured on this CRUD class.
    """
    # Same check as ``paginate``; with 0, ``has_more`` could be True while no
    # ``next_cursor`` can be produced.
    if items_per_page < 1:
        raise ValueError(f"items_per_page must be >= 1, got {items_per_page}")

    filters = list(filters) if filters else []
    fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
    filters.extend(fb_filters)

    if cls.cursor_column is None:
        raise ValueError(
            f"{cls.__name__}.cursor_column is not set. "
            "Pass cursor_column=<column> to CrudFactory() to use cursor_paginate."
        )
    cursor_column: Any = cls.cursor_column
    cursor_col_name: str = cursor_column.key

    direction = _CursorDirection.NEXT
    # The cursor comparison is kept out of ``filters`` so that it is applied to
    # the page query only, never to the facet computation below.
    cursor_filters: list[Any] = []
    if cursor is not None:
        raw_val, direction = _decode_cursor(cursor)
        col_type = cursor_column.property.columns[0].type
        cursor_val: Any = _parse_cursor_value(raw_val, col_type)
        if direction is _CursorDirection.PREV:
            cursor_filters.append(cursor_column < cursor_val)
        else:
            cursor_filters.append(cursor_column > cursor_val)

    # Build search filters
    if search:
        search_filters, new_search_joins = build_search_filters(
            cls.model,
            search,
            search_fields=search_fields,
            default_fields=cls.searchable_fields,
            search_column=search_column,
        )
        filters.extend(search_filters)
        search_joins.extend(new_search_joins)

    # Page query: explicit joins, then search joins (always outer joins).
    q = select(cls.model)
    q = _apply_joins(q, joins, outer_join)
    q = _apply_search_joins(q, search_joins)
    page_filters = [*filters, *cursor_filters]
    if page_filters:
        q = q.where(and_(*page_filters))
    if resolved := cls._resolve_load_options(load_options):
        q = q.options(*resolved)

    # Cursor column is always the primary sort; reverse direction for prev traversal
    if direction is _CursorDirection.PREV:
        q = q.order_by(cursor_column.desc())
    else:
        q = q.order_by(cursor_column)
    if order_by is not None:
        q = q.order_by(order_by)

    # Fetch one extra to detect whether another page exists in this direction
    q = q.limit(items_per_page + 1)
    result = await session.execute(q)
    raw_items = cast(list[ModelType], result.unique().scalars().all())
    has_more = len(raw_items) > items_per_page
    items_page = raw_items[:items_per_page]

    # Restore ascending order when traversing backward
    if direction is _CursorDirection.PREV:
        items_page = list(reversed(items_page))

    # next_cursor: points past the last item in ascending order. When going
    # backward one is always provided (if the page is non-empty) so the client
    # can return forward.
    next_cursor: str | None = None
    if items_page and (direction is _CursorDirection.PREV or has_more):
        next_cursor = _encode_cursor(
            getattr(items_page[-1], cursor_col_name),
            direction=_CursorDirection.NEXT,
        )

    # prev_cursor: points before the first item in ascending order
    prev_cursor: str | None = None
    if direction is _CursorDirection.NEXT and cursor is not None and items_page:
        prev_cursor = _encode_cursor(
            getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
        )
    elif direction is _CursorDirection.PREV and has_more and items_page:
        prev_cursor = _encode_cursor(
            getattr(items_page[0], cursor_col_name), direction=_CursorDirection.PREV
        )

    items: list[Any] = [schema.model_validate(item) for item in items_page]
    # ``filters`` deliberately excludes the cursor comparison (see above).
    filter_attributes = await cls._build_filter_attributes(
        session, facet_fields, filters, search_joins
    )
    search_columns = cls._resolve_search_columns(search_fields)
    return CursorPaginatedResponse(
        data=items,
        pagination=CursorPagination(
            next_cursor=next_cursor,
            prev_cursor=prev_cursor,
            items_per_page=items_per_page,
            has_more=has_more,
        ),
        filter_attributes=filter_attributes,
        search_columns=search_columns,
    )
@overload
@classmethod
async def paginate(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    *,
    # ``= ...`` mirrors the implementation's default (``PaginationType.OFFSET``),
    # so calls that omit ``pagination_type`` resolve to this overload instead of
    # being rejected by type checkers.
    pagination_type: Literal[PaginationType.OFFSET] = ...,
    filters: list[Any] | None = ...,
    joins: JoinType | None = ...,
    outer_join: bool = ...,
    load_options: Sequence[ExecutableOption] | None = ...,
    order_by: OrderByClause | None = ...,
    page: int = ...,
    cursor: str | None = ...,
    items_per_page: int = ...,
    include_total: bool = ...,
    search: str | SearchConfig | None = ...,
    search_fields: Sequence[SearchFieldType] | None = ...,
    search_column: str | None = ...,
    facet_fields: Sequence[FacetFieldType] | None = ...,
    filter_by: dict[str, Any] | BaseModel | None = ...,
    schema: type[BaseModel],
) -> OffsetPaginatedResponse[Any]: ...
@overload
@classmethod
async def paginate(  # pragma: no cover
    cls: type[Self],
    session: AsyncSession,
    *,
    # Cursor pagination must be requested explicitly.
    pagination_type: Literal[PaginationType.CURSOR],
    filters: list[Any] | None = ...,
    joins: JoinType | None = ...,
    outer_join: bool = ...,
    load_options: Sequence[ExecutableOption] | None = ...,
    order_by: OrderByClause | None = ...,
    page: int = ...,
    cursor: str | None = ...,
    items_per_page: int = ...,
    include_total: bool = ...,
    search: str | SearchConfig | None = ...,
    search_fields: Sequence[SearchFieldType] | None = ...,
    search_column: str | None = ...,
    facet_fields: Sequence[FacetFieldType] | None = ...,
    filter_by: dict[str, Any] | BaseModel | None = ...,
    schema: type[BaseModel],
) -> CursorPaginatedResponse[Any]: ...
@classmethod
async def paginate(
    cls: type[Self],
    session: AsyncSession,
    *,
    pagination_type: PaginationType = PaginationType.OFFSET,
    filters: list[Any] | None = None,
    joins: JoinType | None = None,
    outer_join: bool = False,
    load_options: Sequence[ExecutableOption] | None = None,
    order_by: OrderByClause | None = None,
    page: int = 1,
    cursor: str | None = None,
    items_per_page: int = 20,
    include_total: bool = True,
    search: str | SearchConfig | None = None,
    search_fields: Sequence[SearchFieldType] | None = None,
    search_column: str | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    filter_by: dict[str, Any] | BaseModel | None = None,
    schema: type[BaseModel],
) -> OffsetPaginatedResponse[Any] | CursorPaginatedResponse[Any]:
    """Paginate with the strategy selected by ``pagination_type``.

    Delegates to :meth:`cursor_paginate` for ``PaginationType.CURSOR`` and to
    :meth:`offset_paginate` for ``PaginationType.OFFSET``, forwarding every
    argument both strategies understand. ``cursor`` is only forwarded to the
    cursor strategy. ``page`` and ``include_total`` are only forwarded to the
    offset strategy.

    Args:
        session: DB async session.
        pagination_type: Pagination strategy (default ``OFFSET``).
        filters: List of SQLAlchemy filter conditions.
        joins: ``(model, condition)`` tuples for joining related tables.
        outer_join: Use LEFT OUTER JOIN instead of INNER JOIN.
        load_options: SQLAlchemy loader options; falls back to
            ``default_load_options`` when not provided.
        order_by: Column or expression to order results by.
        page: 1-indexed page number (``OFFSET`` only).
        cursor: Cursor token from a previous
            :class:`.CursorPaginatedResponse` (``CURSOR`` only).
        items_per_page: Number of items per page (default 20).
        include_total: When ``False``, skip the ``COUNT`` query
            (``OFFSET`` only).
        search: Search query string or :class:`.SearchConfig` object.
        search_fields: Fields to search in (overrides class default).
        search_column: Restrict search to a single column key.
        facet_fields: Columns to compute distinct values for (overrides
            class default).
        filter_by: ``{column_key: value}`` filters on declared facet fields;
            scalar → equality, list → IN clause. Unknown keys raise
            :exc:`.InvalidFacetFilterError`.
        schema: Pydantic schema to serialize each item into.

    Returns:
        :class:`.OffsetPaginatedResponse` for ``OFFSET``,
        :class:`.CursorPaginatedResponse` for ``CURSOR``.

    Raises:
        ValueError: On ``items_per_page < 1``, on ``page < 1`` with
            ``OFFSET``, or on an unrecognized ``pagination_type``.
    """
    if items_per_page < 1:
        raise ValueError(f"items_per_page must be >= 1, got {items_per_page}")

    # Keyword arguments accepted identically by both strategies.
    shared: dict[str, Any] = {
        "filters": filters,
        "joins": joins,
        "outer_join": outer_join,
        "load_options": load_options,
        "order_by": order_by,
        "items_per_page": items_per_page,
        "search": search,
        "search_fields": search_fields,
        "search_column": search_column,
        "facet_fields": facet_fields,
        "filter_by": filter_by,
        "schema": schema,
    }

    if pagination_type == PaginationType.CURSOR:
        return await cls.cursor_paginate(session, cursor=cursor, **shared)

    if pagination_type == PaginationType.OFFSET:
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")
        return await cls.offset_paginate(
            session, page=page, include_total=include_total, **shared
        )

    raise ValueError(f"Unknown pagination_type: {pagination_type!r}")
def CrudFactory(
    model: type[ModelType],
    *,
    base_class: type[AsyncCrud[Any]] = AsyncCrud,
    searchable_fields: Sequence[SearchFieldType] | None = None,
    facet_fields: Sequence[FacetFieldType] | None = None,
    order_fields: Sequence[QueryableAttribute[Any]] | None = None,
    m2m_fields: M2MFieldType | None = None,
    default_load_options: Sequence[ExecutableOption] | None = None,
    cursor_column: Any | None = None,
) -> type[AsyncCrud[ModelType]]:
    """Create a CRUD class for a specific model.

    The returned class is named ``Async<ModelName>Crud`` and subclasses
    ``base_class``. Every keyword below (``searchable_fields`` through
    ``cursor_column``) is written into the new class's namespace, even when
    left at its ``None`` default. It therefore shadows any value a custom
    ``base_class`` defines for the same attribute.

    Args:
        model: SQLAlchemy model class
        base_class: Optional base class to inherit from instead of ``AsyncCrud``.
            Use this to share custom methods across multiple CRUD classes while
            still using the factory shorthand.
        searchable_fields: Optional list of searchable fields
        facet_fields: Optional list of columns to compute distinct values for in paginated
            responses. Supports direct columns (``User.status``) and relationship tuples
            (``(User.role, Role.name)``). Can be overridden per call.
        order_fields: Optional list of model attributes that callers are allowed to order by
            via ``offset_paginate_params()``. Can be overridden per call.
        m2m_fields: Optional mapping for many-to-many relationships.
            Maps schema field names (containing lists of IDs) to
            SQLAlchemy relationship attributes.
        default_load_options: Default SQLAlchemy loader options applied to all read
            queries when no explicit ``load_options`` are passed. Use this
            instead of ``lazy="selectin"`` on the model so that loading
            strategy is explicit and per-CRUD. Overridden entirely (not
            merged) when ``load_options`` is provided at call-site.
        cursor_column: Required to call ``cursor_paginate``.
            Must be monotonically ordered (e.g. integer PK, UUID v7, timestamp).
            See the cursor pagination docs for supported column types.

    Returns:
        AsyncCrud subclass bound to the model

    Example:
        ```python
        from fastapi_toolsets.crud import CrudFactory
        from myapp.models import User, Post

        UserCrud = CrudFactory(User)
        PostCrud = CrudFactory(Post)

        # With searchable fields:
        UserCrud = CrudFactory(
            User,
            searchable_fields=[User.username, User.email, (User.role, Role.name)]
        )

        # With many-to-many fields:
        # Schema has `tag_ids: list[UUID]`, model has `tags` relationship to Tag
        PostCrud = CrudFactory(
            Post,
            m2m_fields={"tag_ids": Post.tags},
        )

        # With facet fields for filter dropdowns / faceted search:
        UserCrud = CrudFactory(
            User,
            facet_fields=[User.status, User.country, (User.role, Role.name)],
        )

        # With a fixed cursor column for cursor_paginate:
        PostCrud = CrudFactory(
            Post,
            cursor_column=Post.created_at,
        )

        # With default load strategy (replaces lazy="selectin" on the model):
        ArticleCrud = CrudFactory(
            Article,
            default_load_options=[selectinload(Article.category), selectinload(Article.tags)],
        )

        # Override default_load_options for a specific call:
        article = await ArticleCrud.get(
            session,
            [Article.id == 1],
            load_options=[selectinload(Article.category)],  # tags won't load
        )

        # Usage
        user = await UserCrud.get(session, [User.id == 1])
        posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])

        # Create with M2M - tag_ids are automatically resolved
        post = await PostCrud.create(session, PostCreate(title="Hello", tag_ids=[id1, id2]))

        # With search (``schema`` is a required keyword of the paginate methods;
        # ``UserRead`` stands for your Pydantic read schema):
        result = await UserCrud.offset_paginate(session, search="john", schema=UserRead)

        # With joins (inner join by default):
        users = await UserCrud.get_multi(
            session,
            joins=[(Post, Post.user_id == User.id)],
            filters=[Post.published == True],
        )

        # With outer join:
        users = await UserCrud.get_multi(
            session,
            joins=[(Post, Post.user_id == User.id)],
            outer_join=True,
        )

        # With a shared custom base class:
        from typing import Generic, TypeVar
        from sqlalchemy.orm import DeclarativeBase

        T = TypeVar("T", bound=DeclarativeBase)

        class AuditedCrud(AsyncCrud[T], Generic[T]):
            @classmethod
            async def get_active(cls, session):
                return await cls.get_multi(session, filters=[cls.model.is_active == True])

        UserCrud = CrudFactory(User, base_class=AuditedCrud)
        ```
    """
    cls = type(
        f"Async{model.__name__}Crud",
        (base_class,),
        {
            "model": model,
            "searchable_fields": searchable_fields,
            "facet_fields": facet_fields,
            "order_fields": order_fields,
            "m2m_fields": m2m_fields,
            "default_load_options": default_load_options,
            "cursor_column": cursor_column,
        },
    )
    return cast(type[AsyncCrud[ModelType]], cls)