feat: support Annotated[AsyncSession, Depends(...)] in PathDependency and BodyDependency (#146)

This commit is contained in:
d3vyce
2026-03-18 20:10:58 +01:00
committed by GitHub
parent 2c494fcd17
commit 2e9c6c0c90
4 changed files with 122 additions and 7 deletions

View File

@@ -1,10 +1,12 @@
"""Dependency factories for FastAPI routes."""
import inspect
import typing
from collections.abc import Callable
from typing import Any, cast
from fastapi import Depends
from fastapi.params import Depends as DependsClass
from sqlalchemy.ext.asyncio import AsyncSession
from .crud import CrudFactory
@@ -13,6 +15,15 @@ from .types import ModelType, SessionDependency
__all__ = ["BodyDependency", "PathDependency"]
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
if typing.get_origin(session_dep) is typing.Annotated:
for arg in typing.get_args(session_dep)[1:]:
if isinstance(arg, DependsClass):
return arg.dependency
return session_dep
def PathDependency(
model: type[ModelType],
field: Any,
@@ -44,6 +55,7 @@ def PathDependency(
): ...
```
"""
session_callable = _unwrap_session_dep(session_dep)
crud = CrudFactory(model)
name = (
param_name
@@ -53,7 +65,7 @@ def PathDependency(
python_type = field.type.python_type
async def dependency(
session: AsyncSession = Depends(session_dep), **kwargs: Any
session: AsyncSession = Depends(session_callable), **kwargs: Any
) -> ModelType:
value = kwargs[name]
return await crud.get(session, filters=[field == value])
@@ -70,7 +82,7 @@ def PathDependency(
"session",
inspect.Parameter.KEYWORD_ONLY,
annotation=AsyncSession,
default=Depends(session_dep),
default=Depends(session_callable),
),
]
),
@@ -112,11 +124,12 @@ def BodyDependency(
): ...
```
"""
session_callable = _unwrap_session_dep(session_dep)
crud = CrudFactory(model)
python_type = field.type.python_type
async def dependency(
session: AsyncSession = Depends(session_dep), **kwargs: Any
session: AsyncSession = Depends(session_callable), **kwargs: Any
) -> ModelType:
value = kwargs[body_field]
return await crud.get(session, filters=[field == value])
@@ -133,7 +146,7 @@ def BodyDependency(
"session",
inspect.Parameter.KEYWORD_ONLY,
annotation=AsyncSession,
default=Depends(session_dep),
default=Depends(session_callable),
),
]
),

View File

@@ -24,4 +24,4 @@ SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any],
FacetFieldType = SearchFieldType
# Dependency type aliases
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any