mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add join to crud functions (#21)
This commit is contained in:
@@ -4,7 +4,6 @@ from ..exceptions import NoSearchableFieldsError
|
||||
from .factory import CrudFactory
|
||||
from .search import (
|
||||
SearchConfig,
|
||||
SearchFieldType,
|
||||
get_searchable_fields,
|
||||
)
|
||||
|
||||
@@ -13,5 +12,4 @@ __all__ = [
|
||||
"get_searchable_fields",
|
||||
"NoSearchableFieldsError",
|
||||
"SearchConfig",
|
||||
"SearchFieldType",
|
||||
]
|
||||
|
||||
@@ -17,6 +17,7 @@ from ..exceptions import NotFoundError
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||
|
||||
|
||||
class AsyncCrud(Generic[ModelType]):
|
||||
@@ -55,6 +56,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType:
|
||||
@@ -63,6 +66,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
joins: List of (model, condition) tuples for joining related tables
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
|
||||
@@ -73,7 +78,15 @@ class AsyncCrud(Generic[ModelType]):
|
||||
NotFoundError: If no record found
|
||||
MultipleResultsFound: If more than one record found
|
||||
"""
|
||||
q = select(cls.model).where(and_(*filters))
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if with_for_update:
|
||||
@@ -90,6 +103,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
filters: list[Any] | None = None,
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType | None:
|
||||
"""Get the first matching record, or None.
|
||||
@@ -97,12 +112,21 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
joins: List of (model, condition) tuples for joining related tables
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
load_options: SQLAlchemy loader options
|
||||
|
||||
Returns:
|
||||
Model instance or None
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
@@ -116,6 +140,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
order_by: Any | None = None,
|
||||
limit: int | None = None,
|
||||
@@ -126,6 +152,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
joins: List of (model, condition) tuples for joining related tables
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
load_options: SQLAlchemy loader options
|
||||
order_by: Column or list of columns to order by
|
||||
limit: Max number of rows to return
|
||||
@@ -135,6 +163,13 @@ class AsyncCrud(Generic[ModelType]):
|
||||
List of model instances
|
||||
"""
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
@@ -254,17 +289,29 @@ class AsyncCrud(Generic[ModelType]):
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any] | None = None,
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
) -> int:
|
||||
"""Count records matching the filters.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
joins: List of (model, condition) tuples for joining related tables
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
|
||||
Returns:
|
||||
Number of matching records
|
||||
"""
|
||||
q = select(func.count()).select_from(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
result = await session.execute(q)
|
||||
@@ -275,17 +322,30 @@ class AsyncCrud(Generic[ModelType]):
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
) -> bool:
|
||||
"""Check if a record exists.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
joins: List of (model, condition) tuples for joining related tables
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
|
||||
Returns:
|
||||
True if at least one record matches
|
||||
"""
|
||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
||||
q = select(cls.model)
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
q = q.where(and_(*filters)).exists().select()
|
||||
result = await session.execute(q)
|
||||
return bool(result.scalar())
|
||||
|
||||
@@ -295,6 +355,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
session: AsyncSession,
|
||||
*,
|
||||
filters: list[Any] | None = None,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
@@ -307,6 +369,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
joins: List of (model, condition) tuples for joining related tables
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
load_options: SQLAlchemy loader options
|
||||
order_by: Column or list of columns to order by
|
||||
page: Page number (1-indexed)
|
||||
@@ -319,7 +383,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""
|
||||
filters = list(filters) if filters else []
|
||||
offset = (page - 1) * items_per_page
|
||||
joins: list[Any] = []
|
||||
search_joins: list[Any] = []
|
||||
|
||||
# Build search filters
|
||||
if search:
|
||||
@@ -330,11 +394,21 @@ class AsyncCrud(Generic[ModelType]):
|
||||
default_fields=cls.searchable_fields,
|
||||
)
|
||||
filters.extend(search_filters)
|
||||
joins.extend(search_joins)
|
||||
|
||||
# Build query with joins
|
||||
q = select(cls.model)
|
||||
for join_rel in joins:
|
||||
|
||||
# Apply explicit joins
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
q = (
|
||||
q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else q.join(model, condition)
|
||||
)
|
||||
|
||||
# Apply search joins (always outer joins for search)
|
||||
for join_rel in search_joins:
|
||||
q = q.outerjoin(join_rel)
|
||||
|
||||
if filters:
|
||||
@@ -352,8 +426,20 @@ class AsyncCrud(Generic[ModelType]):
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
||||
count_q = count_q.select_from(cls.model)
|
||||
for join_rel in joins:
|
||||
|
||||
# Apply explicit joins to count query
|
||||
if joins:
|
||||
for model, condition in joins:
|
||||
count_q = (
|
||||
count_q.outerjoin(model, condition)
|
||||
if outer_join
|
||||
else count_q.join(model, condition)
|
||||
)
|
||||
|
||||
# Apply search joins to count query
|
||||
for join_rel in search_joins:
|
||||
count_q = count_q.outerjoin(join_rel)
|
||||
|
||||
if filters:
|
||||
count_q = count_q.where(and_(*filters))
|
||||
|
||||
@@ -404,6 +490,20 @@ def CrudFactory(
|
||||
|
||||
# With search
|
||||
result = await UserCrud.paginate(session, search="john")
|
||||
|
||||
# With joins (inner join by default):
|
||||
users = await UserCrud.get_multi(
|
||||
session,
|
||||
joins=[(Post, Post.user_id == User.id)],
|
||||
filters=[Post.published == True],
|
||||
)
|
||||
|
||||
# With outer join:
|
||||
users = await UserCrud.get_multi(
|
||||
session,
|
||||
joins=[(Post, Post.user_id == User.id)],
|
||||
outer_join=True,
|
||||
)
|
||||
"""
|
||||
cls = type(
|
||||
f"Async{model.__name__}Crud",
|
||||
|
||||
Reference in New Issue
Block a user