From 19232d3436dcea647fb6d087d9b1a75ada42d0cc Mon Sep 17 00:00:00 2001 From: d3vyce <44915747+d3vyce@users.noreply.github.com> Date: Thu, 12 Mar 2026 22:46:51 +0100 Subject: [PATCH] feat: add AsyncCrud subclass style and base_class param to CrudFactory (#132) --- docs/module/crud.md | 64 ++++++++++++++++++- src/fastapi_toolsets/crud/factory.py | 55 +++++++++++----- tests/test_crud.py | 95 ++++++++++++++++++++++++++++ 3 files changed, 196 insertions(+), 18 deletions(-) diff --git a/docs/module/crud.md b/docs/module/crud.md index 9dc457c..ea6744a 100644 --- a/docs/module/crud.md +++ b/docs/module/crud.md @@ -7,10 +7,12 @@ Generic async CRUD operations for SQLAlchemy models with search, pagination, and ## Overview -The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), an abstract base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model. +The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), a base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model. ## Creating a CRUD class +### Factory style + ```python from fastapi_toolsets.crud import CrudFactory from myapp.models import User @@ -18,7 +20,65 @@ from myapp.models import User UserCrud = CrudFactory(model=User) ``` -[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. +[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. This is the most concise option for straightforward CRUD with no custom logic. + +### Subclass style + +!!! info "Added in `v2.3.0`" + +```python +from fastapi_toolsets.crud.factory import AsyncCrud +from myapp.models import User + +class UserCrud(AsyncCrud[User]): + model = User + searchable_fields = [User.username, User.email] + default_load_options = [selectinload(User.role)] +``` + +Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body. + +### Adding custom methods + +```python +class UserCrud(AsyncCrud[User]): + model = User + + @classmethod + async def get_active(cls, session: AsyncSession) -> list[User]: + return await cls.get_multi(session, filters=[User.is_active == True]) +``` + +### Sharing a custom base across multiple models + +Define a generic base class with the shared methods, then subclass it for each model: + +```python +from typing import Generic, TypeVar +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import DeclarativeBase +from fastapi_toolsets.crud.factory import AsyncCrud + +T = TypeVar("T", bound=DeclarativeBase) + +class AuditedCrud(AsyncCrud[T], Generic[T]): + """Base CRUD with custom function""" + + @classmethod + async def get_active(cls, session: AsyncSession): + return await cls.get_multi(session, filters=[cls.model.is_active == True]) + + +class UserCrud(AuditedCrud[User]): + model = User + searchable_fields = [User.username, User.email] +``` + +You can also use the factory shorthand with the same base by passing `base_class`: + +```python +UserCrud = CrudFactory(User, base_class=AuditedCrud) +``` ## Basic operations diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 341d354..ce754ff 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -82,6 +82,26 @@ class AsyncCrud(Generic[ModelType]): default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None cursor_column: ClassVar[Any | None] = None + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if "model" not in cls.__dict__: + return + model: type[DeclarativeBase] = cls.__dict__["model"] + pk_key = model.__mapper__.primary_key[0].key + assert pk_key is not None + pk_col = getattr(model, pk_key) + + raw_fields: Sequence[SearchFieldType] | None = cls.__dict__.get( + "searchable_fields", None + ) + if raw_fields is None: + cls.searchable_fields = [pk_col] + else: + existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)} + if pk_key not in existing_keys: + cls.searchable_fields = [pk_col, *raw_fields] + @classmethod def _resolve_load_options( cls, load_options: Sequence[ExecutableOption] | None @@ -1108,6 +1128,7 @@ class AsyncCrud(Generic[ModelType]): def CrudFactory( model: type[ModelType], *, + base_class: type[AsyncCrud[Any]] = AsyncCrud, searchable_fields: Sequence[SearchFieldType] | None = None, facet_fields: Sequence[FacetFieldType] | None = None, order_fields: Sequence[QueryableAttribute[Any]] | None = None, @@ -1119,6 +1140,9 @@ def CrudFactory( Args: model: SQLAlchemy model class + base_class: Optional base class to inherit from instead of ``AsyncCrud``. + Use this to share custom methods across multiple CRUD classes while + still using the factory shorthand. searchable_fields: Optional list of searchable fields facet_fields: Optional list of columns to compute distinct values for in paginated responses. Supports direct columns (``User.status``) and relationship tuples @@ -1209,28 +1233,27 @@ def CrudFactory( joins=[(Post, Post.user_id == User.id)], outer_join=True, ) + + # With a shared custom base class: + from typing import Generic, TypeVar + from sqlalchemy.orm import DeclarativeBase + + T = TypeVar("T", bound=DeclarativeBase) + + class AuditedCrud(AsyncCrud[T], Generic[T]): + @classmethod + async def get_active(cls, session): + return await cls.get_multi(session, filters=[cls.model.is_active == True]) + + UserCrud = CrudFactory(User, base_class=AuditedCrud) ``` """ - pk_key = model.__mapper__.primary_key[0].key - assert pk_key is not None - pk_col = getattr(model, pk_key) - - if searchable_fields is None: - effective_searchable = [pk_col] - else: - existing_keys = {f.key for f in searchable_fields if not isinstance(f, tuple)} - effective_searchable = ( - [pk_col, *searchable_fields] - if pk_key not in existing_keys - else list(searchable_fields) - ) - cls = type( f"Async{model.__name__}Crud", - (AsyncCrud,), + (base_class,), { "model": model, - "searchable_fields": effective_searchable, + "searchable_fields": searchable_fields, "facet_fields": facet_fields, "order_fields": order_fields, "m2m_fields": m2m_fields, diff --git a/tests/test_crud.py b/tests/test_crud.py index f63f21a..58e9a0f 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -86,6 +86,101 @@ class TestCrudFactory: assert crud_with.default_load_options == options assert crud_without.default_load_options is None + def test_base_class_custom_methods_inherited(self): + """CrudFactory with base_class inherits custom methods from that base.""" + from typing import Generic, TypeVar + + from sqlalchemy.orm import DeclarativeBase + + T = TypeVar("T", bound=DeclarativeBase) + + class CustomBase(AsyncCrud[T], Generic[T]): + @classmethod + def custom_method(cls) -> str: + return f"custom:{cls.model.__name__}" + + UserCrudCustom = CrudFactory(User, base_class=CustomBase) + PostCrudCustom = CrudFactory(Post, base_class=CustomBase) + + assert issubclass(UserCrudCustom, CustomBase) + assert issubclass(PostCrudCustom, CustomBase) + assert UserCrudCustom.custom_method() == "custom:User" + assert PostCrudCustom.custom_method() == "custom:Post" + + def test_base_class_pk_injected(self): + """PK is still injected when using a custom base_class.""" + from typing import Generic, TypeVar + + from sqlalchemy.orm import DeclarativeBase + + T = TypeVar("T", bound=DeclarativeBase) + + class CustomBase(AsyncCrud[T], Generic[T]): + pass + + crud = CrudFactory(User, base_class=CustomBase) + assert crud.searchable_fields is not None + assert User.id in crud.searchable_fields + + +class TestAsyncCrudSubclass: + """Tests for direct AsyncCrud subclassing (alternative to CrudFactory).""" + + def test_subclass_with_model_only(self): + """Subclassing with just model auto-injects PK into searchable_fields.""" + + class UserCrudDirect(AsyncCrud[User]): + model = User + + assert UserCrudDirect.searchable_fields == [User.id] + + def test_subclass_with_explicit_fields_prepends_pk(self): + """Subclassing with searchable_fields prepends PK automatically.""" + + class UserCrudDirect(AsyncCrud[User]): + model = User + searchable_fields = [User.username] + + assert UserCrudDirect.searchable_fields == [User.id, User.username] + + def test_subclass_with_pk_already_in_fields(self): + """PK is not duplicated when already in searchable_fields.""" + + class UserCrudDirect(AsyncCrud[User]): + model = User + searchable_fields = [User.id, User.username] + + assert UserCrudDirect.searchable_fields == [User.id, User.username] + + def test_subclass_has_default_class_vars(self): + """Other ClassVars are None by default on a direct subclass.""" + + class UserCrudDirect(AsyncCrud[User]): + model = User + + assert UserCrudDirect.facet_fields is None + assert UserCrudDirect.default_load_options is None + assert UserCrudDirect.cursor_column is None + + def test_subclass_with_load_options(self): + """Direct subclass can declare default_load_options.""" + opts = [selectinload(User.role)] + + class UserCrudDirect(AsyncCrud[User]): + model = User + default_load_options = opts + + assert UserCrudDirect.default_load_options is opts + + def test_abstract_base_without_model_not_processed(self): + """Intermediate abstract class without model is not processed.""" + + class AbstractCrud(AsyncCrud[User]): + pass + + # Should not raise, and searchable_fields inherits base default (None) + assert AbstractCrud.searchable_fields is None + class TestResolveLoadOptions: """Tests for _resolve_load_options logic."""