feat: add join to crud functions (#21)

This commit is contained in:
d3vyce
2026-02-01 15:01:10 +01:00
committed by GitHub
parent 54f5479c24
commit 8c287b3ce7
3 changed files with 377 additions and 8 deletions

View File

@@ -4,7 +4,6 @@ from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory from .factory import CrudFactory
from .search import ( from .search import (
SearchConfig, SearchConfig,
SearchFieldType,
get_searchable_fields, get_searchable_fields,
) )
@@ -13,5 +12,4 @@ __all__ = [
"get_searchable_fields", "get_searchable_fields",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"SearchConfig", "SearchConfig",
"SearchFieldType",
] ]

View File

@@ -17,6 +17,7 @@ from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
@@ -55,6 +56,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*, *,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
) -> ModelType: ) -> ModelType:
@@ -63,6 +66,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload) load_options: SQLAlchemy loader options (e.g., selectinload)
@@ -73,7 +78,15 @@ class AsyncCrud(Generic[ModelType]):
NotFoundError: If no record found NotFoundError: If no record found
MultipleResultsFound: If more than one record found MultipleResultsFound: If more than one record found
""" """
q = select(cls.model).where(and_(*filters)) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters))
if load_options: if load_options:
q = q.options(*load_options) q = q.options(*load_options)
if with_for_update: if with_for_update:
@@ -90,6 +103,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any] | None = None, filters: list[Any] | None = None,
*, *,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
) -> ModelType | None: ) -> ModelType | None:
"""Get the first matching record, or None. """Get the first matching record, or None.
@@ -97,12 +112,21 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
Returns: Returns:
Model instance or None Model instance or None
""" """
q = select(cls.model) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if load_options:
@@ -116,6 +140,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
*, *,
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
limit: int | None = None, limit: int | None = None,
@@ -126,6 +152,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
limit: Max number of rows to return limit: Max number of rows to return
@@ -135,6 +163,13 @@ class AsyncCrud(Generic[ModelType]):
List of model instances List of model instances
""" """
q = select(cls.model) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if load_options:
@@ -254,17 +289,29 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
filters: list[Any] | None = None, filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> int: ) -> int:
"""Count records matching the filters. """Count records matching the filters.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns: Returns:
Number of matching records Number of matching records
""" """
q = select(func.count()).select_from(cls.model) q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
result = await session.execute(q) result = await session.execute(q)
@@ -275,17 +322,30 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> bool: ) -> bool:
"""Check if a record exists. """Check if a record exists.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns: Returns:
True if at least one record matches True if at least one record matches
""" """
q = select(cls.model).where(and_(*filters)).exists().select() q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q) result = await session.execute(q)
return bool(result.scalar()) return bool(result.scalar())
@@ -295,6 +355,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
*, *,
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
page: int = 1, page: int = 1,
@@ -307,6 +369,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
page: Page number (1-indexed) page: Page number (1-indexed)
@@ -319,7 +383,7 @@ class AsyncCrud(Generic[ModelType]):
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
joins: list[Any] = [] search_joins: list[Any] = []
# Build search filters # Build search filters
if search: if search:
@@ -330,11 +394,21 @@ class AsyncCrud(Generic[ModelType]):
default_fields=cls.searchable_fields, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters) filters.extend(search_filters)
joins.extend(search_joins)
# Build query with joins # Build query with joins
q = select(cls.model) q = select(cls.model)
for join_rel in joins:
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel) q = q.outerjoin(join_rel)
if filters: if filters:
@@ -352,8 +426,20 @@ class AsyncCrud(Generic[ModelType]):
pk_col = cls.model.__mapper__.primary_key[0] pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name)))) count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model) count_q = count_q.select_from(cls.model)
for join_rel in joins:
# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel) count_q = count_q.outerjoin(join_rel)
if filters: if filters:
count_q = count_q.where(and_(*filters)) count_q = count_q.where(and_(*filters))
@@ -404,6 +490,20 @@ def CrudFactory(
# With search # With search
result = await UserCrud.paginate(session, search="john") result = await UserCrud.paginate(session, search="john")
# With joins (inner join by default):
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
filters=[Post.published == True],
)
# With outer join:
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
outer_join=True,
)
""" """
cls = type( cls = type(
f"Async{model.__name__}Crud", f"Async{model.__name__}Crud",

View File

@@ -10,6 +10,9 @@ from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError from fastapi_toolsets.exceptions import NotFoundError
from .conftest import ( from .conftest import (
Post,
PostCreate,
PostCrud,
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
@@ -481,3 +484,271 @@ class TestCrudPaginate:
names = [r.name for r in result["data"]] names = [r.name for r in result["data"]]
assert names == ["alpha", "bravo", "charlie"] assert names == ["alpha", "bravo", "charlie"]
class TestCrudJoins:
"""Tests for CRUD operations with joins."""
@pytest.mark.anyio
async def test_get_with_join(self, db_session: AsyncSession):
"""Get with inner join filters correctly."""
# Create user with posts
user = await UserCrud.create(
db_session,
UserCreate(username="author", email="author@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Post 1", author_id=user.id, is_published=True),
)
# Get user with join on published posts
fetched = await UserCrud.get(
db_session,
filters=[User.id == user.id, Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert fetched.id == user.id
@pytest.mark.anyio
async def test_first_with_join(self, db_session: AsyncSession):
"""First with join returns matching record."""
user = await UserCrud.create(
db_session,
UserCreate(username="writer", email="writer@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft", author_id=user.id, is_published=False),
)
# Find user with unpublished posts
result = await UserCrud.first(
db_session,
filters=[Post.is_published == False], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_first_with_outer_join(self, db_session: AsyncSession):
"""First with outer join includes records without related data."""
# User without posts
user = await UserCrud.create(
db_session,
UserCreate(username="no_posts", email="no_posts@test.com"),
)
# With outer join, user should be found even without posts
result = await UserCrud.first(
db_session,
filters=[User.id == user.id],
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_get_multi_with_inner_join(self, db_session: AsyncSession):
"""Get multiple with inner join only returns matching records."""
# User with published post
user1 = await UserCrud.create(
db_session,
UserCreate(username="publisher", email="pub@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published", author_id=user1.id, is_published=True),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="lurker", email="lurk@test.com"),
)
# Inner join should only return user with published post
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "publisher"
@pytest.mark.anyio
async def test_get_multi_with_outer_join(self, db_session: AsyncSession):
"""Get multiple with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="has_post", email="has@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="My Post", author_id=user1.id),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="no_post", email="no@test.com"),
)
# Outer join should return both users
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert len(users) == 2
@pytest.mark.anyio
async def test_count_with_join(self, db_session: AsyncSession):
"""Count with join counts correctly."""
# Create users with different post statuses
user1 = await UserCrud.create(
db_session,
UserCreate(username="active_author", email="active@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published 1", author_id=user1.id, is_published=True),
)
user2 = await UserCrud.create(
db_session,
UserCreate(username="draft_author", email="draft@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft 1", author_id=user2.id, is_published=False),
)
# Count users with published posts
count = await UserCrud.count(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert count == 1
@pytest.mark.anyio
async def test_exists_with_join(self, db_session: AsyncSession):
"""Exists with join checks correctly."""
user = await UserCrud.create(
db_session,
UserCreate(username="poster", email="poster@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Exists Post", author_id=user.id, is_published=True),
)
# Check if user with published post exists
exists = await UserCrud.exists(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert exists is True
# Check if user with specific title exists
exists = await UserCrud.exists(
db_session,
filters=[Post.title == "Nonexistent"],
joins=[(Post, Post.author_id == User.id)],
)
assert exists is False
@pytest.mark.anyio
async def test_paginate_with_join(self, db_session: AsyncSession):
"""Paginate with join works correctly."""
# Create users with posts
for i in range(5):
user = await UserCrud.create(
db_session,
UserCreate(username=f"author{i}", email=f"author{i}@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(
title=f"Post {i}",
author_id=user.id,
is_published=i % 2 == 0,
),
)
# Paginate users with published posts
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 3
assert len(result["data"]) == 3
@pytest.mark.anyio
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
"""Paginate with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="with_post", email="with@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="A Post", author_id=user1.id),
)
# User without post
await UserCrud.create(
db_session,
UserCreate(username="without_post", email="without@test.com"),
)
# Paginate with outer join
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 2
assert len(result["data"]) == 2
@pytest.mark.anyio
async def test_multiple_joins(self, db_session: AsyncSession):
"""Multiple joins can be applied."""
role = await RoleCrud.create(db_session, RoleCreate(name="author_role"))
user = await UserCrud.create(
db_session,
UserCreate(
username="multi_join",
email="multi@test.com",
role_id=role.id,
),
)
await PostCrud.create(
db_session,
PostCreate(title="Multi Join Post", author_id=user.id, is_published=True),
)
# Join both Role and Post
users = await UserCrud.get_multi(
db_session,
joins=[
(Role, Role.id == User.role_id),
(Post, Post.author_id == User.id),
],
filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "multi_join"