mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
feat: add join to crud functions (#21)
This commit is contained in:
@@ -4,7 +4,6 @@ from ..exceptions import NoSearchableFieldsError
|
|||||||
from .factory import CrudFactory
|
from .factory import CrudFactory
|
||||||
from .search import (
|
from .search import (
|
||||||
SearchConfig,
|
SearchConfig,
|
||||||
SearchFieldType,
|
|
||||||
get_searchable_fields,
|
get_searchable_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,5 +12,4 @@ __all__ = [
|
|||||||
"get_searchable_fields",
|
"get_searchable_fields",
|
||||||
"NoSearchableFieldsError",
|
"NoSearchableFieldsError",
|
||||||
"SearchConfig",
|
"SearchConfig",
|
||||||
"SearchFieldType",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from ..exceptions import NotFoundError
|
|||||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
class AsyncCrud(Generic[ModelType]):
|
||||||
@@ -55,6 +56,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any],
|
filters: list[Any],
|
||||||
*,
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
with_for_update: bool = False,
|
with_for_update: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
) -> ModelType:
|
) -> ModelType:
|
||||||
@@ -63,6 +66,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
with_for_update: Lock the row for update
|
with_for_update: Lock the row for update
|
||||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||||
|
|
||||||
@@ -73,7 +78,15 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
NotFoundError: If no record found
|
NotFoundError: If no record found
|
||||||
MultipleResultsFound: If more than one record found
|
MultipleResultsFound: If more than one record found
|
||||||
"""
|
"""
|
||||||
q = select(cls.model).where(and_(*filters))
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
q = q.options(*load_options)
|
q = q.options(*load_options)
|
||||||
if with_for_update:
|
if with_for_update:
|
||||||
@@ -90,6 +103,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
*,
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
) -> ModelType | None:
|
) -> ModelType | None:
|
||||||
"""Get the first matching record, or None.
|
"""Get the first matching record, or None.
|
||||||
@@ -97,12 +112,21 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance or None
|
Model instance or None
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
@@ -116,6 +140,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: Any | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
@@ -126,6 +152,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
order_by: Column or list of columns to order by
|
order_by: Column or list of columns to order by
|
||||||
limit: Max number of rows to return
|
limit: Max number of rows to return
|
||||||
@@ -135,6 +163,13 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
List of model instances
|
List of model instances
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if load_options:
|
if load_options:
|
||||||
@@ -254,17 +289,29 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Count records matching the filters.
|
"""Count records matching the filters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of matching records
|
Number of matching records
|
||||||
"""
|
"""
|
||||||
q = select(func.count()).select_from(cls.model)
|
q = select(func.count()).select_from(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
@@ -275,17 +322,30 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
filters: list[Any],
|
filters: list[Any],
|
||||||
|
*,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if a record exists.
|
"""Check if a record exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if at least one record matches
|
True if at least one record matches
|
||||||
"""
|
"""
|
||||||
q = select(cls.model).where(and_(*filters)).exists().select()
|
q = select(cls.model)
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
q = q.where(and_(*filters)).exists().select()
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
return bool(result.scalar())
|
return bool(result.scalar())
|
||||||
|
|
||||||
@@ -295,6 +355,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
filters: list[Any] | None = None,
|
filters: list[Any] | None = None,
|
||||||
|
joins: JoinType | None = None,
|
||||||
|
outer_join: bool = False,
|
||||||
load_options: list[Any] | None = None,
|
load_options: list[Any] | None = None,
|
||||||
order_by: Any | None = None,
|
order_by: Any | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
@@ -307,6 +369,8 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Args:
|
Args:
|
||||||
session: DB async session
|
session: DB async session
|
||||||
filters: List of SQLAlchemy filter conditions
|
filters: List of SQLAlchemy filter conditions
|
||||||
|
joins: List of (model, condition) tuples for joining related tables
|
||||||
|
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||||
load_options: SQLAlchemy loader options
|
load_options: SQLAlchemy loader options
|
||||||
order_by: Column or list of columns to order by
|
order_by: Column or list of columns to order by
|
||||||
page: Page number (1-indexed)
|
page: Page number (1-indexed)
|
||||||
@@ -319,7 +383,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
"""
|
"""
|
||||||
filters = list(filters) if filters else []
|
filters = list(filters) if filters else []
|
||||||
offset = (page - 1) * items_per_page
|
offset = (page - 1) * items_per_page
|
||||||
joins: list[Any] = []
|
search_joins: list[Any] = []
|
||||||
|
|
||||||
# Build search filters
|
# Build search filters
|
||||||
if search:
|
if search:
|
||||||
@@ -330,11 +394,21 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
default_fields=cls.searchable_fields,
|
default_fields=cls.searchable_fields,
|
||||||
)
|
)
|
||||||
filters.extend(search_filters)
|
filters.extend(search_filters)
|
||||||
joins.extend(search_joins)
|
|
||||||
|
|
||||||
# Build query with joins
|
# Build query with joins
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
for join_rel in joins:
|
|
||||||
|
# Apply explicit joins
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
q = (
|
||||||
|
q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else q.join(model, condition)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search joins (always outer joins for search)
|
||||||
|
for join_rel in search_joins:
|
||||||
q = q.outerjoin(join_rel)
|
q = q.outerjoin(join_rel)
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
@@ -352,8 +426,20 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
pk_col = cls.model.__mapper__.primary_key[0]
|
pk_col = cls.model.__mapper__.primary_key[0]
|
||||||
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
||||||
count_q = count_q.select_from(cls.model)
|
count_q = count_q.select_from(cls.model)
|
||||||
for join_rel in joins:
|
|
||||||
|
# Apply explicit joins to count query
|
||||||
|
if joins:
|
||||||
|
for model, condition in joins:
|
||||||
|
count_q = (
|
||||||
|
count_q.outerjoin(model, condition)
|
||||||
|
if outer_join
|
||||||
|
else count_q.join(model, condition)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search joins to count query
|
||||||
|
for join_rel in search_joins:
|
||||||
count_q = count_q.outerjoin(join_rel)
|
count_q = count_q.outerjoin(join_rel)
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
count_q = count_q.where(and_(*filters))
|
count_q = count_q.where(and_(*filters))
|
||||||
|
|
||||||
@@ -404,6 +490,20 @@ def CrudFactory(
|
|||||||
|
|
||||||
# With search
|
# With search
|
||||||
result = await UserCrud.paginate(session, search="john")
|
result = await UserCrud.paginate(session, search="john")
|
||||||
|
|
||||||
|
# With joins (inner join by default):
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
session,
|
||||||
|
joins=[(Post, Post.user_id == User.id)],
|
||||||
|
filters=[Post.published == True],
|
||||||
|
)
|
||||||
|
|
||||||
|
# With outer join:
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
session,
|
||||||
|
joins=[(Post, Post.user_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
"""
|
"""
|
||||||
cls = type(
|
cls = type(
|
||||||
f"Async{model.__name__}Crud",
|
f"Async{model.__name__}Crud",
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ from fastapi_toolsets.crud.factory import AsyncCrud
|
|||||||
from fastapi_toolsets.exceptions import NotFoundError
|
from fastapi_toolsets.exceptions import NotFoundError
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
|
Post,
|
||||||
|
PostCreate,
|
||||||
|
PostCrud,
|
||||||
Role,
|
Role,
|
||||||
RoleCreate,
|
RoleCreate,
|
||||||
RoleCrud,
|
RoleCrud,
|
||||||
@@ -481,3 +484,271 @@ class TestCrudPaginate:
|
|||||||
|
|
||||||
names = [r.name for r in result["data"]]
|
names = [r.name for r in result["data"]]
|
||||||
assert names == ["alpha", "bravo", "charlie"]
|
assert names == ["alpha", "bravo", "charlie"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrudJoins:
|
||||||
|
"""Tests for CRUD operations with joins."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Get with inner join filters correctly."""
|
||||||
|
# Create user with posts
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="author", email="author@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Post 1", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user with join on published posts
|
||||||
|
fetched = await UserCrud.get(
|
||||||
|
db_session,
|
||||||
|
filters=[User.id == user.id, Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert fetched.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_join(self, db_session: AsyncSession):
|
||||||
|
"""First with join returns matching record."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="writer", email="writer@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Draft", author_id=user.id, is_published=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find user with unpublished posts
|
||||||
|
result = await UserCrud.first(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == False], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_first_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""First with outer join includes records without related data."""
|
||||||
|
# User without posts
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_posts", email="no_posts@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# With outer join, user should be found even without posts
|
||||||
|
result = await UserCrud.first(
|
||||||
|
db_session,
|
||||||
|
filters=[User.id == user.id],
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == user.id
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_with_inner_join(self, db_session: AsyncSession):
|
||||||
|
"""Get multiple with inner join only returns matching records."""
|
||||||
|
# User with published post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="publisher", email="pub@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Published", author_id=user1.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without posts
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="lurker", email="lurk@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inner join should only return user with published post
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
)
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].username == "publisher"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_multi_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""Get multiple with outer join includes all records."""
|
||||||
|
# User with post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="has_post", email="has@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="My Post", author_id=user1.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without posts
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="no_post", email="no@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Outer join should return both users
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
)
|
||||||
|
assert len(users) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_count_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Count with join counts correctly."""
|
||||||
|
# Create users with different post statuses
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="active_author", email="active@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Published 1", author_id=user1.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
user2 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="draft_author", email="draft@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Draft 1", author_id=user2.id, is_published=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count users with published posts
|
||||||
|
count = await UserCrud.count(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_exists_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Exists with join checks correctly."""
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="poster", email="poster@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Exists Post", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user with published post exists
|
||||||
|
exists = await UserCrud.exists(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert exists is True
|
||||||
|
|
||||||
|
# Check if user with specific title exists
|
||||||
|
exists = await UserCrud.exists(
|
||||||
|
db_session,
|
||||||
|
filters=[Post.title == "Nonexistent"],
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
)
|
||||||
|
assert exists is False
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_paginate_with_join(self, db_session: AsyncSession):
|
||||||
|
"""Paginate with join works correctly."""
|
||||||
|
# Create users with posts
|
||||||
|
for i in range(5):
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username=f"author{i}", email=f"author{i}@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(
|
||||||
|
title=f"Post {i}",
|
||||||
|
author_id=user.id,
|
||||||
|
is_published=i % 2 == 0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paginate users with published posts
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
filters=[Post.is_published == True], # noqa: E712
|
||||||
|
page=1,
|
||||||
|
items_per_page=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 3
|
||||||
|
assert len(result["data"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
|
||||||
|
"""Paginate with outer join includes all records."""
|
||||||
|
# User with post
|
||||||
|
user1 = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="with_post", email="with@test.com"),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="A Post", author_id=user1.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# User without post
|
||||||
|
await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(username="without_post", email="without@test.com"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paginate with outer join
|
||||||
|
result = await UserCrud.paginate(
|
||||||
|
db_session,
|
||||||
|
joins=[(Post, Post.author_id == User.id)],
|
||||||
|
outer_join=True,
|
||||||
|
page=1,
|
||||||
|
items_per_page=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["pagination"]["total_count"] == 2
|
||||||
|
assert len(result["data"]) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_multiple_joins(self, db_session: AsyncSession):
|
||||||
|
"""Multiple joins can be applied."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="author_role"))
|
||||||
|
user = await UserCrud.create(
|
||||||
|
db_session,
|
||||||
|
UserCreate(
|
||||||
|
username="multi_join",
|
||||||
|
email="multi@test.com",
|
||||||
|
role_id=role.id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await PostCrud.create(
|
||||||
|
db_session,
|
||||||
|
PostCreate(title="Multi Join Post", author_id=user.id, is_published=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Join both Role and Post
|
||||||
|
users = await UserCrud.get_multi(
|
||||||
|
db_session,
|
||||||
|
joins=[
|
||||||
|
(Role, Role.id == User.role_id),
|
||||||
|
(Post, Post.author_id == User.id),
|
||||||
|
],
|
||||||
|
filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712
|
||||||
|
)
|
||||||
|
assert len(users) == 1
|
||||||
|
assert users[0].username == "multi_join"
|
||||||
|
|||||||
Reference in New Issue
Block a user