mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-16 06:36:26 +02:00
feat: support Annotated[AsyncSession, Depends(...)] in PathDependency and BodyDependency (#146)
This commit is contained in:
@@ -13,8 +13,13 @@ The `dependencies` module provides two factory functions that create FastAPI dep
|
||||
```python
|
||||
from fastapi_toolsets.dependencies import PathDependency
|
||||
|
||||
# Plain callable
|
||||
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||
|
||||
# Annotated
|
||||
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||
UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep)
|
||||
|
||||
@router.get("/users/{user_id}")
|
||||
async def get_user(user: User = UserDep):
|
||||
return user
|
||||
@@ -37,8 +42,14 @@ async def get_user(user: User = UserDep):
|
||||
```python
|
||||
from fastapi_toolsets.dependencies import BodyDependency
|
||||
|
||||
# Plain callable
|
||||
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||
|
||||
# Annotated
|
||||
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id")
|
||||
|
||||
|
||||
@router.post("/users")
|
||||
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||
user = User(username=body.username, role=role)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Dependency factories for FastAPI routes."""
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi.params import Depends as DependsClass
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .crud import CrudFactory
|
||||
@@ -13,6 +15,15 @@ from .types import ModelType, SessionDependency
|
||||
__all__ = ["BodyDependency", "PathDependency"]
|
||||
|
||||
|
||||
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
|
||||
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
|
||||
if typing.get_origin(session_dep) is typing.Annotated:
|
||||
for arg in typing.get_args(session_dep)[1:]:
|
||||
if isinstance(arg, DependsClass):
|
||||
return arg.dependency
|
||||
return session_dep
|
||||
|
||||
|
||||
def PathDependency(
|
||||
model: type[ModelType],
|
||||
field: Any,
|
||||
@@ -44,6 +55,7 @@ def PathDependency(
|
||||
): ...
|
||||
```
|
||||
"""
|
||||
session_callable = _unwrap_session_dep(session_dep)
|
||||
crud = CrudFactory(model)
|
||||
name = (
|
||||
param_name
|
||||
@@ -53,7 +65,7 @@ def PathDependency(
|
||||
python_type = field.type.python_type
|
||||
|
||||
async def dependency(
|
||||
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||
) -> ModelType:
|
||||
value = kwargs[name]
|
||||
return await crud.get(session, filters=[field == value])
|
||||
@@ -70,7 +82,7 @@ def PathDependency(
|
||||
"session",
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=AsyncSession,
|
||||
default=Depends(session_dep),
|
||||
default=Depends(session_callable),
|
||||
),
|
||||
]
|
||||
),
|
||||
@@ -112,11 +124,12 @@ def BodyDependency(
|
||||
): ...
|
||||
```
|
||||
"""
|
||||
session_callable = _unwrap_session_dep(session_dep)
|
||||
crud = CrudFactory(model)
|
||||
python_type = field.type.python_type
|
||||
|
||||
async def dependency(
|
||||
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
||||
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||
) -> ModelType:
|
||||
value = kwargs[body_field]
|
||||
return await crud.get(session, filters=[field == value])
|
||||
@@ -133,7 +146,7 @@ def BodyDependency(
|
||||
"session",
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=AsyncSession,
|
||||
default=Depends(session_dep),
|
||||
default=Depends(session_callable),
|
||||
),
|
||||
]
|
||||
),
|
||||
|
||||
@@ -24,4 +24,4 @@ SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any],
|
||||
FacetFieldType = SearchFieldType
|
||||
|
||||
# Dependency type aliases
|
||||
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any
|
||||
|
||||
@@ -3,13 +3,17 @@
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
from typing import Annotated, Any, cast
|
||||
|
||||
import pytest
|
||||
from fastapi.params import Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from fastapi_toolsets.dependencies import BodyDependency, PathDependency
|
||||
from fastapi_toolsets.dependencies import (
|
||||
BodyDependency,
|
||||
PathDependency,
|
||||
_unwrap_session_dep,
|
||||
)
|
||||
|
||||
from .conftest import Role, RoleCreate, RoleCrud, User
|
||||
|
||||
@@ -19,6 +23,24 @@ async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield None
|
||||
|
||||
|
||||
MockSessionDep = Annotated[AsyncSession, Depends(mock_get_db)]
|
||||
|
||||
|
||||
class TestUnwrapSessionDep:
|
||||
def test_plain_callable_returned_as_is(self):
|
||||
"""Plain callable is returned unchanged."""
|
||||
assert _unwrap_session_dep(mock_get_db) is mock_get_db
|
||||
|
||||
def test_annotated_with_depends_unwrapped(self):
|
||||
"""Annotated form with Depends is unwrapped to the plain callable."""
|
||||
assert _unwrap_session_dep(MockSessionDep) is mock_get_db
|
||||
|
||||
def test_annotated_without_depends_returned_as_is(self):
|
||||
"""Annotated form with no Depends falls back to returning session_dep as-is."""
|
||||
annotated_no_dep = Annotated[AsyncSession, "not_a_depends"]
|
||||
assert _unwrap_session_dep(annotated_no_dep) is annotated_no_dep
|
||||
|
||||
|
||||
class TestPathDependency:
|
||||
"""Tests for PathDependency factory."""
|
||||
|
||||
@@ -95,6 +117,39 @@ class TestPathDependency:
|
||||
assert result.id == role.id
|
||||
assert result.name == "test_role"
|
||||
|
||||
def test_annotated_session_dep_returns_depends_instance(self):
|
||||
"""PathDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||
dep = PathDependency(Role, Role.id, session_dep=MockSessionDep)
|
||||
assert isinstance(dep, Depends)
|
||||
|
||||
def test_annotated_session_dep_signature(self):
|
||||
"""PathDependency with Annotated session_dep produces a valid signature."""
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||
sig = inspect.signature(dep.dependency)
|
||||
|
||||
assert "role_id" in sig.parameters
|
||||
assert "session" in sig.parameters
|
||||
assert isinstance(sig.parameters["session"].default, Depends)
|
||||
|
||||
def test_annotated_session_dep_unwraps_callable(self):
|
||||
"""PathDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||
sig = inspect.signature(dep.dependency)
|
||||
|
||||
inner_dep = sig.parameters["session"].default
|
||||
assert inner_dep.dependency is mock_get_db
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||
"""PathDependency with Annotated session_dep correctly fetches object from database."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="annotated_role"))
|
||||
|
||||
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "annotated_role"
|
||||
|
||||
|
||||
class TestBodyDependency:
|
||||
"""Tests for BodyDependency factory."""
|
||||
@@ -184,3 +239,39 @@ class TestBodyDependency:
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "body_test_role"
|
||||
|
||||
def test_annotated_session_dep_returns_depends_instance(self):
|
||||
"""BodyDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||
dep = BodyDependency(
|
||||
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||
)
|
||||
assert isinstance(dep, Depends)
|
||||
|
||||
def test_annotated_session_dep_unwraps_callable(self):
|
||||
"""BodyDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||
),
|
||||
)
|
||||
sig = inspect.signature(dep.dependency)
|
||||
|
||||
inner_dep = sig.parameters["session"].default
|
||||
assert inner_dep.dependency is mock_get_db
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||
"""BodyDependency with Annotated session_dep correctly fetches object from database."""
|
||||
role = await RoleCrud.create(db_session, RoleCreate(name="body_annotated_role"))
|
||||
|
||||
dep = cast(
|
||||
Any,
|
||||
BodyDependency(
|
||||
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||
),
|
||||
)
|
||||
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||
|
||||
assert result.id == role.id
|
||||
assert result.name == "body_annotated_role"
|
||||
|
||||
Reference in New Issue
Block a user