refactor: centralize type aliases in types.py and simplify crud layer

This commit is contained in:
2026-03-01 07:40:18 -05:00
parent 59d028d00e
commit e0828c7e71
7 changed files with 137 additions and 168 deletions

View File

@@ -6,10 +6,10 @@ import base64
import inspect
import json
import uuid as uuid_module
from collections.abc import Awaitable, Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Sequence
from datetime import date, datetime
from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload
from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
from fastapi import Query
from pydantic import BaseModel
@@ -20,28 +20,28 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.roles import WhereHavingRole
from ..db import get_transaction
from ..exceptions import InvalidOrderFieldError, NotFoundError
from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response
from .search import (
from ..types import (
FacetFieldType,
SearchConfig,
JoinType,
M2MFieldType,
ModelType,
OrderByClause,
SchemaType,
SearchFieldType,
)
from .search import (
SearchConfig,
build_facets,
build_filter_by,
build_search_filters,
facet_keys,
)
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
SchemaType = TypeVar("SchemaType", bound=BaseModel)
JoinType = list[tuple[type[DeclarativeBase], Any]]
M2MFieldType = Mapping[str, QueryableAttribute[Any]]
OrderByClause = ColumnElement[Any] | QueryableAttribute[Any]
def _encode_cursor(value: Any) -> str:
"""Encode cursor column value as an base64 string."""
@@ -53,6 +53,22 @@ def _decode_cursor(cursor: str) -> str:
return json.loads(base64.b64decode(cursor.encode()).decode())
def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any:
"""Apply a list of (model, condition) joins to a SQLAlchemy select query."""
if not joins:
return q
for model, condition in joins:
q = q.outerjoin(model, condition) if outer_join else q.join(model, condition)
return q
def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any:
"""Apply relationship-based outer joins (from search/filter_by) to a query."""
for join_rel in search_joins:
q = q.outerjoin(join_rel)
return q
class AsyncCrud(Generic[ModelType]):
"""Generic async CRUD operations for SQLAlchemy models.
@@ -132,6 +148,40 @@ class AsyncCrud(Generic[ModelType]):
return set()
return set(cls.m2m_fields.keys())
@classmethod
def _prepare_filter_by(
cls: type[Self],
filter_by: dict[str, Any] | BaseModel | None,
facet_fields: Sequence[FacetFieldType] | None,
) -> tuple[list[Any], list[Any]]:
"""Normalize filter_by and return (filters, joins) to apply to the query."""
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
if not filter_by:
return [], []
resolved = facet_fields if facet_fields is not None else cls.facet_fields
return build_filter_by(filter_by, resolved or [])
@classmethod
async def _build_filter_attributes(
cls: type[Self],
session: AsyncSession,
facet_fields: Sequence[FacetFieldType] | None,
filters: list[Any],
search_joins: list[Any],
) -> dict[str, list[Any]] | None:
"""Build facet filter_attributes, or return None if no facet fields configured."""
resolved = facet_fields if facet_fields is not None else cls.facet_fields
if not resolved:
return None
return await build_facets(
session,
cls.model,
resolved,
base_filters=filters or None,
base_joins=search_joins or None,
)
@classmethod
def filter_params(
cls: type[Self],
@@ -290,8 +340,7 @@ class AsyncCrud(Generic[ModelType]):
await session.refresh(db_model)
result = cast(ModelType, db_model)
if schema:
data_out = schema.model_validate(result) if schema else result
return Response(data=data_out)
return Response(data=schema.model_validate(result))
return result
@overload
@@ -354,13 +403,7 @@ class AsyncCrud(Generic[ModelType]):
MultipleResultsFound: If more than one record found
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
q = q.options(*resolved)
@@ -372,8 +415,7 @@ class AsyncCrud(Generic[ModelType]):
raise NotFoundError()
result = cast(ModelType, item)
if schema:
data_out = schema.model_validate(result) if schema else result
return Response(data=data_out)
return Response(data=schema.model_validate(result))
return result
@classmethod
@@ -399,13 +441,7 @@ class AsyncCrud(Generic[ModelType]):
Model instance or None
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -442,13 +478,7 @@ class AsyncCrud(Generic[ModelType]):
List of model instances
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
if filters:
q = q.where(and_(*filters))
if resolved := cls._resolve_load_options(load_options):
@@ -546,8 +576,7 @@ class AsyncCrud(Generic[ModelType]):
setattr(db_model, rel_attr, related_instances)
await session.refresh(db_model)
if schema:
data_out = schema.model_validate(db_model) if schema else db_model
return Response(data=data_out)
return Response(data=schema.model_validate(db_model))
return db_model
@classmethod
@@ -664,13 +693,7 @@ class AsyncCrud(Generic[ModelType]):
Number of matching records
"""
q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
if filters:
q = q.where(and_(*filters))
result = await session.execute(q)
@@ -697,13 +720,7 @@ class AsyncCrud(Generic[ModelType]):
True if at least one record matches
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q)
return bool(result.scalar())
@@ -750,47 +767,29 @@ class AsyncCrud(Generic[ModelType]):
"""
filters = list(filters) if filters else []
offset = (page - 1) * items_per_page
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filters.extend(fb_filters)
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
search_filters, new_search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query with joins
q = select(cls.model)
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
q = _apply_search_joins(q, search_joins)
if filters:
q = q.where(and_(*filters))
@@ -810,17 +809,10 @@ class AsyncCrud(Generic[ModelType]):
count_q = count_q.select_from(cls.model)
# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)
count_q = _apply_joins(count_q, joins, outer_join)
# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel)
count_q = _apply_search_joins(count_q, search_joins)
if filters:
count_q = count_q.where(and_(*filters))
@@ -828,19 +820,9 @@ class AsyncCrud(Generic[ModelType]):
count_result = await session.execute(count_q)
total_count = count_result.scalar_one()
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
filter_attributes = await cls._build_filter_attributes(
session, facet_fields, filters, search_joins
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,
@@ -897,21 +879,9 @@ class AsyncCrud(Generic[ModelType]):
PaginatedResponse with CursorPagination metadata
"""
filters = list(filters) if filters else []
search_joins: list[Any] = []
if isinstance(filter_by, BaseModel):
filter_by = filter_by.model_dump(exclude_none=True) or None
# Build filter_by conditions from declared facet fields
if filter_by:
resolved_facets_for_filter = (
facet_fields if facet_fields is not None else cls.facet_fields
)
fb_filters, fb_joins = build_filter_by(
filter_by, resolved_facets_for_filter or []
)
filters.extend(fb_filters)
search_joins.extend(fb_joins)
fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields)
filters.extend(fb_filters)
if cls.cursor_column is None:
raise ValueError(
@@ -944,29 +914,23 @@ class AsyncCrud(Generic[ModelType]):
# Build search filters
if search:
search_filters, search_joins = build_search_filters(
search_filters, new_search_joins = build_search_filters(
cls.model,
search,
search_fields=search_fields,
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
search_joins.extend(new_search_joins)
# Build query
q = select(cls.model)
# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = _apply_joins(q, joins, outer_join)
# Apply search joins (always outer joins)
for join_rel in search_joins:
q = q.outerjoin(join_rel)
q = _apply_search_joins(q, search_joins)
if filters:
q = q.where(and_(*filters))
@@ -998,19 +962,9 @@ class AsyncCrud(Generic[ModelType]):
items: list[Any] = [schema.model_validate(item) for item in items_page]
# Build facets
resolved_facet_fields = (
facet_fields if facet_fields is not None else cls.facet_fields
filter_attributes = await cls._build_filter_attributes(
session, facet_fields, filters, search_joins
)
filter_attributes: dict[str, list[Any]] | None = None
if resolved_facet_fields:
filter_attributes = await build_facets(
session,
cls.model,
resolved_facet_fields,
base_filters=filters or None,
base_joins=search_joins or None,
)
return PaginatedResponse(
data=items,