feat: add AsyncCrud subclass style and base_class param to CrudFactory (#132)

This commit is contained in:
d3vyce
2026-03-12 22:46:51 +01:00
committed by GitHub
parent 1eafcb3873
commit 19232d3436
3 changed files with 196 additions and 18 deletions

View File

@@ -82,6 +82,26 @@ class AsyncCrud(Generic[ModelType]):
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
cursor_column: ClassVar[Any | None] = None
@classmethod
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if "model" not in cls.__dict__:
return
model: type[DeclarativeBase] = cls.__dict__["model"]
pk_key = model.__mapper__.primary_key[0].key
assert pk_key is not None
pk_col = getattr(model, pk_key)
raw_fields: Sequence[SearchFieldType] | None = cls.__dict__.get(
"searchable_fields", None
)
if raw_fields is None:
cls.searchable_fields = [pk_col]
else:
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)}
if pk_key not in existing_keys:
cls.searchable_fields = [pk_col, *raw_fields]
@classmethod
def _resolve_load_options(
cls, load_options: Sequence[ExecutableOption] | None
@@ -1108,6 +1128,7 @@ class AsyncCrud(Generic[ModelType]):
def CrudFactory(
model: type[ModelType],
*,
base_class: type[AsyncCrud[Any]] = AsyncCrud,
searchable_fields: Sequence[SearchFieldType] | None = None,
facet_fields: Sequence[FacetFieldType] | None = None,
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
@@ -1119,6 +1140,9 @@ def CrudFactory(
Args:
model: SQLAlchemy model class
base_class: Optional base class to inherit from instead of ``AsyncCrud``.
Use this to share custom methods across multiple CRUD classes while
still using the factory shorthand.
searchable_fields: Optional list of searchable fields
facet_fields: Optional list of columns to compute distinct values for in paginated
responses. Supports direct columns (``User.status``) and relationship tuples
@@ -1209,28 +1233,27 @@ def CrudFactory(
joins=[(Post, Post.user_id == User.id)],
outer_join=True,
)
# With a shared custom base class:
from typing import Generic, TypeVar
from sqlalchemy.orm import DeclarativeBase
T = TypeVar("T", bound=DeclarativeBase)
class AuditedCrud(AsyncCrud[T], Generic[T]):
@classmethod
async def get_active(cls, session):
return await cls.get_multi(session, filters=[cls.model.is_active == True])
UserCrud = CrudFactory(User, base_class=AuditedCrud)
```
"""
pk_key = model.__mapper__.primary_key[0].key
assert pk_key is not None
pk_col = getattr(model, pk_key)
if searchable_fields is None:
effective_searchable = [pk_col]
else:
existing_keys = {f.key for f in searchable_fields if not isinstance(f, tuple)}
effective_searchable = (
[pk_col, *searchable_fields]
if pk_key not in existing_keys
else list(searchable_fields)
)
cls = type(
f"Async{model.__name__}Crud",
(AsyncCrud,),
(base_class,),
{
"model": model,
"searchable_fields": effective_searchable,
"searchable_fields": searchable_fields,
"facet_fields": facet_fields,
"order_fields": order_fields,
"m2m_fields": m2m_fields,