5 Commits

Author SHA1 Message Date
d3vyce
8c287b3ce7 feat: add join to crud functions (#21) 2026-02-01 15:01:10 +01:00
54f5479c24 Version 0.4.1 2026-01-29 14:15:55 -05:00
d3vyce
f467754df1 fix: cast to String non-text columns for crud search (#18)
fix: cast to String non-text columns for crud search
2026-01-29 19:44:48 +01:00
b57ce40b05 tests: change models to use UUID as primary key 2026-01-29 13:43:03 -05:00
5264631550 fix: cast to String non-text columns for crud search 2026-01-29 13:35:20 -05:00
12 changed files with 588 additions and 92 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.4.0" version = "0.4.1"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "0.4.0" __version__ = "0.4.1"

View File

@@ -4,7 +4,6 @@ from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory from .factory import CrudFactory
from .search import ( from .search import (
SearchConfig, SearchConfig,
SearchFieldType,
get_searchable_fields, get_searchable_fields,
) )
@@ -13,5 +12,4 @@ __all__ = [
"get_searchable_fields", "get_searchable_fields",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"SearchConfig", "SearchConfig",
"SearchFieldType",
] ]

View File

@@ -17,6 +17,7 @@ from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
@@ -55,6 +56,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*, *,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False, with_for_update: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
) -> ModelType: ) -> ModelType:
@@ -63,6 +66,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload) load_options: SQLAlchemy loader options (e.g., selectinload)
@@ -73,7 +78,15 @@ class AsyncCrud(Generic[ModelType]):
NotFoundError: If no record found NotFoundError: If no record found
MultipleResultsFound: If more than one record found MultipleResultsFound: If more than one record found
""" """
q = select(cls.model).where(and_(*filters)) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters))
if load_options: if load_options:
q = q.options(*load_options) q = q.options(*load_options)
if with_for_update: if with_for_update:
@@ -90,6 +103,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
filters: list[Any] | None = None, filters: list[Any] | None = None,
*, *,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
) -> ModelType | None: ) -> ModelType | None:
"""Get the first matching record, or None. """Get the first matching record, or None.
@@ -97,12 +112,21 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
Returns: Returns:
Model instance or None Model instance or None
""" """
q = select(cls.model) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if load_options:
@@ -116,6 +140,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
*, *,
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
limit: int | None = None, limit: int | None = None,
@@ -126,6 +152,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
limit: Max number of rows to return limit: Max number of rows to return
@@ -135,6 +163,13 @@ class AsyncCrud(Generic[ModelType]):
List of model instances List of model instances
""" """
q = select(cls.model) q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if load_options: if load_options:
@@ -254,17 +289,29 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
filters: list[Any] | None = None, filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> int: ) -> int:
"""Count records matching the filters. """Count records matching the filters.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns: Returns:
Number of matching records Number of matching records
""" """
q = select(func.count()).select_from(cls.model) q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
result = await session.execute(q) result = await session.execute(q)
@@ -275,17 +322,30 @@ class AsyncCrud(Generic[ModelType]):
cls: type[Self], cls: type[Self],
session: AsyncSession, session: AsyncSession,
filters: list[Any], filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> bool: ) -> bool:
"""Check if a record exists. """Check if a record exists.
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
Returns: Returns:
True if at least one record matches True if at least one record matches
""" """
q = select(cls.model).where(and_(*filters)).exists().select() q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q) result = await session.execute(q)
return bool(result.scalar()) return bool(result.scalar())
@@ -295,6 +355,8 @@ class AsyncCrud(Generic[ModelType]):
session: AsyncSession, session: AsyncSession,
*, *,
filters: list[Any] | None = None, filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None, load_options: list[Any] | None = None,
order_by: Any | None = None, order_by: Any | None = None,
page: int = 1, page: int = 1,
@@ -307,6 +369,8 @@ class AsyncCrud(Generic[ModelType]):
Args: Args:
session: DB async session session: DB async session
filters: List of SQLAlchemy filter conditions filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
page: Page number (1-indexed) page: Page number (1-indexed)
@@ -319,7 +383,7 @@ class AsyncCrud(Generic[ModelType]):
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
joins: list[Any] = [] search_joins: list[Any] = []
# Build search filters # Build search filters
if search: if search:
@@ -330,11 +394,21 @@ class AsyncCrud(Generic[ModelType]):
default_fields=cls.searchable_fields, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters) filters.extend(search_filters)
joins.extend(search_joins)
# Build query with joins # Build query with joins
q = select(cls.model) q = select(cls.model)
for join_rel in joins:
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel) q = q.outerjoin(join_rel)
if filters: if filters:
@@ -352,8 +426,20 @@ class AsyncCrud(Generic[ModelType]):
pk_col = cls.model.__mapper__.primary_key[0] pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name)))) count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model) count_q = count_q.select_from(cls.model)
for join_rel in joins:
# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel) count_q = count_q.outerjoin(join_rel)
if filters: if filters:
count_q = count_q.where(and_(*filters)) count_q = count_q.where(and_(*filters))
@@ -404,6 +490,20 @@ def CrudFactory(
# With search # With search
result = await UserCrud.paginate(session, search="john") result = await UserCrud.paginate(session, search="john")
# With joins (inner join by default):
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
filters=[Post.published == True],
)
# With outer join:
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
outer_join=True,
)
""" """
cls = type( cls = type(
f"Async{model.__name__}Crud", f"Async{model.__name__}Crud",

View File

@@ -129,11 +129,12 @@ def build_search_filters(
else: else:
column = field column = field
# Build the filter # Build the filter (cast to String for non-text columns)
column_as_string = column.cast(String)
if config.case_sensitive: if config.case_sensitive:
filters.append(column.like(f"%{query}%")) filters.append(column_as_string.like(f"%{query}%"))
else: else:
filters.append(column.ilike(f"%{query}%")) filters.append(column_as_string.ilike(f"%{query}%"))
if not filters: if not filters:
return [], [] return [], []

View File

@@ -1,4 +1,5 @@
from .exceptions import ( from .exceptions import (
ApiError,
ApiException, ApiException,
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
@@ -10,11 +11,12 @@ from .exceptions import (
from .handler import init_exceptions_handlers from .handler import init_exceptions_handlers
__all__ = [ __all__ = [
"init_exceptions_handlers", "ApiError",
"generate_error_responses",
"ApiException", "ApiException",
"ConflictError", "ConflictError",
"ForbiddenError", "ForbiddenError",
"generate_error_responses",
"init_exceptions_handlers",
"NoSearchableFieldsError", "NoSearchableFieldsError",
"NotFoundError", "NotFoundError",
"UnauthorizedError", "UnauthorizedError",

View File

@@ -1,10 +1,11 @@
"""Shared pytest fixtures for fastapi-utils tests.""" """Shared pytest fixtures for fastapi-utils tests."""
import os import os
import uuid
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import ForeignKey, String from sqlalchemy import ForeignKey, String, Uuid
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -33,7 +34,7 @@ class Role(Base):
__tablename__ = "roles" __tablename__ = "roles"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True) name: Mapped[str] = mapped_column(String(50), unique=True)
users: Mapped[list["User"]] = relationship(back_populates="role") users: Mapped[list["User"]] = relationship(back_populates="role")
@@ -44,11 +45,13 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
username: Mapped[str] = mapped_column(String(50), unique=True) username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True) email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True) is_active: Mapped[bool] = mapped_column(default=True)
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True) role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True
)
role: Mapped[Role | None] = relationship(back_populates="users") role: Mapped[Role | None] = relationship(back_populates="users")
@@ -58,11 +61,11 @@ class Post(Base):
__tablename__ = "posts" __tablename__ = "posts"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(200)) title: Mapped[str] = mapped_column(String(200))
content: Mapped[str] = mapped_column(String(1000), default="") content: Mapped[str] = mapped_column(String(1000), default="")
is_published: Mapped[bool] = mapped_column(default=False) is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[int] = mapped_column(ForeignKey("users.id")) author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
# ============================================================================= # =============================================================================
@@ -73,7 +76,7 @@ class Post(Base):
class RoleCreate(BaseModel): class RoleCreate(BaseModel):
"""Schema for creating a role.""" """Schema for creating a role."""
id: int | None = None id: uuid.UUID | None = None
name: str name: str
@@ -86,11 +89,11 @@ class RoleUpdate(BaseModel):
class UserCreate(BaseModel): class UserCreate(BaseModel):
"""Schema for creating a user.""" """Schema for creating a user."""
id: int | None = None id: uuid.UUID | None = None
username: str username: str
email: str email: str
is_active: bool = True is_active: bool = True
role_id: int | None = None role_id: uuid.UUID | None = None
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
@@ -99,17 +102,17 @@ class UserUpdate(BaseModel):
username: str | None = None username: str | None = None
email: str | None = None email: str | None = None
is_active: bool | None = None is_active: bool | None = None
role_id: int | None = None role_id: uuid.UUID | None = None
class PostCreate(BaseModel): class PostCreate(BaseModel):
"""Schema for creating a post.""" """Schema for creating a post."""
id: int | None = None id: uuid.UUID | None = None
title: str title: str
content: str = "" content: str = ""
is_published: bool = False is_published: bool = False
author_id: int author_id: uuid.UUID
class PostUpdate(BaseModel): class PostUpdate(BaseModel):
@@ -195,5 +198,5 @@ def sample_post_data() -> PostCreate:
title="Test Post", title="Test Post",
content="Test content", content="Test content",
is_published=True, is_published=True,
author_id=1, author_id=uuid.uuid4(),
) )

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.crud module.""" """Tests for fastapi_toolsets.crud module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -8,6 +10,9 @@ from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError from fastapi_toolsets.exceptions import NotFoundError
from .conftest import ( from .conftest import (
Post,
PostCreate,
PostCrud,
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
@@ -89,8 +94,9 @@ class TestCrudGet:
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_raises_not_found(self, db_session: AsyncSession): async def test_get_raises_not_found(self, db_session: AsyncSession):
"""Get raises NotFoundError for missing records.""" """Get raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.get(db_session, [Role.id == 99999]) await RoleCrud.get(db_session, [Role.id == non_existent_id])
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_with_multiple_filters(self, db_session: AsyncSession): async def test_get_with_multiple_filters(self, db_session: AsyncSession):
@@ -223,11 +229,12 @@ class TestCrudUpdate:
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_raises_not_found(self, db_session: AsyncSession): async def test_update_raises_not_found(self, db_session: AsyncSession):
"""Update raises NotFoundError for missing records.""" """Update raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.update( await RoleCrud.update(
db_session, db_session,
RoleUpdate(name="new"), RoleUpdate(name="new"),
[Role.id == 99999], [Role.id == non_existent_id],
) )
@pytest.mark.anyio @pytest.mark.anyio
@@ -340,7 +347,8 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_insert_new_record(self, db_session: AsyncSession): async def test_upsert_insert_new_record(self, db_session: AsyncSession):
"""Upsert inserts a new record when it doesn't exist.""" """Upsert inserts a new record when it doesn't exist."""
data = RoleCreate(id=1, name="upsert_new") role_id = uuid.uuid4()
data = RoleCreate(id=role_id, name="upsert_new")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
data, data,
@@ -353,12 +361,13 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_update_existing_record(self, db_session: AsyncSession): async def test_upsert_update_existing_record(self, db_session: AsyncSession):
"""Upsert updates an existing record.""" """Upsert updates an existing record."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=100, name="original_name") data = RoleCreate(id=role_id, name="original_name")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert with update # Upsert with update
updated_data = RoleCreate(id=100, name="updated_name") updated_data = RoleCreate(id=role_id, name="updated_name")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
updated_data, updated_data,
@@ -370,22 +379,23 @@ class TestCrudUpsert:
assert role.name == "updated_name" assert role.name == "updated_name"
# Verify only one record exists # Verify only one record exists
count = await RoleCrud.count(db_session, [Role.id == 100]) count = await RoleCrud.count(db_session, [Role.id == role_id])
assert count == 1 assert count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession): async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
"""Upsert does nothing on conflict when set_ is not provided.""" """Upsert does nothing on conflict when set_ is not provided."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=200, name="do_nothing_original") data = RoleCreate(id=role_id, name="do_nothing_original")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert without set_ (do nothing) # Upsert without set_ (do nothing)
conflict_data = RoleCreate(id=200, name="do_nothing_conflict") conflict_data = RoleCreate(id=role_id, name="do_nothing_conflict")
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"]) await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
# Original value should be preserved # Original value should be preserved
role = await RoleCrud.first(db_session, [Role.id == 200]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "do_nothing_original" assert role.name == "do_nothing_original"
@@ -474,3 +484,271 @@ class TestCrudPaginate:
names = [r.name for r in result["data"]] names = [r.name for r in result["data"]]
assert names == ["alpha", "bravo", "charlie"] assert names == ["alpha", "bravo", "charlie"]
class TestCrudJoins:
"""Tests for CRUD operations with joins."""
@pytest.mark.anyio
async def test_get_with_join(self, db_session: AsyncSession):
"""Get with inner join filters correctly."""
# Create user with posts
user = await UserCrud.create(
db_session,
UserCreate(username="author", email="author@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Post 1", author_id=user.id, is_published=True),
)
# Get user with join on published posts
fetched = await UserCrud.get(
db_session,
filters=[User.id == user.id, Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert fetched.id == user.id
@pytest.mark.anyio
async def test_first_with_join(self, db_session: AsyncSession):
"""First with join returns matching record."""
user = await UserCrud.create(
db_session,
UserCreate(username="writer", email="writer@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft", author_id=user.id, is_published=False),
)
# Find user with unpublished posts
result = await UserCrud.first(
db_session,
filters=[Post.is_published == False], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_first_with_outer_join(self, db_session: AsyncSession):
"""First with outer join includes records without related data."""
# User without posts
user = await UserCrud.create(
db_session,
UserCreate(username="no_posts", email="no_posts@test.com"),
)
# With outer join, user should be found even without posts
result = await UserCrud.first(
db_session,
filters=[User.id == user.id],
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert result is not None
assert result.id == user.id
@pytest.mark.anyio
async def test_get_multi_with_inner_join(self, db_session: AsyncSession):
"""Get multiple with inner join only returns matching records."""
# User with published post
user1 = await UserCrud.create(
db_session,
UserCreate(username="publisher", email="pub@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published", author_id=user1.id, is_published=True),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="lurker", email="lurk@test.com"),
)
# Inner join should only return user with published post
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "publisher"
@pytest.mark.anyio
async def test_get_multi_with_outer_join(self, db_session: AsyncSession):
"""Get multiple with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="has_post", email="has@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="My Post", author_id=user1.id),
)
# User without posts
await UserCrud.create(
db_session,
UserCreate(username="no_post", email="no@test.com"),
)
# Outer join should return both users
users = await UserCrud.get_multi(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
)
assert len(users) == 2
@pytest.mark.anyio
async def test_count_with_join(self, db_session: AsyncSession):
"""Count with join counts correctly."""
# Create users with different post statuses
user1 = await UserCrud.create(
db_session,
UserCreate(username="active_author", email="active@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Published 1", author_id=user1.id, is_published=True),
)
user2 = await UserCrud.create(
db_session,
UserCreate(username="draft_author", email="draft@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Draft 1", author_id=user2.id, is_published=False),
)
# Count users with published posts
count = await UserCrud.count(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert count == 1
@pytest.mark.anyio
async def test_exists_with_join(self, db_session: AsyncSession):
"""Exists with join checks correctly."""
user = await UserCrud.create(
db_session,
UserCreate(username="poster", email="poster@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="Exists Post", author_id=user.id, is_published=True),
)
# Check if user with published post exists
exists = await UserCrud.exists(
db_session,
filters=[Post.is_published == True], # noqa: E712
joins=[(Post, Post.author_id == User.id)],
)
assert exists is True
# Check if user with specific title exists
exists = await UserCrud.exists(
db_session,
filters=[Post.title == "Nonexistent"],
joins=[(Post, Post.author_id == User.id)],
)
assert exists is False
@pytest.mark.anyio
async def test_paginate_with_join(self, db_session: AsyncSession):
"""Paginate with join works correctly."""
# Create users with posts
for i in range(5):
user = await UserCrud.create(
db_session,
UserCreate(username=f"author{i}", email=f"author{i}@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(
title=f"Post {i}",
author_id=user.id,
is_published=i % 2 == 0,
),
)
# Paginate users with published posts
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
filters=[Post.is_published == True], # noqa: E712
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 3
assert len(result["data"]) == 3
@pytest.mark.anyio
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
"""Paginate with outer join includes all records."""
# User with post
user1 = await UserCrud.create(
db_session,
UserCreate(username="with_post", email="with@test.com"),
)
await PostCrud.create(
db_session,
PostCreate(title="A Post", author_id=user1.id),
)
# User without post
await UserCrud.create(
db_session,
UserCreate(username="without_post", email="without@test.com"),
)
# Paginate with outer join
result = await UserCrud.paginate(
db_session,
joins=[(Post, Post.author_id == User.id)],
outer_join=True,
page=1,
items_per_page=10,
)
assert result["pagination"]["total_count"] == 2
assert len(result["data"]) == 2
@pytest.mark.anyio
async def test_multiple_joins(self, db_session: AsyncSession):
"""Multiple joins can be applied."""
role = await RoleCrud.create(db_session, RoleCreate(name="author_role"))
user = await UserCrud.create(
db_session,
UserCreate(
username="multi_join",
email="multi@test.com",
role_id=role.id,
),
)
await PostCrud.create(
db_session,
PostCreate(title="Multi Join Post", author_id=user.id, is_published=True),
)
# Join both Role and Post
users = await UserCrud.get_multi(
db_session,
joins=[
(Role, Role.id == User.role_id),
(Post, Post.author_id == User.id),
],
filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712
)
assert len(users) == 1
assert users[0].username == "multi_join"

View File

@@ -1,5 +1,7 @@
"""Tests for CRUD search functionality.""" """Tests for CRUD search functionality."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -272,6 +274,27 @@ class TestPaginateSearch:
usernames = [u.username for u in result["data"]] usernames = [u.username for u in result["data"]]
assert usernames == ["alice", "bob", "charlie"] assert usernames == ["alice", "bob", "charlie"]
@pytest.mark.anyio
async def test_search_non_string_column(self, db_session: AsyncSession):
"""Search on non-string columns (e.g., UUID) works via cast."""
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
await UserCrud.create(
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="jane", email="jane@test.com")
)
# Search by UUID (partial match)
result = await UserCrud.paginate(
db_session,
search="12345678",
search_fields=[User.id, User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].id == user_id
class TestSearchConfig: class TestSearchConfig:
"""Tests for SearchConfig options.""" """Tests for SearchConfig options."""

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.fixtures module.""" """Tests for fastapi_toolsets.fixtures module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
def test_register_with_decorator(self): def test_register_with_decorator(self):
"""Register fixture with decorator.""" """Register fixture with decorator."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
assert "roles" in [f.name for f in registry.get_all()] assert "roles" in [f.name for f in registry.get_all()]
def test_register_with_custom_name(self): def test_register_with_custom_name(self):
"""Register fixture with custom name.""" """Register fixture with custom name."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(name="custom_roles") @registry.register(name="custom_roles")
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
fixture = registry.get("custom_roles") fixture = registry.get("custom_roles")
assert fixture.name == "custom_roles" assert fixture.name == "custom_roles"
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
def test_register_with_dependencies(self): def test_register_with_dependencies(self):
"""Register fixture with dependencies.""" """Register fixture with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
fixture = registry.get("users") fixture = registry.get("users")
assert fixture.depends_on == ["roles"] assert fixture.depends_on == ["roles"]
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
def test_register_with_contexts(self): def test_register_with_contexts(self):
"""Register fixture with contexts.""" """Register fixture with contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_data(): def test_data():
return [Role(id=100, name="test")] return [Role(id=role_id, name="test")]
fixture = registry.get("test_data") fixture = registry.get("test_data")
assert Context.TESTING.value in fixture.contexts assert Context.TESTING.value in fixture.contexts
@@ -244,12 +258,14 @@ class TestLoadFixtures:
async def test_load_single_fixture(self, db_session: AsyncSession): async def test_load_single_fixture(self, db_session: AsyncSession):
"""Load a single fixture.""" """Load a single fixture."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures(db_session, registry, "roles") result = await load_fixtures(db_session, registry, "roles")
@@ -266,14 +282,23 @@ class TestLoadFixtures:
async def test_load_with_dependencies(self, db_session: AsyncSession): async def test_load_with_dependencies(self, db_session: AsyncSession):
"""Load fixtures with dependencies.""" """Load fixtures with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
result = await load_fixtures(db_session, registry, "users") result = await load_fixtures(db_session, registry, "users")
@@ -289,10 +314,11 @@ class TestLoadFixtures:
async def test_load_with_merge_strategy(self, db_session: AsyncSession): async def test_load_with_merge_strategy(self, db_session: AsyncSession):
"""Load fixtures with MERGE strategy updates existing.""" """Load fixtures with MERGE strategy updates existing."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
@@ -306,10 +332,11 @@ class TestLoadFixtures:
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession): async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
"""Load fixtures with SKIP_EXISTING strategy.""" """Load fixtures with SKIP_EXISTING strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="original")] return [Role(id=role_id, name="original")]
await load_fixtures( await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
@@ -317,7 +344,7 @@ class TestLoadFixtures:
@registry.register(name="roles_updated") @registry.register(name="roles_updated")
def roles_v2(): def roles_v2():
return [Role(id=1, name="updated")] return [Role(id=role_id, name="updated")]
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated") registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
@@ -327,7 +354,7 @@ class TestLoadFixtures:
from .conftest import RoleCrud from .conftest import RoleCrud
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "original" assert role.name == "original"
@@ -335,12 +362,14 @@ class TestLoadFixtures:
async def test_load_with_insert_strategy(self, db_session: AsyncSession): async def test_load_with_insert_strategy(self, db_session: AsyncSession):
"""Load fixtures with INSERT strategy.""" """Load fixtures with INSERT strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures( result = await load_fixtures(
@@ -375,14 +404,16 @@ class TestLoadFixtures:
): ):
"""Load multiple independent fixtures.""" """Load multiple independent fixtures."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id_1, name="admin")]
@registry.register @registry.register
def other_roles(): def other_roles():
return [Role(id=2, name="user")] return [Role(id=role_id_2, name="user")]
result = await load_fixtures(db_session, registry, "roles", "other_roles") result = await load_fixtures(db_session, registry, "roles", "other_roles")
@@ -402,14 +433,16 @@ class TestLoadFixturesByContext:
async def test_load_by_single_context(self, db_session: AsyncSession): async def test_load_by_single_context(self, db_session: AsyncSession):
"""Load fixtures by single context.""" """Load fixtures by single context."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context(db_session, registry, Context.BASE) await load_fixtures_by_context(db_session, registry, Context.BASE)
@@ -418,7 +451,7 @@ class TestLoadFixturesByContext:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == base_role_id])
assert role is not None assert role is not None
assert role.name == "base_role" assert role.name == "base_role"
@@ -426,14 +459,16 @@ class TestLoadFixturesByContext:
async def test_load_by_multiple_contexts(self, db_session: AsyncSession): async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
"""Load fixtures by multiple contexts.""" """Load fixtures by multiple contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context( await load_fixtures_by_context(
db_session, registry, Context.BASE, Context.TESTING db_session, registry, Context.BASE, Context.TESTING
@@ -448,14 +483,23 @@ class TestLoadFixturesByContext:
async def test_load_context_with_dependencies(self, db_session: AsyncSession): async def test_load_context_with_dependencies(self, db_session: AsyncSession):
"""Load context fixtures with cross-context dependencies.""" """Load context fixtures with cross-context dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"], contexts=[Context.TESTING]) @registry.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users(): def test_users():
return [User(id=1, username="tester", email="test@test.com", role_id=1)] return [
User(
id=user_id,
username="tester",
email="test@test.com",
role_id=role_id,
)
]
await load_fixtures_by_context(db_session, registry, Context.TESTING) await load_fixtures_by_context(db_session, registry, Context.TESTING)
@@ -471,20 +515,41 @@ class TestGetObjByAttr:
def setup_method(self): def setup_method(self):
"""Set up test fixtures for each test.""" """Set up test fixtures for each test."""
self.registry = FixtureRegistry() self.registry = FixtureRegistry()
self.role_id_1 = uuid.uuid4()
self.role_id_2 = uuid.uuid4()
self.role_id_3 = uuid.uuid4()
self.user_id_1 = uuid.uuid4()
self.user_id_2 = uuid.uuid4()
role_id_1 = self.role_id_1
role_id_2 = self.role_id_2
role_id_3 = self.role_id_3
user_id_1 = self.user_id_1
user_id_2 = self.user_id_2
@self.registry.register @self.registry.register
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
Role(id=3, name="moderator"), Role(id=role_id_3, name="moderator"),
] ]
@self.registry.register(depends_on=["roles"]) @self.registry.register(depends_on=["roles"])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1, username="alice", email="alice@example.com", role_id=1), User(
User(id=2, username="bob", email="bob@example.com", role_id=1), id=user_id_1,
username="alice",
email="alice@example.com",
role_id=role_id_1,
),
User(
id=user_id_2,
username="bob",
email="bob@example.com",
role_id=role_id_1,
),
] ]
self.roles = roles self.roles = roles
@@ -492,18 +557,18 @@ class TestGetObjByAttr:
def test_get_by_id(self): def test_get_by_id(self):
"""Get an object by its id attribute.""" """Get an object by its id attribute."""
role = get_obj_by_attr(self.roles, "id", 1) role = get_obj_by_attr(self.roles, "id", self.role_id_1)
assert role.name == "admin" assert role.name == "admin"
def test_get_user_by_username(self): def test_get_user_by_username(self):
"""Get a user by username.""" """Get a user by username."""
user = get_obj_by_attr(self.users, "username", "bob") user = get_obj_by_attr(self.users, "username", "bob")
assert user.id == 2 assert user.id == self.user_id_2
assert user.email == "bob@example.com" assert user.email == "bob@example.com"
def test_returns_first_match(self): def test_returns_first_match(self):
"""Returns the first matching object when multiple could match.""" """Returns the first matching object when multiple could match."""
user = get_obj_by_attr(self.users, "role_id", 1) user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
assert user.username == "alice" assert user.username == "alice"
def test_no_match_raises_stop_iteration(self): def test_no_match_raises_stop_iteration(self):
@@ -514,4 +579,4 @@ class TestGetObjByAttr:
def test_no_match_on_wrong_value_type(self): def test_no_match_on_wrong_value_type(self):
"""Raises StopIteration when value type doesn't match.""" """Raises StopIteration when value type doesn't match."""
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
get_obj_by_attr(self.roles, "id", "1") get_obj_by_attr(self.roles, "id", "not-a-uuid")

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.pytest module.""" """Tests for fastapi_toolsets.pytest module."""
import uuid
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from httpx import AsyncClient from httpx import AsyncClient
@@ -18,27 +20,49 @@ from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
test_registry = FixtureRegistry() test_registry = FixtureRegistry()
# Fixed UUIDs for test fixtures to allow consistent assertions
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
@test_registry.register(contexts=[Context.BASE]) @test_registry.register(contexts=[Context.BASE])
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1000, name="plugin_admin"), Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
Role(id=1001, name="plugin_user"), Role(id=ROLE_USER_ID, name="plugin_user"),
] ]
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE]) @test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000), User(
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001), id=USER_ADMIN_ID,
username="plugin_admin",
email="padmin@test.com",
role_id=ROLE_ADMIN_ID,
),
User(
id=USER_USER_ID,
username="plugin_user",
email="puser@test.com",
role_id=ROLE_USER_ID,
),
] ]
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING]) @test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
def extra_users() -> list[User]: def extra_users() -> list[User]:
return [ return [
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001), User(
id=USER_EXTRA_ID,
username="plugin_extra",
email="pextra@test.com",
role_id=ROLE_USER_ID,
),
] ]
@@ -73,7 +97,7 @@ class TestGeneratedFixtures:
assert fixture_roles[1].name == "plugin_user" assert fixture_roles[1].name == "plugin_user"
# Verify data is in database # Verify data is in database
count = await RoleCrud.count(db_session, [Role.id >= 1000]) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -86,11 +110,11 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Roles should also be in database # Roles should also be in database
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
# Users should be in database # Users should be in database
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert users_count == 2 assert users_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -100,7 +124,7 @@ class TestGeneratedFixtures:
"""Fixture returns actual model instances.""" """Fixture returns actual model instances."""
user = fixture_users[0] user = fixture_users[0]
assert isinstance(user, User) assert isinstance(user, User)
assert user.id == 1000 assert user.id == USER_ADMIN_ID
assert user.username == "plugin_admin" assert user.username == "plugin_admin"
@pytest.mark.anyio @pytest.mark.anyio
@@ -111,7 +135,7 @@ class TestGeneratedFixtures:
# Load user with role relationship # Load user with role relationship
user = await UserCrud.get( user = await UserCrud.get(
db_session, db_session,
[User.id == 1000], [User.id == USER_ADMIN_ID],
load_options=[selectinload(User.role)], load_options=[selectinload(User.role)],
) )
@@ -127,8 +151,8 @@ class TestGeneratedFixtures:
assert len(fixture_extra_users) == 1 assert len(fixture_extra_users) == 1
# All fixtures should be loaded # All fixtures should be loaded
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
assert users_count == 3 # 2 from users + 1 from extra_users assert users_count == 3 # 2 from users + 1 from extra_users
@@ -141,8 +165,7 @@ class TestGeneratedFixtures:
# Get all users loaded by fixture # Get all users loaded by fixture
users = await UserCrud.get_multi( users = await UserCrud.get_multi(
db_session, db_session,
filters=[User.id >= 1000], order_by=User.username,
order_by=User.id,
) )
assert len(users) == 2 assert len(users) == 2
@@ -161,8 +184,8 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Both should be in database # Both should be in database
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000]) roles = await RoleCrud.get_multi(db_session)
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000]) users = await UserCrud.get_multi(db_session)
assert len(roles) == 2 assert len(roles) == 2
assert len(users) == 2 assert len(users) == 2
@@ -215,14 +238,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_creates_working_session(self): async def test_creates_working_session(self):
"""Session can perform database operations.""" """Session can perform database operations."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base) as session: async with create_db_session(DATABASE_URL, Base) as session:
assert isinstance(session, AsyncSession) assert isinstance(session, AsyncSession)
role = Role(id=9001, name="test_helper_role") role = Role(id=role_id, name="test_helper_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
result = await session.execute(select(Role).where(Role.id == 9001)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one() fetched = result.scalar_one()
assert fetched.name == "test_helper_role" assert fetched.name == "test_helper_role"
@@ -237,8 +261,9 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_dropped_after_session(self): async def test_tables_dropped_after_session(self):
"""Tables are dropped after session closes when drop_tables=True.""" """Tables are dropped after session closes when drop_tables=True."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=9002, name="will_be_dropped") role = Role(id=role_id, name="will_be_dropped")
session.add(role) session.add(role)
await session.commit() await session.commit()
@@ -250,14 +275,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_preserved_when_drop_disabled(self): async def test_tables_preserved_when_drop_disabled(self):
"""Tables are preserved when drop_tables=False.""" """Tables are preserved when drop_tables=False."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
role = Role(id=9003, name="preserved_role") role = Role(id=role_id, name="preserved_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
# Create another session without dropping # Create another session without dropping
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
result = await session.execute(select(Role).where(Role.id == 9003)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none() fetched = result.scalar_one_or_none()
assert fetched is not None assert fetched is not None
assert fetched.name == "preserved_role" assert fetched.name == "preserved_role"

2
uv.lock generated
View File

@@ -220,7 +220,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.4.0" version = "0.4.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },