mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: add AsyncCrud subclass style and base_class param to CrudFactory (#132)
This commit is contained in:
@@ -82,6 +82,26 @@ class AsyncCrud(Generic[ModelType]):
|
||||
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
|
||||
cursor_column: ClassVar[Any | None] = None
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if "model" not in cls.__dict__:
|
||||
return
|
||||
model: type[DeclarativeBase] = cls.__dict__["model"]
|
||||
pk_key = model.__mapper__.primary_key[0].key
|
||||
assert pk_key is not None
|
||||
pk_col = getattr(model, pk_key)
|
||||
|
||||
raw_fields: Sequence[SearchFieldType] | None = cls.__dict__.get(
|
||||
"searchable_fields", None
|
||||
)
|
||||
if raw_fields is None:
|
||||
cls.searchable_fields = [pk_col]
|
||||
else:
|
||||
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)}
|
||||
if pk_key not in existing_keys:
|
||||
cls.searchable_fields = [pk_col, *raw_fields]
|
||||
|
||||
@classmethod
|
||||
def _resolve_load_options(
|
||||
cls, load_options: Sequence[ExecutableOption] | None
|
||||
@@ -1108,6 +1128,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
def CrudFactory(
|
||||
model: type[ModelType],
|
||||
*,
|
||||
base_class: type[AsyncCrud[Any]] = AsyncCrud,
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
|
||||
@@ -1119,6 +1140,9 @@ def CrudFactory(
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
base_class: Optional base class to inherit from instead of ``AsyncCrud``.
|
||||
Use this to share custom methods across multiple CRUD classes while
|
||||
still using the factory shorthand.
|
||||
searchable_fields: Optional list of searchable fields
|
||||
facet_fields: Optional list of columns to compute distinct values for in paginated
|
||||
responses. Supports direct columns (``User.status``) and relationship tuples
|
||||
@@ -1209,28 +1233,27 @@ def CrudFactory(
|
||||
joins=[(Post, Post.user_id == User.id)],
|
||||
outer_join=True,
|
||||
)
|
||||
|
||||
# With a shared custom base class:
|
||||
from typing import Generic, TypeVar
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
T = TypeVar("T", bound=DeclarativeBase)
|
||||
|
||||
class AuditedCrud(AsyncCrud[T], Generic[T]):
|
||||
@classmethod
|
||||
async def get_active(cls, session):
|
||||
return await cls.get_multi(session, filters=[cls.model.is_active == True])
|
||||
|
||||
UserCrud = CrudFactory(User, base_class=AuditedCrud)
|
||||
```
|
||||
"""
|
||||
pk_key = model.__mapper__.primary_key[0].key
|
||||
assert pk_key is not None
|
||||
pk_col = getattr(model, pk_key)
|
||||
|
||||
if searchable_fields is None:
|
||||
effective_searchable = [pk_col]
|
||||
else:
|
||||
existing_keys = {f.key for f in searchable_fields if not isinstance(f, tuple)}
|
||||
effective_searchable = (
|
||||
[pk_col, *searchable_fields]
|
||||
if pk_key not in existing_keys
|
||||
else list(searchable_fields)
|
||||
)
|
||||
|
||||
cls = type(
|
||||
f"Async{model.__name__}Crud",
|
||||
(AsyncCrud,),
|
||||
(base_class,),
|
||||
{
|
||||
"model": model,
|
||||
"searchable_fields": effective_searchable,
|
||||
"searchable_fields": searchable_fields,
|
||||
"facet_fields": facet_fields,
|
||||
"order_fields": order_fields,
|
||||
"m2m_fields": m2m_fields,
|
||||
|
||||
Reference in New Issue
Block a user