7 Commits

Author SHA1 Message Date
54f5479c24 Version 0.4.1 2026-01-29 14:15:55 -05:00
d3vyce
f467754df1 fix: cast to String non-text columns for crud search (#18)
fix: cast to String non-text columns for crud search
2026-01-29 19:44:48 +01:00
b57ce40b05 tests: change models to use UUID as primary key 2026-01-29 13:43:03 -05:00
5264631550 fix: cast to String non-text columns for crud search 2026-01-29 13:35:20 -05:00
a76f7c439d Version 0.4.0 2026-01-29 09:15:33 -05:00
d3vyce
d14551781c feat: add search to crud paginate function (#17)
* feat: add search to crud paginate function

* fixes: comments + tests import
2026-01-29 00:08:02 +01:00
d3vyce
577e087321 feat: add support for python 3.14 (#15) 2026-01-28 21:01:15 +01:00
16 changed files with 857 additions and 115 deletions

View File

@@ -17,7 +17,7 @@ jobs:
uses: astral-sh/setup-uv@v7 uses: astral-sh/setup-uv@v7
- name: Set up Python - name: Set up Python
run: uv python install 3.13 run: uv python install 3.14
- name: Install dependencies - name: Install dependencies
run: uv sync run: uv sync

View File

@@ -56,7 +56,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: ["3.11", "3.12", "3.13"] python-version: ["3.11", "3.12", "3.13", "3.14"]
services: services:
postgres: postgres:
@@ -92,7 +92,7 @@ jobs:
uv run pytest --cov --cov-report=xml --cov-report=term-missing uv run pytest --cov --cov-report=xml --cov-report=term-missing
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
if: matrix.python-version == '3.13' if: matrix.python-version == '3.14'
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v5
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1 +1 @@
3.13 3.14

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.3.0" version = "0.4.1"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
@@ -24,6 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",
"Topic :: Software Development", "Topic :: Software Development",

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "0.3.0" __version__ = "0.4.1"

View File

@@ -0,0 +1,17 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory
from .search import (
SearchConfig,
SearchFieldType,
get_searchable_fields,
)
__all__ = [
"CrudFactory",
"get_searchable_fields",
"NoSearchableFieldsError",
"SearchConfig",
"SearchFieldType",
]

View File

@@ -12,13 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from .db import get_transaction from ..db import get_transaction
from .exceptions import NotFoundError from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters
__all__ = [
"AsyncCrud",
"CrudFactory",
]
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
@@ -27,20 +23,10 @@ class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models. """Generic async CRUD operations for SQLAlchemy models.
Subclass this and set the `model` class variable, or use `CrudFactory`. Subclass this and set the `model` class variable, or use `CrudFactory`.
Example:
class UserCrud(AsyncCrud[User]):
model = User
# Or use the factory:
UserCrud = CrudFactory(User)
# Then use it:
user = await UserCrud.get(session, [User.id == 1])
users = await UserCrud.get_multi(session, limit=10)
""" """
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
@classmethod @classmethod
async def create( async def create(
@@ -313,6 +299,8 @@ class AsyncCrud(Generic[ModelType]):
order_by: Any | None = None, order_by: Any | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Get paginated results with metadata. """Get paginated results with metadata.
@@ -323,23 +311,54 @@ class AsyncCrud(Generic[ModelType]):
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
page: Page number (1-indexed) page: Page number (1-indexed)
items_per_page: Number of items per page items_per_page: Number of items per page
search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default)
Returns: Returns:
Dict with 'data' and 'pagination' keys Dict with 'data' and 'pagination' keys
""" """
filters = filters or [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
joins: list[Any] = []
items = await cls.get_multi( # Build search filters
session, if search:
filters=filters, search_filters, search_joins = build_search_filters(
load_options=load_options, cls.model,
order_by=order_by, search,
limit=items_per_page, search_fields=search_fields,
offset=offset, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters)
joins.extend(search_joins)
total_count = await cls.count(session, filters=filters) # Build query with joins
q = select(cls.model)
for join_rel in joins:
q = q.outerjoin(join_rel)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if order_by is not None:
q = q.order_by(order_by)
q = q.offset(offset).limit(items_per_page)
result = await session.execute(q)
items = result.unique().scalars().all()
# Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
for join_rel in joins:
count_q = count_q.outerjoin(join_rel)
if filters:
count_q = count_q.where(and_(*filters))
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
return { return {
"data": items, "data": items,
@@ -354,11 +373,14 @@ class AsyncCrud(Generic[ModelType]):
def CrudFactory( def CrudFactory(
model: type[ModelType], model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
) -> type[AsyncCrud[ModelType]]: ) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model. """Create a CRUD class for a specific model.
Args: Args:
model: SQLAlchemy model class model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
Returns: Returns:
AsyncCrud subclass bound to the model AsyncCrud subclass bound to the model
@@ -370,9 +392,25 @@ def CrudFactory(
UserCrud = CrudFactory(User) UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post) PostCrud = CrudFactory(Post)
# With searchable fields:
UserCrud = CrudFactory(
User,
searchable_fields=[User.username, User.email, (User.role, Role.name)]
)
# Usage # Usage
user = await UserCrud.get(session, [User.id == 1]) user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
# With search
result = await UserCrud.paginate(session, search="john")
""" """
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model}) cls = type(
f"Async{model.__name__}Crud",
(AsyncCrud,),
{
"model": model,
"searchable_fields": searchable_fields,
},
)
return cast(type[AsyncCrud[ModelType]], cls) return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -0,0 +1,146 @@
"""Search utilities for AsyncCrud."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import NoSearchableFieldsError
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
@dataclass
class SearchConfig:
"""Advanced search configuration.
Attributes:
query: The search string
fields: Fields to search (columns or tuples for relationships)
case_sensitive: Case-sensitive search (default: False)
match_mode: "any" (OR) or "all" (AND) to combine fields
"""
query: str
fields: Sequence[SearchFieldType] | None = None
case_sensitive: bool = False
match_mode: Literal["any", "all"] = "any"
def get_searchable_fields(
model: type[DeclarativeBase],
*,
include_relationships: bool = True,
max_depth: int = 1,
) -> list[SearchFieldType]:
"""Auto-detect String fields on a model and its relationships.
Args:
model: SQLAlchemy model class
include_relationships: Include fields from many-to-one/one-to-one relationships
max_depth: Max depth for relationship traversal (default: 1)
Returns:
List of columns and tuples (relationship, column)
"""
fields: list[SearchFieldType] = []
mapper = model.__mapper__
# Direct String columns
for col in mapper.columns:
if isinstance(col.type, String):
fields.append(getattr(model, col.key))
# Relationships (one-to-one, many-to-one only)
if include_relationships and max_depth > 0:
for rel_name, rel_prop in mapper.relationships.items():
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
continue
rel_attr = getattr(model, rel_name)
related_model = rel_prop.mapper.class_
for col in related_model.__mapper__.columns:
if isinstance(col.type, String):
fields.append((rel_attr, getattr(related_model, col.key)))
return fields
def build_search_filters(
model: type[DeclarativeBase],
search: str | SearchConfig,
search_fields: Sequence[SearchFieldType] | None = None,
default_fields: Sequence[SearchFieldType] | None = None,
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
"""Build SQLAlchemy filter conditions for search.
Args:
model: SQLAlchemy model class
search: Search string or SearchConfig
search_fields: Fields specified per-call (takes priority)
default_fields: Default fields (from ClassVar)
Returns:
Tuple of (filter_conditions, joins_needed)
"""
# Normalize input
if isinstance(search, str):
config = SearchConfig(query=search, fields=search_fields)
else:
config = search
if search_fields is not None:
config = SearchConfig(
query=config.query,
fields=search_fields,
case_sensitive=config.case_sensitive,
match_mode=config.match_mode,
)
if not config.query or not config.query.strip():
return [], []
# Determine which fields to search
fields = config.fields or default_fields or get_searchable_fields(model)
if not fields:
raise NoSearchableFieldsError(model)
query = config.query.strip()
filters: list[ColumnElement[bool]] = []
joins: list[InstrumentedAttribute[Any]] = []
added_joins: set[str] = set()
for field in fields:
if isinstance(field, tuple):
# Relationship: (User.role, Role.name) or deeper
for rel in field[:-1]:
rel_key = str(rel)
if rel_key not in added_joins:
joins.append(rel)
added_joins.add(rel_key)
column = field[-1]
else:
column = field
# Build the filter (cast to String for non-text columns)
column_as_string = column.cast(String)
if config.case_sensitive:
filters.append(column_as_string.like(f"%{query}%"))
else:
filters.append(column_as_string.ilike(f"%{query}%"))
if not filters:
return [], []
# Combine based on match_mode
if config.match_mode == "any":
return [or_(*filters)], joins
else:
return filters, joins

View File

@@ -1,7 +1,9 @@
from .exceptions import ( from .exceptions import (
ApiError,
ApiException, ApiException,
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
NoSearchableFieldsError,
NotFoundError, NotFoundError,
UnauthorizedError, UnauthorizedError,
generate_error_responses, generate_error_responses,
@@ -9,11 +11,13 @@ from .exceptions import (
from .handler import init_exceptions_handlers from .handler import init_exceptions_handlers
__all__ = [ __all__ = [
"init_exceptions_handlers", "ApiError",
"generate_error_responses",
"ApiException", "ApiException",
"ConflictError", "ConflictError",
"ForbiddenError", "ForbiddenError",
"generate_error_responses",
"init_exceptions_handlers",
"NoSearchableFieldsError",
"NotFoundError", "NotFoundError",
"UnauthorizedError", "UnauthorizedError",
] ]

View File

@@ -119,6 +119,25 @@ class RoleNotFoundError(NotFoundError):
) )
class NoSearchableFieldsError(ApiException):
"""Raised when search is requested but no searchable fields are available."""
api_error = ApiError(
code=400,
msg="No Searchable Fields",
desc="No searchable fields configured for this resource.",
err_code="SEARCH-400",
)
def __init__(self, model: type) -> None:
self.model = model
detail = (
f"No searchable fields found for model '{model.__name__}'. "
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
)
super().__init__(detail)
def generate_error_responses( def generate_error_responses(
*errors: type[ApiException], *errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]: ) -> dict[int | str, dict[str, Any]]:

View File

@@ -1,10 +1,11 @@
"""Shared pytest fixtures for fastapi-utils tests.""" """Shared pytest fixtures for fastapi-utils tests."""
import os import os
import uuid
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import ForeignKey, String from sqlalchemy import ForeignKey, String, Uuid
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -33,7 +34,7 @@ class Role(Base):
__tablename__ = "roles" __tablename__ = "roles"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True) name: Mapped[str] = mapped_column(String(50), unique=True)
users: Mapped[list["User"]] = relationship(back_populates="role") users: Mapped[list["User"]] = relationship(back_populates="role")
@@ -44,11 +45,13 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
username: Mapped[str] = mapped_column(String(50), unique=True) username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True) email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True) is_active: Mapped[bool] = mapped_column(default=True)
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True) role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True
)
role: Mapped[Role | None] = relationship(back_populates="users") role: Mapped[Role | None] = relationship(back_populates="users")
@@ -58,11 +61,11 @@ class Post(Base):
__tablename__ = "posts" __tablename__ = "posts"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(200)) title: Mapped[str] = mapped_column(String(200))
content: Mapped[str] = mapped_column(String(1000), default="") content: Mapped[str] = mapped_column(String(1000), default="")
is_published: Mapped[bool] = mapped_column(default=False) is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[int] = mapped_column(ForeignKey("users.id")) author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
# ============================================================================= # =============================================================================
@@ -73,7 +76,7 @@ class Post(Base):
class RoleCreate(BaseModel): class RoleCreate(BaseModel):
"""Schema for creating a role.""" """Schema for creating a role."""
id: int | None = None id: uuid.UUID | None = None
name: str name: str
@@ -86,11 +89,11 @@ class RoleUpdate(BaseModel):
class UserCreate(BaseModel): class UserCreate(BaseModel):
"""Schema for creating a user.""" """Schema for creating a user."""
id: int | None = None id: uuid.UUID | None = None
username: str username: str
email: str email: str
is_active: bool = True is_active: bool = True
role_id: int | None = None role_id: uuid.UUID | None = None
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
@@ -99,17 +102,17 @@ class UserUpdate(BaseModel):
username: str | None = None username: str | None = None
email: str | None = None email: str | None = None
is_active: bool | None = None is_active: bool | None = None
role_id: int | None = None role_id: uuid.UUID | None = None
class PostCreate(BaseModel): class PostCreate(BaseModel):
"""Schema for creating a post.""" """Schema for creating a post."""
id: int | None = None id: uuid.UUID | None = None
title: str title: str
content: str = "" content: str = ""
is_published: bool = False is_published: bool = False
author_id: int author_id: uuid.UUID
class PostUpdate(BaseModel): class PostUpdate(BaseModel):
@@ -195,5 +198,5 @@ def sample_post_data() -> PostCreate:
title="Test Post", title="Test Post",
content="Test content", content="Test content",
is_published=True, is_published=True,
author_id=1, author_id=uuid.uuid4(),
) )

View File

@@ -1,9 +1,12 @@
"""Tests for fastapi_toolsets.crud module.""" """Tests for fastapi_toolsets.crud module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import AsyncCrud, CrudFactory from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError from fastapi_toolsets.exceptions import NotFoundError
from .conftest import ( from .conftest import (
@@ -88,8 +91,9 @@ class TestCrudGet:
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_raises_not_found(self, db_session: AsyncSession): async def test_get_raises_not_found(self, db_session: AsyncSession):
"""Get raises NotFoundError for missing records.""" """Get raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.get(db_session, [Role.id == 99999]) await RoleCrud.get(db_session, [Role.id == non_existent_id])
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_with_multiple_filters(self, db_session: AsyncSession): async def test_get_with_multiple_filters(self, db_session: AsyncSession):
@@ -222,11 +226,12 @@ class TestCrudUpdate:
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_raises_not_found(self, db_session: AsyncSession): async def test_update_raises_not_found(self, db_session: AsyncSession):
"""Update raises NotFoundError for missing records.""" """Update raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.update( await RoleCrud.update(
db_session, db_session,
RoleUpdate(name="new"), RoleUpdate(name="new"),
[Role.id == 99999], [Role.id == non_existent_id],
) )
@pytest.mark.anyio @pytest.mark.anyio
@@ -339,7 +344,8 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_insert_new_record(self, db_session: AsyncSession): async def test_upsert_insert_new_record(self, db_session: AsyncSession):
"""Upsert inserts a new record when it doesn't exist.""" """Upsert inserts a new record when it doesn't exist."""
data = RoleCreate(id=1, name="upsert_new") role_id = uuid.uuid4()
data = RoleCreate(id=role_id, name="upsert_new")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
data, data,
@@ -352,12 +358,13 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_update_existing_record(self, db_session: AsyncSession): async def test_upsert_update_existing_record(self, db_session: AsyncSession):
"""Upsert updates an existing record.""" """Upsert updates an existing record."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=100, name="original_name") data = RoleCreate(id=role_id, name="original_name")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert with update # Upsert with update
updated_data = RoleCreate(id=100, name="updated_name") updated_data = RoleCreate(id=role_id, name="updated_name")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
updated_data, updated_data,
@@ -369,22 +376,23 @@ class TestCrudUpsert:
assert role.name == "updated_name" assert role.name == "updated_name"
# Verify only one record exists # Verify only one record exists
count = await RoleCrud.count(db_session, [Role.id == 100]) count = await RoleCrud.count(db_session, [Role.id == role_id])
assert count == 1 assert count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession): async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
"""Upsert does nothing on conflict when set_ is not provided.""" """Upsert does nothing on conflict when set_ is not provided."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=200, name="do_nothing_original") data = RoleCreate(id=role_id, name="do_nothing_original")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert without set_ (do nothing) # Upsert without set_ (do nothing)
conflict_data = RoleCreate(id=200, name="do_nothing_conflict") conflict_data = RoleCreate(id=role_id, name="do_nothing_conflict")
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"]) await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
# Original value should be preserved # Original value should be preserved
role = await RoleCrud.first(db_session, [Role.id == 200]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "do_nothing_original" assert role.name == "do_nothing_original"

415
tests/test_crud_search.py Normal file
View File

@@ -0,0 +1,415 @@
"""Tests for CRUD search functionality."""
import uuid
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
from .conftest import (
Role,
RoleCreate,
RoleCrud,
User,
UserCreate,
UserCrud,
)
class TestPaginateSearch:
"""Tests for paginate() with search."""
@pytest.mark.anyio
async def test_search_single_column(self, db_session: AsyncSession):
"""Search on a single direct column."""
await UserCrud.create(
db_session, UserCreate(username="john_doe", email="john@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="jane_doe", email="jane@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="bob_smith", email="bob@test.com")
)
result = await UserCrud.paginate(
db_session,
search="doe",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_multiple_columns(self, db_session: AsyncSession):
"""Search across multiple columns (OR logic)."""
await UserCrud.create(
db_session, UserCreate(username="alice", email="alice@company.com")
)
await UserCrud.create(
db_session, UserCreate(username="company_bob", email="bob@other.com")
)
result = await UserCrud.paginate(
db_session,
search="company",
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_relationship_depth1(self, db_session: AsyncSession):
"""Search through a relationship (depth 1)."""
admin_role = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
user_role = await RoleCrud.create(db_session, RoleCreate(name="basic_user"))
await UserCrud.create(
db_session,
UserCreate(username="admin1", email="a1@test.com", role_id=admin_role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="admin2", email="a2@test.com", role_id=admin_role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
)
result = await UserCrud.paginate(
db_session,
search="admin",
search_fields=[(User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
"""Search combining direct columns and relationships."""
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="john", email="john@test.com", role_id=role.id),
)
# Search "admin" in username OR role.name
result = await UserCrud.paginate(
db_session,
search="admin",
search_fields=[User.username, (User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_case_insensitive(self, db_session: AsyncSession):
"""Search is case-insensitive by default."""
await UserCrud.create(
db_session, UserCreate(username="JohnDoe", email="j@test.com")
)
result = await UserCrud.paginate(
db_session,
search="johndoe",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_case_sensitive(self, db_session: AsyncSession):
"""Case-sensitive search with SearchConfig."""
await UserCrud.create(
db_session, UserCreate(username="JohnDoe", email="j@test.com")
)
# Should not find (case mismatch)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
# Should find (case match)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_empty_query(self, db_session: AsyncSession):
"""Empty search returns all results."""
await UserCrud.create(
db_session, UserCreate(username="user1", email="u1@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="user2", email="u2@test.com")
)
result = await UserCrud.paginate(db_session, search="")
assert result["pagination"]["total_count"] == 2
result = await UserCrud.paginate(db_session, search=None)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_with_existing_filters(self, db_session: AsyncSession):
"""Search combines with existing filters (AND)."""
await UserCrud.create(
db_session,
UserCreate(username="active_john", email="aj@test.com", is_active=True),
)
await UserCrud.create(
db_session,
UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
)
result = await UserCrud.paginate(
db_session,
filters=[User.is_active == True], # noqa: E712
search="john",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "active_john"
@pytest.mark.anyio
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
"""Auto-detect searchable fields when not specified."""
await UserCrud.create(
db_session, UserCreate(username="findme", email="other@test.com")
)
result = await UserCrud.paginate(db_session, search="findme")
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_no_results(self, db_session: AsyncSession):
"""Search with no matching results."""
await UserCrud.create(
db_session, UserCreate(username="john", email="j@test.com")
)
result = await UserCrud.paginate(
db_session,
search="nonexistent",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
assert result["data"] == []
@pytest.mark.anyio
async def test_search_with_pagination(self, db_session: AsyncSession):
"""Search respects pagination parameters."""
for i in range(15):
await UserCrud.create(
db_session,
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
)
result = await UserCrud.paginate(
db_session,
search="user_",
search_fields=[User.username],
page=1,
items_per_page=5,
)
assert result["pagination"]["total_count"] == 15
assert len(result["data"]) == 5
assert result["pagination"]["has_more"] is True
@pytest.mark.anyio
async def test_search_null_relationship(self, db_session: AsyncSession):
"""Users without relationship are included (outerjoin)."""
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="with_role", email="wr@test.com", role_id=role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="no_role", email="nr@test.com", role_id=None),
)
# Search in username, not in role
result = await UserCrud.paginate(
db_session,
search="role",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_with_order_by(self, db_session: AsyncSession):
"""Search works with order_by parameter."""
await UserCrud.create(
db_session, UserCreate(username="charlie", email="c@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="bob", email="b@test.com")
)
result = await UserCrud.paginate(
db_session,
search="@test.com",
search_fields=[User.email],
order_by=User.username,
)
assert result["pagination"]["total_count"] == 3
usernames = [u.username for u in result["data"]]
assert usernames == ["alice", "bob", "charlie"]
@pytest.mark.anyio
async def test_search_non_string_column(self, db_session: AsyncSession):
"""Search on non-string columns (e.g., UUID) works via cast."""
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
await UserCrud.create(
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="jane", email="jane@test.com")
)
# Search by UUID (partial match)
result = await UserCrud.paginate(
db_session,
search="12345678",
search_fields=[User.id, User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].id == user_id
class TestSearchConfig:
"""Tests for SearchConfig options."""
@pytest.mark.anyio
async def test_match_mode_all(self, db_session: AsyncSession):
"""match_mode='all' requires all fields to match (AND)."""
await UserCrud.create(
db_session,
UserCreate(username="john_test", email="john_test@company.com"),
)
await UserCrud.create(
db_session,
UserCreate(username="john_other", email="other@example.com"),
)
# 'john' must be in username AND email
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="john", match_mode="all"),
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "john_test"
@pytest.mark.anyio
async def test_search_config_with_fields(self, db_session: AsyncSession):
"""SearchConfig can specify fields directly."""
await UserCrud.create(
db_session, UserCreate(username="test", email="findme@test.com")
)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="findme", fields=[User.email]),
)
assert result["pagination"]["total_count"] == 1
class TestNoSearchableFieldsError:
"""Tests for NoSearchableFieldsError exception."""
def test_error_is_api_exception(self):
"""NoSearchableFieldsError inherits from ApiException."""
from fastapi_toolsets.exceptions import ApiException, NoSearchableFieldsError
assert issubclass(NoSearchableFieldsError, ApiException)
def test_error_has_api_error_fields(self):
"""NoSearchableFieldsError has proper api_error configuration."""
from fastapi_toolsets.exceptions import NoSearchableFieldsError
assert NoSearchableFieldsError.api_error.code == 400
assert NoSearchableFieldsError.api_error.err_code == "SEARCH-400"
def test_error_message_contains_model_name(self):
"""Error message includes the model name."""
from fastapi_toolsets.exceptions import NoSearchableFieldsError
error = NoSearchableFieldsError(User)
assert "User" in str(error)
assert error.model is User
def test_error_raised_when_no_fields(self):
"""Error is raised when search has no searchable fields."""
from sqlalchemy import Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from fastapi_toolsets.crud.search import build_search_filters
from fastapi_toolsets.exceptions import NoSearchableFieldsError
# Model with no String columns
class NoStringBase(DeclarativeBase):
pass
class NoStringModel(NoStringBase):
__tablename__ = "no_strings"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
count: Mapped[int] = mapped_column(Integer, default=0)
with pytest.raises(NoSearchableFieldsError) as exc_info:
build_search_filters(NoStringModel, "test")
assert exc_info.value.model is NoStringModel
assert "NoStringModel" in str(exc_info.value)
class TestGetSearchableFields:
"""Tests for auto-detection of searchable fields."""
def test_detects_string_columns(self):
"""Detects String columns on the model."""
fields = get_searchable_fields(User, include_relationships=False)
# Should include username and email (String), not id or is_active
field_names = [str(f) for f in fields]
assert any("username" in f for f in field_names)
assert any("email" in f for f in field_names)
assert not any("id" in f and "role_id" not in f for f in field_names)
assert not any("is_active" in f for f in field_names)
def test_detects_relationship_fields(self):
"""Detects String fields on related models."""
fields = get_searchable_fields(User, include_relationships=True)
# Should include (User.role, Role.name)
has_role_name = any(isinstance(f, tuple) and len(f) == 2 for f in fields)
assert has_role_name
def test_skips_collection_relationships(self):
"""Skips one-to-many relationships."""
fields = get_searchable_fields(Role, include_relationships=True)
# Role.users is a collection, should not be included
field_strs = [str(f) for f in fields]
assert not any("users" in f for f in field_strs)

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.fixtures module.""" """Tests for fastapi_toolsets.fixtures module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
def test_register_with_decorator(self): def test_register_with_decorator(self):
"""Register fixture with decorator.""" """Register fixture with decorator."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
assert "roles" in [f.name for f in registry.get_all()] assert "roles" in [f.name for f in registry.get_all()]
def test_register_with_custom_name(self): def test_register_with_custom_name(self):
"""Register fixture with custom name.""" """Register fixture with custom name."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(name="custom_roles") @registry.register(name="custom_roles")
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
fixture = registry.get("custom_roles") fixture = registry.get("custom_roles")
assert fixture.name == "custom_roles" assert fixture.name == "custom_roles"
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
def test_register_with_dependencies(self): def test_register_with_dependencies(self):
"""Register fixture with dependencies.""" """Register fixture with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
fixture = registry.get("users") fixture = registry.get("users")
assert fixture.depends_on == ["roles"] assert fixture.depends_on == ["roles"]
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
def test_register_with_contexts(self): def test_register_with_contexts(self):
"""Register fixture with contexts.""" """Register fixture with contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_data(): def test_data():
return [Role(id=100, name="test")] return [Role(id=role_id, name="test")]
fixture = registry.get("test_data") fixture = registry.get("test_data")
assert Context.TESTING.value in fixture.contexts assert Context.TESTING.value in fixture.contexts
@@ -244,12 +258,14 @@ class TestLoadFixtures:
async def test_load_single_fixture(self, db_session: AsyncSession): async def test_load_single_fixture(self, db_session: AsyncSession):
"""Load a single fixture.""" """Load a single fixture."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures(db_session, registry, "roles") result = await load_fixtures(db_session, registry, "roles")
@@ -266,14 +282,23 @@ class TestLoadFixtures:
async def test_load_with_dependencies(self, db_session: AsyncSession): async def test_load_with_dependencies(self, db_session: AsyncSession):
"""Load fixtures with dependencies.""" """Load fixtures with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
result = await load_fixtures(db_session, registry, "users") result = await load_fixtures(db_session, registry, "users")
@@ -289,10 +314,11 @@ class TestLoadFixtures:
async def test_load_with_merge_strategy(self, db_session: AsyncSession): async def test_load_with_merge_strategy(self, db_session: AsyncSession):
"""Load fixtures with MERGE strategy updates existing.""" """Load fixtures with MERGE strategy updates existing."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
@@ -306,10 +332,11 @@ class TestLoadFixtures:
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession): async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
"""Load fixtures with SKIP_EXISTING strategy.""" """Load fixtures with SKIP_EXISTING strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="original")] return [Role(id=role_id, name="original")]
await load_fixtures( await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
@@ -317,7 +344,7 @@ class TestLoadFixtures:
@registry.register(name="roles_updated") @registry.register(name="roles_updated")
def roles_v2(): def roles_v2():
return [Role(id=1, name="updated")] return [Role(id=role_id, name="updated")]
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated") registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
@@ -327,7 +354,7 @@ class TestLoadFixtures:
from .conftest import RoleCrud from .conftest import RoleCrud
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "original" assert role.name == "original"
@@ -335,12 +362,14 @@ class TestLoadFixtures:
async def test_load_with_insert_strategy(self, db_session: AsyncSession): async def test_load_with_insert_strategy(self, db_session: AsyncSession):
"""Load fixtures with INSERT strategy.""" """Load fixtures with INSERT strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures( result = await load_fixtures(
@@ -375,14 +404,16 @@ class TestLoadFixtures:
): ):
"""Load multiple independent fixtures.""" """Load multiple independent fixtures."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id_1, name="admin")]
@registry.register @registry.register
def other_roles(): def other_roles():
return [Role(id=2, name="user")] return [Role(id=role_id_2, name="user")]
result = await load_fixtures(db_session, registry, "roles", "other_roles") result = await load_fixtures(db_session, registry, "roles", "other_roles")
@@ -402,14 +433,16 @@ class TestLoadFixturesByContext:
async def test_load_by_single_context(self, db_session: AsyncSession): async def test_load_by_single_context(self, db_session: AsyncSession):
"""Load fixtures by single context.""" """Load fixtures by single context."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context(db_session, registry, Context.BASE) await load_fixtures_by_context(db_session, registry, Context.BASE)
@@ -418,7 +451,7 @@ class TestLoadFixturesByContext:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == base_role_id])
assert role is not None assert role is not None
assert role.name == "base_role" assert role.name == "base_role"
@@ -426,14 +459,16 @@ class TestLoadFixturesByContext:
async def test_load_by_multiple_contexts(self, db_session: AsyncSession): async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
"""Load fixtures by multiple contexts.""" """Load fixtures by multiple contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context( await load_fixtures_by_context(
db_session, registry, Context.BASE, Context.TESTING db_session, registry, Context.BASE, Context.TESTING
@@ -448,14 +483,23 @@ class TestLoadFixturesByContext:
async def test_load_context_with_dependencies(self, db_session: AsyncSession): async def test_load_context_with_dependencies(self, db_session: AsyncSession):
"""Load context fixtures with cross-context dependencies.""" """Load context fixtures with cross-context dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"], contexts=[Context.TESTING]) @registry.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users(): def test_users():
return [User(id=1, username="tester", email="test@test.com", role_id=1)] return [
User(
id=user_id,
username="tester",
email="test@test.com",
role_id=role_id,
)
]
await load_fixtures_by_context(db_session, registry, Context.TESTING) await load_fixtures_by_context(db_session, registry, Context.TESTING)
@@ -471,20 +515,41 @@ class TestGetObjByAttr:
def setup_method(self): def setup_method(self):
"""Set up test fixtures for each test.""" """Set up test fixtures for each test."""
self.registry = FixtureRegistry() self.registry = FixtureRegistry()
self.role_id_1 = uuid.uuid4()
self.role_id_2 = uuid.uuid4()
self.role_id_3 = uuid.uuid4()
self.user_id_1 = uuid.uuid4()
self.user_id_2 = uuid.uuid4()
role_id_1 = self.role_id_1
role_id_2 = self.role_id_2
role_id_3 = self.role_id_3
user_id_1 = self.user_id_1
user_id_2 = self.user_id_2
@self.registry.register @self.registry.register
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
Role(id=3, name="moderator"), Role(id=role_id_3, name="moderator"),
] ]
@self.registry.register(depends_on=["roles"]) @self.registry.register(depends_on=["roles"])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1, username="alice", email="alice@example.com", role_id=1), User(
User(id=2, username="bob", email="bob@example.com", role_id=1), id=user_id_1,
username="alice",
email="alice@example.com",
role_id=role_id_1,
),
User(
id=user_id_2,
username="bob",
email="bob@example.com",
role_id=role_id_1,
),
] ]
self.roles = roles self.roles = roles
@@ -492,18 +557,18 @@ class TestGetObjByAttr:
def test_get_by_id(self): def test_get_by_id(self):
"""Get an object by its id attribute.""" """Get an object by its id attribute."""
role = get_obj_by_attr(self.roles, "id", 1) role = get_obj_by_attr(self.roles, "id", self.role_id_1)
assert role.name == "admin" assert role.name == "admin"
def test_get_user_by_username(self): def test_get_user_by_username(self):
"""Get a user by username.""" """Get a user by username."""
user = get_obj_by_attr(self.users, "username", "bob") user = get_obj_by_attr(self.users, "username", "bob")
assert user.id == 2 assert user.id == self.user_id_2
assert user.email == "bob@example.com" assert user.email == "bob@example.com"
def test_returns_first_match(self): def test_returns_first_match(self):
"""Returns the first matching object when multiple could match.""" """Returns the first matching object when multiple could match."""
user = get_obj_by_attr(self.users, "role_id", 1) user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
assert user.username == "alice" assert user.username == "alice"
def test_no_match_raises_stop_iteration(self): def test_no_match_raises_stop_iteration(self):
@@ -514,4 +579,4 @@ class TestGetObjByAttr:
def test_no_match_on_wrong_value_type(self): def test_no_match_on_wrong_value_type(self):
"""Raises StopIteration when value type doesn't match.""" """Raises StopIteration when value type doesn't match."""
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
get_obj_by_attr(self.roles, "id", "1") get_obj_by_attr(self.roles, "id", "not-a-uuid")

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.pytest module.""" """Tests for fastapi_toolsets.pytest module."""
import uuid
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from httpx import AsyncClient from httpx import AsyncClient
@@ -18,27 +20,49 @@ from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
test_registry = FixtureRegistry() test_registry = FixtureRegistry()
# Fixed UUIDs for test fixtures to allow consistent assertions
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
@test_registry.register(contexts=[Context.BASE]) @test_registry.register(contexts=[Context.BASE])
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1000, name="plugin_admin"), Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
Role(id=1001, name="plugin_user"), Role(id=ROLE_USER_ID, name="plugin_user"),
] ]
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE]) @test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000), User(
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001), id=USER_ADMIN_ID,
username="plugin_admin",
email="padmin@test.com",
role_id=ROLE_ADMIN_ID,
),
User(
id=USER_USER_ID,
username="plugin_user",
email="puser@test.com",
role_id=ROLE_USER_ID,
),
] ]
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING]) @test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
def extra_users() -> list[User]: def extra_users() -> list[User]:
return [ return [
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001), User(
id=USER_EXTRA_ID,
username="plugin_extra",
email="pextra@test.com",
role_id=ROLE_USER_ID,
),
] ]
@@ -73,7 +97,7 @@ class TestGeneratedFixtures:
assert fixture_roles[1].name == "plugin_user" assert fixture_roles[1].name == "plugin_user"
# Verify data is in database # Verify data is in database
count = await RoleCrud.count(db_session, [Role.id >= 1000]) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -86,11 +110,11 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Roles should also be in database # Roles should also be in database
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
# Users should be in database # Users should be in database
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert users_count == 2 assert users_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -100,7 +124,7 @@ class TestGeneratedFixtures:
"""Fixture returns actual model instances.""" """Fixture returns actual model instances."""
user = fixture_users[0] user = fixture_users[0]
assert isinstance(user, User) assert isinstance(user, User)
assert user.id == 1000 assert user.id == USER_ADMIN_ID
assert user.username == "plugin_admin" assert user.username == "plugin_admin"
@pytest.mark.anyio @pytest.mark.anyio
@@ -111,7 +135,7 @@ class TestGeneratedFixtures:
# Load user with role relationship # Load user with role relationship
user = await UserCrud.get( user = await UserCrud.get(
db_session, db_session,
[User.id == 1000], [User.id == USER_ADMIN_ID],
load_options=[selectinload(User.role)], load_options=[selectinload(User.role)],
) )
@@ -127,8 +151,8 @@ class TestGeneratedFixtures:
assert len(fixture_extra_users) == 1 assert len(fixture_extra_users) == 1
# All fixtures should be loaded # All fixtures should be loaded
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
assert users_count == 3 # 2 from users + 1 from extra_users assert users_count == 3 # 2 from users + 1 from extra_users
@@ -141,8 +165,7 @@ class TestGeneratedFixtures:
# Get all users loaded by fixture # Get all users loaded by fixture
users = await UserCrud.get_multi( users = await UserCrud.get_multi(
db_session, db_session,
filters=[User.id >= 1000], order_by=User.username,
order_by=User.id,
) )
assert len(users) == 2 assert len(users) == 2
@@ -161,8 +184,8 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Both should be in database # Both should be in database
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000]) roles = await RoleCrud.get_multi(db_session)
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000]) users = await UserCrud.get_multi(db_session)
assert len(roles) == 2 assert len(roles) == 2
assert len(users) == 2 assert len(users) == 2
@@ -215,14 +238,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_creates_working_session(self): async def test_creates_working_session(self):
"""Session can perform database operations.""" """Session can perform database operations."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base) as session: async with create_db_session(DATABASE_URL, Base) as session:
assert isinstance(session, AsyncSession) assert isinstance(session, AsyncSession)
role = Role(id=9001, name="test_helper_role") role = Role(id=role_id, name="test_helper_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
result = await session.execute(select(Role).where(Role.id == 9001)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one() fetched = result.scalar_one()
assert fetched.name == "test_helper_role" assert fetched.name == "test_helper_role"
@@ -237,8 +261,9 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_dropped_after_session(self): async def test_tables_dropped_after_session(self):
"""Tables are dropped after session closes when drop_tables=True.""" """Tables are dropped after session closes when drop_tables=True."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=9002, name="will_be_dropped") role = Role(id=role_id, name="will_be_dropped")
session.add(role) session.add(role)
await session.commit() await session.commit()
@@ -250,14 +275,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_preserved_when_drop_disabled(self): async def test_tables_preserved_when_drop_disabled(self):
"""Tables are preserved when drop_tables=False.""" """Tables are preserved when drop_tables=False."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
role = Role(id=9003, name="preserved_role") role = Role(id=role_id, name="preserved_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
# Create another session without dropping # Create another session without dropping
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
result = await session.execute(select(Role).where(Role.id == 9003)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none() fetched = result.scalar_one_or_none()
assert fetched is not None assert fetched is not None
assert fetched.name == "preserved_role" assert fetched.name == "preserved_role"

2
uv.lock generated
View File

@@ -220,7 +220,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.3.0" version = "0.4.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },