mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-02 17:30:48 +01:00
refactor: centralize type aliases in types.py and simplify crud layer
This commit is contained in:
@@ -1,12 +1,9 @@
|
|||||||
"""Generic async CRUD operations for SQLAlchemy models."""
|
"""Generic async CRUD operations for SQLAlchemy models."""
|
||||||
|
|
||||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||||
from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause
|
from ..types import FacetFieldType, JoinType, M2MFieldType, OrderByClause
|
||||||
from .search import (
|
from .factory import CrudFactory
|
||||||
FacetFieldType,
|
from .search import SearchConfig, get_searchable_fields
|
||||||
SearchConfig,
|
|
||||||
get_searchable_fields,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CrudFactory",
|
"CrudFactory",
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ import base64
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
from collections.abc import Awaitable, Callable, Sequence
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
|
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -20,28 +20,28 @@ from sqlalchemy.exc import NoResultFound
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
|
||||||
from sqlalchemy.sql.base import ExecutableOption
|
from sqlalchemy.sql.base import ExecutableOption
|
||||||
from sqlalchemy.sql.elements import ColumnElement
|
|
||||||
from sqlalchemy.sql.roles import WhereHavingRole
|
from sqlalchemy.sql.roles import WhereHavingRole
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from ..exceptions import InvalidOrderFieldError, NotFoundError
|
from ..exceptions import InvalidOrderFieldError, NotFoundError
|
||||||
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
|
||||||
from .search import (
|
from ..types import (
|
||||||
FacetFieldType,
|
FacetFieldType,
|
||||||
SearchConfig,
|
JoinType,
|
||||||
|
M2MFieldType,
|
||||||
|
ModelType,
|
||||||
|
OrderByClause,
|
||||||
|
SchemaType,
|
||||||
SearchFieldType,
|
SearchFieldType,
|
||||||
|
)
|
||||||
|
from .search import (
|
||||||
|
SearchConfig,
|
||||||
build_facets,
|
build_facets,
|
||||||
build_filter_by,
|
build_filter_by,
|
||||||
build_search_filters,
|
build_search_filters,
|
||||||
facet_keys,
|
facet_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|
||||||
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
|
||||||
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
|
||||||
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
|
||||||
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_cursor(value: Any) -> str:
|
def _encode_cursor(value: Any) -> str:
|
||||||
"""Encode cursor column value as an base64 string."""
|
"""Encode cursor column value as an base64 string."""
|
||||||
@@ -53,6 +53,22 @@ def _decode_cursor(cursor: str) -> str:
|
|||||||
return json.loads(base64.b64decode(cursor.encode()).decode())
|
return json.loads(base64.b64decode(cursor.encode()).decode())
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
|
||||||
|
"""Apply a list of (model, condition) joins to a SQLAlchemy select query."""
|
||||||
|
if not joins:
|
||||||
|
return q
|
||||||
|
for model, condition in joins:
|
||||||
|
q = q.outerjoin(model, condition) if outer_join else q.join(model, condition)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
|
||||||
|
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
|
||||||
|
for join_rel in search_joins:
|
||||||
|
q = q.outerjoin(join_rel)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrud(Generic[ModelType]):
|
class AsyncCrud(Generic[ModelType]):
|
||||||
"""Generic async CRUD operations for SQLAlchemy models.
|
"""Generic async CRUD operations for SQLAlchemy models.
|
||||||
|
|
||||||
@@ -132,6 +148,40 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
return set()
|
return set()
|
||||||
return set(cls.m2m_fields.keys())
|
return set(cls.m2m_fields.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _prepare_filter_by(
|
||||||
|
cls: type[Self],
|
||||||
|
filter_by: dict[str, Any] | BaseModel | None,
|
||||||
|
facet_fields: Sequence[FacetFieldType] | None,
|
||||||
|
) -> tuple[list[Any], list[Any]]:
|
||||||
|
"""Normalize filter_by and return (filters, joins) to apply to the query."""
|
||||||
|
if isinstance(filter_by, BaseModel):
|
||||||
|
filter_by = filter_by.model_dump(exclude_none=True) or None
|
||||||
|
if not filter_by:
|
||||||
|
return [], []
|
||||||
|
resolved = facet_fields if facet_fields is not None else cls.facet_fields
|
||||||
|
return build_filter_by(filter_by, resolved or [])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _build_filter_attributes(
|
||||||
|
cls: type[Self],
|
||||||
|
session: AsyncSession,
|
||||||
|
facet_fields: Sequence[FacetFieldType] | None,
|
||||||
|
filters: list[Any],
|
||||||
|
search_joins: list[Any],
|
||||||
|
) -> dict[str, list[Any]] | None:
|
||||||
|
"""Build facet filter_attributes, or return None if no facet fields configured."""
|
||||||
|
resolved = facet_fields if facet_fields is not None else cls.facet_fields
|
||||||
|
if not resolved:
|
||||||
|
return None
|
||||||
|
return await build_facets(
|
||||||
|
session,
|
||||||
|
cls.model,
|
||||||
|
resolved,
|
||||||
|
base_filters=filters or None,
|
||||||
|
base_joins=search_joins or None,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def filter_params(
|
def filter_params(
|
||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
@@ -290,8 +340,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
await session.refresh(db_model)
|
await session.refresh(db_model)
|
||||||
result = cast(ModelType, db_model)
|
result = cast(ModelType, db_model)
|
||||||
if schema:
|
if schema:
|
||||||
data_out = schema.model_validate(result) if schema else result
|
return Response(data=schema.model_validate(result))
|
||||||
return Response(data=data_out)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -354,13 +403,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
MultipleResultsFound: If more than one record found
|
MultipleResultsFound: If more than one record found
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
if joins:
|
q = _apply_joins(q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
q = (
|
|
||||||
q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else q.join(model, condition)
|
|
||||||
)
|
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if resolved := cls._resolve_load_options(load_options):
|
if resolved := cls._resolve_load_options(load_options):
|
||||||
q = q.options(*resolved)
|
q = q.options(*resolved)
|
||||||
@@ -372,8 +415,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
raise NotFoundError()
|
raise NotFoundError()
|
||||||
result = cast(ModelType, item)
|
result = cast(ModelType, item)
|
||||||
if schema:
|
if schema:
|
||||||
data_out = schema.model_validate(result) if schema else result
|
return Response(data=schema.model_validate(result))
|
||||||
return Response(data=data_out)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -399,13 +441,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Model instance or None
|
Model instance or None
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
if joins:
|
q = _apply_joins(q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
q = (
|
|
||||||
q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else q.join(model, condition)
|
|
||||||
)
|
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if resolved := cls._resolve_load_options(load_options):
|
if resolved := cls._resolve_load_options(load_options):
|
||||||
@@ -442,13 +478,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
List of model instances
|
List of model instances
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
if joins:
|
q = _apply_joins(q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
q = (
|
|
||||||
q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else q.join(model, condition)
|
|
||||||
)
|
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
if resolved := cls._resolve_load_options(load_options):
|
if resolved := cls._resolve_load_options(load_options):
|
||||||
@@ -546,8 +576,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
setattr(db_model, rel_attr, related_instances)
|
setattr(db_model, rel_attr, related_instances)
|
||||||
await session.refresh(db_model)
|
await session.refresh(db_model)
|
||||||
if schema:
|
if schema:
|
||||||
data_out = schema.model_validate(db_model) if schema else db_model
|
return Response(data=schema.model_validate(db_model))
|
||||||
return Response(data=data_out)
|
|
||||||
return db_model
|
return db_model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -664,13 +693,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
Number of matching records
|
Number of matching records
|
||||||
"""
|
"""
|
||||||
q = select(func.count()).select_from(cls.model)
|
q = select(func.count()).select_from(cls.model)
|
||||||
if joins:
|
q = _apply_joins(q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
q = (
|
|
||||||
q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else q.join(model, condition)
|
|
||||||
)
|
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
@@ -697,13 +720,7 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
True if at least one record matches
|
True if at least one record matches
|
||||||
"""
|
"""
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
if joins:
|
q = _apply_joins(q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
q = (
|
|
||||||
q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else q.join(model, condition)
|
|
||||||
)
|
|
||||||
q = q.where(and_(*filters)).exists().select()
|
q = q.where(and_(*filters)).exists().select()
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
return bool(result.scalar())
|
return bool(result.scalar())
|
||||||
@@ -750,47 +767,29 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
"""
|
"""
|
||||||
filters = list(filters) if filters else []
|
filters = list(filters) if filters else []
|
||||||
offset = (page - 1) * items_per_page
|
offset = (page - 1) * items_per_page
|
||||||
search_joins: list[Any] = []
|
|
||||||
|
|
||||||
if isinstance(filter_by, BaseModel):
|
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
|
||||||
filter_by = filter_by.model_dump(exclude_none=True) or None
|
filters.extend(fb_filters)
|
||||||
|
|
||||||
# Build filter_by conditions from declared facet fields
|
|
||||||
if filter_by:
|
|
||||||
resolved_facets_for_filter = (
|
|
||||||
facet_fields if facet_fields is not None else cls.facet_fields
|
|
||||||
)
|
|
||||||
fb_filters, fb_joins = build_filter_by(
|
|
||||||
filter_by, resolved_facets_for_filter or []
|
|
||||||
)
|
|
||||||
filters.extend(fb_filters)
|
|
||||||
search_joins.extend(fb_joins)
|
|
||||||
|
|
||||||
# Build search filters
|
# Build search filters
|
||||||
if search:
|
if search:
|
||||||
search_filters, search_joins = build_search_filters(
|
search_filters, new_search_joins = build_search_filters(
|
||||||
cls.model,
|
cls.model,
|
||||||
search,
|
search,
|
||||||
search_fields=search_fields,
|
search_fields=search_fields,
|
||||||
default_fields=cls.searchable_fields,
|
default_fields=cls.searchable_fields,
|
||||||
)
|
)
|
||||||
filters.extend(search_filters)
|
filters.extend(search_filters)
|
||||||
|
search_joins.extend(new_search_joins)
|
||||||
|
|
||||||
# Build query with joins
|
# Build query with joins
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
|
||||||
# Apply explicit joins
|
# Apply explicit joins
|
||||||
if joins:
|
q = _apply_joins(q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
q = (
|
|
||||||
q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else q.join(model, condition)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply search joins (always outer joins for search)
|
# Apply search joins (always outer joins for search)
|
||||||
for join_rel in search_joins:
|
q = _apply_search_joins(q, search_joins)
|
||||||
q = q.outerjoin(join_rel)
|
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
@@ -810,17 +809,10 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
count_q = count_q.select_from(cls.model)
|
count_q = count_q.select_from(cls.model)
|
||||||
|
|
||||||
# Apply explicit joins to count query
|
# Apply explicit joins to count query
|
||||||
if joins:
|
count_q = _apply_joins(count_q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
count_q = (
|
|
||||||
count_q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else count_q.join(model, condition)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply search joins to count query
|
# Apply search joins to count query
|
||||||
for join_rel in search_joins:
|
count_q = _apply_search_joins(count_q, search_joins)
|
||||||
count_q = count_q.outerjoin(join_rel)
|
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
count_q = count_q.where(and_(*filters))
|
count_q = count_q.where(and_(*filters))
|
||||||
@@ -828,19 +820,9 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
count_result = await session.execute(count_q)
|
count_result = await session.execute(count_q)
|
||||||
total_count = count_result.scalar_one()
|
total_count = count_result.scalar_one()
|
||||||
|
|
||||||
# Build facets
|
filter_attributes = await cls._build_filter_attributes(
|
||||||
resolved_facet_fields = (
|
session, facet_fields, filters, search_joins
|
||||||
facet_fields if facet_fields is not None else cls.facet_fields
|
|
||||||
)
|
)
|
||||||
filter_attributes: dict[str, list[Any]] | None = None
|
|
||||||
if resolved_facet_fields:
|
|
||||||
filter_attributes = await build_facets(
|
|
||||||
session,
|
|
||||||
cls.model,
|
|
||||||
resolved_facet_fields,
|
|
||||||
base_filters=filters or None,
|
|
||||||
base_joins=search_joins or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return PaginatedResponse(
|
return PaginatedResponse(
|
||||||
data=items,
|
data=items,
|
||||||
@@ -897,21 +879,9 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
PaginatedResponse with CursorPagination metadata
|
PaginatedResponse with CursorPagination metadata
|
||||||
"""
|
"""
|
||||||
filters = list(filters) if filters else []
|
filters = list(filters) if filters else []
|
||||||
search_joins: list[Any] = []
|
|
||||||
|
|
||||||
if isinstance(filter_by, BaseModel):
|
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
|
||||||
filter_by = filter_by.model_dump(exclude_none=True) or None
|
filters.extend(fb_filters)
|
||||||
|
|
||||||
# Build filter_by conditions from declared facet fields
|
|
||||||
if filter_by:
|
|
||||||
resolved_facets_for_filter = (
|
|
||||||
facet_fields if facet_fields is not None else cls.facet_fields
|
|
||||||
)
|
|
||||||
fb_filters, fb_joins = build_filter_by(
|
|
||||||
filter_by, resolved_facets_for_filter or []
|
|
||||||
)
|
|
||||||
filters.extend(fb_filters)
|
|
||||||
search_joins.extend(fb_joins)
|
|
||||||
|
|
||||||
if cls.cursor_column is None:
|
if cls.cursor_column is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -944,29 +914,23 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
# Build search filters
|
# Build search filters
|
||||||
if search:
|
if search:
|
||||||
search_filters, search_joins = build_search_filters(
|
search_filters, new_search_joins = build_search_filters(
|
||||||
cls.model,
|
cls.model,
|
||||||
search,
|
search,
|
||||||
search_fields=search_fields,
|
search_fields=search_fields,
|
||||||
default_fields=cls.searchable_fields,
|
default_fields=cls.searchable_fields,
|
||||||
)
|
)
|
||||||
filters.extend(search_filters)
|
filters.extend(search_filters)
|
||||||
|
search_joins.extend(new_search_joins)
|
||||||
|
|
||||||
# Build query
|
# Build query
|
||||||
q = select(cls.model)
|
q = select(cls.model)
|
||||||
|
|
||||||
# Apply explicit joins
|
# Apply explicit joins
|
||||||
if joins:
|
q = _apply_joins(q, joins, outer_join)
|
||||||
for model, condition in joins:
|
|
||||||
q = (
|
|
||||||
q.outerjoin(model, condition)
|
|
||||||
if outer_join
|
|
||||||
else q.join(model, condition)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply search joins (always outer joins)
|
# Apply search joins (always outer joins)
|
||||||
for join_rel in search_joins:
|
q = _apply_search_joins(q, search_joins)
|
||||||
q = q.outerjoin(join_rel)
|
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
q = q.where(and_(*filters))
|
q = q.where(and_(*filters))
|
||||||
@@ -998,19 +962,9 @@ class AsyncCrud(Generic[ModelType]):
|
|||||||
|
|
||||||
items: list[Any] = [schema.model_validate(item) for item in items_page]
|
items: list[Any] = [schema.model_validate(item) for item in items_page]
|
||||||
|
|
||||||
# Build facets
|
filter_attributes = await cls._build_filter_attributes(
|
||||||
resolved_facet_fields = (
|
session, facet_fields, filters, search_joins
|
||||||
facet_fields if facet_fields is not None else cls.facet_fields
|
|
||||||
)
|
)
|
||||||
filter_attributes: dict[str, list[Any]] | None = None
|
|
||||||
if resolved_facet_fields:
|
|
||||||
filter_attributes = await build_facets(
|
|
||||||
session,
|
|
||||||
cls.model,
|
|
||||||
resolved_facet_fields,
|
|
||||||
base_filters=filters or None,
|
|
||||||
base_joins=search_joins or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return PaginatedResponse(
|
return PaginatedResponse(
|
||||||
data=items,
|
data=items,
|
||||||
|
|||||||
@@ -1,24 +1,23 @@
|
|||||||
"""Search utilities for AsyncCrud."""
|
"""Search utilities for AsyncCrud."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, replace
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import String, or_, select
|
from sqlalchemy import String, and_, or_, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
|
||||||
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
|
||||||
|
from ..types import FacetFieldType, SearchFieldType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sqlalchemy.sql.elements import ColumnElement
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
|
||||||
FacetFieldType = SearchFieldType
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SearchConfig:
|
class SearchConfig:
|
||||||
@@ -37,6 +36,7 @@ class SearchConfig:
|
|||||||
match_mode: Literal["any", "all"] = "any"
|
match_mode: Literal["any", "all"] = "any"
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
|
||||||
def get_searchable_fields(
|
def get_searchable_fields(
|
||||||
model: type[DeclarativeBase],
|
model: type[DeclarativeBase],
|
||||||
*,
|
*,
|
||||||
@@ -101,14 +101,11 @@ def build_search_filters(
|
|||||||
if isinstance(search, str):
|
if isinstance(search, str):
|
||||||
config = SearchConfig(query=search, fields=search_fields)
|
config = SearchConfig(query=search, fields=search_fields)
|
||||||
else:
|
else:
|
||||||
config = search
|
config = (
|
||||||
if search_fields is not None:
|
replace(search, fields=search_fields)
|
||||||
config = SearchConfig(
|
if search_fields is not None
|
||||||
query=config.query,
|
else search
|
||||||
fields=search_fields,
|
)
|
||||||
case_sensitive=config.case_sensitive,
|
|
||||||
match_mode=config.match_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not config.query or not config.query.strip():
|
if not config.query or not config.query.strip():
|
||||||
return [], []
|
return [], []
|
||||||
@@ -227,8 +224,6 @@ async def build_facets(
|
|||||||
q = q.outerjoin(rel)
|
q = q.outerjoin(rel)
|
||||||
|
|
||||||
if base_filters:
|
if base_filters:
|
||||||
from sqlalchemy import and_
|
|
||||||
|
|
||||||
q = q.where(and_(*base_filters))
|
q = q.where(and_(*base_filters))
|
||||||
|
|
||||||
q = q.order_by(column)
|
q = q.order_by(column)
|
||||||
|
|||||||
@@ -1,20 +1,17 @@
|
|||||||
"""Dependency factories for FastAPI routes."""
|
"""Dependency factories for FastAPI routes."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, TypeVar, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
|
|
||||||
from .crud import CrudFactory
|
from .crud import CrudFactory
|
||||||
|
from .types import ModelType, SessionDependency
|
||||||
|
|
||||||
__all__ = ["BodyDependency", "PathDependency"]
|
__all__ = ["BodyDependency", "PathDependency"]
|
||||||
|
|
||||||
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|
||||||
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
|
||||||
|
|
||||||
|
|
||||||
def PathDependency(
|
def PathDependency(
|
||||||
model: type[ModelType],
|
model: type[ModelType],
|
||||||
|
|||||||
@@ -1,24 +1,23 @@
|
|||||||
"""Fixture loading utilities for database seeding."""
|
"""Fixture loading utilities for database seeding."""
|
||||||
|
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from ..db import get_transaction
|
from ..db import get_transaction
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
|
from ..types import ModelType
|
||||||
from .enum import LoadStrategy
|
from .enum import LoadStrategy
|
||||||
from .registry import Context, FixtureRegistry
|
from .registry import Context, FixtureRegistry
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
T = TypeVar("T", bound=DeclarativeBase)
|
|
||||||
|
|
||||||
|
|
||||||
def get_obj_by_attr(
|
def get_obj_by_attr(
|
||||||
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any
|
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
|
||||||
) -> T:
|
) -> ModelType:
|
||||||
"""Get a SQLAlchemy model instance by matching an attribute value.
|
"""Get a SQLAlchemy model instance by matching an attribute value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""Base Pydantic schemas for API responses."""
|
"""Base Pydantic schemas for API responses."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, ClassVar, Generic, TypeVar
|
from typing import Any, ClassVar, Generic
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from .types import DataT
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ApiError",
|
"ApiError",
|
||||||
"CursorPagination",
|
"CursorPagination",
|
||||||
@@ -16,8 +18,6 @@ __all__ = [
|
|||||||
"ResponseStatus",
|
"ResponseStatus",
|
||||||
]
|
]
|
||||||
|
|
||||||
DataT = TypeVar("DataT")
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticBase(BaseModel):
|
class PydanticBase(BaseModel):
|
||||||
"""Base class for all Pydantic models with common configuration."""
|
"""Base class for all Pydantic models with common configuration."""
|
||||||
|
|||||||
27
src/fastapi_toolsets/types.py
Normal file
27
src/fastapi_toolsets/types.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Shared type aliases for the fastapi-toolsets package."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
|
# Generic TypeVars
|
||||||
|
DataT = TypeVar("DataT")
|
||||||
|
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
||||||
|
SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||||
|
|
||||||
|
# CRUD type aliases
|
||||||
|
JoinType = list[tuple[type[DeclarativeBase], Any]]
|
||||||
|
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
|
||||||
|
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
|
||||||
|
|
||||||
|
# Search / facet type aliases
|
||||||
|
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
|
||||||
|
FacetFieldType = SearchFieldType
|
||||||
|
|
||||||
|
# Dependency type aliases
|
||||||
|
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
Reference in New Issue
Block a user