refactor: centralize type aliases in types.py and simplify crud layer

This commit is contained in:
2026-03-01 07:40:18 -05:00
parent fddfc98acc
commit bcf0c3becb
7 changed files with 137 additions and 168 deletions

View File

@@ -1,12 +1,9 @@
"""Generic async CRUD operations for SQLAlchemy models.""" """Generic async CRUD operations for SQLAlchemy models."""
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause from ..types import FacetFieldType, JoinType, M2MFieldType, OrderByClause
from .search import ( from .factory import CrudFactory
FacetFieldType, from .search import SearchConfig, get_searchable_fields
SearchConfig,
get_searchable_fields,
)
__all__ = [ __all__ = [
"CrudFactory", "CrudFactory",

View File

@@ -6,10 +6,10 @@ import base64
import inspect import inspect
import json import json
import uuid as uuid_module import uuid as uuid_module
from collections.abc import Awaitable, Callable, Mapping, Sequence from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query from fastapi import Query
from pydantic import BaseModel from pydantic import BaseModel
@@ -20,28 +20,28 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.roles import WhereHavingRole from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction from ..db import get_transaction
from ..exceptions import InvalidOrderFieldError, NotFoundError from ..exceptions import InvalidOrderFieldError, NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import ( from ..types import (
FacetFieldType, FacetFieldType,
SearchConfig, JoinType,
M2MFieldType,
ModelType,
OrderByClause,
SchemaType,
SearchFieldType, SearchFieldType,
)
from .search import (
SearchConfig,
build_facets, build_facets,
build_filter_by, build_filter_by,
build_search_filters, build_search_filters,
facet_keys, facet_keys,
) )
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
def _encode_cursor(value: Any) -> str: def _encode_cursor(value: Any) -> str:
"""Encode cursor column value as an base64 string.""" """Encode cursor column value as an base64 string."""
@@ -53,6 +53,22 @@ def _decode_cursor(cursor: str) -> str:
return json.loads(base64.b64decode(cursor.encode()).decode()) return json.loads(base64.b64decode(cursor.encode()).decode())
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
"""Apply a list of (model, condition) joins to a SQLAlchemy select query."""
if not joins:
return q
for model, condition in joins:
q = q.outerjoin(model, condition) if outer_join else q.join(model, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
for join_rel in search_joins:
q = q.outerjoin(join_rel)
return q
class AsyncCrud(Generic[ModelType]): class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models. """Generic async CRUD operations for SQLAlchemy models.
@@ -132,6 +148,40 @@ class AsyncCrud(Generic[ModelType]):
return set() return set()
return set(cls.m2m_fields.keys()) return set(cls.m2m_fields.keys())
@classmethod
def _prepare_filter_by(
cls: type[Self],
filter_by: dict[str, Any] | BaseModel | None,
facet_fields: Sequence[FacetFieldType] | None,
) -> tuple[list[Any], list[Any]]:
"""Normalize filter_by and return (filters, joins) to apply to the query."""
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
if not filter_by:
return [], []
resolved = facet_fields if facet_fields is not None else cls.facet_fields
return build_filter_by(filter_by, resolved or [])
@classmethod
async def _build_filter_attributes(
cls: type[Self],
session: AsyncSession,
facet_fields: Sequence[FacetFieldType] | None,
filters: list[Any],
search_joins: list[Any],
) -> dict[str, list[Any]] | None:
"""Build facet filter_attributes, or return None if no facet fields configured."""
resolved = facet_fields if facet_fields is not None else cls.facet_fields
if not resolved:
return None
return await build_facets(
session,
cls.model,
resolved,
base_filters=filters or None,
base_joins=search_joins or None,
)
@classmethod @classmethod
def filter_params( def filter_params(
cls: type[Self], cls: type[Self],
@@ -290,8 +340,7 @@ class AsyncCrud(Generic[ModelType]):
await session.refresh(db_model) await session.refresh(db_model)
result = cast(ModelType, db_model) result = cast(ModelType, db_model)
if schema: if schema:
data_out = schema.model_validate(result) if schema else result return Response(data=schema.model_validate(result))
return Response(data=data_out)
return result return result
@overload @overload
@@ -354,13 +403,7 @@ class AsyncCrud(Generic[ModelType]):
MultipleResultsFound: If more than one record found MultipleResultsFound: If more than one record found
""" """
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options): if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved) q = q.options(*resolved)
@@ -372,8 +415,7 @@ class AsyncCrud(Generic[ModelType]):
raise NotFoundError() raise NotFoundError()
result = cast(ModelType, item) result = cast(ModelType, item)
if schema: if schema:
data_out = schema.model_validate(result) if schema else result return Response(data=schema.model_validate(result))
return Response(data=data_out)
return result return result
@classmethod @classmethod
@@ -399,13 +441,7 @@ class AsyncCrud(Generic[ModelType]):
Model instance or None Model instance or None
""" """
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options): if resolved := cls._resolve_load_options(load_options):
@@ -442,13 +478,7 @@ class AsyncCrud(Generic[ModelType]):
List of model instances List of model instances
""" """
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options): if resolved := cls._resolve_load_options(load_options):
@@ -546,8 +576,7 @@ class AsyncCrud(Generic[ModelType]):
setattr(db_model, rel_attr, related_instances) setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model) await session.refresh(db_model)
if schema: if schema:
data_out = schema.model_validate(db_model) if schema else db_model return Response(data=schema.model_validate(db_model))
return Response(data=data_out)
return db_model return db_model
@classmethod @classmethod
@@ -664,13 +693,7 @@ class AsyncCrud(Generic[ModelType]):
Number of matching records Number of matching records
""" """
q = select(func.count()).select_from(cls.model) q = select(func.count()).select_from(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
result = await session.execute(q) result = await session.execute(q)
@@ -697,13 +720,7 @@ class AsyncCrud(Generic[ModelType]):
True if at least one record matches True if at least one record matches
""" """
q = select(cls.model) q = select(cls.model)
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select() q = q.where(and_(*filters)).exists().select()
result = await session.execute(q) result = await session.execute(q)
return bool(result.scalar()) return bool(result.scalar())
@@ -750,47 +767,29 @@ class AsyncCrud(Generic[ModelType]):
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
offset = (page - 1) * items_per_page offset = (page - 1) * items_per_page
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel): fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filter_by = filter_by.model_dump(exclude_none=True) or None filters.extend(fb_filters)
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
# Build search filters # Build search filters
if search: if search:
search_filters, search_joins = build_search_filters( search_filters, new_search_joins = build_search_filters(
cls.model, cls.model,
search, search,
search_fields=search_fields, search_fields=search_fields,
default_fields=cls.searchable_fields, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters) filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query with joins # Build query with joins
q = select(cls.model) q = select(cls.model)
# Apply explicit joins # Apply explicit joins
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins for search) # Apply search joins (always outer joins for search)
for join_rel in search_joins: q = _apply_search_joins(q, search_joins)
q = q.outerjoin(join_rel)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
@@ -810,17 +809,10 @@ class AsyncCrud(Generic[ModelType]):
count_q = count_q.select_from(cls.model) count_q = count_q.select_from(cls.model)
# Apply explicit joins to count query # Apply explicit joins to count query
if joins: count_q = _apply_joins(count_q, joins, outer_join)
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
# Apply search joins to count query # Apply search joins to count query
for join_rel in search_joins: count_q = _apply_search_joins(count_q, search_joins)
count_q = count_q.outerjoin(join_rel)
if filters: if filters:
count_q = count_q.where(and_(*filters)) count_q = count_q.where(and_(*filters))
@@ -828,19 +820,9 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q) count_result = await session.execute(count_q)
total_count = count_result.scalar_one() total_count = count_result.scalar_one()
# Build facets filter_attributes = await cls._build_filter_attributes(
resolved_facet_fields = ( session, facet_fields, filters, search_joins
facet_fields if facet_fields is not None else cls.facet_fields
) )
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse( return PaginatedResponse(
data=items, data=items,
@@ -897,21 +879,9 @@ class AsyncCrud(Generic[ModelType]):
PaginatedResponse with CursorPagination metadata PaginatedResponse with CursorPagination metadata
""" """
filters = list(filters) if filters else [] filters = list(filters) if filters else []
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel): fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filter_by = filter_by.model_dump(exclude_none=True) or None filters.extend(fb_filters)
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
if cls.cursor_column is None: if cls.cursor_column is None:
raise ValueError( raise ValueError(
@@ -944,29 +914,23 @@ class AsyncCrud(Generic[ModelType]):
# Build search filters # Build search filters
if search: if search:
search_filters, search_joins = build_search_filters( search_filters, new_search_joins = build_search_filters(
cls.model, cls.model,
search, search,
search_fields=search_fields, search_fields=search_fields,
default_fields=cls.searchable_fields, default_fields=cls.searchable_fields,
) )
filters.extend(search_filters) filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query # Build query
q = select(cls.model) q = select(cls.model)
# Apply explicit joins # Apply explicit joins
if joins: q = _apply_joins(q, joins, outer_join)
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
# Apply search joins (always outer joins) # Apply search joins (always outer joins)
for join_rel in search_joins: q = _apply_search_joins(q, search_joins)
q = q.outerjoin(join_rel)
if filters: if filters:
q = q.where(and_(*filters)) q = q.where(and_(*filters))
@@ -998,19 +962,9 @@ class AsyncCrud(Generic[ModelType]):
items: list[Any] = [schema.model_validate(item) for item in items_page] items: list[Any] = [schema.model_validate(item) for item in items_page]
# Build facets filter_attributes = await cls._build_filter_attributes(
resolved_facet_fields = ( session, facet_fields, filters, search_joins
facet_fields if facet_fields is not None else cls.facet_fields
) )
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse( return PaginatedResponse(
data=items, data=items,

View File

@@ -1,24 +1,23 @@
"""Search utilities for AsyncCrud.""" """Search utilities for AsyncCrud."""
import asyncio import asyncio
import functools
from collections import Counter from collections import Counter
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import String, or_, select from sqlalchemy import String, and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError
from ..types import FacetFieldType, SearchFieldType
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
@dataclass @dataclass
class SearchConfig: class SearchConfig:
@@ -37,6 +36,7 @@ class SearchConfig:
match_mode: Literal["any", "all"] = "any" match_mode: Literal["any", "all"] = "any"
@functools.lru_cache(maxsize=128)
def get_searchable_fields( def get_searchable_fields(
model: type[DeclarativeBase], model: type[DeclarativeBase],
*, *,
@@ -101,14 +101,11 @@ def build_search_filters(
if isinstance(search, str): if isinstance(search, str):
config = SearchConfig(query=search, fields=search_fields) config = SearchConfig(query=search, fields=search_fields)
else: else:
config = search config = (
if search_fields is not None: replace(search, fields=search_fields)
config = SearchConfig( if search_fields is not None
query=config.query, else search
fields=search_fields, )
case_sensitive=config.case_sensitive,
match_mode=config.match_mode,
)
if not config.query or not config.query.strip(): if not config.query or not config.query.strip():
return [], [] return [], []
@@ -227,8 +224,6 @@ async def build_facets(
q = q.outerjoin(rel) q = q.outerjoin(rel)
if base_filters: if base_filters:
from sqlalchemy import and_
q = q.where(and_(*base_filters)) q = q.where(and_(*base_filters))
q = q.order_by(column) q = q.order_by(column)

View File

@@ -1,20 +1,17 @@
"""Dependency factories for FastAPI routes.""" """Dependency factories for FastAPI routes."""
import inspect import inspect
from collections.abc import AsyncGenerator, Callable from collections.abc import Callable
from typing import Any, TypeVar, cast from typing import Any, cast
from fastapi import Depends from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase
from .crud import CrudFactory from .crud import CrudFactory
from .types import ModelType, SessionDependency
__all__ = ["BodyDependency", "PathDependency"] __all__ = ["BodyDependency", "PathDependency"]
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
def PathDependency( def PathDependency(
model: type[ModelType], model: type[ModelType],

View File

@@ -1,24 +1,23 @@
"""Fixture loading utilities for database seeding.""" """Fixture loading utilities for database seeding."""
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, TypeVar from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from ..db import get_transaction from ..db import get_transaction
from ..logger import get_logger from ..logger import get_logger
from ..types import ModelType
from .enum import LoadStrategy from .enum import LoadStrategy
from .registry import Context, FixtureRegistry from .registry import Context, FixtureRegistry
logger = get_logger() logger = get_logger()
T = TypeVar("T", bound=DeclarativeBase)
def get_obj_by_attr( def get_obj_by_attr(
fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any
) -> T: ) -> ModelType:
"""Get a SQLAlchemy model instance by matching an attribute value. """Get a SQLAlchemy model instance by matching an attribute value.
Args: Args:

View File

@@ -1,10 +1,12 @@
"""Base Pydantic schemas for API responses.""" """Base Pydantic schemas for API responses."""
from enum import Enum from enum import Enum
from typing import Any, ClassVar, Generic, TypeVar from typing import Any, ClassVar, Generic
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from .types import DataT
__all__ = [ __all__ = [
"ApiError", "ApiError",
"CursorPagination", "CursorPagination",
@@ -16,8 +18,6 @@ __all__ = [
"ResponseStatus", "ResponseStatus",
] ]
DataT = TypeVar("DataT")
class PydanticBase(BaseModel): class PydanticBase(BaseModel):
"""Base class for all Pydantic models with common configuration.""" """Base class for all Pydantic models with common configuration."""

View File

@@ -0,0 +1,27 @@
"""Shared type aliases for the fastapi-toolsets package."""
from collections.abc import AsyncGenerator, Callable, Mapping
from typing import Any, TypeVar
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnElement
# Generic TypeVars
DataT = TypeVar("DataT")
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
# CRUD type aliases
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
# Search / facet type aliases
SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
FacetFieldType = SearchFieldType
# Dependency type aliases
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]