tests: change models to use UUID as primary key

This commit is contained in:
2026-01-29 13:43:03 -05:00
parent 5264631550
commit b57ce40b05
5 changed files with 187 additions and 104 deletions

View File

@@ -1,10 +1,11 @@
"""Shared pytest fixtures for fastapi-utils tests.""" """Shared pytest fixtures for fastapi-utils tests."""
import os import os
import uuid
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import ForeignKey, String from sqlalchemy import ForeignKey, String, Uuid
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@@ -33,7 +34,7 @@ class Role(Base):
__tablename__ = "roles" __tablename__ = "roles"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
name: Mapped[str] = mapped_column(String(50), unique=True) name: Mapped[str] = mapped_column(String(50), unique=True)
users: Mapped[list["User"]] = relationship(back_populates="role") users: Mapped[list["User"]] = relationship(back_populates="role")
@@ -44,11 +45,13 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
username: Mapped[str] = mapped_column(String(50), unique=True) username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True) email: Mapped[str] = mapped_column(String(100), unique=True)
is_active: Mapped[bool] = mapped_column(default=True) is_active: Mapped[bool] = mapped_column(default=True)
role_id: Mapped[int | None] = mapped_column(ForeignKey("roles.id"), nullable=True) role_id: Mapped[uuid.UUID | None] = mapped_column(
ForeignKey("roles.id"), nullable=True
)
role: Mapped[Role | None] = relationship(back_populates="users") role: Mapped[Role | None] = relationship(back_populates="users")
@@ -58,11 +61,11 @@ class Post(Base):
__tablename__ = "posts" __tablename__ = "posts"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
title: Mapped[str] = mapped_column(String(200)) title: Mapped[str] = mapped_column(String(200))
content: Mapped[str] = mapped_column(String(1000), default="") content: Mapped[str] = mapped_column(String(1000), default="")
is_published: Mapped[bool] = mapped_column(default=False) is_published: Mapped[bool] = mapped_column(default=False)
author_id: Mapped[int] = mapped_column(ForeignKey("users.id")) author_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id"))
# ============================================================================= # =============================================================================
@@ -73,7 +76,7 @@ class Post(Base):
class RoleCreate(BaseModel): class RoleCreate(BaseModel):
"""Schema for creating a role.""" """Schema for creating a role."""
id: int | None = None id: uuid.UUID | None = None
name: str name: str
@@ -86,11 +89,11 @@ class RoleUpdate(BaseModel):
class UserCreate(BaseModel): class UserCreate(BaseModel):
"""Schema for creating a user.""" """Schema for creating a user."""
id: int | None = None id: uuid.UUID | None = None
username: str username: str
email: str email: str
is_active: bool = True is_active: bool = True
role_id: int | None = None role_id: uuid.UUID | None = None
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
@@ -99,17 +102,17 @@ class UserUpdate(BaseModel):
username: str | None = None username: str | None = None
email: str | None = None email: str | None = None
is_active: bool | None = None is_active: bool | None = None
role_id: int | None = None role_id: uuid.UUID | None = None
class PostCreate(BaseModel): class PostCreate(BaseModel):
"""Schema for creating a post.""" """Schema for creating a post."""
id: int | None = None id: uuid.UUID | None = None
title: str title: str
content: str = "" content: str = ""
is_published: bool = False is_published: bool = False
author_id: int author_id: uuid.UUID
class PostUpdate(BaseModel): class PostUpdate(BaseModel):
@@ -195,5 +198,5 @@ def sample_post_data() -> PostCreate:
title="Test Post", title="Test Post",
content="Test content", content="Test content",
is_published=True, is_published=True,
author_id=1, author_id=uuid.uuid4(),
) )

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.crud module.""" """Tests for fastapi_toolsets.crud module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -89,8 +91,9 @@ class TestCrudGet:
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_raises_not_found(self, db_session: AsyncSession): async def test_get_raises_not_found(self, db_session: AsyncSession):
"""Get raises NotFoundError for missing records.""" """Get raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.get(db_session, [Role.id == 99999]) await RoleCrud.get(db_session, [Role.id == non_existent_id])
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_with_multiple_filters(self, db_session: AsyncSession): async def test_get_with_multiple_filters(self, db_session: AsyncSession):
@@ -223,11 +226,12 @@ class TestCrudUpdate:
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_raises_not_found(self, db_session: AsyncSession): async def test_update_raises_not_found(self, db_session: AsyncSession):
"""Update raises NotFoundError for missing records.""" """Update raises NotFoundError for missing records."""
non_existent_id = uuid.uuid4()
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
await RoleCrud.update( await RoleCrud.update(
db_session, db_session,
RoleUpdate(name="new"), RoleUpdate(name="new"),
[Role.id == 99999], [Role.id == non_existent_id],
) )
@pytest.mark.anyio @pytest.mark.anyio
@@ -340,7 +344,8 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_insert_new_record(self, db_session: AsyncSession): async def test_upsert_insert_new_record(self, db_session: AsyncSession):
"""Upsert inserts a new record when it doesn't exist.""" """Upsert inserts a new record when it doesn't exist."""
data = RoleCreate(id=1, name="upsert_new") role_id = uuid.uuid4()
data = RoleCreate(id=role_id, name="upsert_new")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
data, data,
@@ -353,12 +358,13 @@ class TestCrudUpsert:
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_update_existing_record(self, db_session: AsyncSession): async def test_upsert_update_existing_record(self, db_session: AsyncSession):
"""Upsert updates an existing record.""" """Upsert updates an existing record."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=100, name="original_name") data = RoleCreate(id=role_id, name="original_name")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert with update # Upsert with update
updated_data = RoleCreate(id=100, name="updated_name") updated_data = RoleCreate(id=role_id, name="updated_name")
role = await RoleCrud.upsert( role = await RoleCrud.upsert(
db_session, db_session,
updated_data, updated_data,
@@ -370,22 +376,23 @@ class TestCrudUpsert:
assert role.name == "updated_name" assert role.name == "updated_name"
# Verify only one record exists # Verify only one record exists
count = await RoleCrud.count(db_session, [Role.id == 100]) count = await RoleCrud.count(db_session, [Role.id == role_id])
assert count == 1 assert count == 1
@pytest.mark.anyio @pytest.mark.anyio
async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession): async def test_upsert_do_nothing_on_conflict(self, db_session: AsyncSession):
"""Upsert does nothing on conflict when set_ is not provided.""" """Upsert does nothing on conflict when set_ is not provided."""
role_id = uuid.uuid4()
# First insert # First insert
data = RoleCreate(id=200, name="do_nothing_original") data = RoleCreate(id=role_id, name="do_nothing_original")
await RoleCrud.upsert(db_session, data, index_elements=["id"]) await RoleCrud.upsert(db_session, data, index_elements=["id"])
# Upsert without set_ (do nothing) # Upsert without set_ (do nothing)
conflict_data = RoleCreate(id=200, name="do_nothing_conflict") conflict_data = RoleCreate(id=role_id, name="do_nothing_conflict")
await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"]) await RoleCrud.upsert(db_session, conflict_data, index_elements=["id"])
# Original value should be preserved # Original value should be preserved
role = await RoleCrud.first(db_session, [Role.id == 200]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "do_nothing_original" assert role.name == "do_nothing_original"

View File

@@ -3,15 +3,11 @@
import uuid import uuid
import pytest import pytest
from pydantic import BaseModel
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_toolsets.crud import CrudFactory, SearchConfig, get_searchable_fields from fastapi_toolsets.crud import SearchConfig, get_searchable_fields
from .conftest import ( from .conftest import (
Base,
Role, Role,
RoleCreate, RoleCreate,
RoleCrud, RoleCrud,
@@ -281,37 +277,23 @@ class TestPaginateSearch:
@pytest.mark.anyio @pytest.mark.anyio
async def test_search_non_string_column(self, db_session: AsyncSession): async def test_search_non_string_column(self, db_session: AsyncSession):
"""Search on non-string columns (e.g., UUID) works via cast.""" """Search on non-string columns (e.g., UUID) works via cast."""
user_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
class Account(Base): await UserCrud.create(
__tablename__ = "accounts" db_session, UserCreate(id=user_id, username="john", email="john@test.com")
id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4) )
name: Mapped[str] = mapped_column(String(100)) await UserCrud.create(
db_session, UserCreate(username="jane", email="jane@test.com")
class AccountCreate(BaseModel):
id: uuid.UUID | None = None
name: str
AccountCrud = CrudFactory(Account)
# Create table for this test
async with db_session.get_bind().begin() as conn:
await conn.run_sync(Base.metadata.create_all)
account_id = uuid.UUID("12345678-1234-5678-1234-567812345678")
await AccountCrud.create(
db_session, AccountCreate(id=account_id, name="Test Account")
) )
await AccountCrud.create(db_session, AccountCreate(name="Other Account"))
# Search by UUID (partial match) # Search by UUID (partial match)
result = await AccountCrud.paginate( result = await UserCrud.paginate(
db_session, db_session,
search="12345678", search="12345678",
search_fields=[Account.id, Account.name], search_fields=[User.id, User.username],
) )
assert result["pagination"]["total_count"] == 1 assert result["pagination"]["total_count"] == 1
assert result["data"][0].id == account_id assert result["data"][0].id == user_id
class TestSearchConfig: class TestSearchConfig:

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.fixtures module.""" """Tests for fastapi_toolsets.fixtures module."""
import uuid
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -57,20 +59,22 @@ class TestFixtureRegistry:
def test_register_with_decorator(self): def test_register_with_decorator(self):
"""Register fixture with decorator.""" """Register fixture with decorator."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
assert "roles" in [f.name for f in registry.get_all()] assert "roles" in [f.name for f in registry.get_all()]
def test_register_with_custom_name(self): def test_register_with_custom_name(self):
"""Register fixture with custom name.""" """Register fixture with custom name."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(name="custom_roles") @registry.register(name="custom_roles")
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
fixture = registry.get("custom_roles") fixture = registry.get("custom_roles")
assert fixture.name == "custom_roles" assert fixture.name == "custom_roles"
@@ -78,14 +82,23 @@ class TestFixtureRegistry:
def test_register_with_dependencies(self): def test_register_with_dependencies(self):
"""Register fixture with dependencies.""" """Register fixture with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
fixture = registry.get("users") fixture = registry.get("users")
assert fixture.depends_on == ["roles"] assert fixture.depends_on == ["roles"]
@@ -93,10 +106,11 @@ class TestFixtureRegistry:
def test_register_with_contexts(self): def test_register_with_contexts(self):
"""Register fixture with contexts.""" """Register fixture with contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_data(): def test_data():
return [Role(id=100, name="test")] return [Role(id=role_id, name="test")]
fixture = registry.get("test_data") fixture = registry.get("test_data")
assert Context.TESTING.value in fixture.contexts assert Context.TESTING.value in fixture.contexts
@@ -244,12 +258,14 @@ class TestLoadFixtures:
async def test_load_single_fixture(self, db_session: AsyncSession): async def test_load_single_fixture(self, db_session: AsyncSession):
"""Load a single fixture.""" """Load a single fixture."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures(db_session, registry, "roles") result = await load_fixtures(db_session, registry, "roles")
@@ -266,14 +282,23 @@ class TestLoadFixtures:
async def test_load_with_dependencies(self, db_session: AsyncSession): async def test_load_with_dependencies(self, db_session: AsyncSession):
"""Load fixtures with dependencies.""" """Load fixtures with dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"]) @registry.register(depends_on=["roles"])
def users(): def users():
return [User(id=1, username="admin", email="admin@test.com", role_id=1)] return [
User(
id=user_id,
username="admin",
email="admin@test.com",
role_id=role_id,
)
]
result = await load_fixtures(db_session, registry, "users") result = await load_fixtures(db_session, registry, "users")
@@ -289,10 +314,11 @@ class TestLoadFixtures:
async def test_load_with_merge_strategy(self, db_session: AsyncSession): async def test_load_with_merge_strategy(self, db_session: AsyncSession):
"""Load fixtures with MERGE strategy updates existing.""" """Load fixtures with MERGE strategy updates existing."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE) await load_fixtures(db_session, registry, "roles", strategy=LoadStrategy.MERGE)
@@ -306,10 +332,11 @@ class TestLoadFixtures:
async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession): async def test_load_with_skip_existing_strategy(self, db_session: AsyncSession):
"""Load fixtures with SKIP_EXISTING strategy.""" """Load fixtures with SKIP_EXISTING strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="original")] return [Role(id=role_id, name="original")]
await load_fixtures( await load_fixtures(
db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING db_session, registry, "roles", strategy=LoadStrategy.SKIP_EXISTING
@@ -317,7 +344,7 @@ class TestLoadFixtures:
@registry.register(name="roles_updated") @registry.register(name="roles_updated")
def roles_v2(): def roles_v2():
return [Role(id=1, name="updated")] return [Role(id=role_id, name="updated")]
registry._fixtures["roles"] = registry._fixtures.pop("roles_updated") registry._fixtures["roles"] = registry._fixtures.pop("roles_updated")
@@ -327,7 +354,7 @@ class TestLoadFixtures:
from .conftest import RoleCrud from .conftest import RoleCrud
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == role_id])
assert role is not None assert role is not None
assert role.name == "original" assert role.name == "original"
@@ -335,12 +362,14 @@ class TestLoadFixtures:
async def test_load_with_insert_strategy(self, db_session: AsyncSession): async def test_load_with_insert_strategy(self, db_session: AsyncSession):
"""Load fixtures with INSERT strategy.""" """Load fixtures with INSERT strategy."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
] ]
result = await load_fixtures( result = await load_fixtures(
@@ -375,14 +404,16 @@ class TestLoadFixtures:
): ):
"""Load multiple independent fixtures.""" """Load multiple independent fixtures."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id_1 = uuid.uuid4()
role_id_2 = uuid.uuid4()
@registry.register @registry.register
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id_1, name="admin")]
@registry.register @registry.register
def other_roles(): def other_roles():
return [Role(id=2, name="user")] return [Role(id=role_id_2, name="user")]
result = await load_fixtures(db_session, registry, "roles", "other_roles") result = await load_fixtures(db_session, registry, "roles", "other_roles")
@@ -402,14 +433,16 @@ class TestLoadFixturesByContext:
async def test_load_by_single_context(self, db_session: AsyncSession): async def test_load_by_single_context(self, db_session: AsyncSession):
"""Load fixtures by single context.""" """Load fixtures by single context."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context(db_session, registry, Context.BASE) await load_fixtures_by_context(db_session, registry, Context.BASE)
@@ -418,7 +451,7 @@ class TestLoadFixturesByContext:
count = await RoleCrud.count(db_session) count = await RoleCrud.count(db_session)
assert count == 1 assert count == 1
role = await RoleCrud.first(db_session, [Role.id == 1]) role = await RoleCrud.first(db_session, [Role.id == base_role_id])
assert role is not None assert role is not None
assert role.name == "base_role" assert role.name == "base_role"
@@ -426,14 +459,16 @@ class TestLoadFixturesByContext:
async def test_load_by_multiple_contexts(self, db_session: AsyncSession): async def test_load_by_multiple_contexts(self, db_session: AsyncSession):
"""Load fixtures by multiple contexts.""" """Load fixtures by multiple contexts."""
registry = FixtureRegistry() registry = FixtureRegistry()
base_role_id = uuid.uuid4()
test_role_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def base_roles(): def base_roles():
return [Role(id=1, name="base_role")] return [Role(id=base_role_id, name="base_role")]
@registry.register(contexts=[Context.TESTING]) @registry.register(contexts=[Context.TESTING])
def test_roles(): def test_roles():
return [Role(id=100, name="test_role")] return [Role(id=test_role_id, name="test_role")]
await load_fixtures_by_context( await load_fixtures_by_context(
db_session, registry, Context.BASE, Context.TESTING db_session, registry, Context.BASE, Context.TESTING
@@ -448,14 +483,23 @@ class TestLoadFixturesByContext:
async def test_load_context_with_dependencies(self, db_session: AsyncSession): async def test_load_context_with_dependencies(self, db_session: AsyncSession):
"""Load context fixtures with cross-context dependencies.""" """Load context fixtures with cross-context dependencies."""
registry = FixtureRegistry() registry = FixtureRegistry()
role_id = uuid.uuid4()
user_id = uuid.uuid4()
@registry.register(contexts=[Context.BASE]) @registry.register(contexts=[Context.BASE])
def roles(): def roles():
return [Role(id=1, name="admin")] return [Role(id=role_id, name="admin")]
@registry.register(depends_on=["roles"], contexts=[Context.TESTING]) @registry.register(depends_on=["roles"], contexts=[Context.TESTING])
def test_users(): def test_users():
return [User(id=1, username="tester", email="test@test.com", role_id=1)] return [
User(
id=user_id,
username="tester",
email="test@test.com",
role_id=role_id,
)
]
await load_fixtures_by_context(db_session, registry, Context.TESTING) await load_fixtures_by_context(db_session, registry, Context.TESTING)
@@ -471,20 +515,41 @@ class TestGetObjByAttr:
def setup_method(self): def setup_method(self):
"""Set up test fixtures for each test.""" """Set up test fixtures for each test."""
self.registry = FixtureRegistry() self.registry = FixtureRegistry()
self.role_id_1 = uuid.uuid4()
self.role_id_2 = uuid.uuid4()
self.role_id_3 = uuid.uuid4()
self.user_id_1 = uuid.uuid4()
self.user_id_2 = uuid.uuid4()
role_id_1 = self.role_id_1
role_id_2 = self.role_id_2
role_id_3 = self.role_id_3
user_id_1 = self.user_id_1
user_id_2 = self.user_id_2
@self.registry.register @self.registry.register
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1, name="admin"), Role(id=role_id_1, name="admin"),
Role(id=2, name="user"), Role(id=role_id_2, name="user"),
Role(id=3, name="moderator"), Role(id=role_id_3, name="moderator"),
] ]
@self.registry.register(depends_on=["roles"]) @self.registry.register(depends_on=["roles"])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1, username="alice", email="alice@example.com", role_id=1), User(
User(id=2, username="bob", email="bob@example.com", role_id=1), id=user_id_1,
username="alice",
email="alice@example.com",
role_id=role_id_1,
),
User(
id=user_id_2,
username="bob",
email="bob@example.com",
role_id=role_id_1,
),
] ]
self.roles = roles self.roles = roles
@@ -492,18 +557,18 @@ class TestGetObjByAttr:
def test_get_by_id(self): def test_get_by_id(self):
"""Get an object by its id attribute.""" """Get an object by its id attribute."""
role = get_obj_by_attr(self.roles, "id", 1) role = get_obj_by_attr(self.roles, "id", self.role_id_1)
assert role.name == "admin" assert role.name == "admin"
def test_get_user_by_username(self): def test_get_user_by_username(self):
"""Get a user by username.""" """Get a user by username."""
user = get_obj_by_attr(self.users, "username", "bob") user = get_obj_by_attr(self.users, "username", "bob")
assert user.id == 2 assert user.id == self.user_id_2
assert user.email == "bob@example.com" assert user.email == "bob@example.com"
def test_returns_first_match(self): def test_returns_first_match(self):
"""Returns the first matching object when multiple could match.""" """Returns the first matching object when multiple could match."""
user = get_obj_by_attr(self.users, "role_id", 1) user = get_obj_by_attr(self.users, "role_id", self.role_id_1)
assert user.username == "alice" assert user.username == "alice"
def test_no_match_raises_stop_iteration(self): def test_no_match_raises_stop_iteration(self):
@@ -514,4 +579,4 @@ class TestGetObjByAttr:
def test_no_match_on_wrong_value_type(self): def test_no_match_on_wrong_value_type(self):
"""Raises StopIteration when value type doesn't match.""" """Raises StopIteration when value type doesn't match."""
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
get_obj_by_attr(self.roles, "id", "1") get_obj_by_attr(self.roles, "id", "not-a-uuid")

View File

@@ -1,5 +1,7 @@
"""Tests for fastapi_toolsets.pytest module.""" """Tests for fastapi_toolsets.pytest module."""
import uuid
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from httpx import AsyncClient from httpx import AsyncClient
@@ -18,27 +20,49 @@ from .conftest import DATABASE_URL, Base, Role, RoleCrud, User, UserCrud
test_registry = FixtureRegistry() test_registry = FixtureRegistry()
# Fixed UUIDs for test fixtures to allow consistent assertions
ROLE_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000001000")
ROLE_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000001001")
USER_ADMIN_ID = uuid.UUID("00000000-0000-0000-0000-000000002000")
USER_USER_ID = uuid.UUID("00000000-0000-0000-0000-000000002001")
USER_EXTRA_ID = uuid.UUID("00000000-0000-0000-0000-000000002002")
@test_registry.register(contexts=[Context.BASE]) @test_registry.register(contexts=[Context.BASE])
def roles() -> list[Role]: def roles() -> list[Role]:
return [ return [
Role(id=1000, name="plugin_admin"), Role(id=ROLE_ADMIN_ID, name="plugin_admin"),
Role(id=1001, name="plugin_user"), Role(id=ROLE_USER_ID, name="plugin_user"),
] ]
@test_registry.register(depends_on=["roles"], contexts=[Context.BASE]) @test_registry.register(depends_on=["roles"], contexts=[Context.BASE])
def users() -> list[User]: def users() -> list[User]:
return [ return [
User(id=1000, username="plugin_admin", email="padmin@test.com", role_id=1000), User(
User(id=1001, username="plugin_user", email="puser@test.com", role_id=1001), id=USER_ADMIN_ID,
username="plugin_admin",
email="padmin@test.com",
role_id=ROLE_ADMIN_ID,
),
User(
id=USER_USER_ID,
username="plugin_user",
email="puser@test.com",
role_id=ROLE_USER_ID,
),
] ]
@test_registry.register(depends_on=["users"], contexts=[Context.TESTING]) @test_registry.register(depends_on=["users"], contexts=[Context.TESTING])
def extra_users() -> list[User]: def extra_users() -> list[User]:
return [ return [
User(id=1002, username="plugin_extra", email="pextra@test.com", role_id=1001), User(
id=USER_EXTRA_ID,
username="plugin_extra",
email="pextra@test.com",
role_id=ROLE_USER_ID,
),
] ]
@@ -73,7 +97,7 @@ class TestGeneratedFixtures:
assert fixture_roles[1].name == "plugin_user" assert fixture_roles[1].name == "plugin_user"
# Verify data is in database # Verify data is in database
count = await RoleCrud.count(db_session, [Role.id >= 1000]) count = await RoleCrud.count(db_session)
assert count == 2 assert count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -86,11 +110,11 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Roles should also be in database # Roles should also be in database
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
# Users should be in database # Users should be in database
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert users_count == 2 assert users_count == 2
@pytest.mark.anyio @pytest.mark.anyio
@@ -100,7 +124,7 @@ class TestGeneratedFixtures:
"""Fixture returns actual model instances.""" """Fixture returns actual model instances."""
user = fixture_users[0] user = fixture_users[0]
assert isinstance(user, User) assert isinstance(user, User)
assert user.id == 1000 assert user.id == USER_ADMIN_ID
assert user.username == "plugin_admin" assert user.username == "plugin_admin"
@pytest.mark.anyio @pytest.mark.anyio
@@ -111,7 +135,7 @@ class TestGeneratedFixtures:
# Load user with role relationship # Load user with role relationship
user = await UserCrud.get( user = await UserCrud.get(
db_session, db_session,
[User.id == 1000], [User.id == USER_ADMIN_ID],
load_options=[selectinload(User.role)], load_options=[selectinload(User.role)],
) )
@@ -127,8 +151,8 @@ class TestGeneratedFixtures:
assert len(fixture_extra_users) == 1 assert len(fixture_extra_users) == 1
# All fixtures should be loaded # All fixtures should be loaded
roles_count = await RoleCrud.count(db_session, [Role.id >= 1000]) roles_count = await RoleCrud.count(db_session)
users_count = await UserCrud.count(db_session, [User.id >= 1000]) users_count = await UserCrud.count(db_session)
assert roles_count == 2 assert roles_count == 2
assert users_count == 3 # 2 from users + 1 from extra_users assert users_count == 3 # 2 from users + 1 from extra_users
@@ -141,8 +165,7 @@ class TestGeneratedFixtures:
# Get all users loaded by fixture # Get all users loaded by fixture
users = await UserCrud.get_multi( users = await UserCrud.get_multi(
db_session, db_session,
filters=[User.id >= 1000], order_by=User.username,
order_by=User.id,
) )
assert len(users) == 2 assert len(users) == 2
@@ -161,8 +184,8 @@ class TestGeneratedFixtures:
assert len(fixture_users) == 2 assert len(fixture_users) == 2
# Both should be in database # Both should be in database
roles = await RoleCrud.get_multi(db_session, filters=[Role.id >= 1000]) roles = await RoleCrud.get_multi(db_session)
users = await UserCrud.get_multi(db_session, filters=[User.id >= 1000]) users = await UserCrud.get_multi(db_session)
assert len(roles) == 2 assert len(roles) == 2
assert len(users) == 2 assert len(users) == 2
@@ -215,14 +238,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_creates_working_session(self): async def test_creates_working_session(self):
"""Session can perform database operations.""" """Session can perform database operations."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base) as session: async with create_db_session(DATABASE_URL, Base) as session:
assert isinstance(session, AsyncSession) assert isinstance(session, AsyncSession)
role = Role(id=9001, name="test_helper_role") role = Role(id=role_id, name="test_helper_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
result = await session.execute(select(Role).where(Role.id == 9001)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one() fetched = result.scalar_one()
assert fetched.name == "test_helper_role" assert fetched.name == "test_helper_role"
@@ -237,8 +261,9 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_dropped_after_session(self): async def test_tables_dropped_after_session(self):
"""Tables are dropped after session closes when drop_tables=True.""" """Tables are dropped after session closes when drop_tables=True."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=True) as session:
role = Role(id=9002, name="will_be_dropped") role = Role(id=role_id, name="will_be_dropped")
session.add(role) session.add(role)
await session.commit() await session.commit()
@@ -250,14 +275,15 @@ class TestCreateDbSession:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tables_preserved_when_drop_disabled(self): async def test_tables_preserved_when_drop_disabled(self):
"""Tables are preserved when drop_tables=False.""" """Tables are preserved when drop_tables=False."""
role_id = uuid.uuid4()
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
role = Role(id=9003, name="preserved_role") role = Role(id=role_id, name="preserved_role")
session.add(role) session.add(role)
await session.commit() await session.commit()
# Create another session without dropping # Create another session without dropping
async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session: async with create_db_session(DATABASE_URL, Base, drop_tables=False) as session:
result = await session.execute(select(Role).where(Role.id == 9003)) result = await session.execute(select(Role).where(Role.id == role_id))
fetched = result.scalar_one_or_none() fetched = result.scalar_one_or_none()
assert fetched is not None assert fetched is not None
assert fetched.name == "preserved_role" assert fetched.name == "preserved_role"