3 Commits

4 changed files with 74 additions and 6 deletions

View File

@@ -29,9 +29,14 @@ def get_obj_by_attr(
The first model instance where the attribute matches the given value.
Raises:
StopIteration: If no matching object is found.
StopIteration: If no matching object is found in the fixture group.
"""
try:
return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
except StopIteration:
raise StopIteration(
f"No object with {attr_name}={value} found in fixture '{getattr(fixtures, '__name__', repr(fixtures))}'"
) from None
async def load_fixtures(

View File

@@ -2,9 +2,10 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from typing import Any, Literal
from httpx import ASGITransport, AsyncClient
from httpx import ASGITransport, AsyncClient, Response
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
@@ -108,3 +109,61 @@ async def create_db_session(
await conn.run_sync(base.metadata.drop_all)
finally:
await engine.dispose()
def _normalize_expected(
expected: BaseModel | list[BaseModel] | dict | list[dict],
) -> Any:
"""Normalize expected data to a JSON-compatible structure."""
if isinstance(expected, BaseModel):
return expected.model_dump(mode="json")
if isinstance(expected, list):
return [
item.model_dump(mode="json") if isinstance(item, BaseModel) else item
for item in expected
]
return expected
HttpMethod = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
async def assert_endpoint(
client: AsyncClient,
method: HttpMethod,
url: str,
*,
expected_status: int = 200,
expected_data: BaseModel | list[BaseModel] | dict | list[dict] | None = None,
request_headers: dict[str, str] | None = None,
request_json: Any | None = None,
request_params: dict[str, Any] | None = None,
request_content: bytes | None = None,
) -> Response:
"""Assert an API endpoint returns the expected status and data."""
kwargs: dict[str, Any] = {}
if request_headers is not None:
kwargs["headers"] = request_headers
if request_json is not None:
kwargs["json"] = request_json
if request_params is not None:
kwargs["params"] = request_params
if request_content is not None:
kwargs["content"] = request_content
response = await client.request(method, url, **kwargs)
assert response.status_code == expected_status, (
f"Expected status {expected_status}, got {response.status_code}. "
f"Response body: {response.text}"
)
if expected_data is not None:
response_json = response.json()
actual_data = response_json.get("data")
normalized = _normalize_expected(expected_data)
assert actual_data == normalized, (
f"Response data mismatch.\nExpected: {normalized}\nActual: {actual_data}"
)
return response

View File

@@ -10,6 +10,7 @@ __all__ = [
"ErrorResponse",
"Pagination",
"PaginatedResponse",
"PydanticBase",
"Response",
"ResponseStatus",
]

View File

@@ -744,8 +744,11 @@ class TestGetObjByAttr:
assert user.username == "alice"
def test_no_match_raises_stop_iteration(self):
"""Raises StopIteration when no object matches."""
with pytest.raises(StopIteration):
"""Raises StopIteration with contextual message when no object matches."""
with pytest.raises(
StopIteration,
match="No object with name=nonexistent found in fixture 'roles'",
):
get_obj_by_attr(self.roles, "name", "nonexistent")
def test_no_match_on_wrong_value_type(self):