mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
2 Commits
aa72dc2eb5
...
d14551781c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d14551781c | ||
|
|
577e087321 |
2
.github/workflows/build-release.yml
vendored
2
.github/workflows/build-release.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Set up Python
|
||||
run: uv python install 3.13
|
||||
run: uv python install 3.14
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
|
||||
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.11", "3.12", "3.13"]
|
||||
python-version: ["3.11", "3.12", "3.13", "3.14"]
|
||||
|
||||
services:
|
||||
postgres:
|
||||
@@ -92,7 +92,7 @@ jobs:
|
||||
uv run pytest --cov --cov-report=xml --cov-report=term-missing
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: matrix.python-version == '3.13'
|
||||
if: matrix.python-version == '3.14'
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.13
|
||||
3.14
|
||||
|
||||
@@ -24,6 +24,7 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.14",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development",
|
||||
|
||||
17
src/fastapi_toolsets/crud/__init__.py
Normal file
17
src/fastapi_toolsets/crud/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from ..exceptions import NoSearchableFieldsError
|
||||
from .factory import CrudFactory
|
||||
from .search import (
|
||||
SearchConfig,
|
||||
SearchFieldType,
|
||||
get_searchable_fields,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CrudFactory",
|
||||
"get_searchable_fields",
|
||||
"NoSearchableFieldsError",
|
||||
"SearchConfig",
|
||||
"SearchFieldType",
|
||||
]
|
||||
@@ -12,13 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from .db import get_transaction
|
||||
from .exceptions import NotFoundError
|
||||
|
||||
__all__ = [
|
||||
"AsyncCrud",
|
||||
"CrudFactory",
|
||||
]
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
|
||||
@@ -27,20 +23,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
"""Generic async CRUD operations for SQLAlchemy models.
|
||||
|
||||
Subclass this and set the `model` class variable, or use `CrudFactory`.
|
||||
|
||||
Example:
|
||||
class UserCrud(AsyncCrud[User]):
|
||||
model = User
|
||||
|
||||
# Or use the factory:
|
||||
UserCrud = CrudFactory(User)
|
||||
|
||||
# Then use it:
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
users = await UserCrud.get_multi(session, limit=10)
|
||||
"""
|
||||
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
@@ -313,6 +299,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
order_by: Any | None = None,
|
||||
page: int = 1,
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get paginated results with metadata.
|
||||
|
||||
@@ -323,23 +311,54 @@ class AsyncCrud(Generic[ModelType]):
|
||||
order_by: Column or list of columns to order by
|
||||
page: Page number (1-indexed)
|
||||
items_per_page: Number of items per page
|
||||
search: Search query string or SearchConfig object
|
||||
search_fields: Fields to search in (overrides class default)
|
||||
|
||||
Returns:
|
||||
Dict with 'data' and 'pagination' keys
|
||||
"""
|
||||
filters = filters or []
|
||||
filters = list(filters) if filters else []
|
||||
offset = (page - 1) * items_per_page
|
||||
joins: list[Any] = []
|
||||
|
||||
items = await cls.get_multi(
|
||||
session,
|
||||
filters=filters,
|
||||
load_options=load_options,
|
||||
order_by=order_by,
|
||||
limit=items_per_page,
|
||||
offset=offset,
|
||||
)
|
||||
# Build search filters
|
||||
if search:
|
||||
search_filters, search_joins = build_search_filters(
|
||||
cls.model,
|
||||
search,
|
||||
search_fields=search_fields,
|
||||
default_fields=cls.searchable_fields,
|
||||
)
|
||||
filters.extend(search_filters)
|
||||
joins.extend(search_joins)
|
||||
|
||||
total_count = await cls.count(session, filters=filters)
|
||||
# Build query with joins
|
||||
q = select(cls.model)
|
||||
for join_rel in joins:
|
||||
q = q.outerjoin(join_rel)
|
||||
|
||||
if filters:
|
||||
q = q.where(and_(*filters))
|
||||
if load_options:
|
||||
q = q.options(*load_options)
|
||||
if order_by is not None:
|
||||
q = q.order_by(order_by)
|
||||
|
||||
q = q.offset(offset).limit(items_per_page)
|
||||
result = await session.execute(q)
|
||||
items = result.unique().scalars().all()
|
||||
|
||||
# Count query (with same joins and filters)
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
|
||||
count_q = count_q.select_from(cls.model)
|
||||
for join_rel in joins:
|
||||
count_q = count_q.outerjoin(join_rel)
|
||||
if filters:
|
||||
count_q = count_q.where(and_(*filters))
|
||||
|
||||
count_result = await session.execute(count_q)
|
||||
total_count = count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"data": items,
|
||||
@@ -354,11 +373,14 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
def CrudFactory(
|
||||
model: type[ModelType],
|
||||
*,
|
||||
searchable_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> type[AsyncCrud[ModelType]]:
|
||||
"""Create a CRUD class for a specific model.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
searchable_fields: Optional list of searchable fields
|
||||
|
||||
Returns:
|
||||
AsyncCrud subclass bound to the model
|
||||
@@ -370,9 +392,25 @@ def CrudFactory(
|
||||
UserCrud = CrudFactory(User)
|
||||
PostCrud = CrudFactory(Post)
|
||||
|
||||
# With searchable fields:
|
||||
UserCrud = CrudFactory(
|
||||
User,
|
||||
searchable_fields=[User.username, User.email, (User.role, Role.name)]
|
||||
)
|
||||
|
||||
# Usage
|
||||
user = await UserCrud.get(session, [User.id == 1])
|
||||
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
|
||||
|
||||
# With search
|
||||
result = await UserCrud.paginate(session, search="john")
|
||||
"""
|
||||
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
|
||||
cls = type(
|
||||
f"Async{model.__name__}Crud",
|
||||
(AsyncCrud,),
|
||||
{
|
||||
"model": model,
|
||||
"searchable_fields": searchable_fields,
|
||||
},
|
||||
)
|
||||
return cast(type[AsyncCrud[ModelType]], cls)
|
||||
145
src/fastapi_toolsets/crud/search.py
Normal file
145
src/fastapi_toolsets/crud/search.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Search utilities for AsyncCrud."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy import String, or_
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
from ..exceptions import NoSearchableFieldsError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchConfig:
|
||||
"""Advanced search configuration.
|
||||
|
||||
Attributes:
|
||||
query: The search string
|
||||
fields: Fields to search (columns or tuples for relationships)
|
||||
case_sensitive: Case-sensitive search (default: False)
|
||||
match_mode: "any" (OR) or "all" (AND) to combine fields
|
||||
"""
|
||||
|
||||
query: str
|
||||
fields: Sequence[SearchFieldType] | None = None
|
||||
case_sensitive: bool = False
|
||||
match_mode: Literal["any", "all"] = "any"
|
||||
|
||||
|
||||
def get_searchable_fields(
|
||||
model: type[DeclarativeBase],
|
||||
*,
|
||||
include_relationships: bool = True,
|
||||
max_depth: int = 1,
|
||||
) -> list[SearchFieldType]:
|
||||
"""Auto-detect String fields on a model and its relationships.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
include_relationships: Include fields from many-to-one/one-to-one relationships
|
||||
max_depth: Max depth for relationship traversal (default: 1)
|
||||
|
||||
Returns:
|
||||
List of columns and tuples (relationship, column)
|
||||
"""
|
||||
fields: list[SearchFieldType] = []
|
||||
mapper = model.__mapper__
|
||||
|
||||
# Direct String columns
|
||||
for col in mapper.columns:
|
||||
if isinstance(col.type, String):
|
||||
fields.append(getattr(model, col.key))
|
||||
|
||||
# Relationships (one-to-one, many-to-one only)
|
||||
if include_relationships and max_depth > 0:
|
||||
for rel_name, rel_prop in mapper.relationships.items():
|
||||
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
|
||||
continue
|
||||
|
||||
rel_attr = getattr(model, rel_name)
|
||||
related_model = rel_prop.mapper.class_
|
||||
|
||||
for col in related_model.__mapper__.columns:
|
||||
if isinstance(col.type, String):
|
||||
fields.append((rel_attr, getattr(related_model, col.key)))
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def build_search_filters(
|
||||
model: type[DeclarativeBase],
|
||||
search: str | SearchConfig,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
default_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
|
||||
"""Build SQLAlchemy filter conditions for search.
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
search: Search string or SearchConfig
|
||||
search_fields: Fields specified per-call (takes priority)
|
||||
default_fields: Default fields (from ClassVar)
|
||||
|
||||
Returns:
|
||||
Tuple of (filter_conditions, joins_needed)
|
||||
"""
|
||||
# Normalize input
|
||||
if isinstance(search, str):
|
||||
config = SearchConfig(query=search, fields=search_fields)
|
||||
else:
|
||||
config = search
|
||||
if search_fields is not None:
|
||||
config = SearchConfig(
|
||||
query=config.query,
|
||||
fields=search_fields,
|
||||
case_sensitive=config.case_sensitive,
|
||||
match_mode=config.match_mode,
|
||||
)
|
||||
|
||||
if not config.query or not config.query.strip():
|
||||
return [], []
|
||||
|
||||
# Determine which fields to search
|
||||
fields = config.fields or default_fields or get_searchable_fields(model)
|
||||
|
||||
if not fields:
|
||||
raise NoSearchableFieldsError(model)
|
||||
|
||||
query = config.query.strip()
|
||||
filters: list[ColumnElement[bool]] = []
|
||||
joins: list[InstrumentedAttribute[Any]] = []
|
||||
added_joins: set[str] = set()
|
||||
|
||||
for field in fields:
|
||||
if isinstance(field, tuple):
|
||||
# Relationship: (User.role, Role.name) or deeper
|
||||
for rel in field[:-1]:
|
||||
rel_key = str(rel)
|
||||
if rel_key not in added_joins:
|
||||
joins.append(rel)
|
||||
added_joins.add(rel_key)
|
||||
column = field[-1]
|
||||
else:
|
||||
column = field
|
||||
|
||||
# Build the filter
|
||||
if config.case_sensitive:
|
||||
filters.append(column.like(f"%{query}%"))
|
||||
else:
|
||||
filters.append(column.ilike(f"%{query}%"))
|
||||
|
||||
if not filters:
|
||||
return [], []
|
||||
|
||||
# Combine based on match_mode
|
||||
if config.match_mode == "any":
|
||||
return [or_(*filters)], joins
|
||||
else:
|
||||
return filters, joins
|
||||
@@ -2,6 +2,7 @@ from .exceptions import (
|
||||
ApiException,
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
NoSearchableFieldsError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
generate_error_responses,
|
||||
@@ -14,6 +15,7 @@ __all__ = [
|
||||
"ApiException",
|
||||
"ConflictError",
|
||||
"ForbiddenError",
|
||||
"NoSearchableFieldsError",
|
||||
"NotFoundError",
|
||||
"UnauthorizedError",
|
||||
]
|
||||
|
||||
@@ -119,6 +119,25 @@ class RoleNotFoundError(NotFoundError):
|
||||
)
|
||||
|
||||
|
||||
class NoSearchableFieldsError(ApiException):
|
||||
"""Raised when search is requested but no searchable fields are available."""
|
||||
|
||||
api_error = ApiError(
|
||||
code=400,
|
||||
msg="No Searchable Fields",
|
||||
desc="No searchable fields configured for this resource.",
|
||||
err_code="SEARCH-400",
|
||||
)
|
||||
|
||||
def __init__(self, model: type) -> None:
|
||||
self.model = model
|
||||
detail = (
|
||||
f"No searchable fields found for model '{model.__name__}'. "
|
||||
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
|
||||
)
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
def generate_error_responses(
|
||||
*errors: type[ApiException],
|
||||
) -> dict[int | str, dict[str, Any]]:
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fastapi_toolsets.crud import AsyncCrud, CrudFactory
|
||||
from fastapi_toolsets.crud import CrudFactory
|
||||
from fastapi_toolsets.crud.factory import AsyncCrud
|
||||
from fastapi_toolsets.exceptions import NotFoundError
|
||||
|
||||
from .conftest import (
|
||||
|
||||
392
tests/test_crud_search.py
Normal file
392
tests/test_crud_search.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Tests for CRUD search functionality."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
|
||||
|
||||
from .conftest import (
|
||||
Role,
|
||||
RoleCreate,
|
||||
RoleCrud,
|
||||
User,
|
||||
UserCreate,
|
||||
UserCrud,
|
||||
)
|
||||
|
||||
|
||||
class TestPaginateSearch:
|
||||
"""Tests for paginate() with search."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_single_column(self, db_session: AsyncSession):
|
||||
"""Search on a single direct column."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="john_doe", email="john@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="jane_doe", email="jane@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob_smith", email="bob@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="doe",
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
||||
"""Search across multiple columns (OR logic)."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="alice@company.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="company_bob", email="bob@other.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="company",
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
||||
"""Search through a relationship (depth 1)."""
|
||||
admin_role = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
|
||||
user_role = await RoleCrud.create(db_session, RoleCreate(name="basic_user"))
|
||||
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="admin1", email="a1@test.com", role_id=admin_role.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="admin2", email="a2@test.com", role_id=admin_role.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="admin",
|
||||
search_fields=[(User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
||||
"""Search combining direct columns and relationships."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="john", email="john@test.com", role_id=role.id),
|
||||
)
|
||||
|
||||
# Search "admin" in username OR role.name
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="admin",
|
||||
search_fields=[User.username, (User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
||||
"""Search is case-insensitive by default."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="JohnDoe", email="j@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="johndoe",
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
||||
"""Case-sensitive search with SearchConfig."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="JohnDoe", email="j@test.com")
|
||||
)
|
||||
|
||||
# Should not find (case mismatch)
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert result["pagination"]["total_count"] == 0
|
||||
|
||||
# Should find (case match)
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_empty_query(self, db_session: AsyncSession):
|
||||
"""Empty search returns all results."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="user1", email="u1@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="user2", email="u2@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="")
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
|
||||
result = await UserCrud.paginate(db_session, search=None)
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
||||
"""Search combines with existing filters (AND)."""
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="active_john", email="aj@test.com", is_active=True),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
filters=[User.is_active == True], # noqa: E712
|
||||
search="john",
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].username == "active_john"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
||||
"""Auto-detect searchable fields when not specified."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="findme", email="other@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="findme")
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_no_results(self, db_session: AsyncSession):
|
||||
"""Search with no matching results."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="john", email="j@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="nonexistent",
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 0
|
||||
assert result["data"] == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_pagination(self, db_session: AsyncSession):
|
||||
"""Search respects pagination parameters."""
|
||||
for i in range(15):
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="user_",
|
||||
search_fields=[User.username],
|
||||
page=1,
|
||||
items_per_page=5,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 15
|
||||
assert len(result["data"]) == 5
|
||||
assert result["pagination"]["has_more"] is True
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_null_relationship(self, db_session: AsyncSession):
|
||||
"""Users without relationship are included (outerjoin)."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="with_role", email="wr@test.com", role_id=role.id),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="no_role", email="nr@test.com", role_id=None),
|
||||
)
|
||||
|
||||
# Search in username, not in role
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="role",
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_order_by(self, db_session: AsyncSession):
|
||||
"""Search works with order_by parameter."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="charlie", email="c@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="alice", email="a@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="bob", email="b@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="@test.com",
|
||||
search_fields=[User.email],
|
||||
order_by=User.username,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 3
|
||||
usernames = [u.username for u in result["data"]]
|
||||
assert usernames == ["alice", "bob", "charlie"]
|
||||
|
||||
|
||||
class TestSearchConfig:
|
||||
"""Tests for SearchConfig options."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_match_mode_all(self, db_session: AsyncSession):
|
||||
"""match_mode='all' requires all fields to match (AND)."""
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="john_test", email="john_test@company.com"),
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session,
|
||||
UserCreate(username="john_other", email="other@example.com"),
|
||||
)
|
||||
|
||||
# 'john' must be in username AND email
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="john", match_mode="all"),
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].username == "john_test"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
||||
"""SearchConfig can specify fields directly."""
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="test", email="findme@test.com")
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search=SearchConfig(query="findme", fields=[User.email]),
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
|
||||
|
||||
class TestNoSearchableFieldsError:
|
||||
"""Tests for NoSearchableFieldsError exception."""
|
||||
|
||||
def test_error_is_api_exception(self):
|
||||
"""NoSearchableFieldsError inherits from ApiException."""
|
||||
from fastapi_toolsets.exceptions import ApiException, NoSearchableFieldsError
|
||||
|
||||
assert issubclass(NoSearchableFieldsError, ApiException)
|
||||
|
||||
def test_error_has_api_error_fields(self):
|
||||
"""NoSearchableFieldsError has proper api_error configuration."""
|
||||
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||
|
||||
assert NoSearchableFieldsError.api_error.code == 400
|
||||
assert NoSearchableFieldsError.api_error.err_code == "SEARCH-400"
|
||||
|
||||
def test_error_message_contains_model_name(self):
|
||||
"""Error message includes the model name."""
|
||||
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||
|
||||
error = NoSearchableFieldsError(User)
|
||||
assert "User" in str(error)
|
||||
assert error.model is User
|
||||
|
||||
def test_error_raised_when_no_fields(self):
|
||||
"""Error is raised when search has no searchable fields."""
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from fastapi_toolsets.crud.search import build_search_filters
|
||||
from fastapi_toolsets.exceptions import NoSearchableFieldsError
|
||||
|
||||
# Model with no String columns
|
||||
class NoStringBase(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class NoStringModel(NoStringBase):
|
||||
__tablename__ = "no_strings"
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
with pytest.raises(NoSearchableFieldsError) as exc_info:
|
||||
build_search_filters(NoStringModel, "test")
|
||||
|
||||
assert exc_info.value.model is NoStringModel
|
||||
assert "NoStringModel" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGetSearchableFields:
|
||||
"""Tests for auto-detection of searchable fields."""
|
||||
|
||||
def test_detects_string_columns(self):
|
||||
"""Detects String columns on the model."""
|
||||
fields = get_searchable_fields(User, include_relationships=False)
|
||||
|
||||
# Should include username and email (String), not id or is_active
|
||||
field_names = [str(f) for f in fields]
|
||||
assert any("username" in f for f in field_names)
|
||||
assert any("email" in f for f in field_names)
|
||||
assert not any("id" in f and "role_id" not in f for f in field_names)
|
||||
assert not any("is_active" in f for f in field_names)
|
||||
|
||||
def test_detects_relationship_fields(self):
|
||||
"""Detects String fields on related models."""
|
||||
fields = get_searchable_fields(User, include_relationships=True)
|
||||
|
||||
# Should include (User.role, Role.name)
|
||||
has_role_name = any(isinstance(f, tuple) and len(f) == 2 for f in fields)
|
||||
assert has_role_name
|
||||
|
||||
def test_skips_collection_relationships(self):
|
||||
"""Skips one-to-many relationships."""
|
||||
fields = get_searchable_fields(Role, include_relationships=True)
|
||||
|
||||
# Role.users is a collection, should not be included
|
||||
field_strs = [str(f) for f in fields]
|
||||
assert not any("users" in f for f in field_strs)
|
||||
Reference in New Issue
Block a user