Compare commits

..

1 Commits

4 changed files with 122 additions and 7 deletions

View File

@@ -13,8 +13,13 @@ The `dependencies` module provides two factory functions that create FastAPI dep
```python ```python
from fastapi_toolsets.dependencies import PathDependency from fastapi_toolsets.dependencies import PathDependency
# Plain callable
UserDep = PathDependency(model=User, field=User.id, session_dep=get_db) UserDep = PathDependency(model=User, field=User.id, session_dep=get_db)
# Annotated
SessionDep = Annotated[AsyncSession, Depends(get_db)]
UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep)
@router.get("/users/{user_id}") @router.get("/users/{user_id}")
async def get_user(user: User = UserDep): async def get_user(user: User = UserDep):
return user return user
@@ -37,8 +42,14 @@ async def get_user(user: User = UserDep):
```python ```python
from fastapi_toolsets.dependencies import BodyDependency from fastapi_toolsets.dependencies import BodyDependency
# Plain callable
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id") RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id")
# Annotated
SessionDep = Annotated[AsyncSession, Depends(get_db)]
RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id")
@router.post("/users") @router.post("/users")
async def create_user(body: UserCreateSchema, role: Role = RoleDep): async def create_user(body: UserCreateSchema, role: Role = RoleDep):
user = User(username=body.username, role=role) user = User(username=body.username, role=role)

View File

@@ -1,10 +1,12 @@
"""Dependency factories for FastAPI routes.""" """Dependency factories for FastAPI routes."""
import inspect import inspect
import typing
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast from typing import Any, cast
from fastapi import Depends from fastapi import Depends
from fastapi.params import Depends as DependsClass
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from .crud import CrudFactory from .crud import CrudFactory
@@ -13,6 +15,15 @@ from .types import ModelType, SessionDependency
__all__ = ["BodyDependency", "PathDependency"] __all__ = ["BodyDependency", "PathDependency"]
def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]:
"""Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed."""
if typing.get_origin(session_dep) is typing.Annotated:
for arg in typing.get_args(session_dep)[1:]:
if isinstance(arg, DependsClass):
return arg.dependency
return session_dep
def PathDependency( def PathDependency(
model: type[ModelType], model: type[ModelType],
field: Any, field: Any,
@@ -44,6 +55,7 @@ def PathDependency(
): ... ): ...
``` ```
""" """
session_callable = _unwrap_session_dep(session_dep)
crud = CrudFactory(model) crud = CrudFactory(model)
name = ( name = (
param_name param_name
@@ -53,7 +65,7 @@ def PathDependency(
python_type = field.type.python_type python_type = field.type.python_type
async def dependency( async def dependency(
session: AsyncSession = Depends(session_dep), **kwargs: Any session: AsyncSession = Depends(session_callable), **kwargs: Any
) -> ModelType: ) -> ModelType:
value = kwargs[name] value = kwargs[name]
return await crud.get(session, filters=[field == value]) return await crud.get(session, filters=[field == value])
@@ -70,7 +82,7 @@ def PathDependency(
"session", "session",
inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.KEYWORD_ONLY,
annotation=AsyncSession, annotation=AsyncSession,
default=Depends(session_dep), default=Depends(session_callable),
), ),
] ]
), ),
@@ -112,11 +124,12 @@ def BodyDependency(
): ... ): ...
``` ```
""" """
session_callable = _unwrap_session_dep(session_dep)
crud = CrudFactory(model) crud = CrudFactory(model)
python_type = field.type.python_type python_type = field.type.python_type
async def dependency( async def dependency(
session: AsyncSession = Depends(session_dep), **kwargs: Any session: AsyncSession = Depends(session_callable), **kwargs: Any
) -> ModelType: ) -> ModelType:
value = kwargs[body_field] value = kwargs[body_field]
return await crud.get(session, filters=[field == value]) return await crud.get(session, filters=[field == value])
@@ -133,7 +146,7 @@ def BodyDependency(
"session", "session",
inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.KEYWORD_ONLY,
annotation=AsyncSession, annotation=AsyncSession,
default=Depends(session_dep), default=Depends(session_callable),
), ),
] ]
), ),

View File

@@ -24,4 +24,4 @@ SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any],
FacetFieldType = SearchFieldType FacetFieldType = SearchFieldType
# Dependency type aliases # Dependency type aliases
SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any

View File

@@ -3,13 +3,17 @@
import inspect import inspect
import uuid import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any, cast from typing import Annotated, Any, cast
import pytest import pytest
from fastapi.params import Depends from fastapi.params import Depends
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from fastapi_toolsets.dependencies import BodyDependency, PathDependency from fastapi_toolsets.dependencies import (
BodyDependency,
PathDependency,
_unwrap_session_dep,
)
from .conftest import Role, RoleCreate, RoleCrud, User from .conftest import Role, RoleCreate, RoleCrud, User
@@ -19,6 +23,24 @@ async def mock_get_db() -> AsyncGenerator[AsyncSession, None]:
yield None yield None
MockSessionDep = Annotated[AsyncSession, Depends(mock_get_db)]
class TestUnwrapSessionDep:
def test_plain_callable_returned_as_is(self):
"""Plain callable is returned unchanged."""
assert _unwrap_session_dep(mock_get_db) is mock_get_db
def test_annotated_with_depends_unwrapped(self):
"""Annotated form with Depends is unwrapped to the plain callable."""
assert _unwrap_session_dep(MockSessionDep) is mock_get_db
def test_annotated_without_depends_returned_as_is(self):
"""Annotated form with no Depends falls back to returning session_dep as-is."""
annotated_no_dep = Annotated[AsyncSession, "not_a_depends"]
assert _unwrap_session_dep(annotated_no_dep) is annotated_no_dep
class TestPathDependency: class TestPathDependency:
"""Tests for PathDependency factory.""" """Tests for PathDependency factory."""
@@ -95,6 +117,39 @@ class TestPathDependency:
assert result.id == role.id assert result.id == role.id
assert result.name == "test_role" assert result.name == "test_role"
def test_annotated_session_dep_returns_depends_instance(self):
"""PathDependency accepts Annotated[AsyncSession, Depends(...)] form."""
dep = PathDependency(Role, Role.id, session_dep=MockSessionDep)
assert isinstance(dep, Depends)
def test_annotated_session_dep_signature(self):
"""PathDependency with Annotated session_dep produces a valid signature."""
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
sig = inspect.signature(dep.dependency)
assert "role_id" in sig.parameters
assert "session" in sig.parameters
assert isinstance(sig.parameters["session"].default, Depends)
def test_annotated_session_dep_unwraps_callable(self):
"""PathDependency with Annotated form uses the underlying callable, not the Annotated type."""
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
sig = inspect.signature(dep.dependency)
inner_dep = sig.parameters["session"].default
assert inner_dep.dependency is mock_get_db
@pytest.mark.anyio
async def test_annotated_session_dep_fetches_object(self, db_session):
"""PathDependency with Annotated session_dep correctly fetches object from database."""
role = await RoleCrud.create(db_session, RoleCreate(name="annotated_role"))
dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep))
result = await dep.dependency(session=db_session, role_id=role.id)
assert result.id == role.id
assert result.name == "annotated_role"
class TestBodyDependency: class TestBodyDependency:
"""Tests for BodyDependency factory.""" """Tests for BodyDependency factory."""
@@ -184,3 +239,39 @@ class TestBodyDependency:
assert result.id == role.id assert result.id == role.id
assert result.name == "body_test_role" assert result.name == "body_test_role"
def test_annotated_session_dep_returns_depends_instance(self):
"""BodyDependency accepts Annotated[AsyncSession, Depends(...)] form."""
dep = BodyDependency(
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
)
assert isinstance(dep, Depends)
def test_annotated_session_dep_unwraps_callable(self):
"""BodyDependency with Annotated form uses the underlying callable, not the Annotated type."""
dep = cast(
Any,
BodyDependency(
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
),
)
sig = inspect.signature(dep.dependency)
inner_dep = sig.parameters["session"].default
assert inner_dep.dependency is mock_get_db
@pytest.mark.anyio
async def test_annotated_session_dep_fetches_object(self, db_session):
"""BodyDependency with Annotated session_dep correctly fetches object from database."""
role = await RoleCrud.create(db_session, RoleCreate(name="body_annotated_role"))
dep = cast(
Any,
BodyDependency(
Role, Role.id, session_dep=MockSessionDep, body_field="role_id"
),
)
result = await dep.dependency(session=db_session, role_id=role.id)
assert result.id == role.id
assert result.name == "body_annotated_role"