mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 14:46:24 +02:00
feat: add AsyncCrud subclass style and base_class param to CrudFactory (#132)
This commit is contained in:
@@ -7,10 +7,12 @@ Generic async CRUD operations for SQLAlchemy models with search, pagination, and
|
|||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), an abstract base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
The `crud` module provides [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud), a base class with a full suite of async database operations, and [`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory), a convenience function to instantiate it for a given model.
|
||||||
|
|
||||||
## Creating a CRUD class
|
## Creating a CRUD class
|
||||||
|
|
||||||
|
### Factory style
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from fastapi_toolsets.crud import CrudFactory
|
from fastapi_toolsets.crud import CrudFactory
|
||||||
from myapp.models import User
|
from myapp.models import User
|
||||||
@@ -18,7 +20,65 @@ from myapp.models import User
|
|||||||
UserCrud = CrudFactory(model=User)
|
UserCrud = CrudFactory(model=User)
|
||||||
```
|
```
|
||||||
|
|
||||||
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model.
|
[`CrudFactory`](../reference/crud.md#fastapi_toolsets.crud.factory.CrudFactory) dynamically creates a class named `AsyncUserCrud` with `User` as its model. This is the most concise option for straightforward CRUD with no custom logic.
|
||||||
|
|
||||||
|
### Subclass style
|
||||||
|
|
||||||
|
!!! info "Added in `v2.3.0`"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
|
from myapp.models import User
|
||||||
|
|
||||||
|
class UserCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.username, User.email]
|
||||||
|
default_load_options = [selectinload(User.role)]
|
||||||
|
```
|
||||||
|
|
||||||
|
Subclassing [`AsyncCrud`](../reference/crud.md#fastapi_toolsets.crud.factory.AsyncCrud) directly is the preferred style when you need to add custom methods or when the configuration is complex enough to benefit from a named class body.
|
||||||
|
|
||||||
|
### Adding custom methods
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserCrud(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_active(cls, session: AsyncSession) -> list[User]:
|
||||||
|
return await cls.get_multi(session, filters=[User.is_active == True])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sharing a custom base across multiple models
|
||||||
|
|
||||||
|
Define a generic base class with the shared methods, then subclass it for each model:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
class AuditedCrud(AsyncCrud[T], Generic[T]):
|
||||||
|
"""Base CRUD with custom function"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_active(cls, session: AsyncSession):
|
||||||
|
return await cls.get_multi(session, filters=[cls.model.is_active == True])
|
||||||
|
|
||||||
|
|
||||||
|
class UserCrud(AuditedCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.username, User.email]
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also use the factory shorthand with the same base by passing `base_class`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
UserCrud = CrudFactory(User, base_class=AuditedCrud)
|
||||||
|
```
|
||||||
|
|
||||||
## Basic operations
|
## Basic operations
|
||||||
|
|
||||||
|
|||||||
@@ -82,6 +82,26 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
|
default_load_options: ClassVar[Sequence[ExecutableOption] | None] = None
|
||||||
cursor_column: ClassVar[Any | None] = None
|
cursor_column: ClassVar[Any | None] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
if "model" not in cls.__dict__:
|
||||||
|
return
|
||||||
|
model: type[DeclarativeBase] = cls.__dict__["model"]
|
||||||
|
pk_key = model.__mapper__.primary_key[0].key
|
||||||
|
assert pk_key is not None
|
||||||
|
pk_col = getattr(model, pk_key)
|
||||||
|
|
||||||
|
raw_fields: Sequence[SearchFieldType] | None = cls.__dict__.get(
|
||||||
|
"searchable_fields", None
|
||||||
|
)
|
||||||
|
if raw_fields is None:
|
||||||
|
cls.searchable_fields = [pk_col]
|
||||||
|
else:
|
||||||
|
existing_keys = {f.key for f in raw_fields if not isinstance(f, tuple)}
|
||||||
|
if pk_key not in existing_keys:
|
||||||
|
cls.searchable_fields = [pk_col, *raw_fields]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve_load_options(
|
def _resolve_load_options(
|
||||||
cls, load_options: Sequence[ExecutableOption] | None
|
cls, load_options: Sequence[ExecutableOption] | None
|
||||||
@@ -1108,6 +1128,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
def CrudFactory(
|
def CrudFactory(
|
||||||
model: type[ModelType],
|
model: type[ModelType],
|
||||||
*,
|
*,
|
||||||
|
base_class: type[AsyncCrud[Any]] = AsyncCrud,
|
||||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||||
facet_fields: Sequence[FacetFieldType] | None = None,
|
facet_fields: Sequence[FacetFieldType] | None = None,
|
||||||
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
|
order_fields: Sequence[QueryableAttribute[Any]] | None = None,
|
||||||
@@ -1119,6 +1140,9 @@ def CrudFactory(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: SQLAlchemy model class
|
model: SQLAlchemy model class
|
||||||
|
base_class: Optional base class to inherit from instead of ``AsyncCrud``.
|
||||||
|
Use this to share custom methods across multiple CRUD classes while
|
||||||
|
still using the factory shorthand.
|
||||||
searchable_fields: Optional list of searchable fields
|
searchable_fields: Optional list of searchable fields
|
||||||
facet_fields: Optional list of columns to compute distinct values for in paginated
|
facet_fields: Optional list of columns to compute distinct values for in paginated
|
||||||
responses. Supports direct columns (``User.status``) and relationship tuples
|
responses. Supports direct columns (``User.status``) and relationship tuples
|
||||||
@@ -1209,28 +1233,27 @@ def CrudFactory(
|
|||||||
joins=[(Post, Post.user_id == User.id)],
|
joins=[(Post, Post.user_id == User.id)],
|
||||||
outer_join=True,
|
outer_join=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# With a shared custom base class:
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
class AuditedCrud(AsyncCrud[T], Generic[T]):
|
||||||
|
@classmethod
|
||||||
|
async def get_active(cls, session):
|
||||||
|
return await cls.get_multi(session, filters=[cls.model.is_active == True])
|
||||||
|
|
||||||
|
UserCrud = CrudFactory(User, base_class=AuditedCrud)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
pk_key = model.__mapper__.primary_key[0].key
|
|
||||||
assert pk_key is not None
|
|
||||||
pk_col = getattr(model, pk_key)
|
|
||||||
|
|
||||||
if searchable_fields is None:
|
|
||||||
effective_searchable = [pk_col]
|
|
||||||
else:
|
|
||||||
existing_keys = {f.key for f in searchable_fields if not isinstance(f, tuple)}
|
|
||||||
effective_searchable = (
|
|
||||||
[pk_col, *searchable_fields]
|
|
||||||
if pk_key not in existing_keys
|
|
||||||
else list(searchable_fields)
|
|
||||||
)
|
|
||||||
|
|
||||||
cls = type(
|
cls = type(
|
||||||
f"Async{model.__name__}Crud",
|
f"Async{model.__name__}Crud",
|
||||||
(AsyncCrud,),
|
(base_class,),
|
||||||
{
|
{
|
||||||
"model": model,
|
"model": model,
|
||||||
"searchable_fields": effective_searchable,
|
"searchable_fields": searchable_fields,
|
||||||
"facet_fields": facet_fields,
|
"facet_fields": facet_fields,
|
||||||
"order_fields": order_fields,
|
"order_fields": order_fields,
|
||||||
"m2m_fields": m2m_fields,
|
"m2m_fields": m2m_fields,
|
||||||
|
|||||||
@@ -86,6 +86,101 @@ class TestCrudFactory:
|
|||||||
assert crud_with.default_load_options == options
|
assert crud_with.default_load_options == options
|
||||||
assert crud_without.default_load_options is None
|
assert crud_without.default_load_options is None
|
||||||
|
|
||||||
|
def test_base_class_custom_methods_inherited(self):
|
||||||
|
"""CrudFactory with base_class inherits custom methods from that base."""
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
class CustomBase(AsyncCrud[T], Generic[T]):
|
||||||
|
@classmethod
|
||||||
|
def custom_method(cls) -> str:
|
||||||
|
return f"custom:{cls.model.__name__}"
|
||||||
|
|
||||||
|
UserCrudCustom = CrudFactory(User, base_class=CustomBase)
|
||||||
|
PostCrudCustom = CrudFactory(Post, base_class=CustomBase)
|
||||||
|
|
||||||
|
assert issubclass(UserCrudCustom, CustomBase)
|
||||||
|
assert issubclass(PostCrudCustom, CustomBase)
|
||||||
|
assert UserCrudCustom.custom_method() == "custom:User"
|
||||||
|
assert PostCrudCustom.custom_method() == "custom:Post"
|
||||||
|
|
||||||
|
def test_base_class_pk_injected(self):
|
||||||
|
"""PK is still injected when using a custom base_class."""
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=DeclarativeBase)
|
||||||
|
|
||||||
|
class CustomBase(AsyncCrud[T], Generic[T]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
crud = CrudFactory(User, base_class=CustomBase)
|
||||||
|
assert crud.searchable_fields is not None
|
||||||
|
assert User.id in crud.searchable_fields
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCrudSubclass:
|
||||||
|
"""Tests for direct AsyncCrud subclassing (alternative to CrudFactory)."""
|
||||||
|
|
||||||
|
def test_subclass_with_model_only(self):
|
||||||
|
"""Subclassing with just model auto-injects PK into searchable_fields."""
|
||||||
|
|
||||||
|
class UserCrudDirect(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
|
||||||
|
assert UserCrudDirect.searchable_fields == [User.id]
|
||||||
|
|
||||||
|
def test_subclass_with_explicit_fields_prepends_pk(self):
|
||||||
|
"""Subclassing with searchable_fields prepends PK automatically."""
|
||||||
|
|
||||||
|
class UserCrudDirect(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.username]
|
||||||
|
|
||||||
|
assert UserCrudDirect.searchable_fields == [User.id, User.username]
|
||||||
|
|
||||||
|
def test_subclass_with_pk_already_in_fields(self):
|
||||||
|
"""PK is not duplicated when already in searchable_fields."""
|
||||||
|
|
||||||
|
class UserCrudDirect(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
searchable_fields = [User.id, User.username]
|
||||||
|
|
||||||
|
assert UserCrudDirect.searchable_fields == [User.id, User.username]
|
||||||
|
|
||||||
|
def test_subclass_has_default_class_vars(self):
|
||||||
|
"""Other ClassVars are None by default on a direct subclass."""
|
||||||
|
|
||||||
|
class UserCrudDirect(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
|
||||||
|
assert UserCrudDirect.facet_fields is None
|
||||||
|
assert UserCrudDirect.default_load_options is None
|
||||||
|
assert UserCrudDirect.cursor_column is None
|
||||||
|
|
||||||
|
def test_subclass_with_load_options(self):
|
||||||
|
"""Direct subclass can declare default_load_options."""
|
||||||
|
opts = [selectinload(User.role)]
|
||||||
|
|
||||||
|
class UserCrudDirect(AsyncCrud[User]):
|
||||||
|
model = User
|
||||||
|
default_load_options = opts
|
||||||
|
|
||||||
|
assert UserCrudDirect.default_load_options is opts
|
||||||
|
|
||||||
|
def test_abstract_base_without_model_not_processed(self):
|
||||||
|
"""Intermediate abstract class without model is not processed."""
|
||||||
|
|
||||||
|
class AbstractCrud(AsyncCrud[User]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Should not raise, and searchable_fields inherits base default (None)
|
||||||
|
assert AbstractCrud.searchable_fields is None
|
||||||
|
|
||||||
|
|
||||||
class TestResolveLoadOptions:
|
class TestResolveLoadOptions:
|
||||||
"""Tests for _resolve_load_options logic."""
|
"""Tests for _resolve_load_options logic."""
|
||||||
|
|||||||
Reference in New Issue
Block a user