mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-04-15 22:26:25 +02:00
feat: support Annotated[AsyncSession, Depends(...)] in PathDependency and BodyDependency (#146)
This commit is contained in:
@@ -13,8 +13,13 @@ The `dependencies` module provides two factory functions that create FastAPI dep
|
|||||||
```python
|
```python
|
||||||
from fastapi_toolsets.dependencies import PathDependency
|
from fastapi_toolsets.dependencies import PathDependency
|
||||||
|
|
||||||
|
# Plain callable
|
||||||
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
|
||||||
|
|
||||||
|
# Annotated
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep)
|
||||||
|
|
||||||
@router.get("/users/{user_id}")
|
@router.get("/users/{user_id}")
|
||||||
async def get_user(user: User = UserDep):
|
async def get_user(user: User = UserDep):
|
||||||
return user
|
return user
|
||||||
@@ -37,8 +42,14 @@ async def get_user(user: User = UserDep):
|
|||||||
```python
|
```python
|
||||||
from fastapi_toolsets.dependencies import BodyDependency
|
from fastapi_toolsets.dependencies import BodyDependency
|
||||||
|
|
||||||
|
# Plain callable
|
||||||
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
|
||||||
|
|
||||||
|
# Annotated
|
||||||
|
SessionDep = Annotated[AsyncSession, Depends(get_db)]
|
||||||
|
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/users")
|
@router.post("/users")
|
||||||
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
async def create_user(body: UserCreateSchema, role: Role = RoleDep):
|
||||||
user = User(username=body.username, role=role)
|
user = User(username=body.username, role=role)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""Dependency factories for FastAPI routes."""
|
"""Dependency factories for FastAPI routes."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import typing
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
from fastapi.params import Depends as DependsClass
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from .crud import CrudFactory
|
from .crud import CrudFactory
|
||||||
@@ -13,6 +15,15 @@ from .types import ModelType, SessionDependency
|
|||||||
__all__ = ["BodyDependency", "PathDependency"]
|
__all__ = ["BodyDependency", "PathDependency"]
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
|
||||||
|
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
|
||||||
|
if typing.get_origin(session_dep) is typing.Annotated:
|
||||||
|
for arg in typing.get_args(session_dep)[1:]:
|
||||||
|
if isinstance(arg, DependsClass):
|
||||||
|
return arg.dependency
|
||||||
|
return session_dep
|
||||||
|
|
||||||
|
|
||||||
def PathDependency(
|
def PathDependency(
|
||||||
model: type[ModelType],
|
model: type[ModelType],
|
||||||
field: Any,
|
field: Any,
|
||||||
@@ -44,6 +55,7 @@ def PathDependency(
|
|||||||
): ...
|
): ...
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
session_callable = _unwrap_session_dep(session_dep)
|
||||||
crud = CrudFactory(model)
|
crud = CrudFactory(model)
|
||||||
name = (
|
name = (
|
||||||
param_name
|
param_name
|
||||||
@@ -53,7 +65,7 @@ def PathDependency(
|
|||||||
python_type = field.type.python_type
|
python_type = field.type.python_type
|
||||||
|
|
||||||
async def dependency(
|
async def dependency(
|
||||||
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||||
) -> ModelType:
|
) -> ModelType:
|
||||||
value = kwargs[name]
|
value = kwargs[name]
|
||||||
return await crud.get(session, filters=[field == value])
|
return await crud.get(session, filters=[field == value])
|
||||||
@@ -70,7 +82,7 @@ def PathDependency(
|
|||||||
"session",
|
"session",
|
||||||
inspect.Parameter.KEYWORD_ONLY,
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
annotation=AsyncSession,
|
annotation=AsyncSession,
|
||||||
default=Depends(session_dep),
|
default=Depends(session_callable),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
@@ -112,11 +124,12 @@ def BodyDependency(
|
|||||||
): ...
|
): ...
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
session_callable = _unwrap_session_dep(session_dep)
|
||||||
crud = CrudFactory(model)
|
crud = CrudFactory(model)
|
||||||
python_type = field.type.python_type
|
python_type = field.type.python_type
|
||||||
|
|
||||||
async def dependency(
|
async def dependency(
|
||||||
session: AsyncSession = Depends(session_dep), **kwargs: Any
|
session: AsyncSession = Depends(session_callable), **kwargs: Any
|
||||||
) -> ModelType:
|
) -> ModelType:
|
||||||
value = kwargs[body_field]
|
value = kwargs[body_field]
|
||||||
return await crud.get(session, filters=[field == value])
|
return await crud.get(session, filters=[field == value])
|
||||||
@@ -133,7 +146,7 @@ def BodyDependency(
|
|||||||
"session",
|
"session",
|
||||||
inspect.Parameter.KEYWORD_ONLY,
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
annotation=AsyncSession,
|
annotation=AsyncSession,
|
||||||
default=Depends(session_dep),
|
default=Depends(session_callable),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -24,4 +24,4 @@ SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any],
|
|||||||
FacetFieldType = SearchFieldType
|
FacetFieldType = SearchFieldType
|
||||||
|
|
||||||
# Dependency type aliases
|
# Dependency type aliases
|
||||||
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]]
|
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any
|
||||||
|
|||||||
@@ -3,13 +3,17 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any, cast
|
from typing import Annotated, Any, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.params import Depends
|
from fastapi.params import Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from fastapi_toolsets.dependencies import BodyDependency, PathDependency
|
from fastapi_toolsets.dependencies import (
|
||||||
|
BodyDependency,
|
||||||
|
PathDependency,
|
||||||
|
_unwrap_session_dep,
|
||||||
|
)
|
||||||
|
|
||||||
from .conftest import Role, RoleCreate, RoleCrud, User
|
from .conftest import Role, RoleCreate, RoleCrud, User
|
||||||
|
|
||||||
@@ -19,6 +23,24 @@ async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
|
|
||||||
|
MockSessionDep = Annotated[AsyncSession, Depends(mock_get_db)]
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnwrapSessionDep:
|
||||||
|
def test_plain_callable_returned_as_is(self):
|
||||||
|
"""Plain callable is returned unchanged."""
|
||||||
|
assert _unwrap_session_dep(mock_get_db) is mock_get_db
|
||||||
|
|
||||||
|
def test_annotated_with_depends_unwrapped(self):
|
||||||
|
"""Annotated form with Depends is unwrapped to the plain callable."""
|
||||||
|
assert _unwrap_session_dep(MockSessionDep) is mock_get_db
|
||||||
|
|
||||||
|
def test_annotated_without_depends_returned_as_is(self):
|
||||||
|
"""Annotated form with no Depends falls back to returning session_dep as-is."""
|
||||||
|
annotated_no_dep = Annotated[AsyncSession, "not_a_depends"]
|
||||||
|
assert _unwrap_session_dep(annotated_no_dep) is annotated_no_dep
|
||||||
|
|
||||||
|
|
||||||
class TestPathDependency:
|
class TestPathDependency:
|
||||||
"""Tests for PathDependency factory."""
|
"""Tests for PathDependency factory."""
|
||||||
|
|
||||||
@@ -95,6 +117,39 @@ class TestPathDependency:
|
|||||||
assert result.id == role.id
|
assert result.id == role.id
|
||||||
assert result.name == "test_role"
|
assert result.name == "test_role"
|
||||||
|
|
||||||
|
def test_annotated_session_dep_returns_depends_instance(self):
|
||||||
|
"""PathDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||||
|
dep = PathDependency(Role, Role.id, session_dep=MockSessionDep)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_annotated_session_dep_signature(self):
|
||||||
|
"""PathDependency with Annotated session_dep produces a valid signature."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||||
|
sig = inspect.signature(dep.dependency)
|
||||||
|
|
||||||
|
assert "role_id" in sig.parameters
|
||||||
|
assert "session" in sig.parameters
|
||||||
|
assert isinstance(sig.parameters["session"].default, Depends)
|
||||||
|
|
||||||
|
def test_annotated_session_dep_unwraps_callable(self):
|
||||||
|
"""PathDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||||
|
sig = inspect.signature(dep.dependency)
|
||||||
|
|
||||||
|
inner_dep = sig.parameters["session"].default
|
||||||
|
assert inner_dep.dependency is mock_get_db
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||||
|
"""PathDependency with Annotated session_dep correctly fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="annotated_role"))
|
||||||
|
|
||||||
|
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
|
||||||
|
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "annotated_role"
|
||||||
|
|
||||||
|
|
||||||
class TestBodyDependency:
|
class TestBodyDependency:
|
||||||
"""Tests for BodyDependency factory."""
|
"""Tests for BodyDependency factory."""
|
||||||
@@ -184,3 +239,39 @@ class TestBodyDependency:
|
|||||||
|
|
||||||
assert result.id == role.id
|
assert result.id == role.id
|
||||||
assert result.name == "body_test_role"
|
assert result.name == "body_test_role"
|
||||||
|
|
||||||
|
def test_annotated_session_dep_returns_depends_instance(self):
|
||||||
|
"""BodyDependency accepts Annotated[AsyncSession, Depends(...)] form."""
|
||||||
|
dep = BodyDependency(
|
||||||
|
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||||
|
)
|
||||||
|
assert isinstance(dep, Depends)
|
||||||
|
|
||||||
|
def test_annotated_session_dep_unwraps_callable(self):
|
||||||
|
"""BodyDependency with Annotated form uses the underlying callable, not the Annotated type."""
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
sig = inspect.signature(dep.dependency)
|
||||||
|
|
||||||
|
inner_dep = sig.parameters["session"].default
|
||||||
|
assert inner_dep.dependency is mock_get_db
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_annotated_session_dep_fetches_object(self, db_session):
|
||||||
|
"""BodyDependency with Annotated session_dep correctly fetches object from database."""
|
||||||
|
role = await RoleCrud.create(db_session, RoleCreate(name="body_annotated_role"))
|
||||||
|
|
||||||
|
dep = cast(
|
||||||
|
Any,
|
||||||
|
BodyDependency(
|
||||||
|
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
result = await dep.dependency(session=db_session, role_id=role.id)
|
||||||
|
|
||||||
|
assert result.id == role.id
|
||||||
|
assert result.name == "body_annotated_role"
|
||||||
|
|||||||
Reference in New Issue
Block a user