3 Commits

Author SHA1 Message Date
a76f7c439d Version 0.4.0 2026-01-29 09:15:33 -05:00
d3vyce
d14551781c feat: add search to crud paginate function (#17)
* feat: add search to crud paginate function

* fixes: comments + tests import
2026-01-29 00:08:02 +01:00
d3vyce
577e087321 feat: add support for python 3.14 (#15) 2026-01-28 21:01:15 +01:00
13 changed files with 652 additions and 37 deletions

View File

@@ -17,7 +17,7 @@ jobs:
uses: astral-sh/setup-uv@v7 uses: astral-sh/setup-uv@v7
- name: Set up Python - name: Set up Python
run: uv python install 3.13 run: uv python install 3.14
- name: Install dependencies - name: Install dependencies
run: uv sync run: uv sync

View File

@@ -56,7 +56,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: ["3.11", "3.12", "3.13"] python-version: ["3.11", "3.12", "3.13", "3.14"]
services: services:
postgres: postgres:
@@ -92,7 +92,7 @@ jobs:
uv run pytest --cov --cov-report=xml --cov-report=term-missing uv run pytest --cov --cov-report=xml --cov-report=term-missing
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
if: matrix.python-version == '3.13' if: matrix.python-version == '3.14'
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v5
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1 +1 @@
3.13 3.14

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.3.0" version = "0.4.0"
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL" description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
@@ -24,6 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",
"Topic :: Software Development", "Topic :: Software Development",

View File

@@ -21,4 +21,4 @@ Example usage:
return Response(data={"user": user.username}, message="Success") return Response(data={"user": user.username}, message="Success")
""" """
__version__ = "0.3.0" __version__ = "0.4.0"

View File

@@ -0,0 +1,17 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import NoSearchableFieldsError
from .factory import CrudFactory
from .search import (
SearchConfig,
SearchFieldType,
get_searchable_fields,
)
__all__ = [
"CrudFactory",
"get_searchable_fields",
"NoSearchableFieldsError",
"SearchConfig",
"SearchFieldType",
]

View File

@@ -12,13 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from .db import get_transaction from ..db import get_transaction
from .exceptions import NotFoundError from ..exceptions import NotFoundError
from .search import SearchConfig, SearchFieldType, build_search_filters
__all__ = [
"AsyncCrud",
"CrudFactory",
]
ModelType = TypeVar("ModelType", bound=DeclarativeBase) ModelType = TypeVar("ModelType", bound=DeclarativeBase)
@@ -27,20 +23,10 @@ class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models. """Generic async CRUD operations for SQLAlchemy models.
Subclass this and set the `model` class variable, or use `CrudFactory`. Subclass this and set the `model` class variable, or use `CrudFactory`.
Example:
class UserCrud(AsyncCrud[User]):
model = User
# Or use the factory:
UserCrud = CrudFactory(User)
# Then use it:
user = await UserCrud.get(session, [User.id == 1])
users = await UserCrud.get_multi(session, limit=10)
""" """
model: ClassVar[type[DeclarativeBase]] model: ClassVar[type[DeclarativeBase]]
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
@classmethod @classmethod
async def create( async def create(
@@ -313,6 +299,8 @@ class AsyncCrud(Generic[ModelType]):
order_by: Any | None = None, order_by: Any | None = None,
page: int = 1, page: int = 1,
items_per_page: int = 20, items_per_page: int = 20,
search: str | SearchConfig | None = None,
search_fields: Sequence[SearchFieldType] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Get paginated results with metadata. """Get paginated results with metadata.
@@ -323,23 +311,54 @@ class AsyncCrud(Generic[ModelType]):
order_by: Column or list of columns to order by order_by: Column or list of columns to order by
page: Page number (1-indexed) page: Page number (1-indexed)
items_per_page: Number of items per page items_per_page: Number of items per page
search: Search query string or SearchConfig object
search_fields: Fields to search in (overrides class default)
Returns: Returns:
Dict with 'data' and 'pagination' keys Dict with 'data' and 'pagination' keys
""" """
filters = filters or [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
joins: list[Any] = []
items = await cls.get_multi( # Build search filters
session, if search:
filters=filters, search_filters, search_joins = build_search_filters(
load_options=load_options, cls.model,
order_by=order_by, search,
limit=items_per_page, search_fields=search_fields,
offset=offset, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters)
joins.extend(search_joins)
total_count = await cls.count(session, filters=filters) # Build query with joins
q = select(cls.model)
for join_rel in joins:
q = q.outerjoin(join_rel)
if filters:
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if order_by is not None:
q = q.order_by(order_by)
q = q.offset(offset).limit(items_per_page)
result = await session.execute(q)
items = result.unique().scalars().all()
# Count query (with same joins and filters)
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
for join_rel in joins:
count_q = count_q.outerjoin(join_rel)
if filters:
count_q = count_q.where(and_(*filters))
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
return { return {
"data": items, "data": items,
@@ -354,11 +373,14 @@ class AsyncCrud(Generic[ModelType]):
def CrudFactory( def CrudFactory(
model: type[ModelType], model: type[ModelType],
*,
searchable_fields: Sequence[SearchFieldType] | None = None,
) -> type[AsyncCrud[ModelType]]: ) -> type[AsyncCrud[ModelType]]:
"""Create a CRUD class for a specific model. """Create a CRUD class for a specific model.
Args: Args:
model: SQLAlchemy model class model: SQLAlchemy model class
searchable_fields: Optional list of searchable fields
Returns: Returns:
AsyncCrud subclass bound to the model AsyncCrud subclass bound to the model
@@ -370,9 +392,25 @@ def CrudFactory(
UserCrud = CrudFactory(User) UserCrud = CrudFactory(User)
PostCrud = CrudFactory(Post) PostCrud = CrudFactory(Post)
# With searchable fields:
UserCrud = CrudFactory(
User,
searchable_fields=[User.username, User.email, (User.role, Role.name)]
)
# Usage # Usage
user = await UserCrud.get(session, [User.id == 1]) user = await UserCrud.get(session, [User.id == 1])
posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id]) posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
# With search
result = await UserCrud.paginate(session, search="john")
""" """
cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model}) cls = type(
f"Async{model.__name__}Crud",
(AsyncCrud,),
{
"model": model,
"searchable_fields": searchable_fields,
},
)
return cast(type[AsyncCrud[ModelType]], cls) return cast(type[AsyncCrud[ModelType]], cls)

View File

@@ -0,0 +1,145 @@
"""Search utilities for AsyncCrud."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import NoSearchableFieldsError
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
@dataclass
class SearchConfig:
"""Advanced search configuration.
Attributes:
query: The search string
fields: Fields to search (columns or tuples for relationships)
case_sensitive: Case-sensitive search (default: False)
match_mode: "any" (OR) or "all" (AND) to combine fields
"""
query: str
fields: Sequence[SearchFieldType] | None = None
case_sensitive: bool = False
match_mode: Literal["any", "all"] = "any"
def get_searchable_fields(
model: type[DeclarativeBase],
*,
include_relationships: bool = True,
max_depth: int = 1,
) -> list[SearchFieldType]:
"""Auto-detect String fields on a model and its relationships.
Args:
model: SQLAlchemy model class
include_relationships: Include fields from many-to-one/one-to-one relationships
max_depth: Max depth for relationship traversal (default: 1)
Returns:
List of columns and tuples (relationship, column)
"""
fields: list[SearchFieldType] = []
mapper = model.__mapper__
# Direct String columns
for col in mapper.columns:
if isinstance(col.type, String):
fields.append(getattr(model, col.key))
# Relationships (one-to-one, many-to-one only)
if include_relationships and max_depth > 0:
for rel_name, rel_prop in mapper.relationships.items():
if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
continue
rel_attr = getattr(model, rel_name)
related_model = rel_prop.mapper.class_
for col in related_model.__mapper__.columns:
if isinstance(col.type, String):
fields.append((rel_attr, getattr(related_model, col.key)))
return fields
def build_search_filters(
model: type[DeclarativeBase],
search: str | SearchConfig,
search_fields: Sequence[SearchFieldType] | None = None,
default_fields: Sequence[SearchFieldType] | None = None,
) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
"""Build SQLAlchemy filter conditions for search.
Args:
model: SQLAlchemy model class
search: Search string or SearchConfig
search_fields: Fields specified per-call (takes priority)
default_fields: Default fields (from ClassVar)
Returns:
Tuple of (filter_conditions, joins_needed)
"""
# Normalize input
if isinstance(search, str):
config = SearchConfig(query=search, fields=search_fields)
else:
config = search
if search_fields is not None:
config = SearchConfig(
query=config.query,
fields=search_fields,
case_sensitive=config.case_sensitive,
match_mode=config.match_mode,
)
if not config.query or not config.query.strip():
return [], []
# Determine which fields to search
fields = config.fields or default_fields or get_searchable_fields(model)
if not fields:
raise NoSearchableFieldsError(model)
query = config.query.strip()
filters: list[ColumnElement[bool]] = []
joins: list[InstrumentedAttribute[Any]] = []
added_joins: set[str] = set()
for field in fields:
if isinstance(field, tuple):
# Relationship: (User.role, Role.name) or deeper
for rel in field[:-1]:
rel_key = str(rel)
if rel_key not in added_joins:
joins.append(rel)
added_joins.add(rel_key)
column = field[-1]
else:
column = field
# Build the filter
if config.case_sensitive:
filters.append(column.like(f"%{query}%"))
else:
filters.append(column.ilike(f"%{query}%"))
if not filters:
return [], []
# Combine based on match_mode
if config.match_mode == "any":
return [or_(*filters)], joins
else:
return filters, joins

View File

@@ -2,6 +2,7 @@ from .exceptions import (
ApiException, ApiException,
ConflictError, ConflictError,
ForbiddenError, ForbiddenError,
NoSearchableFieldsError,
NotFoundError, NotFoundError,
UnauthorizedError, UnauthorizedError,
generate_error_responses, generate_error_responses,
@@ -14,6 +15,7 @@ __all__ = [
"ApiException", "ApiException",
"ConflictError", "ConflictError",
"ForbiddenError", "ForbiddenError",
"NoSearchableFieldsError",
"NotFoundError", "NotFoundError",
"UnauthorizedError", "UnauthorizedError",
] ]

View File

@@ -119,6 +119,25 @@ class RoleNotFoundError(NotFoundError):
) )
class NoSearchableFieldsError(ApiException):
"""Raised when search is requested but no searchable fields are available."""
api_error = ApiError(
code=400,
msg="No Searchable Fields",
desc="No searchable fields configured for this resource.",
err_code="SEARCH-400",
)
def __init__(self, model: type) -> None:
self.model = model
detail = (
f"No searchable fields found for model '{model.__name__}'. "
"Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
)
super().__init__(detail)
def generate_error_responses( def generate_error_responses(
*errors: type[ApiException], *errors: type[ApiException],
) -> dict[int | str, dict[str, Any]]: ) -> dict[int | str, dict[str, Any]]:

View File

@@ -3,7 +3,8 @@
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import AsyncCrud, CrudFactory from fastapi_toolsets.crud import CrudFactory
from fastapi_toolsets.crud.factory import AsyncCrud
from fastapi_toolsets.exceptions import NotFoundError from fastapi_toolsets.exceptions import NotFoundError
from .conftest import ( from .conftest import (

392
tests/test_crud_search.py Normal file
View File

@@ -0,0 +1,392 @@
"""Tests for CRUD search functionality."""
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
from .conftest import (
Role,
RoleCreate,
RoleCrud,
User,
UserCreate,
UserCrud,
)
class TestPaginateSearch:
"""Tests for paginate() with search."""
@pytest.mark.anyio
async def test_search_single_column(self, db_session: AsyncSession):
"""Search on a single direct column."""
await UserCrud.create(
db_session, UserCreate(username="john_doe", email="john@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="jane_doe", email="jane@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="bob_smith", email="bob@test.com")
)
result = await UserCrud.paginate(
db_session,
search="doe",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_multiple_columns(self, db_session: AsyncSession):
"""Search across multiple columns (OR logic)."""
await UserCrud.create(
db_session, UserCreate(username="alice", email="alice@company.com")
)
await UserCrud.create(
db_session, UserCreate(username="company_bob", email="bob@other.com")
)
result = await UserCrud.paginate(
db_session,
search="company",
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_relationship_depth1(self, db_session: AsyncSession):
"""Search through a relationship (depth 1)."""
admin_role = await RoleCrud.create(db_session, RoleCreate(name="administrator"))
user_role = await RoleCrud.create(db_session, RoleCreate(name="basic_user"))
await UserCrud.create(
db_session,
UserCreate(username="admin1", email="a1@test.com", role_id=admin_role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="admin2", email="a2@test.com", role_id=admin_role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="user1", email="u1@test.com", role_id=user_role.id),
)
result = await UserCrud.paginate(
db_session,
search="admin",
search_fields=[(User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
"""Search combining direct columns and relationships."""
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="john", email="john@test.com", role_id=role.id),
)
# Search "admin" in username OR role.name
result = await UserCrud.paginate(
db_session,
search="admin",
search_fields=[User.username, (User.role, Role.name)],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_case_insensitive(self, db_session: AsyncSession):
"""Search is case-insensitive by default."""
await UserCrud.create(
db_session, UserCreate(username="JohnDoe", email="j@test.com")
)
result = await UserCrud.paginate(
db_session,
search="johndoe",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_case_sensitive(self, db_session: AsyncSession):
"""Case-sensitive search with SearchConfig."""
await UserCrud.create(
db_session, UserCreate(username="JohnDoe", email="j@test.com")
)
# Should not find (case mismatch)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="johndoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
# Should find (case match)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="JohnDoe", case_sensitive=True),
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_empty_query(self, db_session: AsyncSession):
"""Empty search returns all results."""
await UserCrud.create(
db_session, UserCreate(username="user1", email="u1@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="user2", email="u2@test.com")
)
result = await UserCrud.paginate(db_session, search="")
assert result["pagination"]["total_count"] == 2
result = await UserCrud.paginate(db_session, search=None)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_with_existing_filters(self, db_session: AsyncSession):
"""Search combines with existing filters (AND)."""
await UserCrud.create(
db_session,
UserCreate(username="active_john", email="aj@test.com", is_active=True),
)
await UserCrud.create(
db_session,
UserCreate(username="inactive_john", email="ij@test.com", is_active=False),
)
result = await UserCrud.paginate(
db_session,
filters=[User.is_active == True], # noqa: E712
search="john",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "active_john"
@pytest.mark.anyio
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
"""Auto-detect searchable fields when not specified."""
await UserCrud.create(
db_session, UserCreate(username="findme", email="other@test.com")
)
result = await UserCrud.paginate(db_session, search="findme")
assert result["pagination"]["total_count"] == 1
@pytest.mark.anyio
async def test_search_no_results(self, db_session: AsyncSession):
"""Search with no matching results."""
await UserCrud.create(
db_session, UserCreate(username="john", email="j@test.com")
)
result = await UserCrud.paginate(
db_session,
search="nonexistent",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 0
assert result["data"] == []
@pytest.mark.anyio
async def test_search_with_pagination(self, db_session: AsyncSession):
"""Search respects pagination parameters."""
for i in range(15):
await UserCrud.create(
db_session,
UserCreate(username=f"user_{i}", email=f"user{i}@test.com"),
)
result = await UserCrud.paginate(
db_session,
search="user_",
search_fields=[User.username],
page=1,
items_per_page=5,
)
assert result["pagination"]["total_count"] == 15
assert len(result["data"]) == 5
assert result["pagination"]["has_more"] is True
@pytest.mark.anyio
async def test_search_null_relationship(self, db_session: AsyncSession):
"""Users without relationship are included (outerjoin)."""
role = await RoleCrud.create(db_session, RoleCreate(name="admin"))
await UserCrud.create(
db_session,
UserCreate(username="with_role", email="wr@test.com", role_id=role.id),
)
await UserCrud.create(
db_session,
UserCreate(username="no_role", email="nr@test.com", role_id=None),
)
# Search in username, not in role
result = await UserCrud.paginate(
db_session,
search="role",
search_fields=[User.username],
)
assert result["pagination"]["total_count"] == 2
@pytest.mark.anyio
async def test_search_with_order_by(self, db_session: AsyncSession):
"""Search works with order_by parameter."""
await UserCrud.create(
db_session, UserCreate(username="charlie", email="c@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="alice", email="a@test.com")
)
await UserCrud.create(
db_session, UserCreate(username="bob", email="b@test.com")
)
result = await UserCrud.paginate(
db_session,
search="@test.com",
search_fields=[User.email],
order_by=User.username,
)
assert result["pagination"]["total_count"] == 3
usernames = [u.username for u in result["data"]]
assert usernames == ["alice", "bob", "charlie"]
class TestSearchConfig:
"""Tests for SearchConfig options."""
@pytest.mark.anyio
async def test_match_mode_all(self, db_session: AsyncSession):
"""match_mode='all' requires all fields to match (AND)."""
await UserCrud.create(
db_session,
UserCreate(username="john_test", email="john_test@company.com"),
)
await UserCrud.create(
db_session,
UserCreate(username="john_other", email="other@example.com"),
)
# 'john' must be in username AND email
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="john", match_mode="all"),
search_fields=[User.username, User.email],
)
assert result["pagination"]["total_count"] == 1
assert result["data"][0].username == "john_test"
@pytest.mark.anyio
async def test_search_config_with_fields(self, db_session: AsyncSession):
"""SearchConfig can specify fields directly."""
await UserCrud.create(
db_session, UserCreate(username="test", email="findme@test.com")
)
result = await UserCrud.paginate(
db_session,
search=SearchConfig(query="findme", fields=[User.email]),
)
assert result["pagination"]["total_count"] == 1
class TestNoSearchableFieldsError:
"""Tests for NoSearchableFieldsError exception."""
def test_error_is_api_exception(self):
"""NoSearchableFieldsError inherits from ApiException."""
from fastapi_toolsets.exceptions import ApiException, NoSearchableFieldsError
assert issubclass(NoSearchableFieldsError, ApiException)
def test_error_has_api_error_fields(self):
"""NoSearchableFieldsError has proper api_error configuration."""
from fastapi_toolsets.exceptions import NoSearchableFieldsError
assert NoSearchableFieldsError.api_error.code == 400
assert NoSearchableFieldsError.api_error.err_code == "SEARCH-400"
def test_error_message_contains_model_name(self):
"""Error message includes the model name."""
from fastapi_toolsets.exceptions import NoSearchableFieldsError
error = NoSearchableFieldsError(User)
assert "User" in str(error)
assert error.model is User
def test_error_raised_when_no_fields(self):
"""Error is raised when search has no searchable fields."""
from sqlalchemy import Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from fastapi_toolsets.crud.search import build_search_filters
from fastapi_toolsets.exceptions import NoSearchableFieldsError
# Model with no String columns
class NoStringBase(DeclarativeBase):
pass
class NoStringModel(NoStringBase):
__tablename__ = "no_strings"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
count: Mapped[int] = mapped_column(Integer, default=0)
with pytest.raises(NoSearchableFieldsError) as exc_info:
build_search_filters(NoStringModel, "test")
assert exc_info.value.model is NoStringModel
assert "NoStringModel" in str(exc_info.value)
class TestGetSearchableFields:
"""Tests for auto-detection of searchable fields."""
def test_detects_string_columns(self):
"""Detects String columns on the model."""
fields = get_searchable_fields(User, include_relationships=False)
# Should include username and email (String), not id or is_active
field_names = [str(f) for f in fields]
assert any("username" in f for f in field_names)
assert any("email" in f for f in field_names)
assert not any("id" in f and "role_id" not in f for f in field_names)
assert not any("is_active" in f for f in field_names)
def test_detects_relationship_fields(self):
"""Detects String fields on related models."""
fields = get_searchable_fields(User, include_relationships=True)
# Should include (User.role, Role.name)
has_role_name = any(isinstance(f, tuple) and len(f) == 2 for f in fields)
assert has_role_name
def test_skips_collection_relationships(self):
"""Skips one-to-many relationships."""
fields = get_searchable_fields(Role, include_relationships=True)
# Role.users is a collection, should not be included
field_strs = [str(f) for f in fields]
assert not any("users" in f for f in field_strs)

2
uv.lock generated
View File

@@ -220,7 +220,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi-toolsets" name = "fastapi-toolsets"
version = "0.3.0" version = "0.4.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" }, { name = "asyncpg" },