mirror of
https://github.com/d3vyce/sqlalchemy-pgview.git
synced 2026-03-01 18:00:47 +01:00
397 lines
12 KiB
Python
397 lines
12 KiB
Python
"""Tests for async engine support with asyncpg."""
|
|
|
|
import os
from collections.abc import AsyncIterator
from decimal import Decimal
from typing import Any

import pytest
from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    MetaData,
    Numeric,
    String,
    Table,
    func,
    insert,
    select,
)
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from sqlalchemy_pgview import (
    CreateMaterializedView,
    CreateView,
    DropMaterializedView,
    DropView,
    MaterializedView,
    RefreshMaterializedView,
    View,
)
|
|
|
|
|
|
@pytest.fixture
async def async_engine() -> AsyncIterator[AsyncEngine]:
    """Yield an ``AsyncEngine`` for the PostgreSQL database in ``POSTGRES_URL``.

    The test is skipped when ``POSTGRES_URL`` is unset or empty. Whatever
    driver the URL names (``postgresql://``, ``postgresql+psycopg2://``,
    ``postgres://``, ...) is replaced with ``postgresql+asyncpg`` so that
    ``create_async_engine`` always receives an async driver. The engine is
    disposed once the test using it has finished.
    """
    url = os.environ.get("POSTGRES_URL")
    if not url:
        pytest.skip("POSTGRES_URL not set")

    # Swap only the drivername. A plain string replace of "postgresql://"
    # misses URLs that already specify a (sync) driver, which
    # create_async_engine would then reject.
    async_url = make_url(url).set(drivername="postgresql+asyncpg")
    engine = create_async_engine(async_url)

    yield engine

    await engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def async_tables(async_engine: AsyncEngine) -> AsyncIterator[dict[str, Any]]:
    """Create and seed the ``async_users`` / ``async_orders`` tables for one test.

    Yields a dict with:
      * ``"users"``   -- the ``async_users`` ``Table`` (id, name, email)
      * ``"orders"``  -- the ``async_orders`` ``Table`` (id, user_id -> users.id, total)
      * ``"metadata"`` -- the ``MetaData`` owning both tables

    Seed data: users Alice (id=1) and Bob (id=2); orders 1 and 2 for Alice
    (100.00, 200.00) and order 3 for Bob (150.00). Both tables are dropped
    after the test.
    """
    metadata = MetaData()

    users = Table(
        "async_users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(100)),
        Column("email", String(100)),
    )

    orders = Table(
        "async_orders",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("user_id", Integer, ForeignKey("async_users.id")),
        Column("total", Numeric(10, 2)),
    )

    async with async_engine.begin() as conn:
        # Clear tables left behind by an interrupted earlier run; otherwise
        # create_all() skips them and the fixed-id inserts below collide.
        await conn.run_sync(metadata.drop_all)
        await conn.run_sync(metadata.create_all)

        # Insert test data
        await conn.execute(
            insert(users),
            [
                {"id": 1, "name": "Alice", "email": "alice@example.com"},
                {"id": 2, "name": "Bob", "email": "bob@example.com"},
            ],
        )

        await conn.execute(
            insert(orders),
            [
                {"id": 1, "user_id": 1, "total": Decimal("100.00")},
                {"id": 2, "user_id": 1, "total": Decimal("200.00")},
                {"id": 3, "user_id": 2, "total": Decimal("150.00")},
            ],
        )

    yield {"users": users, "orders": orders, "metadata": metadata}

    async with async_engine.begin() as conn:
        await conn.run_sync(metadata.drop_all)
|
|
|
|
|
|
class TestAsyncView:
    """Tests for View with async engine."""

    @pytest.mark.asyncio
    async def test_create_view_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """Create a per-user order-aggregate view and read it back asynchronously.

        The view outer-joins users to orders, so every user appears once with
        an order count and a ``coalesce``-d order total.
        """
        users = async_tables["users"]
        orders = async_tables["orders"]

        user_stats = View(
            "async_user_stats",
            select(
                users.c.id,
                users.c.name,
                func.count(orders.c.id).label("order_count"),
                func.coalesce(func.sum(orders.c.total), 0).label("total_spent"),
            )
            .select_from(users.outerjoin(orders, users.c.id == orders.c.user_id))
            .group_by(users.c.id, users.c.name),
        )

        async with async_engine.begin() as conn:
            # Create view
            await conn.execute(CreateView(user_stats, or_replace=True))

            # Build the table handle once so SELECT and ORDER BY refer to the
            # same object instead of relying on as_table() returning a cached
            # instance (two distinct objects would put the view in FROM twice).
            stats = user_stats.as_table()
            result = await conn.execute(select(stats).order_by(stats.c.name))
            rows = result.fetchall()

            assert len(rows) == 2
            assert rows[0].name == "Alice"
            assert rows[0].order_count == 2
            assert rows[0].total_spent == Decimal("300.00")

            assert rows[1].name == "Bob"
            assert rows[1].order_count == 1
            assert rows[1].total_spent == Decimal("150.00")

            # Drop view
            await conn.execute(DropView(user_stats, if_exists=True))

    @pytest.mark.asyncio
    async def test_view_reflects_changes_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """A plain view reflects a newly inserted order on the next query."""
        orders = async_tables["orders"]

        order_count_view = View(
            "async_order_count",
            select(func.count(orders.c.id).label("total_orders")),
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(order_count_view, or_replace=True))
            count_table = order_count_view.as_table()

            # The fixture seeds three orders.
            # one() raises a clear error if the view does not yield exactly
            # one row, instead of failing later on ``None.total_orders``.
            result = await conn.execute(select(count_table))
            assert result.one().total_orders == 3

            # Insert new order
            await conn.execute(
                insert(orders).values(id=100, user_id=1, total=Decimal("50.00"))
            )

            # View should show updated count
            result = await conn.execute(select(count_table))
            assert result.one().total_orders == 4

            await conn.execute(DropView(order_count_view, if_exists=True))
|
|
|
|
|
|
class TestAsyncMaterializedView:
    """Tests for MaterializedView with async engine."""

    @pytest.mark.asyncio
    async def test_create_materialized_view_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """Create a per-user order-count materialized view and query it asynchronously."""
        users = async_tables["users"]
        orders = async_tables["orders"]

        user_summary_mv = MaterializedView(
            "async_user_summary_mv",
            select(
                users.c.id,
                users.c.name,
                func.count(orders.c.id).label("order_count"),
            )
            .select_from(users.outerjoin(orders, users.c.id == orders.c.user_id))
            .group_by(users.c.id, users.c.name),
            with_data=True,
        )

        async with async_engine.begin() as conn:
            # Create materialized view
            await conn.execute(CreateMaterializedView(user_summary_mv))

            # One table handle for both SELECT and ORDER BY, so the query does
            # not depend on as_table() returning a cached object.
            summary = user_summary_mv.as_table()
            result = await conn.execute(select(summary).order_by(summary.c.name))
            rows = result.fetchall()

            assert len(rows) == 2
            assert rows[0].name == "Alice"
            assert rows[0].order_count == 2
            assert rows[1].name == "Bob"
            assert rows[1].order_count == 1

            # Drop materialized view
            await conn.execute(DropMaterializedView(user_summary_mv, if_exists=True))

    @pytest.mark.asyncio
    async def test_materialized_view_stale_until_refresh_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """The materialized view keeps its old snapshot until RefreshMaterializedView runs."""
        orders = async_tables["orders"]

        order_count_mv = MaterializedView(
            "async_order_count_mv",
            select(func.count(orders.c.id).label("total_orders")),
            with_data=True,
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateMaterializedView(order_count_mv))
            count_table = order_count_mv.as_table()

            # The fixture seeds three orders; one() fails loudly on 0 or >1 rows.
            result = await conn.execute(select(count_table))
            assert result.one().total_orders == 3

            # Insert new order
            await conn.execute(
                insert(orders).values(id=101, user_id=1, total=Decimal("75.00"))
            )

            # Materialized view still shows old count (stale)
            result = await conn.execute(select(count_table))
            assert result.one().total_orders == 3

            # Refresh materialized view
            await conn.execute(RefreshMaterializedView(order_count_mv))

            # Now shows updated count
            result = await conn.execute(select(count_table))
            assert result.one().total_orders == 4

            await conn.execute(DropMaterializedView(order_count_mv, if_exists=True))

    @pytest.mark.asyncio
    async def test_materialized_view_refresh_method_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """MaterializedView.refresh() works when driven through AsyncConnection.run_sync."""
        orders = async_tables["orders"]

        count_mv = MaterializedView(
            "async_count_mv",
            select(func.count(orders.c.id).label("cnt")),
            with_data=True,
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateMaterializedView(count_mv))
            count_table = count_mv.as_table()

            result = await conn.execute(select(count_table))
            assert result.one().cnt == 3

            # Insert data
            await conn.execute(
                insert(orders).values(id=102, user_id=2, total=Decimal("25.00"))
            )

            # refresh() is synchronous; run_sync calls it with the underlying
            # sync Connection as its first argument.
            await conn.run_sync(count_mv.refresh)

            result = await conn.execute(select(count_table))
            assert result.one().cnt == 4

            await conn.execute(DropMaterializedView(count_mv, if_exists=True))
|
|
|
|
|
|
class TestAsyncDependencies:
    """Tests for dependency functions with async engine."""

    @pytest.mark.asyncio
    async def test_get_all_views_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """Check that get_all_views, called through run_sync, sees a freshly created view."""
        from sqlalchemy_pgview import get_all_views

        users_tbl = async_tables["users"]
        deps_view = View(
            "async_test_view_deps",
            select(users_tbl.c.id, users_tbl.c.name),
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(deps_view, or_replace=True))

            # run_sync hands the wrapped sync Connection to the callable as
            # its first positional argument, so no wrapper closure is needed.
            discovered = await conn.run_sync(get_all_views)

            assert any(
                view.name == "async_test_view_deps" for view in discovered
            )

            await conn.execute(DropView(deps_view, if_exists=True))

    @pytest.mark.asyncio
    async def test_get_view_definition_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """Check that get_view_definition returns SQL that mentions the source table."""
        from sqlalchemy_pgview import get_view_definition

        users_tbl = async_tables["users"]
        filtered_view = View(
            "async_def_test_view",
            select(users_tbl.c.id, users_tbl.c.name).where(users_tbl.c.id > 0),
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(filtered_view, or_replace=True))

            # Extra positional args to run_sync are forwarded after the sync
            # Connection: get_view_definition(sync_conn, "async_def_test_view").
            view_sql = await conn.run_sync(get_view_definition, "async_def_test_view")

            assert view_sql is not None
            assert "async_users" in view_sql.lower()

            await conn.execute(DropView(filtered_view, if_exists=True))
|
|
|
|
|
|
class TestAsyncViewWithJoins:
    """Tests for views with complex joins using async engine."""

    @pytest.mark.asyncio
    async def test_view_with_aggregation_async(
        self, async_engine: AsyncEngine, async_tables: dict
    ) -> None:
        """View over an inner join with GROUP BY and count/sum/avg aggregates."""
        users = async_tables["users"]
        orders = async_tables["orders"]

        stats_view = View(
            "async_stats_view",
            select(
                users.c.name,
                func.count(orders.c.id).label("orders"),
                func.sum(orders.c.total).label("total"),
                func.avg(orders.c.total).label("avg_order"),
            )
            .select_from(users.join(orders, users.c.id == orders.c.user_id))
            .group_by(users.c.name),
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(stats_view, or_replace=True))

            # One table handle for both SELECT and ORDER BY, so the query does
            # not depend on as_table() returning a cached object.
            stats = stats_view.as_table()
            result = await conn.execute(select(stats).order_by(stats.c.total.desc()))
            rows = result.fetchall()

            # ORDER BY total DESC: Alice (300.00) comes before Bob (150.00).
            assert [r.name for r in rows] == ["Alice", "Bob"]
            alice, bob = rows

            # Alice: 2 orders, total 300
            assert alice.orders == 2
            assert alice.total == Decimal("300.00")
            assert alice.avg_order == Decimal("150.00")

            # Bob: 1 order, total 150
            assert bob.orders == 1
            assert bob.total == Decimal("150.00")
            assert bob.avg_order == Decimal("150.00")

            await conn.execute(DropView(stats_view, if_exists=True))
|