refactor: centralize type aliases in types.py and simplify crud layer

This commit is contained in:
2026-03-01 07:40:18 -05:00
parent 59d028d00e
commit e0828c7e71
7 changed files with 137 additions and 168 deletions

View File

@@ -1,12 +1,9 @@
"""Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause
from .search import (
FacetFieldType,
SearchConfig,
get_searchable_fields,
)
from ..types import FacetFieldType, JoinType, M2MFieldType, OrderByClause
from .factory import CrudFactory
from .search import SearchConfig, get_searchable_fields
__all__ = [
"CrudFactory",

View File

@@ -6,10 +6,10 @@ import base64
import inspect
import json
import uuid as uuid_module
from collections.abc import Awaitable, Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query
from pydantic import BaseModel
@@ -20,28 +20,28 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import InvalidOrderFieldError, NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import (
from ..types import (
FacetFieldType,
SearchConfig,
JoinType,
M2MFieldType,
ModelType,
OrderByClause,
SchemaType,
SearchFieldType,
)
from .search import (
SearchConfig,
build_facets,
build_filter_by,
build_search_filters,
facet_keys,
)
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
def _encode_cursor(value: Any) -> str:
"""Encode cursor column value as an base64 string."""
@@ -53,6 +53,22 @@ def _decode_cursor(cursor: str) -> str:
return json.loads(base64.b64decode(cursor.encode()).decode())
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
"""Apply a list of (model, condition) joins to a SQLAlchemy select query."""
if not joins:
return q
for model, condition in joins:
q = q.outerjoin(model, condition) if outer_join else q.join(model, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
for join_rel in search_joins:
q = q.outerjoin(join_rel)
return q
class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models.
@@ -132,6 +148,40 @@ class AsyncCrud(Generic[ModelType]):
return set()
return set(cls.m2m_fields.keys())
@classmethod
def _prepare_filter_by(
cls: type[Self],
filter_by: dict[str, Any] | BaseModel | None,
facet_fields: Sequence[FacetFieldType] | None,
) -> tuple[list[Any], list[Any]]:
"""Normalize filter_by and return (filters, joins) to apply to the query."""
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
if not filter_by:
return [], []
resolved = facet_fields if facet_fields is not None else cls.facet_fields
return build_filter_by(filter_by, resolved or [])
@classmethod
async def _build_filter_attributes(
cls: type[Self],
session: AsyncSession,
facet_fields: Sequence[FacetFieldType] | None,
filters: list[Any],
search_joins: list[Any],
) -> dict[str, list[Any]] | None:
"""Build facet filter_attributes, or return None if no facet fields configured."""
resolved = facet_fields if facet_fields is not None else cls.facet_fields
if not resolved:
return None
return await build_facets(
session,
cls.model,
resolved,
base_filters=filters or None,
base_joins=search_joins or None,
)
@classmethod
def filter_params(
cls: type[Self],
@@ -290,8 +340,7 @@ class AsyncCrud(Generic[ModelType]):
await session.refresh(db_model)
result = cast(ModelType, db_model)
if schema:
data_out = schema.model_validate(result) if schema else result
return Response(data=data_out)
return Response(data=schema.model_validate(result))
return result
@overload
@@ -354,13 +403,7 @@ class AsyncCrud(Generic[ModelType]):
MultipleResultsFound: If more than one record found
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
@@ -372,8 +415,7 @@ class AsyncCrud(Generic[ModelType]):
raise NotFoundError()
result = cast(ModelType, item)
if schema:
data_out = schema.model_validate(result) if schema else result
return Response(data=data_out)
return Response(data=schema.model_validate(result))
return result
@classmethod
@@ -399,13 +441,7 @@ class AsyncCrud(Generic[ModelType]):
Model instance or None
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -442,13 +478,7 @@ class AsyncCrud(Generic[ModelType]):
List of model instances
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -546,8 +576,7 @@ class AsyncCrud(Generic[ModelType]):
setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model)
if schema:
data_out = schema.model_validate(db_model) if schema else db_model
return Response(data=data_out)
return Response(data=schema.model_validate(db_model))
return db_model
@classmethod
@@ -664,13 +693,7 @@ class AsyncCrud(Generic[ModelType]):
Number of matching records
"""
q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
if filters:
q = q.where(and_(*filters))
result = await session.execute(q)
@@ -697,13 +720,7 @@ class AsyncCrud(Generic[ModelType]):
True if at least one record matches
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q)
return bool(result.scalar())
@@ -750,47 +767,29 @@ class AsyncCrud(Generic[ModelType]):
"""
filters = list(filters) if filters else []
offset = (page - 1) * items_per_page
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filters.extend(fb_filters)
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
search_filters, new_search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query with joins
q = select(cls.model)
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
q = _apply_search_joins(q, search_joins)
if filters:
q = q.where(and_(*filters))
@@ -810,17 +809,10 @@ class AsyncCrud(Generic[ModelType]):
count_q = count_q.select_from(cls.model)
# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
count_q = _apply_joins(count_q, joins, outer_join)
# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel)
count_q = _apply_search_joins(count_q, search_joins)
if filters:
count_q = count_q.where(and_(*filters))
@@ -828,19 +820,9 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
filter_attributes = await cls._build_filter_attributes(
session, facet_fields, filters, search_joins
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,
@@ -897,21 +879,9 @@ class AsyncCrud(Generic[ModelType]):
PaginatedResponse with CursorPagination metadata
"""
filters = list(filters) if filters else []
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filters.extend(fb_filters)
if cls.cursor_column is None:
raise ValueError(
@@ -944,29 +914,23 @@ class AsyncCrud(Generic[ModelType]):
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
search_filters, new_search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query
q = select(cls.model)
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
# Apply search joins (always outer joins)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
q = _apply_search_joins(q, search_joins)
if filters:
q = q.where(and_(*filters))
@@ -998,19 +962,9 @@ class AsyncCrud(Generic[ModelType]):
items: list[Any] = [schema.model_validate(item) for item in items_page]
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
filter_attributes = await cls._build_filter_attributes(
session, facet_fields, filters, search_joins
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,

View File

@@ -1,24 +1,23 @@
"""Search utilities for AsyncCrud."""
import asyncio
import functools
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_, select
from sqlalchemy import String, and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from ..types import FacetFieldType, SearchFieldType
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
@dataclass
class SearchConfig:
@@ -37,6 +36,7 @@ class SearchConfig:
match_mode: Literal["any", "all"] = "any"
@functools.lru_cache(maxsize=128)
def get_searchable_fields(
model: type[DeclarativeBase],
*,
@@ -101,14 +101,11 @@ def build_search_filters(
if isinstance(search, str):
config = SearchConfig(query=search, fields=search_fields)
else:
config = search
if search_fields is not None:
config = SearchConfig(
query=config.query,
fields=search_fields,
case_sensitive=config.case_sensitive,
match_mode=config.match_mode,
)
config = (
replace(search, fields=search_fields)
if search_fields is not None
else search
)
if not config.query or not config.query.strip():
return [], []
@@ -227,8 +224,6 @@ async def build_facets(
q = q.outerjoin(rel)
if base_filters:
from sqlalchemy import and_
q = q.where(and_(*base_filters))
q = q.order_by(column)

View File

@@ -1,20 +1,17 @@
"""Dependency factories for FastAPI routes."""
import inspect
from collections.abc import AsyncGenerator, Callable
from typing import Any, TypeVar, cast
from collections.abc import Callable
from typing import Any, cast
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from .crud import CrudFactory
from .types import ModelType, SessionDependency
__all__ = ["BodyDependency", "PathDependency"]
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
def PathDependency(
model: type[ModelType],

View File

@@ -1,24 +1,23 @@
"""Fixture loading utilities for database seeding."""
from collections.abc import Callable, Sequence
from typing import Any, TypeVar
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction
from ..logger import get_logger
from ..types import ModelType
from .enum import LoadStrategy
from .registry import Context, FixtureRegistry
logger = get_logger()
T = TypeVar("T", bound=DeclarativeBase)
def get_obj_by_attr(
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any
) -> T:
fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> ModelType:
"""Get a SQLAlchemy model instance by matching an attribute value.
Args:

View File

@@ -1,10 +1,12 @@
"""Base Pydantic schemas for API responses."""
from enum import Enum
from typing import Any, ClassVar, Generic, TypeVar
from typing import Any, ClassVar, Generic
from pydantic import BaseModel, ConfigDict
from .types import DataT
__all__ = [
"ApiError",
"CursorPagination",
@@ -16,8 +18,6 @@ __all__ = [
"ResponseStatus",
]
DataT = TypeVar("DataT")
class PydanticBase(BaseModel):
"""Base class for all Pydantic models with common configuration."""

View File

@@ -0,0 +1,27 @@
"""Shared type aliases for the fastapi-toolsets package."""
from collections.abc import AsyncGenerator, Callable, Mapping
from typing import Any, TypeVar
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnElement
# Generic TypeVars
DataT = TypeVar("DataT")
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
# CRUD type aliases
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
# Search / facet type aliases
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
# Dependency type aliases
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]