mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
348ed4c148
|
|||
|
bd6e90de1b
|
|||
|
|
4404fb3df9 | ||
|
|
f68793fbdb | ||
|
|
3a69c3c788 | ||
|
e861a0a49a
|
|||
|
|
cb2cf572e0 |
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
@@ -59,10 +59,10 @@ dev = [
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
manager = "fastapi_toolsets.cli:cli"
|
||||
manager = "fastapi_toolsets.cli.app:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.9.26,<0.10.0"]
|
||||
requires = ["uv_build>=0.10,<0.11.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.7.0"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import typer
|
||||
|
||||
from ..logger import configure_logging
|
||||
from .config import get_custom_cli
|
||||
from .pyproject import load_pyproject
|
||||
|
||||
@@ -27,4 +28,5 @@ if _config.get("fixtures") and _config.get("db_context"):
|
||||
@cli.callback()
|
||||
def main(ctx: typer.Context) -> None:
|
||||
"""FastAPI utilities CLI."""
|
||||
configure_logging()
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, func, select
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy.sql.roles import WhereHavingRole
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..exceptions import NotFoundError
|
||||
from ..schemas import PaginatedResponse, Pagination, Response
|
||||
from .search import SearchConfig, SearchFieldType, build_search_filters
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||
@@ -29,26 +30,80 @@ class AsyncCrud(Generic[ModelType]):
|
||||
model: ClassVar[type[DeclarativeBase]]
|
||||
searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def create( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
) -> ModelType:
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Create a new record in the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
obj: Pydantic model with data to create
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Created model instance
|
||||
Created model instance or Response wrapping it
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
db_model = cls.model(**obj.model_dump())
|
||||
session.add(db_model)
|
||||
await session.refresh(db_model)
|
||||
return cast(ModelType, db_model)
|
||||
result = cast(ModelType, db_model)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
return result
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def get( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
joins: JoinType | None = None,
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
@@ -60,7 +115,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: bool = False,
|
||||
with_for_update: bool = False,
|
||||
load_options: list[Any] | None = None,
|
||||
) -> ModelType:
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Get exactly one record. Raises NotFoundError if not found.
|
||||
|
||||
Args:
|
||||
@@ -70,9 +126,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
|
||||
with_for_update: Lock the row for update
|
||||
load_options: SQLAlchemy loader options (e.g., selectinload)
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Model instance
|
||||
Model instance or Response wrapping it
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
@@ -95,7 +152,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
item = result.unique().scalar_one_or_none()
|
||||
if not item:
|
||||
raise NotFoundError()
|
||||
return cast(ModelType, item)
|
||||
result = cast(ModelType, item)
|
||||
if as_response:
|
||||
return Response(data=result)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def first(
|
||||
@@ -183,6 +243,32 @@ class AsyncCrud(Generic[ModelType]):
|
||||
result = await session.execute(q)
|
||||
return cast(Sequence[ModelType], result.unique().scalars().all())
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[True],
|
||||
) -> Response[ModelType]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def update( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
obj: BaseModel,
|
||||
filters: list[Any],
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> ModelType: ...
|
||||
|
||||
@classmethod
|
||||
async def update(
|
||||
cls: type[Self],
|
||||
@@ -192,7 +278,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
*,
|
||||
exclude_unset: bool = True,
|
||||
exclude_none: bool = False,
|
||||
) -> ModelType:
|
||||
as_response: bool = False,
|
||||
) -> ModelType | Response[ModelType]:
|
||||
"""Update a record in the database.
|
||||
|
||||
Args:
|
||||
@@ -201,9 +288,10 @@ class AsyncCrud(Generic[ModelType]):
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
exclude_unset: Exclude fields not explicitly set in the schema
|
||||
exclude_none: Exclude fields with None value
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
Updated model instance
|
||||
Updated model instance or Response wrapping it
|
||||
|
||||
Raises:
|
||||
NotFoundError: If no record found
|
||||
@@ -216,6 +304,8 @@ class AsyncCrud(Generic[ModelType]):
|
||||
for key, value in values.items():
|
||||
setattr(db_model, key, value)
|
||||
await session.refresh(db_model)
|
||||
if as_response:
|
||||
return Response(data=db_model)
|
||||
return db_model
|
||||
|
||||
@classmethod
|
||||
@@ -264,24 +354,49 @@ class AsyncCrud(Generic[ModelType]):
|
||||
)
|
||||
return cast(ModelType | None, db_model)
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def delete( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[True],
|
||||
) -> Response[None]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def delete( # pragma: no cover
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
*,
|
||||
as_response: Literal[False] = ...,
|
||||
) -> bool: ...
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[Self],
|
||||
session: AsyncSession,
|
||||
filters: list[Any],
|
||||
) -> bool:
|
||||
*,
|
||||
as_response: bool = False,
|
||||
) -> bool | Response[None]:
|
||||
"""Delete records from the database.
|
||||
|
||||
Args:
|
||||
session: DB async session
|
||||
filters: List of SQLAlchemy filter conditions
|
||||
as_response: If True, wrap result in Response object
|
||||
|
||||
Returns:
|
||||
True if deletion was executed
|
||||
True if deletion was executed, or Response wrapping it
|
||||
"""
|
||||
async with get_transaction(session):
|
||||
q = sql_delete(cls.model).where(and_(*filters))
|
||||
await session.execute(q)
|
||||
if as_response:
|
||||
return Response(data=None)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@@ -363,7 +478,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
items_per_page: int = 20,
|
||||
search: str | SearchConfig | None = None,
|
||||
search_fields: Sequence[SearchFieldType] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> PaginatedResponse[ModelType]:
|
||||
"""Get paginated results with metadata.
|
||||
|
||||
Args:
|
||||
@@ -420,7 +535,7 @@ class AsyncCrud(Generic[ModelType]):
|
||||
|
||||
q = q.offset(offset).limit(items_per_page)
|
||||
result = await session.execute(q)
|
||||
items = result.unique().scalars().all()
|
||||
items = cast(list[ModelType], result.unique().scalars().all())
|
||||
|
||||
# Count query (with same joins and filters)
|
||||
pk_col = cls.model.__mapper__.primary_key[0]
|
||||
@@ -446,15 +561,15 @@ class AsyncCrud(Generic[ModelType]):
|
||||
count_result = await session.execute(count_q)
|
||||
total_count = count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"data": items,
|
||||
"pagination": {
|
||||
"total_count": total_count,
|
||||
"items_per_page": items_per_page,
|
||||
"page": page,
|
||||
"has_more": page * items_per_page < total_count,
|
||||
},
|
||||
}
|
||||
return PaginatedResponse(
|
||||
data=items,
|
||||
pagination=Pagination(
|
||||
total_count=total_count,
|
||||
items_per_page=items_per_page,
|
||||
page=page,
|
||||
has_more=page * items_per_page < total_count,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def CrudFactory(
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Fixture system with dependency management and context support."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..logger import get_logger
|
||||
from .enum import Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, TypeVar
|
||||
|
||||
@@ -6,10 +5,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from ..db import get_transaction
|
||||
from ..logger import get_logger
|
||||
from .enum import LoadStrategy
|
||||
from .registry import Context, FixtureRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger()
|
||||
|
||||
T = TypeVar("T", bound=DeclarativeBase)
|
||||
|
||||
|
||||
81
src/fastapi_toolsets/logger.py
Normal file
81
src/fastapi_toolsets/logger.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Logging configuration for FastAPI applications and CLI tools."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Literal
|
||||
|
||||
__all__ = ["LogLevel", "configure_logging", "get_logger"]
|
||||
|
||||
DEFAULT_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
UVICORN_LOGGERS = ("uvicorn", "uvicorn.access", "uvicorn.error")
|
||||
|
||||
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
|
||||
def configure_logging(
|
||||
level: LogLevel | int = "INFO",
|
||||
fmt: str = DEFAULT_FORMAT,
|
||||
logger_name: str | None = None,
|
||||
) -> logging.Logger:
|
||||
"""Configure logging with a stdout handler and consistent format.
|
||||
|
||||
Sets up a :class:`~logging.StreamHandler` writing to stdout with the
|
||||
given format and level. Also configures the uvicorn loggers so that
|
||||
FastAPI access logs use the same format.
|
||||
|
||||
Calling this function multiple times is safe -- existing handlers are
|
||||
replaced rather than duplicated.
|
||||
|
||||
Args:
|
||||
level: Log level (e.g. ``"DEBUG"``, ``"INFO"``, or ``logging.DEBUG``).
|
||||
fmt: Log format string. Defaults to
|
||||
``"%(asctime)s - %(name)s - %(levelname)s - %(message)s"``.
|
||||
logger_name: Logger name to configure. ``None`` (the default)
|
||||
configures the root logger so all loggers inherit the settings.
|
||||
|
||||
Returns:
|
||||
The configured Logger instance.
|
||||
"""
|
||||
formatter = logging.Formatter(fmt)
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.handlers.clear()
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(level)
|
||||
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv_logger = logging.getLogger(name)
|
||||
uv_logger.handlers.clear()
|
||||
uv_logger.addHandler(handler)
|
||||
uv_logger.setLevel(level)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def get_logger(name: str | None = _SENTINEL) -> logging.Logger: # type: ignore[assignment]
|
||||
"""Return a logger with the given *name*.
|
||||
|
||||
A thin convenience wrapper around :func:`logging.getLogger` that keeps
|
||||
logging imports consistent across the codebase.
|
||||
|
||||
When called without arguments, the caller's ``__name__`` is used
|
||||
automatically, so ``get_logger()`` in a module is equivalent to
|
||||
``logging.getLogger(__name__)``. Pass ``None`` explicitly to get the
|
||||
root logger.
|
||||
|
||||
Args:
|
||||
name: Logger name. Defaults to the caller's ``__name__``.
|
||||
Pass ``None`` to get the root logger.
|
||||
|
||||
Returns:
|
||||
A Logger instance.
|
||||
"""
|
||||
if name is _SENTINEL:
|
||||
name = sys._getframe(1).f_globals.get("__name__")
|
||||
return logging.getLogger(name)
|
||||
@@ -429,11 +429,11 @@ class TestCrudPaginate:
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=1, items_per_page=10)
|
||||
|
||||
assert len(result["data"]) == 10
|
||||
assert result["pagination"]["total_count"] == 25
|
||||
assert result["pagination"]["page"] == 1
|
||||
assert result["pagination"]["items_per_page"] == 10
|
||||
assert result["pagination"]["has_more"] is True
|
||||
assert len(result.data) == 10
|
||||
assert result.pagination.total_count == 25
|
||||
assert result.pagination.page == 1
|
||||
assert result.pagination.items_per_page == 10
|
||||
assert result.pagination.has_more is True
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_last_page(self, db_session: AsyncSession):
|
||||
@@ -443,8 +443,8 @@ class TestCrudPaginate:
|
||||
|
||||
result = await RoleCrud.paginate(db_session, page=3, items_per_page=10)
|
||||
|
||||
assert len(result["data"]) == 5
|
||||
assert result["pagination"]["has_more"] is False
|
||||
assert len(result.data) == 5
|
||||
assert result.pagination.has_more is False
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_filters(self, db_session: AsyncSession):
|
||||
@@ -466,7 +466,7 @@ class TestCrudPaginate:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 5
|
||||
assert result.pagination.total_count == 5
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_ordering(self, db_session: AsyncSession):
|
||||
@@ -482,7 +482,7 @@ class TestCrudPaginate:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
names = [r.name for r in result["data"]]
|
||||
names = [r.name for r in result.data]
|
||||
assert names == ["alpha", "bravo", "charlie"]
|
||||
|
||||
|
||||
@@ -690,8 +690,8 @@ class TestCrudJoins:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 3
|
||||
assert len(result["data"]) == 3
|
||||
assert result.pagination.total_count == 3
|
||||
assert len(result.data) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_paginate_with_outer_join(self, db_session: AsyncSession):
|
||||
@@ -721,8 +721,8 @@ class TestCrudJoins:
|
||||
items_per_page=10,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert len(result["data"]) == 2
|
||||
assert result.pagination.total_count == 2
|
||||
assert len(result.data) == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multiple_joins(self, db_session: AsyncSession):
|
||||
@@ -752,3 +752,63 @@ class TestCrudJoins:
|
||||
)
|
||||
assert len(users) == 1
|
||||
assert users[0].username == "multi_join"
|
||||
|
||||
|
||||
class TestAsResponse:
|
||||
"""Tests for as_response parameter."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_as_response(self, db_session: AsyncSession):
|
||||
"""Create with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
data = RoleCreate(name="response_role")
|
||||
result = await RoleCrud.create(db_session, data, as_response=True)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.name == "response_role"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_as_response(self, db_session: AsyncSession):
|
||||
"""Get with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="get_response"))
|
||||
result = await RoleCrud.get(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.id == created.id
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_as_response(self, db_session: AsyncSession):
|
||||
"""Update with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="old_name"))
|
||||
result = await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="new_name"),
|
||||
[Role.id == created.id],
|
||||
as_response=True,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is not None
|
||||
assert result.data.name == "new_name"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_as_response(self, db_session: AsyncSession):
|
||||
"""Delete with as_response=True returns Response."""
|
||||
from fastapi_toolsets.schemas import Response
|
||||
|
||||
created = await RoleCrud.create(db_session, RoleCreate(name="to_delete"))
|
||||
result = await RoleCrud.delete(
|
||||
db_session, [Role.id == created.id], as_response=True
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.data is None
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_multiple_columns(self, db_session: AsyncSession):
|
||||
@@ -57,7 +57,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_relationship_depth1(self, db_session: AsyncSession):
|
||||
@@ -84,7 +84,7 @@ class TestPaginateSearch:
|
||||
search_fields=[(User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_mixed_direct_and_relation(self, db_session: AsyncSession):
|
||||
@@ -102,7 +102,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username, (User.role, Role.name)],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_case_insensitive(self, db_session: AsyncSession):
|
||||
@@ -117,7 +117,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_case_sensitive(self, db_session: AsyncSession):
|
||||
@@ -132,7 +132,7 @@ class TestPaginateSearch:
|
||||
search=SearchConfig(query="johndoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert result["pagination"]["total_count"] == 0
|
||||
assert result.pagination.total_count == 0
|
||||
|
||||
# Should find (case match)
|
||||
result = await UserCrud.paginate(
|
||||
@@ -140,7 +140,7 @@ class TestPaginateSearch:
|
||||
search=SearchConfig(query="JohnDoe", case_sensitive=True),
|
||||
search_fields=[User.username],
|
||||
)
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_empty_query(self, db_session: AsyncSession):
|
||||
@@ -153,10 +153,10 @@ class TestPaginateSearch:
|
||||
)
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="")
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
result = await UserCrud.paginate(db_session, search=None)
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_existing_filters(self, db_session: AsyncSession):
|
||||
@@ -177,8 +177,8 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].username == "active_john"
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "active_john"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_auto_detect_fields(self, db_session: AsyncSession):
|
||||
@@ -189,7 +189,7 @@ class TestPaginateSearch:
|
||||
|
||||
result = await UserCrud.paginate(db_session, search="findme")
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_no_results(self, db_session: AsyncSession):
|
||||
@@ -204,8 +204,8 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 0
|
||||
assert result["data"] == []
|
||||
assert result.pagination.total_count == 0
|
||||
assert result.data == []
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_pagination(self, db_session: AsyncSession):
|
||||
@@ -224,9 +224,9 @@ class TestPaginateSearch:
|
||||
items_per_page=5,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 15
|
||||
assert len(result["data"]) == 5
|
||||
assert result["pagination"]["has_more"] is True
|
||||
assert result.pagination.total_count == 15
|
||||
assert len(result.data) == 5
|
||||
assert result.pagination.has_more is True
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_null_relationship(self, db_session: AsyncSession):
|
||||
@@ -248,7 +248,7 @@ class TestPaginateSearch:
|
||||
search_fields=[User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 2
|
||||
assert result.pagination.total_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_with_order_by(self, db_session: AsyncSession):
|
||||
@@ -270,8 +270,8 @@ class TestPaginateSearch:
|
||||
order_by=User.username,
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 3
|
||||
usernames = [u.username for u in result["data"]]
|
||||
assert result.pagination.total_count == 3
|
||||
usernames = [u.username for u in result.data]
|
||||
assert usernames == ["alice", "bob", "charlie"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -292,8 +292,8 @@ class TestPaginateSearch:
|
||||
search_fields=[User.id, User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].id == user_id
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].id == user_id
|
||||
|
||||
|
||||
class TestSearchConfig:
|
||||
@@ -318,8 +318,8 @@ class TestSearchConfig:
|
||||
search_fields=[User.username, User.email],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].username == "john_test"
|
||||
assert result.pagination.total_count == 1
|
||||
assert result.data[0].username == "john_test"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_config_with_fields(self, db_session: AsyncSession):
|
||||
@@ -333,7 +333,7 @@ class TestSearchConfig:
|
||||
search=SearchConfig(query="findme", fields=[User.email]),
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result.pagination.total_count == 1
|
||||
|
||||
|
||||
class TestNoSearchableFieldsError:
|
||||
|
||||
118
tests/test_logger.py
Normal file
118
tests/test_logger.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_toolsets.logger import (
|
||||
DEFAULT_FORMAT,
|
||||
UVICORN_LOGGERS,
|
||||
configure_logging,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_loggers():
|
||||
"""Reset the root and uvicorn loggers after each test."""
|
||||
yield
|
||||
root = logging.getLogger()
|
||||
root.handlers.clear()
|
||||
root.setLevel(logging.WARNING)
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv = logging.getLogger(name)
|
||||
uv.handlers.clear()
|
||||
uv.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class TestConfigureLogging:
|
||||
def test_sets_up_handler_and_format(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert len(logger.handlers) == 1
|
||||
handler = logger.handlers[0]
|
||||
assert isinstance(handler, logging.StreamHandler)
|
||||
assert handler.stream is sys.stdout
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||
|
||||
def test_default_level_is_info(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert logger.level == logging.INFO
|
||||
|
||||
def test_custom_level_string(self):
|
||||
logger = configure_logging(level="DEBUG")
|
||||
|
||||
assert logger.level == logging.DEBUG
|
||||
|
||||
def test_custom_level_int(self):
|
||||
logger = configure_logging(level=logging.WARNING)
|
||||
|
||||
assert logger.level == logging.WARNING
|
||||
|
||||
def test_custom_format(self):
|
||||
custom_fmt = "%(levelname)s: %(message)s"
|
||||
logger = configure_logging(fmt=custom_fmt)
|
||||
|
||||
handler = logger.handlers[0]
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == custom_fmt
|
||||
|
||||
def test_named_logger(self):
|
||||
logger = configure_logging(logger_name="myapp")
|
||||
|
||||
assert logger.name == "myapp"
|
||||
assert len(logger.handlers) == 1
|
||||
|
||||
def test_default_configures_root_logger(self):
|
||||
logger = configure_logging()
|
||||
|
||||
assert logger is logging.getLogger()
|
||||
|
||||
def test_idempotent_no_duplicate_handlers(self):
|
||||
configure_logging()
|
||||
configure_logging()
|
||||
logger = configure_logging()
|
||||
|
||||
assert len(logger.handlers) == 1
|
||||
|
||||
def test_configures_uvicorn_loggers(self):
|
||||
configure_logging(level="DEBUG")
|
||||
|
||||
for name in UVICORN_LOGGERS:
|
||||
uv_logger = logging.getLogger(name)
|
||||
assert len(uv_logger.handlers) == 1
|
||||
assert uv_logger.level == logging.DEBUG
|
||||
handler = uv_logger.handlers[0]
|
||||
assert handler.formatter is not None
|
||||
assert handler.formatter._fmt == DEFAULT_FORMAT
|
||||
|
||||
def test_returns_configured_logger(self):
|
||||
logger = configure_logging(logger_name="test.return")
|
||||
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert logger.name == "test.return"
|
||||
|
||||
|
||||
class TestGetLogger:
|
||||
def test_returns_named_logger(self):
|
||||
logger = get_logger("myapp.services")
|
||||
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert logger.name == "myapp.services"
|
||||
|
||||
def test_returns_root_logger_when_none(self):
|
||||
logger = get_logger(None)
|
||||
|
||||
assert logger is logging.getLogger()
|
||||
|
||||
def test_defaults_to_caller_module_name(self):
|
||||
logger = get_logger()
|
||||
|
||||
assert logger.name == __name__
|
||||
|
||||
def test_same_name_returns_same_logger(self):
|
||||
a = get_logger("myapp")
|
||||
b = get_logger("myapp")
|
||||
|
||||
assert a is b
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -205,7 +205,7 @@ toml = [
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
version = "0.128.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-doc" },
|
||||
@@ -213,14 +213,14 @@ dependencies = [
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682, upload-time = "2025-12-27T15:21:13.714Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f6/59/28bde150415783ff084334e3de106eb7461a57864cf69f343950ad5a5ddd/fastapi-0.128.1.tar.gz", hash = "sha256:ce5be4fa26d4ce6f54debcc873d1fb8e0e248f5c48d7502ba6c61457ab2dc766", size = 374260, upload-time = "2026-02-04T17:35:10.542Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094, upload-time = "2025-12-27T15:21:12.154Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/08/3953db1979ea131c68279b997c6465080118b407f0800445b843f8e164b3/fastapi-0.128.1-py3-none-any.whl", hash = "sha256:ee82146bbf91ea5bbf2bb8629e4c6e056c4fbd997ea6068501b11b15260b50fb", size = 103810, upload-time = "2026-02-04T17:35:08.02Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi-toolsets"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "asyncpg" },
|
||||
|
||||
Reference in New Issue
Block a user