mirror of
https://github.com/d3vyce/fastapi-toolsets.git
synced 2026-03-01 17:00:48 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
54f5479c24
|
|||
|
|
f467754df1 | ||
|
b57ce40b05
|
|||
|
5264631550
|
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "fastapi-toolsets"
|
||||
version = "0.4.0"
|
||||
version = "0.4.1"
|
||||
description = "Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
|
||||
@@ -21,4 +21,4 @@ Example usage:
|
||||
return Response(data={"user": user.username}, message="Success")
|
||||
"""
|
||||
|
||||
__version__ = "0.4.0"
|
||||
__version__ = "0.4.1"
|
||||
|
||||
@@ -129,11 +129,12 @@ def build_search_filters(
|
||||
else:
|
||||
column = field
|
||||
|
||||
# Build the filter
|
||||
# Build the filter (cast to String for non-text columns)
|
||||
column_as_string = column.cast(String)
|
||||
if config.case_sensitive:
|
||||
filters.append(column.like(f"%{query}%"))
|
||||
filters.append(column_as_string.like(f"%{query}%"))
|
||||
else:
|
||||
filters.append(column.ilike(f"%{query}%"))
|
||||
filters.append(column_as_string.ilike(f"%{query}%"))
|
||||
|
||||
if not filters:
|
||||
return [], []
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .exceptions import (
|
||||
ApiError,
|
||||
ApiException,
|
||||
ConflictError,
|
||||
ForbiddenError,
|
||||
@@ -10,11 +11,12 @@ from .exceptions import (
|
||||
from .handler import init_exceptions_handlers
|
||||
|
||||
__all__ = [
|
||||
"init_exceptions_handlers",
|
||||
"generate_error_responses",
|
||||
"ApiError",
|
||||
"ApiException",
|
||||
"ConflictError",
|
||||
"ForbiddenError",
|
||||
"generate_error_responses",
|
||||
"init_exceptions_handlers",
|
||||
"NoSearchableFieldsError",
|
||||
"NotFoundError",
|
||||
"UnauthorizedError",
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Shared pytest fixtures for fastapi-utils tests."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy import ForeignKey, String, Uuid
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
@@ -33,7 +34,7 @@ class Role(Base):
|
||||
|
||||
__tablename__ = "roles"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
|
||||
users: Mapped[list["User"]] = relationship(back_populates="role")
|
||||
@@ -44,11 +45,13 @@ class User(Base):
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
is_active: Mapped[bool] = mapped_column(default=True)
|
||||
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True)
|
||||
role_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
ForeignKey("roles.id"), nullable=True
|
||||
)
|
||||
|
||||
role: Mapped[Role | None] = relationship(back_populates="users")
|
||||
|
||||
@@ -58,11 +61,11 @@ class Post(Base):
|
||||
|
||||
__tablename__ = "posts"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
|
||||
title: Mapped[str] = mapped_column(String(200))
|
||||
content: Mapped[str] = mapped_column(String(1000), default="")
|
||||
is_published: Mapped[bool] = mapped_column(default=False)
|
||||
author_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
||||
author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -73,7 +76,7 @@ class Post(Base):
|
||||
class RoleCreate(BaseModel):
|
||||
"""Schema for creating a role."""
|
||||
|
||||
id: int | None = None
|
||||
id: uuid.UUID | None = None
|
||||
name: str
|
||||
|
||||
|
||||
@@ -86,11 +89,11 @@ class RoleUpdate(BaseModel):
|
||||
class UserCreate(BaseModel):
|
||||
"""Schema for creating a user."""
|
||||
|
||||
id: int | None = None
|
||||
id: uuid.UUID | None = None
|
||||
username: str
|
||||
email: str
|
||||
is_active: bool = True
|
||||
role_id: int | None = None
|
||||
role_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
@@ -99,17 +102,17 @@ class UserUpdate(BaseModel):
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
is_active: bool | None = None
|
||||
role_id: int | None = None
|
||||
role_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class PostCreate(BaseModel):
|
||||
"""Schema for creating a post."""
|
||||
|
||||
id: int | None = None
|
||||
id: uuid.UUID | None = None
|
||||
title: str
|
||||
content: str = ""
|
||||
is_published: bool = False
|
||||
author_id: int
|
||||
author_id: uuid.UUID
|
||||
|
||||
|
||||
class PostUpdate(BaseModel):
|
||||
@@ -195,5 +198,5 @@ def sample_post_data() -> PostCreate:
|
||||
title="Test Post",
|
||||
content="Test content",
|
||||
is_published=True,
|
||||
author_id=1,
|
||||
author_id=uuid.uuid4(),
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for fastapi_toolsets.crud module."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -89,8 +91,9 @@ class TestCrudGet:
|
||||
@pytest.mark.anyio
|
||||
async def test_get_raises_not_found(self, db_session: AsyncSession):
|
||||
"""Get raises NotFoundError for missing records."""
|
||||
non_existent_id = uuid.uuid4()
|
||||
with pytest.raises(NotFoundError):
|
||||
await RoleCrud.get(db_session, [Role.id == 99999])
|
||||
await RoleCrud.get(db_session, [Role.id == non_existent_id])
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_with_multiple_filters(self, db_session: AsyncSession):
|
||||
@@ -223,11 +226,12 @@ class TestCrudUpdate:
|
||||
@pytest.mark.anyio
|
||||
async def test_update_raises_not_found(self, db_session: AsyncSession):
|
||||
"""Update raises NotFoundError for missing records."""
|
||||
non_existent_id = uuid.uuid4()
|
||||
with pytest.raises(NotFoundError):
|
||||
await RoleCrud.update(
|
||||
db_session,
|
||||
RoleUpdate(name="new"),
|
||||
[Role.id == 99999],
|
||||
[Role.id == non_existent_id],
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -340,7 +344,8 @@ class TestCrudUpsert:
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_insert_new_record(self, db_session: AsyncSession):
|
||||
"""Upsert inserts a new record when it doesn't exist."""
|
||||
data = RoleCreate(id=1, name="upsert_new")
|
||||
role_id = uuid.uuid4()
|
||||
data = RoleCreate(id=role_id, name="upsert_new")
|
||||
role = await RoleCrud.upsert(
|
||||
db_session,
|
||||
data,
|
||||
@@ -353,12 +358,13 @@ class TestCrudUpsert:
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_update_existing_record(self, db_session: AsyncSession):
|
||||
"""Upsert updates an existing record."""
|
||||
role_id = uuid.uuid4()
|
||||
# First insert
|
||||
data = RoleCreate(id=100, name="original_name")
|
||||
data = RoleCreate(id=role_id, name="original_name")
|
||||
await RoleCrud.upsert(db_session, data, index_elements=["id"])
|
||||
|
||||
# Upsert with update
|
||||
updated_data = RoleCreate(id=100, name="updated_name")
|
||||
updated_data = RoleCreate(id=role_id, name="updated_name")
|
||||
role = await RoleCrud.upsert(
|
||||
db_session,
|
||||
updated_data,
|
||||
@@ -370,22 +376,23 @@ class TestCrudUpsert:
|
||||
assert role.name == "updated_name"
|
||||
|
||||
# Verify only one record exists
|
||||
count = await RoleCrud.count(db_session, [Role.id == 100])
|
||||
count = await RoleCrud.count(db_session, [Role.id == role_id])
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
|
||||
"""Upsert does nothing on conflict when set_ is not provided."""
|
||||
role_id = uuid.uuid4()
|
||||
# First insert
|
||||
data = RoleCreate(id=200, name="do_nothing_original")
|
||||
data = RoleCreate(id=role_id, name="do_nothing_original")
|
||||
await RoleCrud.upsert(db_session, data, index_elements=["id"])
|
||||
|
||||
# Upsert without set_ (do nothing)
|
||||
conflict_data = RoleCreate(id=200, name="do_nothing_conflict")
|
||||
conflict_data = RoleCreate(id=role_id, name="do_nothing_conflict")
|
||||
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
|
||||
|
||||
# Original value should be preserved
|
||||
role = await RoleCrud.first(db_session, [Role.id == 200])
|
||||
role = await RoleCrud.first(db_session, [Role.id == role_id])
|
||||
assert role is not None
|
||||
assert role.name == "do_nothing_original"
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for CRUD search functionality."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -272,6 +274,27 @@ class TestPaginateSearch:
|
||||
usernames = [u.username for u in result["data"]]
|
||||
assert usernames == ["alice", "bob", "charlie"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_search_non_string_column(self, db_session: AsyncSession):
|
||||
"""Search on non-string columns (e.g., UUID) works via cast."""
|
||||
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(id=user_id, username="john", email="john@test.com")
|
||||
)
|
||||
await UserCrud.create(
|
||||
db_session, UserCreate(username="jane", email="jane@test.com")
|
||||
)
|
||||
|
||||
# Search by UUID (partial match)
|
||||
result = await UserCrud.paginate(
|
||||
db_session,
|
||||
search="12345678",
|
||||
search_fields=[User.id, User.username],
|
||||
)
|
||||
|
||||
assert result["pagination"]["total_count"] == 1
|
||||
assert result["data"][0].id == user_id
|
||||
|
||||
|
||||
class TestSearchConfig:
|
||||
"""Tests for SearchConfig options."""
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for fastapi_toolsets.fixtures module."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
|
||||
def test_register_with_decorator(self):
|
||||
"""Register fixture with decorator."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
assert "roles" in [f.name for f in registry.get_all()]
|
||||
|
||||
def test_register_with_custom_name(self):
|
||||
"""Register fixture with custom name."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(name="custom_roles")
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
fixture = registry.get("custom_roles")
|
||||
assert fixture.name == "custom_roles"
|
||||
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
|
||||
def test_register_with_dependencies(self):
|
||||
"""Register fixture with dependencies."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
@registry.register(depends_on=["roles"])
|
||||
def users():
|
||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
||||
return [
|
||||
User(
|
||||
id=user_id,
|
||||
username="admin",
|
||||
email="admin@test.com",
|
||||
role_id=role_id,
|
||||
)
|
||||
]
|
||||
|
||||
fixture = registry.get("users")
|
||||
assert fixture.depends_on == ["roles"]
|
||||
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
|
||||
def test_register_with_contexts(self):
|
||||
"""Register fixture with contexts."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.TESTING])
|
||||
def test_data():
|
||||
return [Role(id=100, name="test")]
|
||||
return [Role(id=role_id, name="test")]
|
||||
|
||||
fixture = registry.get("test_data")
|
||||
assert Context.TESTING.value in fixture.contexts
|
||||
@@ -244,12 +258,14 @@ class TestLoadFixtures:
|
||||
async def test_load_single_fixture(self, db_session: AsyncSession):
|
||||
"""Load a single fixture."""
|
||||
registry = FixtureRegistry()
|
||||
role_id_1 = uuid.uuid4()
|
||||
role_id_2 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
Role(id=role_id_1, name="admin"),
|
||||
Role(id=role_id_2, name="user"),
|
||||
]
|
||||
|
||||
result = await load_fixtures(db_session, registry, "roles")
|
||||
@@ -266,14 +282,23 @@ class TestLoadFixtures:
|
||||
async def test_load_with_dependencies(self, db_session: AsyncSession):
|
||||
"""Load fixtures with dependencies."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
@registry.register(depends_on=["roles"])
|
||||
def users():
|
||||
return [User(id=1, username="admin", email="admin@test.com", role_id=1)]
|
||||
return [
|
||||
User(
|
||||
id=user_id,
|
||||
username="admin",
|
||||
email="admin@test.com",
|
||||
role_id=role_id,
|
||||
)
|
||||
]
|
||||
|
||||
result = await load_fixtures(db_session, registry, "users")
|
||||
|
||||
@@ -289,10 +314,11 @@ class TestLoadFixtures:
|
||||
async def test_load_with_merge_strategy(self, db_session: AsyncSession):
|
||||
"""Load fixtures with MERGE strategy updates existing."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
|
||||
@@ -306,10 +332,11 @@ class TestLoadFixtures:
|
||||
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
|
||||
"""Load fixtures with SKIP_EXISTING strategy."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="original")]
|
||||
return [Role(id=role_id, name="original")]
|
||||
|
||||
await load_fixtures(
|
||||
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
|
||||
@@ -317,7 +344,7 @@ class TestLoadFixtures:
|
||||
|
||||
@registry.register(name="roles_updated")
|
||||
def roles_v2():
|
||||
return [Role(id=1, name="updated")]
|
||||
return [Role(id=role_id, name="updated")]
|
||||
|
||||
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
|
||||
|
||||
@@ -327,7 +354,7 @@ class TestLoadFixtures:
|
||||
|
||||
from .conftest import RoleCrud
|
||||
|
||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
||||
role = await RoleCrud.first(db_session, [Role.id == role_id])
|
||||
assert role is not None
|
||||
assert role.name == "original"
|
||||
|
||||
@@ -335,12 +362,14 @@ class TestLoadFixtures:
|
||||
async def test_load_with_insert_strategy(self, db_session: AsyncSession):
|
||||
"""Load fixtures with INSERT strategy."""
|
||||
registry = FixtureRegistry()
|
||||
role_id_1 = uuid.uuid4()
|
||||
role_id_2 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
Role(id=role_id_1, name="admin"),
|
||||
Role(id=role_id_2, name="user"),
|
||||
]
|
||||
|
||||
result = await load_fixtures(
|
||||
@@ -375,14 +404,16 @@ class TestLoadFixtures:
|
||||
):
|
||||
"""Load multiple independent fixtures."""
|
||||
registry = FixtureRegistry()
|
||||
role_id_1 = uuid.uuid4()
|
||||
role_id_2 = uuid.uuid4()
|
||||
|
||||
@registry.register
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id_1, name="admin")]
|
||||
|
||||
@registry.register
|
||||
def other_roles():
|
||||
return [Role(id=2, name="user")]
|
||||
return [Role(id=role_id_2, name="user")]
|
||||
|
||||
result = await load_fixtures(db_session, registry, "roles", "other_roles")
|
||||
|
||||
@@ -402,14 +433,16 @@ class TestLoadFixturesByContext:
|
||||
async def test_load_by_single_context(self, db_session: AsyncSession):
|
||||
"""Load fixtures by single context."""
|
||||
registry = FixtureRegistry()
|
||||
base_role_id = uuid.uuid4()
|
||||
test_role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.BASE])
|
||||
def base_roles():
|
||||
return [Role(id=1, name="base_role")]
|
||||
return [Role(id=base_role_id, name="base_role")]
|
||||
|
||||
@registry.register(contexts=[Context.TESTING])
|
||||
def test_roles():
|
||||
return [Role(id=100, name="test_role")]
|
||||
return [Role(id=test_role_id, name="test_role")]
|
||||
|
||||
await load_fixtures_by_context(db_session, registry, Context.BASE)
|
||||
|
||||
@@ -418,7 +451,7 @@ class TestLoadFixturesByContext:
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 1
|
||||
|
||||
role = await RoleCrud.first(db_session, [Role.id == 1])
|
||||
role = await RoleCrud.first(db_session, [Role.id == base_role_id])
|
||||
assert role is not None
|
||||
assert role.name == "base_role"
|
||||
|
||||
@@ -426,14 +459,16 @@ class TestLoadFixturesByContext:
|
||||
async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
|
||||
"""Load fixtures by multiple contexts."""
|
||||
registry = FixtureRegistry()
|
||||
base_role_id = uuid.uuid4()
|
||||
test_role_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.BASE])
|
||||
def base_roles():
|
||||
return [Role(id=1, name="base_role")]
|
||||
return [Role(id=base_role_id, name="base_role")]
|
||||
|
||||
@registry.register(contexts=[Context.TESTING])
|
||||
def test_roles():
|
||||
return [Role(id=100, name="test_role")]
|
||||
return [Role(id=test_role_id, name="test_role")]
|
||||
|
||||
await load_fixtures_by_context(
|
||||
db_session, registry, Context.BASE, Context.TESTING
|
||||
@@ -448,14 +483,23 @@ class TestLoadFixturesByContext:
|
||||
async def test_load_context_with_dependencies(self, db_session: AsyncSession):
|
||||
"""Load context fixtures with cross-context dependencies."""
|
||||
registry = FixtureRegistry()
|
||||
role_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
@registry.register(contexts=[Context.BASE])
|
||||
def roles():
|
||||
return [Role(id=1, name="admin")]
|
||||
return [Role(id=role_id, name="admin")]
|
||||
|
||||
@registry.register(depends_on=["roles"], contexts=[Context.TESTING])
|
||||
def test_users():
|
||||
return [User(id=1, username="tester", email="test@test.com", role_id=1)]
|
||||
return [
|
||||
User(
|
||||
id=user_id,
|
||||
username="tester",
|
||||
email="test@test.com",
|
||||
role_id=role_id,
|
||||
)
|
||||
]
|
||||
|
||||
await load_fixtures_by_context(db_session, registry, Context.TESTING)
|
||||
|
||||
@@ -471,20 +515,41 @@ class TestGetObjByAttr:
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures for each test."""
|
||||
self.registry = FixtureRegistry()
|
||||
self.role_id_1 = uuid.uuid4()
|
||||
self.role_id_2 = uuid.uuid4()
|
||||
self.role_id_3 = uuid.uuid4()
|
||||
self.user_id_1 = uuid.uuid4()
|
||||
self.user_id_2 = uuid.uuid4()
|
||||
|
||||
role_id_1 = self.role_id_1
|
||||
role_id_2 = self.role_id_2
|
||||
role_id_3 = self.role_id_3
|
||||
user_id_1 = self.user_id_1
|
||||
user_id_2 = self.user_id_2
|
||||
|
||||
@self.registry.register
|
||||
def roles() -> list[Role]:
|
||||
return [
|
||||
Role(id=1, name="admin"),
|
||||
Role(id=2, name="user"),
|
||||
Role(id=3, name="moderator"),
|
||||
Role(id=role_id_1, name="admin"),
|
||||
Role(id=role_id_2, name="user"),
|
||||
Role(id=role_id_3, name="moderator"),
|
||||
]
|
||||
|
||||
@self.registry.register(depends_on=["roles"])
|
||||
def users() -> list[User]:
|
||||
return [
|
||||
User(id=1, username="alice", email="alice@example.com", role_id=1),
|
||||
User(id=2, username="bob", email="bob@example.com", role_id=1),
|
||||
User(
|
||||
id=user_id_1,
|
||||
username="alice",
|
||||
email="alice@example.com",
|
||||
role_id=role_id_1,
|
||||
),
|
||||
User(
|
||||
id=user_id_2,
|
||||
username="bob",
|
||||
email="bob@example.com",
|
||||
role_id=role_id_1,
|
||||
),
|
||||
]
|
||||
|
||||
self.roles = roles
|
||||
@@ -492,18 +557,18 @@ class TestGetObjByAttr:
|
||||
|
||||
def test_get_by_id(self):
|
||||
"""Get an object by its id attribute."""
|
||||
role = get_obj_by_attr(self.roles, "id", 1)
|
||||
role = get_obj_by_attr(self.roles, "id", self.role_id_1)
|
||||
assert role.name == "admin"
|
||||
|
||||
def test_get_user_by_username(self):
|
||||
"""Get a user by username."""
|
||||
user = get_obj_by_attr(self.users, "username", "bob")
|
||||
assert user.id == 2
|
||||
assert user.id == self.user_id_2
|
||||
assert user.email == "bob@example.com"
|
||||
|
||||
def test_returns_first_match(self):
|
||||
"""Returns the first matching object when multiple could match."""
|
||||
user = get_obj_by_attr(self.users, "role_id", 1)
|
||||
user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
|
||||
assert user.username == "alice"
|
||||
|
||||
def test_no_match_raises_stop_iteration(self):
|
||||
@@ -514,4 +579,4 @@ class TestGetObjByAttr:
|
||||
def test_no_match_on_wrong_value_type(self):
|
||||
"""Raises StopIteration when value type doesn't match."""
|
||||
with pytest.raises(StopIteration):
|
||||
get_obj_by_attr(self.roles, "id", "1")
|
||||
get_obj_by_attr(self.roles, "id", "not-a-uuid")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for fastapi_toolsets.pytest module."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient
|
||||
@@ -18,27 +20,49 @@ from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
|
||||
|
||||
test_registry = FixtureRegistry()
|
||||
|
||||
# Fixed UUIDs for test fixtures to allow consistent assertions
|
||||
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
|
||||
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
|
||||
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
|
||||
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
|
||||
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
|
||||
|
||||
|
||||
@test_registry.register(contexts=[Context.BASE])
|
||||
def roles() -> list[Role]:
|
||||
return [
|
||||
Role(id=1000, name="plugin_admin"),
|
||||
Role(id=1001, name="plugin_user"),
|
||||
Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
|
||||
Role(id=ROLE_USER_ID, name="plugin_user"),
|
||||
]
|
||||
|
||||
|
||||
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
|
||||
def users() -> list[User]:
|
||||
return [
|
||||
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000),
|
||||
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001),
|
||||
User(
|
||||
id=USER_ADMIN_ID,
|
||||
username="plugin_admin",
|
||||
email="padmin@test.com",
|
||||
role_id=ROLE_ADMIN_ID,
|
||||
),
|
||||
User(
|
||||
id=USER_USER_ID,
|
||||
username="plugin_user",
|
||||
email="puser@test.com",
|
||||
role_id=ROLE_USER_ID,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
|
||||
def extra_users() -> list[User]:
|
||||
return [
|
||||
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001),
|
||||
User(
|
||||
id=USER_EXTRA_ID,
|
||||
username="plugin_extra",
|
||||
email="pextra@test.com",
|
||||
role_id=ROLE_USER_ID,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -73,7 +97,7 @@ class TestGeneratedFixtures:
|
||||
assert fixture_roles[1].name == "plugin_user"
|
||||
|
||||
# Verify data is in database
|
||||
count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
||||
count = await RoleCrud.count(db_session)
|
||||
assert count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -86,11 +110,11 @@ class TestGeneratedFixtures:
|
||||
assert len(fixture_users) == 2
|
||||
|
||||
# Roles should also be in database
|
||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
||||
roles_count = await RoleCrud.count(db_session)
|
||||
assert roles_count == 2
|
||||
|
||||
# Users should be in database
|
||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
||||
users_count = await UserCrud.count(db_session)
|
||||
assert users_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -100,7 +124,7 @@ class TestGeneratedFixtures:
|
||||
"""Fixture returns actual model instances."""
|
||||
user = fixture_users[0]
|
||||
assert isinstance(user, User)
|
||||
assert user.id == 1000
|
||||
assert user.id == USER_ADMIN_ID
|
||||
assert user.username == "plugin_admin"
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -111,7 +135,7 @@ class TestGeneratedFixtures:
|
||||
# Load user with role relationship
|
||||
user = await UserCrud.get(
|
||||
db_session,
|
||||
[User.id == 1000],
|
||||
[User.id == USER_ADMIN_ID],
|
||||
load_options=[selectinload(User.role)],
|
||||
)
|
||||
|
||||
@@ -127,8 +151,8 @@ class TestGeneratedFixtures:
|
||||
assert len(fixture_extra_users) == 1
|
||||
|
||||
# All fixtures should be loaded
|
||||
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000])
|
||||
users_count = await UserCrud.count(db_session, [User.id >= 1000])
|
||||
roles_count = await RoleCrud.count(db_session)
|
||||
users_count = await UserCrud.count(db_session)
|
||||
|
||||
assert roles_count == 2
|
||||
assert users_count == 3 # 2 from users + 1 from extra_users
|
||||
@@ -141,8 +165,7 @@ class TestGeneratedFixtures:
|
||||
# Get all users loaded by fixture
|
||||
users = await UserCrud.get_multi(
|
||||
db_session,
|
||||
filters=[User.id >= 1000],
|
||||
order_by=User.id,
|
||||
order_by=User.username,
|
||||
)
|
||||
|
||||
assert len(users) == 2
|
||||
@@ -161,8 +184,8 @@ class TestGeneratedFixtures:
|
||||
assert len(fixture_users) == 2
|
||||
|
||||
# Both should be in database
|
||||
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000])
|
||||
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000])
|
||||
roles = await RoleCrud.get_multi(db_session)
|
||||
users = await UserCrud.get_multi(db_session)
|
||||
|
||||
assert len(roles) == 2
|
||||
assert len(users) == 2
|
||||
@@ -215,14 +238,15 @@ class TestCreateDbSession:
|
||||
@pytest.mark.anyio
|
||||
async def test_creates_working_session(self):
|
||||
"""Session can perform database operations."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(DATABASE_URL, Base) as session:
|
||||
assert isinstance(session, AsyncSession)
|
||||
|
||||
role = Role(id=9001, name="test_helper_role")
|
||||
role = Role(id=role_id, name="test_helper_role")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
result = await session.execute(select(Role).where(Role.id == 9001))
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
fetched = result.scalar_one()
|
||||
assert fetched.name == "test_helper_role"
|
||||
|
||||
@@ -237,8 +261,9 @@ class TestCreateDbSession:
|
||||
@pytest.mark.anyio
|
||||
async def test_tables_dropped_after_session(self):
|
||||
"""Tables are dropped after session closes when drop_tables=True."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
|
||||
role = Role(id=9002, name="will_be_dropped")
|
||||
role = Role(id=role_id, name="will_be_dropped")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
@@ -250,14 +275,15 @@ class TestCreateDbSession:
|
||||
@pytest.mark.anyio
|
||||
async def test_tables_preserved_when_drop_disabled(self):
|
||||
"""Tables are preserved when drop_tables=False."""
|
||||
role_id = uuid.uuid4()
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||
role = Role(id=9003, name="preserved_role")
|
||||
role = Role(id=role_id, name="preserved_role")
|
||||
session.add(role)
|
||||
await session.commit()
|
||||
|
||||
# Create another session without dropping
|
||||
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
|
||||
result = await session.execute(select(Role).where(Role.id == 9003))
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
fetched = result.scalar_one_or_none()
|
||||
assert fetched is not None
|
||||
assert fetched.name == "preserved_role"
|
||||
|
||||
Reference in New Issue
Block a user