From e0828c7e71aa171c665b1a5cdd07d6fe2b38354c Mon Sep 17 00:00:00 2001 From: d3vyce Date: Sun, 1 Mar 2026 07:40:18 -0500 Subject: [PATCH] refactor: centralize type aliases in types.py and simplify crud layer --- src/fastapi_toolsets/crud/__init__.py | 9 +- src/fastapi_toolsets/crud/factory.py | 220 ++++++++++--------------- src/fastapi_toolsets/crud/search.py | 25 ++- src/fastapi_toolsets/dependencies.py | 9 +- src/fastapi_toolsets/fixtures/utils.py | 9 +- src/fastapi_toolsets/schemas.py | 6 +- src/fastapi_toolsets/types.py | 27 +++ 7 files changed, 137 insertions(+), 168 deletions(-) create mode 100644 src/fastapi_toolsets/types.py diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index 68c6fe5..59058ad 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -1,12 +1,9 @@ """Generic async CRUD operations for SQLAlchemy models.""" from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError -from .factory import CrudFactory, JoinType, M2MFieldType, OrderByClause -from .search import ( - FacetFieldType, - SearchConfig, - get_searchable_fields, -) +from ..types import FacetFieldType, JoinType, M2MFieldType, OrderByClause +from .factory import CrudFactory +from .search import SearchConfig, get_searchable_fields __all__ = [ "CrudFactory", diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index 9a272c2..829fe47 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -6,10 +6,10 @@ import base64 import inspect import json import uuid as uuid_module -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Sequence from datetime import date, datetime from decimal import Decimal -from typing import Any, ClassVar, Generic, Literal, Self, TypeVar, cast, overload +from typing import Any, ClassVar, Generic, Literal, Self, cast, overload from fastapi import Query from pydantic import BaseModel @@ -20,28 +20,28 @@ from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase, QueryableAttribute, selectinload from sqlalchemy.sql.base import ExecutableOption -from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.roles import WhereHavingRole from ..db import get_transaction from ..exceptions import InvalidOrderFieldError, NotFoundError from ..schemas import CursorPagination, OffsetPagination, PaginatedResponse, Response -from .search import ( +from ..types import ( FacetFieldType, - SearchConfig, + JoinType, + M2MFieldType, + ModelType, + OrderByClause, + SchemaType, SearchFieldType, +) +from .search import ( + SearchConfig, build_facets, build_filter_by, build_search_filters, facet_keys, ) -ModelType = TypeVar("ModelType", bound=DeclarativeBase) -SchemaType = TypeVar("SchemaType", bound=BaseModel) -JoinType = list[tuple[type[DeclarativeBase], Any]] -M2MFieldType = Mapping[str, QueryableAttribute[Any]] -OrderByClause = ColumnElement[Any] | QueryableAttribute[Any] - def _encode_cursor(value: Any) -> str: """Encode cursor column value as an base64 string.""" @@ -53,6 +53,22 @@ def _decode_cursor(cursor: str) -> str: return json.loads(base64.b64decode(cursor.encode()).decode()) +def _apply_joins(q: Any, joins: JoinType | None, outer_join: bool) -> Any: + """Apply a list of (model, condition) joins to a SQLAlchemy select query.""" + if not joins: + return q + for model, condition in joins: + q = q.outerjoin(model, condition) if outer_join else q.join(model, condition) + return q + + +def _apply_search_joins(q: Any, search_joins: list[Any]) -> Any: + """Apply relationship-based outer joins (from search/filter_by) to a query.""" + for join_rel in search_joins: + q = q.outerjoin(join_rel) + return q + + class AsyncCrud(Generic[ModelType]): """Generic async CRUD operations for SQLAlchemy models. @@ -132,6 +148,40 @@ class AsyncCrud(Generic[ModelType]): return set() return set(cls.m2m_fields.keys()) + @classmethod + def _prepare_filter_by( + cls: type[Self], + filter_by: dict[str, Any] | BaseModel | None, + facet_fields: Sequence[FacetFieldType] | None, + ) -> tuple[list[Any], list[Any]]: + """Normalize filter_by and return (filters, joins) to apply to the query.""" + if isinstance(filter_by, BaseModel): + filter_by = filter_by.model_dump(exclude_none=True) or None + if not filter_by: + return [], [] + resolved = facet_fields if facet_fields is not None else cls.facet_fields + return build_filter_by(filter_by, resolved or []) + + @classmethod + async def _build_filter_attributes( + cls: type[Self], + session: AsyncSession, + facet_fields: Sequence[FacetFieldType] | None, + filters: list[Any], + search_joins: list[Any], + ) -> dict[str, list[Any]] | None: + """Build facet filter_attributes, or return None if no facet fields configured.""" + resolved = facet_fields if facet_fields is not None else cls.facet_fields + if not resolved: + return None + return await build_facets( + session, + cls.model, + resolved, + base_filters=filters or None, + base_joins=search_joins or None, + ) + @classmethod def filter_params( cls: type[Self], @@ -290,8 +340,7 @@ class AsyncCrud(Generic[ModelType]): await session.refresh(db_model) result = cast(ModelType, db_model) if schema: - data_out = schema.model_validate(result) if schema else result - return Response(data=data_out) + return Response(data=schema.model_validate(result)) return result @overload @@ -354,13 +403,7 @@ class AsyncCrud(Generic[ModelType]): MultipleResultsFound: If more than one record found """ q = select(cls.model) - if joins: - for model, condition in joins: - q = ( - q.outerjoin(model, condition) - if outer_join - else q.join(model, condition) - ) + q = _apply_joins(q, joins, outer_join) q = q.where(and_(*filters)) if resolved := cls._resolve_load_options(load_options): q = q.options(*resolved) @@ -372,8 +415,7 @@ class AsyncCrud(Generic[ModelType]): raise NotFoundError() result = cast(ModelType, item) if schema: - data_out = schema.model_validate(result) if schema else result - return Response(data=data_out) + return Response(data=schema.model_validate(result)) return result @classmethod @@ -399,13 +441,7 @@ class AsyncCrud(Generic[ModelType]): Model instance or None """ q = select(cls.model) - if joins: - for model, condition in joins: - q = ( - q.outerjoin(model, condition) - if outer_join - else q.join(model, condition) - ) + q = _apply_joins(q, joins, outer_join) if filters: q = q.where(and_(*filters)) if resolved := cls._resolve_load_options(load_options): @@ -442,13 +478,7 @@ class AsyncCrud(Generic[ModelType]): List of model instances """ q = select(cls.model) - if joins: - for model, condition in joins: - q = ( - q.outerjoin(model, condition) - if outer_join - else q.join(model, condition) - ) + q = _apply_joins(q, joins, outer_join) if filters: q = q.where(and_(*filters)) if resolved := cls._resolve_load_options(load_options): @@ -546,8 +576,7 @@ class AsyncCrud(Generic[ModelType]): setattr(db_model, rel_attr, related_instances) await session.refresh(db_model) if schema: - data_out = schema.model_validate(db_model) if schema else db_model - return Response(data=data_out) + return Response(data=schema.model_validate(db_model)) return db_model @classmethod @@ -664,13 +693,7 @@ class AsyncCrud(Generic[ModelType]): Number of matching records """ q = select(func.count()).select_from(cls.model) - if joins: - for model, condition in joins: - q = ( - q.outerjoin(model, condition) - if outer_join - else q.join(model, condition) - ) + q = _apply_joins(q, joins, outer_join) if filters: q = q.where(and_(*filters)) result = await session.execute(q) @@ -697,13 +720,7 @@ class AsyncCrud(Generic[ModelType]): True if at least one record matches """ q = select(cls.model) - if joins: - for model, condition in joins: - q = ( - q.outerjoin(model, condition) - if outer_join - else q.join(model, condition) - ) + q = _apply_joins(q, joins, outer_join) q = q.where(and_(*filters)).exists().select() result = await session.execute(q) return bool(result.scalar()) @@ -750,47 +767,29 @@ class AsyncCrud(Generic[ModelType]): """ filters = list(filters) if filters else [] offset = (page - 1) * items_per_page - search_joins: list[Any] = [] - if isinstance(filter_by, BaseModel): - filter_by = filter_by.model_dump(exclude_none=True) or None - - # Build filter_by conditions from declared facet fields - if filter_by: - resolved_facets_for_filter = ( - facet_fields if facet_fields is not None else cls.facet_fields - ) - fb_filters, fb_joins = build_filter_by( - filter_by, resolved_facets_for_filter or [] - ) - filters.extend(fb_filters) - search_joins.extend(fb_joins) + fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields) + filters.extend(fb_filters) # Build search filters if search: - search_filters, search_joins = build_search_filters( + search_filters, new_search_joins = build_search_filters( cls.model, search, search_fields=search_fields, default_fields=cls.searchable_fields, ) filters.extend(search_filters) + search_joins.extend(new_search_joins) # Build query with joins q = select(cls.model) # Apply explicit joins - if joins: - for model, condition in joins: - q = ( - q.outerjoin(model, condition) - if outer_join - else q.join(model, condition) - ) + q = _apply_joins(q, joins, outer_join) # Apply search joins (always outer joins for search) - for join_rel in search_joins: - q = q.outerjoin(join_rel) + q = _apply_search_joins(q, search_joins) if filters: q = q.where(and_(*filters)) @@ -810,17 +809,10 @@ class AsyncCrud(Generic[ModelType]): count_q = count_q.select_from(cls.model) # Apply explicit joins to count query - if joins: - for model, condition in joins: - count_q = ( - count_q.outerjoin(model, condition) - if outer_join - else count_q.join(model, condition) - ) + count_q = _apply_joins(count_q, joins, outer_join) # Apply search joins to count query - for join_rel in search_joins: - count_q = count_q.outerjoin(join_rel) + count_q = _apply_search_joins(count_q, search_joins) if filters: count_q = count_q.where(and_(*filters)) @@ -828,19 +820,9 @@ class AsyncCrud(Generic[ModelType]): count_result = await session.execute(count_q) total_count = count_result.scalar_one() - # Build facets - resolved_facet_fields = ( - facet_fields if facet_fields is not None else cls.facet_fields + filter_attributes = await cls._build_filter_attributes( + session, facet_fields, filters, search_joins ) - filter_attributes: dict[str, list[Any]] | None = None - if resolved_facet_fields: - filter_attributes = await build_facets( - session, - cls.model, - resolved_facet_fields, - base_filters=filters or None, - base_joins=search_joins or None, - ) return PaginatedResponse( data=items, @@ -897,21 +879,9 @@ class AsyncCrud(Generic[ModelType]): PaginatedResponse with CursorPagination metadata """ filters = list(filters) if filters else [] - search_joins: list[Any] = [] - if isinstance(filter_by, BaseModel): - filter_by = filter_by.model_dump(exclude_none=True) or None - - # Build filter_by conditions from declared facet fields - if filter_by: - resolved_facets_for_filter = ( - facet_fields if facet_fields is not None else cls.facet_fields - ) - fb_filters, fb_joins = build_filter_by( - filter_by, resolved_facets_for_filter or [] - ) - filters.extend(fb_filters) - search_joins.extend(fb_joins) + fb_filters, search_joins = cls._prepare_filter_by(filter_by, facet_fields) + filters.extend(fb_filters) if cls.cursor_column is None: raise ValueError( @@ -944,29 +914,23 @@ class AsyncCrud(Generic[ModelType]): # Build search filters if search: - search_filters, search_joins = build_search_filters( + search_filters, new_search_joins = build_search_filters( cls.model, search, search_fields=search_fields, default_fields=cls.searchable_fields, ) filters.extend(search_filters) + search_joins.extend(new_search_joins) # Build query q = select(cls.model) # Apply explicit joins - if joins: - for model, condition in joins: - q = ( - q.outerjoin(model, condition) - if outer_join - else q.join(model, condition) - ) + q = _apply_joins(q, joins, outer_join) # Apply search joins (always outer joins) - for join_rel in search_joins: - q = q.outerjoin(join_rel) + q = _apply_search_joins(q, search_joins) if filters: q = q.where(and_(*filters)) @@ -998,19 +962,9 @@ class AsyncCrud(Generic[ModelType]): items: list[Any] = [schema.model_validate(item) for item in items_page] - # Build facets - resolved_facet_fields = ( - facet_fields if facet_fields is not None else cls.facet_fields + filter_attributes = await cls._build_filter_attributes( + session, facet_fields, filters, search_joins ) - filter_attributes: dict[str, list[Any]] | None = None - if resolved_facet_fields: - filter_attributes = await build_facets( - session, - cls.model, - resolved_facet_fields, - base_filters=filters or None, - base_joins=search_joins or None, - ) return PaginatedResponse( data=items, diff --git a/src/fastapi_toolsets/crud/search.py b/src/fastapi_toolsets/crud/search.py index 6e01df7..efc5f31 100644 --- a/src/fastapi_toolsets/crud/search.py +++ b/src/fastapi_toolsets/crud/search.py @@ -1,24 +1,23 @@ """Search utilities for AsyncCrud.""" import asyncio +import functools from collections import Counter from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import TYPE_CHECKING, Any, Literal -from sqlalchemy import String, or_, select +from sqlalchemy import String, and_, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.attributes import InstrumentedAttribute from ..exceptions import InvalidFacetFilterError, NoSearchableFieldsError +from ..types import FacetFieldType, SearchFieldType if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement -SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...] -FacetFieldType = SearchFieldType - @dataclass class SearchConfig: @@ -37,6 +36,7 @@ class SearchConfig: match_mode: Literal["any", "all"] = "any" +@functools.lru_cache(maxsize=128) def get_searchable_fields( model: type[DeclarativeBase], *, @@ -101,14 +101,11 @@ def build_search_filters( if isinstance(search, str): config = SearchConfig(query=search, fields=search_fields) else: - config = search - if search_fields is not None: - config = SearchConfig( - query=config.query, - fields=search_fields, - case_sensitive=config.case_sensitive, - match_mode=config.match_mode, - ) + config = ( + replace(search, fields=search_fields) + if search_fields is not None + else search + ) if not config.query or not config.query.strip(): return [], [] @@ -227,8 +224,6 @@ async def build_facets( q = q.outerjoin(rel) if base_filters: - from sqlalchemy import and_ - q = q.where(and_(*base_filters)) q = q.order_by(column) diff --git a/src/fastapi_toolsets/dependencies.py b/src/fastapi_toolsets/dependencies.py index 90d8468..26eb75c 100644 --- a/src/fastapi_toolsets/dependencies.py +++ b/src/fastapi_toolsets/dependencies.py @@ -1,20 +1,17 @@ """Dependency factories for FastAPI routes.""" import inspect -from collections.abc import AsyncGenerator, Callable -from typing import Any, TypeVar, cast +from collections.abc import Callable +from typing import Any, cast from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase from .crud import CrudFactory +from .types import ModelType, SessionDependency __all__ = ["BodyDependency", "PathDependency"] -ModelType = TypeVar("ModelType", bound=DeclarativeBase) -SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] - def PathDependency( model: type[ModelType], diff --git a/src/fastapi_toolsets/fixtures/utils.py b/src/fastapi_toolsets/fixtures/utils.py index b2e8cef..1263c90 100644 --- a/src/fastapi_toolsets/fixtures/utils.py +++ b/src/fastapi_toolsets/fixtures/utils.py @@ -1,24 +1,23 @@ """Fixture loading utilities for database seeding.""" from collections.abc import Callable, Sequence -from typing import Any, TypeVar +from typing import Any from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import DeclarativeBase from ..db import get_transaction from ..logger import get_logger +from ..types import ModelType from .enum import LoadStrategy from .registry import Context, FixtureRegistry logger = get_logger() -T = TypeVar("T", bound=DeclarativeBase) - def get_obj_by_attr( - fixtures: Callable[[], Sequence[T]], attr_name: str, value: Any -) -> T: + fixtures: Callable[[], Sequence[ModelType]], attr_name: str, value: Any +) -> ModelType: """Get a SQLAlchemy model instance by matching an attribute value. Args: diff --git a/src/fastapi_toolsets/schemas.py b/src/fastapi_toolsets/schemas.py index e69090b..80016cd 100644 --- a/src/fastapi_toolsets/schemas.py +++ b/src/fastapi_toolsets/schemas.py @@ -1,10 +1,12 @@ """Base Pydantic schemas for API responses.""" from enum import Enum -from typing import Any, ClassVar, Generic, TypeVar +from typing import Any, ClassVar, Generic from pydantic import BaseModel, ConfigDict +from .types import DataT + __all__ = [ "ApiError", "CursorPagination", @@ -16,8 +18,6 @@ __all__ = [ "ResponseStatus", ] -DataT = TypeVar("DataT") - class PydanticBase(BaseModel): """Base class for all Pydantic models with common configuration.""" diff --git a/src/fastapi_toolsets/types.py b/src/fastapi_toolsets/types.py new file mode 100644 index 0000000..1941781 --- /dev/null +++ b/src/fastapi_toolsets/types.py @@ -0,0 +1,27 @@ +"""Shared type aliases for the fastapi-toolsets package.""" + +from collections.abc import AsyncGenerator, Callable, Mapping +from typing import Any, TypeVar + +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import DeclarativeBase, QueryableAttribute +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.sql.elements import ColumnElement + +# Generic TypeVars +DataT = TypeVar("DataT") +ModelType = TypeVar("ModelType", bound=DeclarativeBase) +SchemaType = TypeVar("SchemaType", bound=BaseModel) + +# CRUD type aliases +JoinType = list[tuple[type[DeclarativeBase], Any]] +M2MFieldType = Mapping[str, QueryableAttribute[Any]] +OrderByClause = ColumnElement[Any] | QueryableAttribute[Any] + +# Search / facet type aliases +SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...] +FacetFieldType = SearchFieldType + +# Dependency type aliases +SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]