diff --git a/docs/module/dependencies.md b/docs/module/dependencies.md index 4097718..5f54c67 100644 --- a/docs/module/dependencies.md +++ b/docs/module/dependencies.md @@ -13,8 +13,13 @@ The `dependencies` module provides two factory functions that create FastAPI dep ```python from fastapi_toolsets.dependencies import PathDependency +# Plain callable UserDep = PathDependency(model=User, field=User.id, session_dep=get_db) +# Annotated +SessionDep = Annotated[AsyncSession, Depends(get_db)] +UserDep = PathDependency(model=User, field=User.id, session_dep=SessionDep) + @router.get("/users/{user_id}") async def get_user(user: User = UserDep): return user @@ -37,8 +42,14 @@ async def get_user(user: User = UserDep): ```python from fastapi_toolsets.dependencies import BodyDependency +# Plain callable RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=get_db, body_field="role_id") +# Annotated +SessionDep = Annotated[AsyncSession, Depends(get_db)] +RoleDep = BodyDependency(model=Role, field=Role.id, session_dep=SessionDep, body_field="role_id") + + @router.post("/users") async def create_user(body: UserCreateSchema, role: Role = RoleDep): user = User(username=body.username, role=role) diff --git a/src/fastapi_toolsets/dependencies.py b/src/fastapi_toolsets/dependencies.py index 26eb75c..bfcbf0b 100644 --- a/src/fastapi_toolsets/dependencies.py +++ b/src/fastapi_toolsets/dependencies.py @@ -1,10 +1,12 @@ """Dependency factories for FastAPI routes.""" import inspect +import typing from collections.abc import Callable from typing import Any, cast from fastapi import Depends +from fastapi.params import Depends as DependsClass from sqlalchemy.ext.asyncio import AsyncSession from .crud import CrudFactory @@ -13,6 +15,15 @@ from .types import ModelType, SessionDependency __all__ = ["BodyDependency", "PathDependency"] +def _unwrap_session_dep(session_dep: SessionDependency) -> Callable[..., Any]: + """Extract the plain callable from ``Annotated[AsyncSession, Depends(fn)]`` if needed.""" + if typing.get_origin(session_dep) is typing.Annotated: + for arg in typing.get_args(session_dep)[1:]: + if isinstance(arg, DependsClass): + return arg.dependency + return session_dep + + def PathDependency( model: type[ModelType], field: Any, @@ -44,6 +55,7 @@ def PathDependency( ): ... ``` """ + session_callable = _unwrap_session_dep(session_dep) crud = CrudFactory(model) name = ( param_name @@ -53,7 +65,7 @@ def PathDependency( python_type = field.type.python_type async def dependency( - session: AsyncSession = Depends(session_dep), **kwargs: Any + session: AsyncSession = Depends(session_callable), **kwargs: Any ) -> ModelType: value = kwargs[name] return await crud.get(session, filters=[field == value]) @@ -70,7 +82,7 @@ def PathDependency( "session", inspect.Parameter.KEYWORD_ONLY, annotation=AsyncSession, - default=Depends(session_dep), + default=Depends(session_callable), ), ] ), @@ -112,11 +124,12 @@ def BodyDependency( ): ... ``` """ + session_callable = _unwrap_session_dep(session_dep) crud = CrudFactory(model) python_type = field.type.python_type async def dependency( - session: AsyncSession = Depends(session_dep), **kwargs: Any + session: AsyncSession = Depends(session_callable), **kwargs: Any ) -> ModelType: value = kwargs[body_field] return await crud.get(session, filters=[field == value]) @@ -133,7 +146,7 @@ def BodyDependency( "session", inspect.Parameter.KEYWORD_ONLY, annotation=AsyncSession, - default=Depends(session_dep), + default=Depends(session_callable), ), ] ), diff --git a/src/fastapi_toolsets/types.py b/src/fastapi_toolsets/types.py index 1941781..a89eef8 100644 --- a/src/fastapi_toolsets/types.py +++ b/src/fastapi_toolsets/types.py @@ -24,4 +24,4 @@ SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], FacetFieldType = SearchFieldType # Dependency type aliases -SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] +SessionDependency = Callable[[], AsyncGenerator[AsyncSession, None]] | Any diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 2a89b81..b974b40 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -3,13 +3,17 @@ import inspect import uuid from collections.abc import AsyncGenerator -from typing import Any, cast +from typing import Annotated, Any, cast import pytest from fastapi.params import Depends from sqlalchemy.ext.asyncio import AsyncSession -from fastapi_toolsets.dependencies import BodyDependency, PathDependency +from fastapi_toolsets.dependencies import ( + BodyDependency, + PathDependency, + _unwrap_session_dep, +) from .conftest import Role, RoleCreate, RoleCrud, User @@ -19,6 +23,24 @@ async def mock_get_db() -> AsyncGenerator[AsyncSession, None]: yield None +MockSessionDep = Annotated[AsyncSession, Depends(mock_get_db)] + + +class TestUnwrapSessionDep: + def test_plain_callable_returned_as_is(self): + """Plain callable is returned unchanged.""" + assert _unwrap_session_dep(mock_get_db) is mock_get_db + + def test_annotated_with_depends_unwrapped(self): + """Annotated form with Depends is unwrapped to the plain callable.""" + assert _unwrap_session_dep(MockSessionDep) is mock_get_db + + def test_annotated_without_depends_returned_as_is(self): + """Annotated form with no Depends falls back to returning session_dep as-is.""" + annotated_no_dep = Annotated[AsyncSession, "not_a_depends"] + assert _unwrap_session_dep(annotated_no_dep) is annotated_no_dep + + class TestPathDependency: """Tests for PathDependency factory.""" @@ -95,6 +117,39 @@ class TestPathDependency: assert result.id == role.id assert result.name == "test_role" + def test_annotated_session_dep_returns_depends_instance(self): + """PathDependency accepts Annotated[AsyncSession, Depends(...)] form.""" + dep = PathDependency(Role, Role.id, session_dep=MockSessionDep) + assert isinstance(dep, Depends) + + def test_annotated_session_dep_signature(self): + """PathDependency with Annotated session_dep produces a valid signature.""" + dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep)) + sig = inspect.signature(dep.dependency) + + assert "role_id" in sig.parameters + assert "session" in sig.parameters + assert isinstance(sig.parameters["session"].default, Depends) + + def test_annotated_session_dep_unwraps_callable(self): + """PathDependency with Annotated form uses the underlying callable, not the Annotated type.""" + dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep)) + sig = inspect.signature(dep.dependency) + + inner_dep = sig.parameters["session"].default + assert inner_dep.dependency is mock_get_db + + @pytest.mark.anyio + async def test_annotated_session_dep_fetches_object(self, db_session): + """PathDependency with Annotated session_dep correctly fetches object from database.""" + role = await RoleCrud.create(db_session, RoleCreate(name="annotated_role")) + + dep = cast(Any, PathDependency(Role, Role.id, session_dep=MockSessionDep)) + result = await dep.dependency(session=db_session, role_id=role.id) + + assert result.id == role.id + assert result.name == "annotated_role" + class TestBodyDependency: """Tests for BodyDependency factory.""" @@ -184,3 +239,39 @@ class TestBodyDependency: assert result.id == role.id assert result.name == "body_test_role" + + def test_annotated_session_dep_returns_depends_instance(self): + """BodyDependency accepts Annotated[AsyncSession, Depends(...)] form.""" + dep = BodyDependency( + Role, Role.id, session_dep=MockSessionDep, body_field="role_id" + ) + assert isinstance(dep, Depends) + + def test_annotated_session_dep_unwraps_callable(self): + """BodyDependency with Annotated form uses the underlying callable, not the Annotated type.""" + dep = cast( + Any, + BodyDependency( + Role, Role.id, session_dep=MockSessionDep, body_field="role_id" + ), + ) + sig = inspect.signature(dep.dependency) + + inner_dep = sig.parameters["session"].default + assert inner_dep.dependency is mock_get_db + + @pytest.mark.anyio + async def test_annotated_session_dep_fetches_object(self, db_session): + """BodyDependency with Annotated session_dep correctly fetches object from database.""" + role = await RoleCrud.create(db_session, RoleCreate(name="body_annotated_role")) + + dep = cast( + Any, + BodyDependency( + Role, Role.id, session_dep=MockSessionDep, body_field="role_id" + ), + ) + result = await dep.dependency(session=db_session, role_id=role.id) + + assert result.id == role.id + assert result.name == "body_annotated_role"