"""Tests for async engine support with asyncpg."""

import os
from collections.abc import AsyncIterator
from decimal import Decimal
from typing import Any

import pytest
from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    MetaData,
    Numeric,
    String,
    Table,
    func,
    insert,
    select,
)
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from sqlalchemy_pgview import (
    CreateMaterializedView,
    CreateView,
    DropMaterializedView,
    DropView,
    MaterializedView,
    RefreshMaterializedView,
    View,
)


@pytest.fixture
async def async_engine() -> AsyncIterator[AsyncEngine]:
    """Yield an async PostgreSQL engine that uses the asyncpg driver.

    The target database is read from the ``POSTGRES_URL`` environment
    variable; the requesting test is skipped when it is unset. The engine is
    disposed after the test finishes.
    """
    url = os.environ.get("POSTGRES_URL")
    if not url:
        pytest.skip("POSTGRES_URL not set")

    # Swap the driver on the parsed URL instead of doing a substring replace.
    # A plain str.replace leaves "postgresql+psycopg2://..." or
    # "postgres://..." untouched (yielding a non-async or unknown dialect)
    # and could also rewrite unrelated parts of the URL string.
    async_url = make_url(url).set(drivername="postgresql+asyncpg")
    engine = create_async_engine(async_url)
    yield engine
    await engine.dispose()


@pytest.fixture
async def async_tables(async_engine: AsyncEngine) -> AsyncIterator[dict[str, Any]]:
    """Create and seed ``async_users`` / ``async_orders`` for a single test.

    Yields a dict with keys ``"users"`` and ``"orders"`` (the :class:`Table`
    objects) and ``"metadata"`` (the :class:`MetaData` that owns them). Both
    tables are dropped again after the test.

    Seed data: Alice (id 1) has two orders (100.00 and 200.00); Bob (id 2)
    has one order (150.00).
    """
    metadata = MetaData()

    users = Table(
        "async_users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(100)),
        Column("email", String(100)),
    )

    orders = Table(
        "async_orders",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("user_id", Integer, ForeignKey("async_users.id")),
        Column("total", Numeric(10, 2)),
    )

    async with async_engine.begin() as conn:
        # MetaData.create_all is a sync API; run_sync hands it a sync connection.
        await conn.run_sync(metadata.create_all)

        await conn.execute(
            insert(users),
            [
                {"id": 1, "name": "Alice", "email": "alice@example.com"},
                {"id": 2, "name": "Bob", "email": "bob@example.com"},
            ],
        )
        await conn.execute(
            insert(orders),
            [
                {"id": 1, "user_id": 1, "total": Decimal("100.00")},
                {"id": 2, "user_id": 1, "total": Decimal("200.00")},
                {"id": 3, "user_id": 2, "total": Decimal("150.00")},
            ],
        )

    yield {"users": users, "orders": orders, "metadata": metadata}

    async with async_engine.begin() as conn:
        await conn.run_sync(metadata.drop_all)


class TestAsyncView:
    """Tests for View with async engine."""

    @pytest.mark.asyncio
    async def test_create_view_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test creating a view with async engine."""
        users = async_tables["users"]
        orders = async_tables["orders"]

        # The outer join keeps users without orders; coalesce maps their NULL
        # sum to 0.
        user_stats = View(
            "async_user_stats",
            select(
                users.c.id,
                users.c.name,
                func.count(orders.c.id).label("order_count"),
                func.coalesce(func.sum(orders.c.total), 0).label("total_spent"),
            )
            .select_from(users.outerjoin(orders, users.c.id == orders.c.user_id))
            .group_by(users.c.id, users.c.name),
        )
        stats_table = user_stats.as_table()

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(user_stats, or_replace=True))

            result = await conn.execute(
                select(stats_table).order_by(stats_table.c.name)
            )
            rows = result.fetchall()

            assert len(rows) == 2
            assert rows[0].name == "Alice"
            assert rows[0].order_count == 2
            assert rows[0].total_spent == Decimal("300.00")
            assert rows[1].name == "Bob"
            assert rows[1].order_count == 1
            assert rows[1].total_spent == Decimal("150.00")

            await conn.execute(DropView(user_stats, if_exists=True))

    @pytest.mark.asyncio
    async def test_view_reflects_changes_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test that view reflects changes immediately with async engine."""
        orders = async_tables["orders"]

        order_count_view = View(
            "async_order_count",
            select(func.count(orders.c.id).label("total_orders")),
        )
        view_table = order_count_view.as_table()

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(order_count_view, or_replace=True))

            # .one() fails loudly if the aggregate unexpectedly returns no row.
            row = (await conn.execute(select(view_table))).one()
            assert row.total_orders == 3

            await conn.execute(
                insert(orders).values(id=100, user_id=1, total=Decimal("50.00"))
            )

            # A plain view is re-evaluated on every query, so the new order
            # is visible immediately on the same connection.
            row = (await conn.execute(select(view_table))).one()
            assert row.total_orders == 4

            await conn.execute(DropView(order_count_view, if_exists=True))


class TestAsyncMaterializedView:
    """Tests for MaterializedView with async engine."""

    @pytest.mark.asyncio
    async def test_create_materialized_view_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test creating a materialized view with async engine."""
        users = async_tables["users"]
        orders = async_tables["orders"]

        user_summary_mv = MaterializedView(
            "async_user_summary_mv",
            select(
                users.c.id,
                users.c.name,
                func.count(orders.c.id).label("order_count"),
            )
            .select_from(users.outerjoin(orders, users.c.id == orders.c.user_id))
            .group_by(users.c.id, users.c.name),
            with_data=True,
        )
        mv_table = user_summary_mv.as_table()

        async with async_engine.begin() as conn:
            await conn.execute(CreateMaterializedView(user_summary_mv))

            result = await conn.execute(
                select(mv_table).order_by(mv_table.c.name)
            )
            rows = result.fetchall()

            assert len(rows) == 2
            assert rows[0].name == "Alice"
            assert rows[0].order_count == 2

            await conn.execute(DropMaterializedView(user_summary_mv, if_exists=True))

    @pytest.mark.asyncio
    async def test_materialized_view_stale_until_refresh_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test that materialized view is stale until refreshed with async engine."""
        orders = async_tables["orders"]

        order_count_mv = MaterializedView(
            "async_order_count_mv",
            select(func.count(orders.c.id).label("total_orders")),
            with_data=True,
        )
        mv_table = order_count_mv.as_table()

        async with async_engine.begin() as conn:
            await conn.execute(CreateMaterializedView(order_count_mv))

            row = (await conn.execute(select(mv_table))).one()
            assert row.total_orders == 3

            await conn.execute(
                insert(orders).values(id=101, user_id=1, total=Decimal("75.00"))
            )

            # The stored result is not recomputed on insert: still 3.
            row = (await conn.execute(select(mv_table))).one()
            assert row.total_orders == 3

            await conn.execute(RefreshMaterializedView(order_count_mv))

            row = (await conn.execute(select(mv_table))).one()
            assert row.total_orders == 4

            await conn.execute(DropMaterializedView(order_count_mv, if_exists=True))

    @pytest.mark.asyncio
    async def test_materialized_view_refresh_method_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test MaterializedView.refresh() method with async engine."""
        orders = async_tables["orders"]

        count_mv = MaterializedView(
            "async_count_mv",
            select(func.count(orders.c.id).label("cnt")),
            with_data=True,
        )
        mv_table = count_mv.as_table()

        async with async_engine.begin() as conn:
            await conn.execute(CreateMaterializedView(count_mv))

            assert (await conn.execute(select(mv_table))).one().cnt == 3

            await conn.execute(
                insert(orders).values(id=102, user_id=2, total=Decimal("25.00"))
            )

            # refresh() takes a synchronous connection; run_sync passes the
            # sync connection as the first positional argument.
            await conn.run_sync(count_mv.refresh)

            assert (await conn.execute(select(mv_table))).one().cnt == 4

            await conn.execute(DropMaterializedView(count_mv, if_exists=True))


class TestAsyncDependencies:
    """Tests for dependency functions with async engine."""

    @pytest.mark.asyncio
    async def test_get_all_views_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test get_all_views with async engine using run_sync."""
        from sqlalchemy_pgview import get_all_views

        users = async_tables["users"]

        test_view = View(
            "async_test_view_deps",
            select(users.c.id, users.c.name),
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(test_view, or_replace=True))

            # get_all_views is a sync inspection helper; run it via run_sync.
            views = await conn.run_sync(get_all_views)
            view_names = {v.name for v in views}

            assert "async_test_view_deps" in view_names

            await conn.execute(DropView(test_view, if_exists=True))

    @pytest.mark.asyncio
    async def test_get_view_definition_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test get_view_definition with async engine."""
        from sqlalchemy_pgview import get_view_definition

        users = async_tables["users"]

        test_view = View(
            "async_def_test_view",
            select(users.c.id, users.c.name).where(users.c.id > 0),
        )

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(test_view, or_replace=True))

            # Extra positional args to run_sync are forwarded after the
            # sync connection.
            definition = await conn.run_sync(
                get_view_definition, "async_def_test_view"
            )

            assert definition is not None
            assert "async_users" in definition.lower()

            await conn.execute(DropView(test_view, if_exists=True))


class TestAsyncViewWithJoins:
    """Tests for views with complex joins using async engine."""

    @pytest.mark.asyncio
    async def test_view_with_aggregation_async(
        self, async_engine: AsyncEngine, async_tables: dict[str, Any]
    ) -> None:
        """Test view with GROUP BY and aggregation functions."""
        users = async_tables["users"]
        orders = async_tables["orders"]

        stats_view = View(
            "async_stats_view",
            select(
                users.c.name,
                func.count(orders.c.id).label("orders"),
                func.sum(orders.c.total).label("total"),
                func.avg(orders.c.total).label("avg_order"),
            )
            .select_from(users.join(orders, users.c.id == orders.c.user_id))
            .group_by(users.c.name),
        )
        stats_table = stats_view.as_table()

        async with async_engine.begin() as conn:
            await conn.execute(CreateView(stats_view, or_replace=True))

            result = await conn.execute(
                select(stats_table).order_by(stats_table.c.total.desc())
            )
            rows = result.fetchall()

            assert len(rows) == 2
            # Alice's total (300) exceeds Bob's (150), so she sorts first.
            assert [r.name for r in rows] == ["Alice", "Bob"]
            by_name = {r.name: r for r in rows}

            alice = by_name["Alice"]
            assert alice.orders == 2
            assert alice.total == Decimal("300.00")
            # Decimal equality is numeric, so extra scale from AVG is fine.
            assert alice.avg_order == Decimal("150.00")

            bob = by_name["Bob"]
            assert bob.orders == 1
            assert bob.total == Decimal("150.00")

            await conn.execute(DropView(stats_view, if_exists=True))