fix: cleanup + simplify

This commit is contained in:
2026-03-18 15:25:02 -04:00
parent 4342d224c7
commit 168345756f
8 changed files with 120 additions and 43 deletions

View File

@@ -10,13 +10,15 @@ from fastapi.security import SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
async def _call_validator(
validator: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
"""Call *validator* with *args* and *kwargs*, awaiting it if it is a coroutine function."""
if inspect.iscoroutinefunction(validator):
return await validator(*args, **kwargs)
return validator(*args, **kwargs)
def _ensure_async(fn: Callable[..., Any]) -> Callable[..., Any]:
"""Wrap *fn* so it can always be awaited, caching the coroutine check at init time."""
if inspect.iscoroutinefunction(fn):
return fn
async def wrapper(*args: Any, **kwargs: Any) -> Any:
return fn(*args, **kwargs)
return wrapper
class AuthSource(ABC):

View File

@@ -9,7 +9,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityS
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
from ..abc import AuthSource, _ensure_async
class BearerTokenAuth(AuthSource):
@@ -42,32 +42,30 @@ class BearerTokenAuth(AuthSource):
prefix: str | None = None,
**kwargs: Any,
) -> None:
self._validator = validator
self._validator = _ensure_async(validator)
self._prefix = prefix
self._kwargs = kwargs
self._scheme = HTTPBearer(auto_error=False)
_scheme = self._scheme
_validator = validator
_kwargs = kwargs
_prefix = prefix
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
credentials: Annotated[
HTTPAuthorizationCredentials | None, Depends(_scheme)
HTTPAuthorizationCredentials | None, Depends(self._scheme)
] = None,
) -> Any:
if credentials is None:
raise UnauthorizedError()
token = credentials.credentials
if _prefix is not None and not token.startswith(_prefix):
raise UnauthorizedError()
return await _call_validator(_validator, token, **_kwargs)
return await self._validate(credentials.credentials)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
async def _validate(self, token: str) -> Any:
"""Check prefix and call the validator."""
if self._prefix is not None and not token.startswith(self._prefix):
raise UnauthorizedError()
return await self._validator(token, **self._kwargs)
async def extract(self, request: Any) -> str | None:
"""Extract the raw credential from the request without validating.
@@ -91,7 +89,7 @@ class BearerTokenAuth(AuthSource):
Calls ``await validator(credential, **kwargs)`` where ``kwargs`` are
the extra keyword arguments provided at instantiation.
"""
return await _call_validator(self._validator, credential, **self._kwargs)
return await self._validate(credential)
def require(self, **kwargs: Any) -> "BearerTokenAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""

View File

@@ -13,7 +13,7 @@ from fastapi.security import APIKeyCookie, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
from ..abc import AuthSource, _ensure_async
class CookieAuth(AuthSource):
@@ -50,30 +50,27 @@ class CookieAuth(AuthSource):
**kwargs: Any,
) -> None:
self._name = name
self._validator = validator
self._validator = _ensure_async(validator)
self._secret_key = secret_key
self._ttl = ttl
self._kwargs = kwargs
self._scheme = APIKeyCookie(name=name, auto_error=False)
_scheme = self._scheme
_self = self
_kwargs = kwargs
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
value: Annotated[str | None, Depends(_scheme)] = None,
value: Annotated[str | None, Depends(self._scheme)] = None,
) -> Any:
if value is None:
raise UnauthorizedError()
plain = _self._verify(value)
return await _call_validator(_self._validator, plain, **_kwargs)
plain = self._verify(value)
return await self._validator(plain, **self._kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
def _hmac(self, data: str) -> str:
assert self._secret_key is not None
if self._secret_key is None:
raise RuntimeError("_hmac called without secret_key configured")
return hmac.new(
self._secret_key.encode(), data.encode(), hashlib.sha256
).hexdigest()
@@ -114,7 +111,7 @@ class CookieAuth(AuthSource):
async def authenticate(self, credential: str) -> Any:
plain = self._verify(credential)
return await _call_validator(self._validator, plain, **self._kwargs)
return await self._validator(plain, **self._kwargs)
def require(self, **kwargs: Any) -> "CookieAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""

View File

@@ -8,7 +8,7 @@ from fastapi.security import APIKeyHeader, SecurityScopes
from fastapi_toolsets.exceptions import UnauthorizedError
from ..abc import AuthSource, _call_validator
from ..abc import AuthSource, _ensure_async
class APIKeyHeaderAuth(AuthSource):
@@ -35,21 +35,17 @@ class APIKeyHeaderAuth(AuthSource):
**kwargs: Any,
) -> None:
self._name = name
self._validator = validator
self._validator = _ensure_async(validator)
self._kwargs = kwargs
self._scheme = APIKeyHeader(name=name, auto_error=False)
_scheme = self._scheme
_validator = validator
_kwargs = kwargs
async def _call(
security_scopes: SecurityScopes, # noqa: ARG001
api_key: Annotated[str | None, Depends(_scheme)] = None,
api_key: Annotated[str | None, Depends(self._scheme)] = None,
) -> Any:
if api_key is None:
raise UnauthorizedError()
return await _call_validator(_validator, api_key, **_kwargs)
return await self._validator(api_key, **self._kwargs)
self._call_fn = _call
self.__signature__ = inspect.signature(_call)
@@ -60,7 +56,7 @@ class APIKeyHeaderAuth(AuthSource):
async def authenticate(self, credential: str) -> Any:
"""Validate a credential and return the identity."""
return await _call_validator(self._validator, credential, **self._kwargs)
return await self._validator(credential, **self._kwargs)
def require(self, **kwargs: Any) -> "APIKeyHeaderAuth":
"""Return a new instance with additional (or overriding) validator kwargs."""

View File

@@ -55,14 +55,12 @@ class MultiAuth:
def __init__(self, *sources: AuthSource) -> None:
self._sources = sources
_sources = sources
async def _call(
request: Request,
security_scopes: SecurityScopes, # noqa: ARG001
**kwargs: Any, # noqa: ARG001 — absorbs scheme values injected by FastAPI
) -> Any:
for source in _sources:
for source in self._sources:
credential = await source.extract(request)
if credential is not None:
return await source.authenticate(credential)