7 Commits

Author SHA1 Message Date
54f5479c24 Version 0.4.1 2026-01-29 14:15:55 -05:00
d3vyce
f467754df1 fix: cast to String non-text columns for crud search (#18)
fix: cast to String non-text columns for crud search
2026-01-29 19:44:48 +01:00
b57ce40b05 tests: change models to use UUID as primary key 2026-01-29 13:43:03 -05:00
5264631550 fix: cast to String non-text columns for crud search 2026-01-29 13:35:20 -05:00
a76f7c439d Version 0.4.0 2026-01-29 09:15:33 -05:00
d3vyce
d14551781c feat: add search to crud paginate function (#17)
* feat: add search to crud paginate function

* fixes: comments + tests import
2026-01-29 00:08:02 +01:00
d3vyce
577e087321 feat: add support for python 3.14 (#15) 2026-01-28 21:01:15 +01:00
16 changed files with 857 additions and 115 deletions

View File

@@ -17,7 +17,7 @@ jobs:
uses: astral-sh/setup-uv@v7
- name: Set up Python
run: uv python install 3.13
run: uv python install 3.14
- name: Install dependencies
run: uv sync

View File

@@ -56,7 +56,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12", "3.13"]
python-version: ["3.11", "3.12", "3.13", "3.14"]
services:
postgres:
@@ -92,7 +92,7 @@ jobs:
uv run pytest --cov --cov-report=xml --cov-report=term-missing
- name: Upload coverage to Codecov
if: matrix.python-version == '3.13'
if: matrix.python-version == '3.14'
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1 +1 @@
3.13
3.14

View File

@@ -1,6 +1,6 @@
[project]
name = "fastapi-toolsets"
version = "0.3.0"
version = "0.4.1"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md"
license = "MIT"
@@ -24,6 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development",

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success")
"""
__version__ = "0.3.0"
__version__ = "0.4.1"

View File

@@ -0,0 +1,17 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory
from .search import (
SearchConfig,
SearchFieldType,
get_searchable_fields,
)
__all__ = [
"CrudFactory",
"get_searchable_fields",
"NoSearchableFieldsError",
"SearchConfig",
"SearchFieldType",
]

View File

@@ -12,13 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.roles import WhereHavingRole
from .db import get_transaction
from .exceptions import NotFoundError
__all__ = [
"AsyncCrud",
"CrudFactory",
]
from ..db import get_transaction
from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
@@ -27,20 +23,10 @@ class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models.
Subclass this and set the `model` class variable, or use `CrudFactory`.
Example:
class UserCrud(AsyncCrud[User]):
model = User
# Or use the factory:
UserCrud = CrudFactory(User)
# Then use it:
user = await UserCrud.get(session, [User.id == 1])
users = await UserCrud.get_multi(session, limit=10)
"""
model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
@classmethod
async def create(
@@ -313,6 +299,8 @@ class AsyncCrud(Generic[ModelType]):
order_by: Any | None = None,
page: int = 1,
items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
) -> dict[str, Any]:
"""Get paginated results with metadata.
@@ -323,23 +311,54 @@ class AsyncCrud(Generic[ModelType]):
order_by: Column or list of columns to order by
page: Page number (1-indexed)
items_per_page: Number of items per page
search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default)
Returns:
Dict with 'data' and 'pagination' keys
"""
filters = filters or []
filters = list(filters) if filters else []
offset = (page - 1) * items_per_page
joins: list[Any] = []
items = await cls.get_multi(
session,
filters=filters,
load_options=load_options,
order_by=order_by,
limit=items_per_page,
offset=offset,
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
joins.extend(search_joins)
total_count = await cls.count(session, filters=filters)
# Build query with joins
q = select(cls.model)
for join_rel in joins:
q = q.outerjoin(join_rel)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if order_by is not None:
q = q.order_by(order_by)
q = q.offset(offset).limit(items_per_page)
result = await session.execute(q)
items = result.unique().scalars().all()
# Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
for join_rel in joins:
count_q = count_q.outerjoin(join_rel)
if filters:
count_q = count_q.where(and_(*filters))
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
return {
"data": items,
@@ -354,11 +373,14 @@ class AsyncCrud(Generic[ModelType]):
def CrudFactory(
model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model.
Args:
model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
Returns:
AsyncCrud subclass bound to the model
@@ -370,9 +392,25 @@ def CrudFactory(
UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post)
# With searchable fields:
UserCrud = CrudFactory(
User,
searchable_fields=[User.username, User.email, (User.role, Role.name)]
)
# Usage
user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
# With search
result = await UserCrud.paginate(session, search="john")
"""
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
cls = type(
f"Async{model.__name__}Crud",
(AsyncCrud,),
{
"model": model,
"searchable_fields": searchable_fields,
},
)
return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -0,0 +1,146 @@
"""Search utilities for AsyncCrud."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import NoSearchableFieldsError
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
@dataclass
class SearchConfig:
"""Advanced search configuration.
Attributes:
query: The search string
fields: Fields to search (columns or tuples for relationships)
case_sensitive: Case-sensitive search (default: False)
match_mode: "any" (OR) or "all" (AND) to combine fields
"""
query: str
fields: Sequence[SearchFieldType] | None = None
case_sensitive: bool = False
match_mode: Literal["any", "all"] = "any"
def get_searchable_fields(
model: type[DeclarativeBase],
*,
include_relationships: bool = True,
max_depth: int = 1,
) -> list[SearchFieldType]:
"""Auto-detect String fields on a model and its relationships.
Args:
model: SQLAlchemy model class
include_relationships: Include fields from many-to-one/one-to-one relationships
max_depth: Max depth for relationship traversal (default: 1)
Returns:
List of columns and tuples (relationship, column)
"""
fields: list[SearchFieldType] = []
mapper = model.__mapper__
# Direct String columns
for col in mapper.columns:
if isinstance(col.type, String):
fields.append(getattr(model, col.key))
# Relationships (one-to-one, many-to-one only)
if include_relationships and max_depth > 0:
for rel_name, rel_prop in mapper.relationships.items():
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
continue
rel_attr = getattr(model, rel_name)
related_model = rel_prop.mapper.class_
for col in related_model.__mapper__.columns:
if isinstance(col.type, String):
fields.append((rel_attr, getattr(related_model, col.key)))
return fields
def build_search_filters(
model: type[DeclarativeBase],
search: str | SearchConfig,
search_fields: Sequence[SearchFieldType] | None = None,
default_fields: Sequence[SearchFieldType] | None = None,
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
"""Build SQLAlchemy filter conditions for search.
Args:
model: SQLAlchemy model class
search: Search string or SearchConfig
search_fields: Fields specified per-call (takes priority)
default_fields: Default fields (from ClassVar)
Returns:
Tuple of (filter_conditions, joins_needed)
"""
# Normalize input
if isinstance(search, str):
config = SearchConfig(query=search, fields=search_fields)
else:
config = search
if search_fields is not None:
config = SearchConfig(
query=config.query,
fields=search_fields,
case_sensitive=config.case_sensitive,
match_mode=config.match_mode,
)
if not config.query or not config.query.strip():
return [], []
# Determine which fields to search
fields = config.fields or default_fields or get_searchable_fields(model)
if not fields:
raise NoSearchableFieldsError(model)
query = config.query.strip()
filters: list[ColumnElement[bool]] = []
joins: list[InstrumentedAttribute[Any]] = []
added_joins: set[str] = set()
for field in fields:
if isinstance(field, tuple):
# Relationship: (User.role, Role.name) or deeper
for rel in field[:-1]:
rel_key = str(rel)
if rel_key not in added_joins:
joins.append(rel)
added_joins.add(rel_key)
column = field[-1]
else:
column = field
# Build the filter (cast to String for non-text columns)
column_as_string = column.cast(String)
if config.case_sensitive:
filters.append(column_as_string.like(f"%{query}%"))
else:
filters.append(column_as_string.ilike(f"%{query}%"))
if not filters:
return [], []
# Combine based on match_mode
if config.match_mode == "any":
return [or_(*filters)], joins
else:
return filters, joins

View File

@@ -1,7 +1,9 @@
from .exceptions import (
ApiError,
ApiException,
ConflictError,
ForbiddenError,
NoSearchableFieldsError,
NotFoundError,
UnauthorizedError,
generate_error_responses,
@@ -9,11 +11,13 @@ from .exceptions import (
from .handler import init_exceptions_handlers
__all__ = [
"init_exceptions_handlers",
"generate_error_responses",
"ApiError",
"ApiException",
"ConflictError",
"ForbiddenError",
"generate_error_responses",
"init_exceptions_handlers",
"NoSearchableFieldsError",
"NotFoundError",
"UnauthorizedError",
]

View File

@@ -119,6 +119,25 @@ class RoleNotFoundError(NotFoundError):
)
class NoSearchableFieldsError(ApiException):
"""Raised when search is requested but no searchable fields are available."""
api_error = ApiError(
code=400,
msg="No Searchable Fields",
desc="No searchable fields configured for this resource.",
err_code="SEARCH-400",
)
def __init__(self, model: type) -> None:
self.model = model
detail = (
f"No searchable fields found for model '{model.__name__}'. "
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
)
super().__init__(detail)
def generate_error_responses(
*errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]:

View File

@@ -1,10 +1,11 @@
"""Shared pytest fixtures for fastapi-utils tests."""
import os
import uuid
import pytest
from pydantic import BaseModel
from sqlalchemy import ForeignKey, String
from sqlalchemy import ForeignKey, String, Uuid
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -33,7 +34,7 @@ class Role(Base):
__tablename__ = "roles"
id: Mapped[int] = mapped_column(primary_key=True)
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True)
users: Mapped[list["User"]] = relationship(back_populates="role")
@@ -44,11 +45,13 @@ class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True)
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True
)
role: Mapped[Role | None] = relationship(back_populates="users")
@@ -58,11 +61,11 @@ class Post(Base):
__tablename__ = "posts"
id: Mapped[int] = mapped_column(primary_key=True)
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(200))
content: Mapped[str] = mapped_column(String(1000), default="")
is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
# =============================================================================
@@ -73,7 +76,7 @@ class Post(Base):
class RoleCreate(BaseModel):
"""Schema for creating a role."""
id: int | None = None
id: uuid.UUID | None = None
name: str
@@ -86,11 +89,11 @@ class RoleUpdate(BaseModel):
class UserCreate(BaseModel):
"""Schema for creating a user."""
id: int | None = None
id: uuid.UUID | None = None
username: str
email: str
is_active: bool = True
role_id: int | None = None
role_id: uuid.UUID | None = None
class UserUpdate(BaseModel):
@@ -99,17 +102,17 @@ class UserUpdate(BaseModel):
username: str | None = None
email: str | None = None
is_active: bool | None = None
role_id: int | None = None
role_id: uuid.UUID | None = None
class PostCreate(BaseModel):
"""Schema for creating a post."""
id: int | None = None
id: uuid.UUID | None = None
title: str
content: str = ""
is_published: bool = False
author_id: int
author_id: uuid.UUID
class PostUpdate(BaseModel):
@@ -195,5 +198,5 @@ def sample_post_data() -> PostCreate:
title="Test Post",
content="Test content",
is_published=True,
author_id=1,
author_id=uuid.uuid4(),
)

View File

@@ -1,9 +1,12 @@
"""Tests for fastapi_toolsets.crud module."""
import uuid
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import AsyncCrud, CrudFactory
from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError
from .conftest import (
@@ -88,8 +91,9 @@ class TestCrudGet:
@pytest.mark.anyio
async def test_get_raises_not_found(self, db_session: AsyncSession):
"""Get raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError):
await RoleCrud.get(db_session, [Role.id == 99999])
await RoleCrud.get(db_session, [Role.id == non_existent_id])
@pytest.mark.anyio
async def test_get_with_multiple_filters(self, db_session: AsyncSession):
@@ -222,11 +226,12 @@ class TestCrudUpdate:
@pytest.mark.anyio
async def test_update_raises_not_found(self, db_session: AsyncSession):
"""Update raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError):
await RoleCrud.update(
db_session,
RoleUpdate(name="new"),
[Role.id == 99999],
[Role.id == non_existent_id],
)
@pytest.mark.anyio
@@ -339,7 +344,8 @@ class TestCrudUpsert:
@pytest.mark.anyio
async def test_upsert_insert_new_record(self, db_session: AsyncSession):
"""Upsert inserts a new record when it doesn't exist."""
data = RoleCreate(id=1, name="upsert_new")
role_id = uuid.uuid4()
data = RoleCreate(id=role_id, name="upsert_new")
role = await RoleCrud.upsert(
db_session,
data,
@@ -352,12 +358,13 @@ class TestCrudUpsert:
@pytest.mark.anyio
async def test_upsert_update_existing_record(self, db_session: AsyncSession):
"""Upsert updates an existing record."""
role_id = uuid.uuid4()
# First insert
data = RoleCreate(id=100, name="original_name")
data = RoleCreate(id=role_id, name="original_name")
await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert with update
updated_data = RoleCreate(id=100, name="updated_name")
updated_data = RoleCreate(id=role_id, name="updated_name")
role = await RoleCrud.upsert(
db_session,
updated_data,
@@ -369,22 +376,23 @@ class TestCrudUpsert:
assert role.name == "updated_name"
# Verify only one record exists
count = await RoleCrud.count(db_session, [Role.id == 100])
count = await RoleCrud.count(db_session, [Role.id == role_id])
assert count == 1
@pytest.mark.anyio
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
"""Upsert does nothing on conflict when set_ is not provided."""
role_id = uuid.uuid4()
# First insert
data = RoleCreate(id=200, name="do_nothing_original")
data = RoleCreate(id=role_id, name="do_nothing_original")
await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert without set_ (do nothing)
conflict_data = RoleCreate(id=200, name="do_nothing_conflict")
conflict_data = RoleCreate(id=role_id, name="do_nothing_conflict")
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
# Original value should be preserved
role = await RoleCrud.first(db_session, [Role.id == 200])
role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None
assert role.name == "do_nothing_original"

415
tests/test_crud_search.py Normal file
View File

@@ -0,0 +1,415 @@
"""Tests for CRUD search functionality."""
import uuid
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
from .conftest import (
Role,
RoleCreate,
RoleCrud,
User,
UserCreate,
UserCrud,
)
class TestPaginateSearch:
"""Tests for paginate() with search."""
@pytest.mark.anyio
async def test_search_single_column(self, db_session: AsyncSession):
"""Search on a single direct column."""
await UserCrud.create(
db_session, UserCreate(username="john_doe", email="john@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="jane_doe", email="jane@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="bob_smith", email="bob@test.com")
)
result = await UserCrud.paginate(
db_session,
search="doe",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_multiple_columns(self, db_session: AsyncSession):
"""Search across multiple columns (OR logic)."""
await UserCrud.create(
db_session, UserCreate(username="alice", email="alice@company.com")
)
await UserCrud.create(
db_session, UserCreate(username="company_bob", email="bob@other.com")
)
result = await UserCrud.paginate(
db_session,
search="company",
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_relationship_depth1(self, db_session: AsyncSession):
"""Search through a relationship (depth 1)."""
admin_role = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
user_role = await RoleCrud.create(db_session, RoleCreate(name="basic_user"))
await UserCrud.create(
db_session,
UserCreate(username="admin1", email="a1@test.com", role_id=admin_role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="admin2", email="a2@test.com", role_id=admin_role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
)
result = await UserCrud.paginate(
db_session,
search="admin",
search_fields=[(User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
"""Search combining direct columns and relationships."""
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="john", email="john@test.com", role_id=role.id),
)
# Search "admin" in username OR role.name
result = await UserCrud.paginate(
db_session,
search="admin",
search_fields=[User.username, (User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_case_insensitive(self, db_session: AsyncSession):
"""Search is case-insensitive by default."""
await UserCrud.create(
db_session, UserCreate(username="JohnDoe", email="j@test.com")
)
result = await UserCrud.paginate(
db_session,
search="johndoe",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_case_sensitive(self, db_session: AsyncSession):
"""Case-sensitive search with SearchConfig."""
await UserCrud.create(
db_session, UserCreate(username="JohnDoe", email="j@test.com")
)
# Should not find (case mismatch)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
# Should find (case match)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_empty_query(self, db_session: AsyncSession):
"""Empty search returns all results."""
await UserCrud.create(
db_session, UserCreate(username="user1", email="u1@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="user2", email="u2@test.com")
)
result = await UserCrud.paginate(db_session, search="")
assert result["pagination"]["total_count"] == 2
result = await UserCrud.paginate(db_session, search=None)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_with_existing_filters(self, db_session: AsyncSession):
"""Search combines with existing filters (AND)."""
await UserCrud.create(
db_session,
UserCreate(username="active_john", email="aj@test.com", is_active=True),
)
await UserCrud.create(
db_session,
UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
)
result = await UserCrud.paginate(
db_session,
filters=[User.is_active == True], # noqa: E712
search="john",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "active_john"
@pytest.mark.anyio
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
"""Auto-detect searchable fields when not specified."""
await UserCrud.create(
db_session, UserCreate(username="findme", email="other@test.com")
)
result = await UserCrud.paginate(db_session, search="findme")
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_no_results(self, db_session: AsyncSession):
"""Search with no matching results."""
await UserCrud.create(
db_session, UserCreate(username="john", email="j@test.com")
)
result = await UserCrud.paginate(
db_session,
search="nonexistent",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
assert result["data"] == []
@pytest.mark.anyio
async def test_search_with_pagination(self, db_session: AsyncSession):
"""Search respects pagination parameters."""
for i in range(15):
await UserCrud.create(
db_session,
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
)
result = await UserCrud.paginate(
db_session,
search="user_",
search_fields=[User.username],
page=1,
items_per_page=5,
)
assert result["pagination"]["total_count"] == 15
assert len(result["data"]) == 5
assert result["pagination"]["has_more"] is True
@pytest.mark.anyio
async def test_search_null_relationship(self, db_session: AsyncSession):
"""Users without relationship are included (outerjoin)."""
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="with_role", email="wr@test.com", role_id=role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="no_role", email="nr@test.com", role_id=None),
)
# Search in username, not in role
result = await UserCrud.paginate(
db_session,
search="role",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_with_order_by(self, db_session: AsyncSession):
"""Search works with order_by parameter."""
await UserCrud.create(
db_session, UserCreate(username="charlie", email="c@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="bob", email="b@test.com")
)
result = await UserCrud.paginate(
db_session,
search="@test.com",
search_fields=[User.email],
order_by=User.username,
)
assert result["pagination"]["total_count"] == 3
usernames = [u.username for u in result["data"]]
assert usernames == ["alice", "bob", "charlie"]
@pytest.mark.anyio
async def test_search_non_string_column(self, db_session: AsyncSession):
"""Search on non-string columns (e.g., UUID) works via cast."""
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
await UserCrud.create(
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="jane", email="jane@test.com")
)
# Search by UUID (partial match)
result = await UserCrud.paginate(
db_session,
search="12345678",
search_fields=[User.id, User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].id == user_id
class TestSearchConfig:
"""Tests for SearchConfig options."""
@pytest.mark.anyio
async def test_match_mode_all(self, db_session: AsyncSession):
"""match_mode='all' requires all fields to match (AND)."""
await UserCrud.create(
db_session,
UserCreate(username="john_test", email="john_test@company.com"),
)
await UserCrud.create(
db_session,
UserCreate(username="john_other", email="other@example.com"),
)
# 'john' must be in username AND email
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="john", match_mode="all"),
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "john_test"
@pytest.mark.anyio
async def test_search_config_with_fields(self, db_session: AsyncSession):
"""SearchConfig can specify fields directly."""
await UserCrud.create(
db_session, UserCreate(username="test", email="findme@test.com")
)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="findme", fields=[User.email]),
)
assert result["pagination"]["total_count"] == 1
class TestNoSearchableFieldsError:
"""Tests for NoSearchableFieldsError exception."""
def test_error_is_api_exception(self):
"""NoSearchableFieldsError inherits from ApiException."""
from fastapi_toolsets.exceptions import ApiException, NoSearchableFieldsError
assert issubclass(NoSearchableFieldsError, ApiException)
def test_error_has_api_error_fields(self):
"""NoSearchableFieldsError has proper api_error configuration."""
from fastapi_toolsets.exceptions import NoSearchableFieldsError
assert NoSearchableFieldsError.api_error.code == 400
assert NoSearchableFieldsError.api_error.err_code == "SEARCH-400"
def test_error_message_contains_model_name(self):
"""Error message includes the model name."""
from fastapi_toolsets.exceptions import NoSearchableFieldsError
error = NoSearchableFieldsError(User)
assert "User" in str(error)
assert error.model is User
def test_error_raised_when_no_fields(self):
"""Error is raised when search has no searchable fields."""
from sqlalchemy import Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from fastapi_toolsets.crud.search import build_search_filters
from fastapi_toolsets.exceptions import NoSearchableFieldsError
# Model with no String columns
class NoStringBase(DeclarativeBase):
pass
class NoStringModel(NoStringBase):
__tablename__ = "no_strings"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
count: Mapped[int] = mapped_column(Integer, default=0)
with pytest.raises(NoSearchableFieldsError) as exc_info:
build_search_filters(NoStringModel, "test")
assert exc_info.value.model is NoStringModel
assert "NoStringModel" in str(exc_info.value)
class TestGetSearchableFields:
"""Tests for auto-detection of searchable fields."""
def test_detects_string_columns(self):
"""Detects String columns on the model."""
fields = get_searchable_fields(User, include_relationships=False)
# Should include username and email (String), not id or is_active
field_names = [str(f) for f in fields]
assert any("username" in f for f in field_names)
assert any("email" in f for f in field_names)
assert not any("id" in f and "role_id" not in f for f in field_names)
assert not any("is_active" in f for f in field_names)
def test_detects_relationship_fields(self):
"""Detects String fields on related models."""
fields = get_searchable_fields(User, include_relationships=True)
# Should include (User.role, Role.name)
has_role_name = any(isinstance(f, tuple) and len(f) == 2 for f in fields)
assert has_role_name
def test_skips_collection_relationships(self):
"""Skips one-to-many relationships."""
fields = get_searchable_fields(Role, include_relationships=True)
# Role.users is a collection, should not be included
field_strs = [str(f) for f in fields]
assert not any("users" in f for f in field_strs)

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.fixtures module."""
import uuid
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
def test_register_with_decorator(self):
"""Register fixture with decorator."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register
def roles():
return [Role(id=1, name="admin")]
return [Role(id=role_id, name="admin")]
assert "roles" in [f.name for f in registry.get_all()]
def test_register_with_custom_name(self):
"""Register fixture with custom name."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(name="custom_roles")
def roles():
return [Role(id=1, name="admin")]
return [Role(id=role_id, name="admin")]
fixture = registry.get("custom_roles")
assert fixture.name == "custom_roles"
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
def test_register_with_dependencies(self):
"""Register fixture with dependencies."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register
def roles():
return [Role(id=1, name="admin")]
return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"])
def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
fixture = registry.get("users")
assert fixture.depends_on == ["roles"]
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
def test_register_with_contexts(self):
"""Register fixture with contexts."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(contexts=[Context.TESTING])
def test_data():
return [Role(id=100, name="test")]
return [Role(id=role_id, name="test")]
fixture = registry.get("test_data")
assert Context.TESTING.value in fixture.contexts
@@ -244,12 +258,14 @@ class TestLoadFixtures:
async def test_load_single_fixture(self, db_session: AsyncSession):
"""Load a single fixture."""
registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register
def roles():
return [
Role(id=1, name="admin"),
Role(id=2, name="user"),
Role(id=role_id_1, name="admin"),
Role(id=role_id_2, name="user"),
]
result = await load_fixtures(db_session, registry, "roles")
@@ -266,14 +282,23 @@ class TestLoadFixtures:
async def test_load_with_dependencies(self, db_session: AsyncSession):
"""Load fixtures with dependencies."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register
def roles():
return [Role(id=1, name="admin")]
return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"])
def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
result = await load_fixtures(db_session, registry, "users")
@@ -289,10 +314,11 @@ class TestLoadFixtures:
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
"""Load fixtures with MERGE strategy updates existing."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register
def roles():
return [Role(id=1, name="admin")]
return [Role(id=role_id, name="admin")]
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
@@ -306,10 +332,11 @@ class TestLoadFixtures:
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
"""Load fixtures with SKIP_EXISTING strategy."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register
def roles():
return [Role(id=1, name="original")]
return [Role(id=role_id, name="original")]
await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
@@ -317,7 +344,7 @@ class TestLoadFixtures:
@registry.register(name="roles_updated")
def roles_v2():
return [Role(id=1, name="updated")]
return [Role(id=role_id, name="updated")]
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
@@ -327,7 +354,7 @@ class TestLoadFixtures:
from .conftest import RoleCrud
role = await RoleCrud.first(db_session, [Role.id == 1])
role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None
assert role.name == "original"
@@ -335,12 +362,14 @@ class TestLoadFixtures:
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
"""Load fixtures with INSERT strategy."""
registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register
def roles():
return [
Role(id=1, name="admin"),
Role(id=2, name="user"),
Role(id=role_id_1, name="admin"),
Role(id=role_id_2, name="user"),
]
result = await load_fixtures(
@@ -375,14 +404,16 @@ class TestLoadFixtures:
):
"""Load multiple independent fixtures."""
registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register
def roles():
return [Role(id=1, name="admin")]
return [Role(id=role_id_1, name="admin")]
@registry.register
def other_roles():
return [Role(id=2, name="user")]
return [Role(id=role_id_2, name="user")]
result = await load_fixtures(db_session, registry, "roles", "other_roles")
@@ -402,14 +433,16 @@ class TestLoadFixturesByContext:
async def test_load_by_single_context(self, db_session: AsyncSession):
"""Load fixtures by single context."""
registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE])
def base_roles():
return [Role(id=1, name="base_role")]
return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING])
def test_roles():
return [Role(id=100, name="test_role")]
return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context(db_session, registry, Context.BASE)
@@ -418,7 +451,7 @@ class TestLoadFixturesByContext:
count = await RoleCrud.count(db_session)
assert count == 1
role = await RoleCrud.first(db_session, [Role.id == 1])
role = await RoleCrud.first(db_session, [Role.id == base_role_id])
assert role is not None
assert role.name == "base_role"
@@ -426,14 +459,16 @@ class TestLoadFixturesByContext:
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
"""Load fixtures by multiple contexts."""
registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE])
def base_roles():
return [Role(id=1, name="base_role")]
return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING])
def test_roles():
return [Role(id=100, name="test_role")]
return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context(
db_session, registry, Context.BASE, Context.TESTING
@@ -448,14 +483,23 @@ class TestLoadFixturesByContext:
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
"""Load context fixtures with cross-context dependencies."""
registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE])
def roles():
return [Role(id=1, name="admin")]
return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users():
return [User(id=1, username="tester", email="test@test.com", role_id=1)]
return [
User(
id=user_id,
username="tester",
email="test@test.com",
role_id=role_id,
)
]
await load_fixtures_by_context(db_session, registry, Context.TESTING)
@@ -471,20 +515,41 @@ class TestGetObjByAttr:
def setup_method(self):
"""Set up test fixtures for each test."""
self.registry = FixtureRegistry()
self.role_id_1 = uuid.uuid4()
self.role_id_2 = uuid.uuid4()
self.role_id_3 = uuid.uuid4()
self.user_id_1 = uuid.uuid4()
self.user_id_2 = uuid.uuid4()
role_id_1 = self.role_id_1
role_id_2 = self.role_id_2
role_id_3 = self.role_id_3
user_id_1 = self.user_id_1
user_id_2 = self.user_id_2
@self.registry.register
def roles() -> list[Role]:
return [
Role(id=1, name="admin"),
Role(id=2, name="user"),
Role(id=3, name="moderator"),
Role(id=role_id_1, name="admin"),
Role(id=role_id_2, name="user"),
Role(id=role_id_3, name="moderator"),
]
@self.registry.register(depends_on=["roles"])
def users() -> list[User]:
return [
User(id=1, username="alice", email="alice@example.com", role_id=1),
User(id=2, username="bob", email="bob@example.com", role_id=1),
User(
id=user_id_1,
username="alice",
email="alice@example.com",
role_id=role_id_1,
),
User(
id=user_id_2,
username="bob",
email="bob@example.com",
role_id=role_id_1,
),
]
self.roles = roles
@@ -492,18 +557,18 @@ class TestGetObjByAttr:
def test_get_by_id(self):
"""Get an object by its id attribute."""
role = get_obj_by_attr(self.roles, "id", 1)
role = get_obj_by_attr(self.roles, "id", self.role_id_1)
assert role.name == "admin"
def test_get_user_by_username(self):
"""Get a user by username."""
user = get_obj_by_attr(self.users, "username", "bob")
assert user.id == 2
assert user.id == self.user_id_2
assert user.email == "bob@example.com"
def test_returns_first_match(self):
"""Returns the first matching object when multiple could match."""
user = get_obj_by_attr(self.users, "role_id", 1)
user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
assert user.username == "alice"
def test_no_match_raises_stop_iteration(self):
@@ -514,4 +579,4 @@ class TestGetObjByAttr:
def test_no_match_on_wrong_value_type(self):
"""Raises StopIteration when value type doesn't match."""
with pytest.raises(StopIteration):
get_obj_by_attr(self.roles, "id", "1")
get_obj_by_attr(self.roles, "id", "not-a-uuid")

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.pytest module."""
import uuid
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
@@ -18,27 +20,49 @@ from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
test_registry = FixtureRegistry()
# Fixed UUIDs for test fixtures to allow consistent assertions
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
@test_registry.register(contexts=[Context.BASE])
def roles() -> list[Role]:
return [
Role(id=1000, name="plugin_admin"),
Role(id=1001, name="plugin_user"),
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
Role(id=ROLE_USER_ID, name="plugin_user"),
]
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
def users() -> list[User]:
return [
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000),
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001),
User(
id=USER_ADMIN_ID,
username="plugin_admin",
email="padmin@test.com",
role_id=ROLE_ADMIN_ID,
),
User(
id=USER_USER_ID,
username="plugin_user",
email="puser@test.com",
role_id=ROLE_USER_ID,
),
]
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
def extra_users() -> list[User]:
return [
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001),
User(
id=USER_EXTRA_ID,
username="plugin_extra",
email="pextra@test.com",
role_id=ROLE_USER_ID,
),
]
@@ -73,7 +97,7 @@ class TestGeneratedFixtures:
assert fixture_roles[1].name == "plugin_user"
# Verify data is in database
count = await RoleCrud.count(db_session, [Role.id >= 1000])
count = await RoleCrud.count(db_session)
assert count == 2
@pytest.mark.anyio
@@ -86,11 +110,11 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2
# Roles should also be in database
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
roles_count = await RoleCrud.count(db_session)
assert roles_count == 2
# Users should be in database
users_count = await UserCrud.count(db_session, [User.id >= 1000])
users_count = await UserCrud.count(db_session)
assert users_count == 2
@pytest.mark.anyio
@@ -100,7 +124,7 @@ class TestGeneratedFixtures:
"""Fixture returns actual model instances."""
user = fixture_users[0]
assert isinstance(user, User)
assert user.id == 1000
assert user.id == USER_ADMIN_ID
assert user.username == "plugin_admin"
@pytest.mark.anyio
@@ -111,7 +135,7 @@ class TestGeneratedFixtures:
# Load user with role relationship
user = await UserCrud.get(
db_session,
[User.id == 1000],
[User.id == USER_ADMIN_ID],
load_options=[selectinload(User.role)],
)
@@ -127,8 +151,8 @@ class TestGeneratedFixtures:
assert len(fixture_extra_users) == 1
# All fixtures should be loaded
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
users_count = await UserCrud.count(db_session, [User.id >= 1000])
roles_count = await RoleCrud.count(db_session)
users_count = await UserCrud.count(db_session)
assert roles_count == 2
assert users_count == 3 # 2 from users + 1 from extra_users
@@ -141,8 +165,7 @@ class TestGeneratedFixtures:
# Get all users loaded by fixture
users = await UserCrud.get_multi(
db_session,
filters=[User.id >= 1000],
order_by=User.id,
order_by=User.username,
)
assert len(users) == 2
@@ -161,8 +184,8 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2
# Both should be in database
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000])
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000])
roles = await RoleCrud.get_multi(db_session)
users = await UserCrud.get_multi(db_session)
assert len(roles) == 2
assert len(users) == 2
@@ -215,14 +238,15 @@ class TestCreateDbSession:
@pytest.mark.anyio
async def test_creates_working_session(self):
"""Session can perform database operations."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base) as session:
assert isinstance(session, AsyncSession)
role = Role(id=9001, name="test_helper_role")
role = Role(id=role_id, name="test_helper_role")
session.add(role)
await session.commit()
result = await session.execute(select(Role).where(Role.id == 9001))
result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one()
assert fetched.name == "test_helper_role"
@@ -237,8 +261,9 @@ class TestCreateDbSession:
@pytest.mark.anyio
async def test_tables_dropped_after_session(self):
"""Tables are dropped after session closes when drop_tables=True."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=9002, name="will_be_dropped")
role = Role(id=role_id, name="will_be_dropped")
session.add(role)
await session.commit()
@@ -250,14 +275,15 @@ class TestCreateDbSession:
@pytest.mark.anyio
async def test_tables_preserved_when_drop_disabled(self):
"""Tables are preserved when drop_tables=False."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
role = Role(id=9003, name="preserved_role")
role = Role(id=role_id, name="preserved_role")
session.add(role)
await session.commit()
# Create another session without dropping
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
result = await session.execute(select(Role).where(Role.id == 9003))
result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none()
assert fetched is not None
assert fetched.name == "preserved_role"

2
uv.lock generated
View File

@@ -220,7 +220,7 @@ wheels = [
[[package]]
name = "fastapi-toolsets"
version = "0.3.0"
version = "0.4.1"
source = { editable = "." }
dependencies = [
{ name = "asyncpg" },