mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat: support Annotated[AsyncSession, Depends(...)] in PathDependency and BodyDependency (#146)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
"""Dependency factories for FastAPI routes."""
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.params import Depends as DependsClass
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .crud import CrudFactory
|
||||
@@ -13,6 +15,15 @@ from .types import ModelType, SessionDependency
|
||||
__all__ = ["BodyDependency", "PathDependency"]
|
||||
|
||||
|
||||
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
|
||||
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
|
||||
if typing.get_origin(session_dep) is typing.Annotated:
|
||||
for arg in typing.get_args(session_dep)[1:]:
|
||||
if isinstance(arg, DependsClass):
|
||||
return arg.dependency
|
||||
return session_dep
|
||||
|
||||
|
||||
def PathDependency(
|
||||
model: type[ModelType],
|
||||
field: Any,
|
||||
@@ -44,6 +55,7 @@ def PathDependency(
|
||||
): ...
|
||||
```
|
||||
"""
|
||||
session_callable = _unwrap_session_dep(session_dep)
|
||||
crud = CrudFactory(model)
|
||||
name = (
|
||||
param_name
|
||||
@@ -53,7 +65,7 @@ def PathDependency(
|
||||
python_type = field.type.python_type
|
||||
|
||||
async def dependency(
|
||||
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||
) -> ModelType:
|
||||
value = kwargs[name]
|
||||
return await crud.get(session, filters=[field == value])
|
||||
@@ -70,7 +82,7 @@ def PathDependency(
|
||||
"session",
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=AsyncSession,
|
||||
default=Depends(session_dep),
|
||||
default=Depends(session_callable),
|
||||
),
|
||||
]
|
||||
),
|
||||
@@ -112,11 +124,12 @@ def BodyDependency(
|
||||
): ...
|
||||
```
|
||||
"""
|
||||
session_callable = _unwrap_session_dep(session_dep)
|
||||
crud = CrudFactory(model)
|
||||
python_type = field.type.python_type
|
||||
|
||||
async def dependency(
|
||||
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||
) -> ModelType:
|
||||
value = kwargs[body_field]
|
||||
return await crud.get(session, filters=[field == value])
|
||||
@@ -133,7 +146,7 @@ def BodyDependency(
|
||||
"session",
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=AsyncSession,
|
||||
default=Depends(session_dep),
|
||||
default=Depends(session_callable),
|
||||
),
|
||||
]
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user