"""Tests for declarative view classes.""" from __future__ import annotations import os from decimal import Decimal from typing import TYPE_CHECKING import pytest from sqlalchemy import ( Integer, Numeric, String, func, select, ) from sqlalchemy.engine import create_engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy_pgview import ( MaterializedViewBase, ViewBase, ) from sqlalchemy_pgview.ddl import ( DropMaterializedView, DropView, ) if TYPE_CHECKING: from sqlalchemy.engine import Engine @pytest.fixture def pg_engine() -> Engine: """Create a PostgreSQL engine for testing.""" url = os.environ.get("POSTGRES_URL", "postgresql://test:test@localhost:5432/testdb") return create_engine(url) class TestViewBase: """Tests for ViewBase declarative views.""" def test_basic_view_class(self, pg_engine: Engine) -> None: """Test basic ViewBase subclass creation and querying.""" class Base(DeclarativeBase): pass class User(Base): __tablename__ = "test_users_vb" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) is_active: Mapped[int] = mapped_column(Integer) class ActiveUsers(ViewBase, Base): __tablename__ = "test_active_users_vb" __select__ = select(User.id, User.name).where(User.is_active == 1) try: Base.metadata.create_all(pg_engine) with Session(pg_engine) as session: session.add(User(id=1, name="Alice", is_active=1)) session.add(User(id=2, name="Bob", is_active=0)) session.add(User(id=3, name="Charlie", is_active=1)) session.commit() with pg_engine.connect() as conn: result = conn.execute(select(ActiveUsers.as_table())).fetchall() assert len(result) == 2 names = {row.name for row in result} assert names == {"Alice", "Charlie"} finally: with pg_engine.begin() as conn: conn.execute(DropView(ActiveUsers._view, if_exists=True)) Base.metadata.drop_all(pg_engine) def test_view_class_columns_shorthand(self, pg_engine: Engine) -> None: """Test that .c shorthand works for columns.""" class Base(DeclarativeBase): pass class Product(Base): __tablename__ = "test_products_vb" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) price: Mapped[Decimal] = mapped_column(Numeric(10, 2)) class ProductView(ViewBase, Base): __tablename__ = "test_product_view" __select__ = select(Product.id, Product.name, Product.price) try: Base.metadata.create_all(pg_engine) with Session(pg_engine) as session: session.add(Product(id=1, name="Widget", price=Decimal("9.99"))) session.commit() with pg_engine.connect() as conn: result = conn.execute( select(ProductView.c.name, ProductView.c.price) ).fetchone() assert result.name == "Widget" assert result.price == Decimal("9.99") finally: with pg_engine.begin() as conn: conn.execute(DropView(ProductView._view, if_exists=True)) Base.metadata.drop_all(pg_engine) def test_view_class_missing_tablename_raises(self) -> None: """Test that missing __tablename__ raises TypeError.""" class Base(DeclarativeBase): pass with pytest.raises(TypeError, match="must define __tablename__"): class BadView(ViewBase, Base): __select__ = select() def test_view_class_missing_select_raises(self) -> None: """Test that missing __select__ raises TypeError.""" class Base(DeclarativeBase): pass with pytest.raises(TypeError, match="must define __select__"): class BadView(ViewBase, Base): __tablename__ = "bad_view" class TestMaterializedViewBase: """Tests for MaterializedViewBase declarative views.""" def test_basic_materialized_view_class(self, pg_engine: Engine) -> None: """Test basic MaterializedViewBase subclass.""" class Base(DeclarativeBase): pass class Order(Base): __tablename__ = "test_orders_mvb" id: Mapped[int] = mapped_column(primary_key=True) total: Mapped[Decimal] = mapped_column(Numeric(10, 2)) class OrderStats(MaterializedViewBase, Base): __tablename__ = "test_order_stats_mvb" __select__ = select( func.count(Order.id).label("order_count"), func.sum(Order.total).label("total_revenue"), ) try: Base.metadata.create_all(pg_engine) # Initial state - empty with pg_engine.connect() as conn: result = conn.execute(select(OrderStats.as_table())).fetchone() assert result.order_count == 0 with Session(pg_engine) as session: session.add(Order(id=1, total=Decimal("100.00"))) session.add(Order(id=2, total=Decimal("200.00"))) session.commit() # MV is stale until refresh with pg_engine.connect() as conn: result = conn.execute(select(OrderStats.as_table())).fetchone() assert result.order_count == 0 # Refresh using class method with pg_engine.begin() as conn: OrderStats.refresh(conn) with pg_engine.connect() as conn: result = conn.execute(select(OrderStats.as_table())).fetchone() assert result.order_count == 2 assert result.total_revenue == Decimal("300.00") finally: with pg_engine.begin() as conn: conn.execute(DropMaterializedView(OrderStats._view, if_exists=True)) Base.metadata.drop_all(pg_engine) def test_materialized_view_with_data_false(self, pg_engine: Engine) -> None: """Test MaterializedViewBase with __with_data__ = False.""" class Base(DeclarativeBase): pass class Item(Base): __tablename__ = "test_items_mvb_nodata" id: Mapped[int] = mapped_column(primary_key=True) class ItemCount(MaterializedViewBase, Base): __tablename__ = "test_item_count_mvb" __select__ = select(func.count(Item.id).label("count")) __with_data__ = False try: Base.metadata.create_all(pg_engine) with Session(pg_engine) as session: session.add(Item(id=1)) session.commit() with pg_engine.begin() as conn: ItemCount.refresh(conn) with pg_engine.connect() as conn: result = conn.execute(select(ItemCount.as_table())).fetchone() assert result.count == 1 finally: with pg_engine.begin() as conn: conn.execute(DropMaterializedView(ItemCount._view, if_exists=True)) Base.metadata.drop_all(pg_engine) def test_materialized_view_auto_refresh(self, pg_engine: Engine) -> None: """Test MaterializedViewBase auto-refresh functionality.""" class Base(DeclarativeBase): pass class Counter(Base): __tablename__ = "test_counter_mvb_ar" id: Mapped[int] = mapped_column(primary_key=True) value: Mapped[int] = mapped_column(Integer) class CounterSum(MaterializedViewBase, Base): __tablename__ = "test_counter_sum_mvb" __select__ = select(func.sum(Counter.value).label("total")) class TestSession(Session): pass CounterSum.auto_refresh_on(TestSession, Counter.__table__) try: Base.metadata.create_all(pg_engine) with TestSession(pg_engine) as session: session.add(Counter(id=1, value=10)) session.add(Counter(id=2, value=20)) session.commit() with pg_engine.connect() as conn: result = conn.execute(select(CounterSum.as_table())).fetchone() assert result.total == 30 finally: with pg_engine.begin() as conn: conn.execute(DropMaterializedView(CounterSum._view, if_exists=True)) Base.metadata.drop_all(pg_engine) class TestMultipleInheritance: """Tests for multiple inheritance with DeclarativeBase.""" def test_view_inherits_metadata(self, pg_engine: Engine) -> None: """Test that ViewBase inherits metadata from DeclarativeBase.""" class Base(DeclarativeBase): pass class Product(Base): __tablename__ = "test_product_mi" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) price: Mapped[Decimal] = mapped_column(Numeric(10, 2)) class ExpensiveProducts(ViewBase, Base): __tablename__ = "test_expensive_products_mi" __select__ = select(Product.id, Product.name, Product.price).where( Product.price > 50 ) try: Base.metadata.create_all(pg_engine) with Session(pg_engine) as session: session.add(Product(id=1, name="Cheap", price=Decimal("10.00"))) session.add(Product(id=2, name="Expensive", price=Decimal("100.00"))) session.add(Product(id=3, name="Premium", price=Decimal("200.00"))) session.commit() with pg_engine.connect() as conn: result = conn.execute(select(ExpensiveProducts.as_table())).fetchall() assert len(result) == 2 names = {row.name for row in result} assert names == {"Expensive", "Premium"} finally: with pg_engine.begin() as conn: conn.execute(DropView(ExpensiveProducts._view, if_exists=True)) Base.metadata.drop_all(pg_engine) def test_materialized_view_inherits_metadata(self, pg_engine: Engine) -> None: """Test that MaterializedViewBase inherits metadata from DeclarativeBase.""" class Base(DeclarativeBase): pass class Sale(Base): __tablename__ = "test_sale_mi" id: Mapped[int] = mapped_column(primary_key=True) amount: Mapped[Decimal] = mapped_column(Numeric(10, 2)) class SaleStats(MaterializedViewBase, Base): __tablename__ = "test_sale_stats_mi" __select__ = select( func.count(Sale.id).label("sale_count"), func.sum(Sale.amount).label("total_amount"), ) try: Base.metadata.create_all(pg_engine) with Session(pg_engine) as session: session.add(Sale(id=1, amount=Decimal("100.00"))) session.add(Sale(id=2, amount=Decimal("200.00"))) session.commit() with pg_engine.begin() as conn: SaleStats.refresh(conn) with pg_engine.connect() as conn: result = conn.execute(select(SaleStats.as_table())).fetchone() assert result.sale_count == 2 assert result.total_amount == Decimal("300.00") finally: with pg_engine.begin() as conn: conn.execute(DropMaterializedView(SaleStats._view, if_exists=True)) Base.metadata.drop_all(pg_engine) def test_multiple_views_same_base(self, pg_engine: Engine) -> None: """Test multiple views sharing the same DeclarativeBase.""" class Base(DeclarativeBase): pass class User(Base): __tablename__ = "test_user_multi" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) role: Mapped[str] = mapped_column(String(50)) class AdminUsers(ViewBase, Base): __tablename__ = "test_admin_users" __select__ = select(User.id, User.name).where(User.role == "admin") class RegularUsers(ViewBase, Base): __tablename__ = "test_regular_users" __select__ = select(User.id, User.name).where(User.role == "user") class UserStats(MaterializedViewBase, Base): __tablename__ = "test_user_stats_multi" __select__ = select(func.count(User.id).label("count")) try: Base.metadata.create_all(pg_engine) with Session(pg_engine) as session: session.add(User(id=1, name="Alice", role="admin")) session.add(User(id=2, name="Bob", role="user")) session.add(User(id=3, name="Charlie", role="user")) session.commit() with pg_engine.connect() as conn: admins = conn.execute(select(AdminUsers.as_table())).fetchall() assert len(admins) == 1 assert admins[0].name == "Alice" regulars = conn.execute(select(RegularUsers.as_table())).fetchall() assert len(regulars) == 2 with pg_engine.begin() as conn: UserStats.refresh(conn) with pg_engine.connect() as conn: stats = conn.execute(select(UserStats.as_table())).fetchone() assert stats.count == 3 finally: with pg_engine.begin() as conn: conn.execute(DropView(AdminUsers._view, if_exists=True)) conn.execute(DropView(RegularUsers._view, if_exists=True)) conn.execute(DropMaterializedView(UserStats._view, if_exists=True)) Base.metadata.drop_all(pg_engine)