mirror of
https://github.com/d3vyce/sqlalchemy-pgview.git
synced 2026-03-01 21:30:48 +01:00
396 lines
14 KiB
Python
396 lines
14 KiB
Python
"""Tests for declarative view classes."""
|
|
|
|
from __future__ import annotations

import os
from decimal import Decimal
from typing import TYPE_CHECKING

import pytest
from sqlalchemy import (
    Integer,
    Numeric,
    String,
    func,
    select,
)
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

from sqlalchemy_pgview import (
    MaterializedViewBase,
    ViewBase,
)
from sqlalchemy_pgview.ddl import (
    DropMaterializedView,
    DropView,
)

if TYPE_CHECKING:
    from collections.abc import Iterator

    from sqlalchemy.engine import Engine
|
|
|
|
|
|
@pytest.fixture
def pg_engine() -> Iterator[Engine]:
    """Yield a PostgreSQL engine for testing and dispose of it afterwards.

    The connection URL is read from the ``POSTGRES_URL`` environment
    variable; when it is unset, a local default test database URL is used.

    Yields:
        The engine built by ``create_engine`` for that URL.
    """
    url = os.environ.get("POSTGRES_URL", "postgresql://test:test@localhost:5432/testdb")
    engine = create_engine(url)
    yield engine
    # Teardown: release the engine's pooled connections so that every test
    # does not leave its own pool of open connections behind.
    engine.dispose()
|
|
|
|
|
|
class TestViewBase:
    """Exercise plain views declared through ``ViewBase``."""

    def test_basic_view_class(self, pg_engine: Engine) -> None:
        """Declare a filtered view, insert rows, and read them back through it."""

        class Base(DeclarativeBase):
            pass

        class User(Base):
            __tablename__ = "test_users_vb"
            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str] = mapped_column(String(100))
            is_active: Mapped[int] = mapped_column(Integer)

        class ActiveUsers(ViewBase, Base):
            __tablename__ = "test_active_users_vb"
            __select__ = select(User.id, User.name).where(User.is_active == 1)

        try:
            Base.metadata.create_all(pg_engine)

            # Two active users and one inactive user.
            with Session(pg_engine) as db:
                db.add_all(
                    [
                        User(id=1, name="Alice", is_active=1),
                        User(id=2, name="Bob", is_active=0),
                        User(id=3, name="Charlie", is_active=1),
                    ]
                )
                db.commit()

            with pg_engine.connect() as connection:
                view_query = select(ActiveUsers.as_table())
                rows = connection.execute(view_query).fetchall()
                assert len(rows) == 2
                seen_names = {row.name for row in rows}
                assert seen_names == {"Alice", "Charlie"}

        finally:
            # Drop the view explicitly before dropping the mapped tables.
            with pg_engine.begin() as connection:
                connection.execute(DropView(ActiveUsers._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)

    def test_view_class_columns_shorthand(self, pg_engine: Engine) -> None:
        """Select individual view columns through the ``.c`` accessor."""

        class Base(DeclarativeBase):
            pass

        class Product(Base):
            __tablename__ = "test_products_vb"
            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str] = mapped_column(String(100))
            price: Mapped[Decimal] = mapped_column(Numeric(10, 2))

        class ProductView(ViewBase, Base):
            __tablename__ = "test_product_view"
            __select__ = select(Product.id, Product.name, Product.price)

        try:
            Base.metadata.create_all(pg_engine)

            with Session(pg_engine) as db:
                db.add(Product(id=1, name="Widget", price=Decimal("9.99")))
                db.commit()

            with pg_engine.connect() as connection:
                column_query = select(ProductView.c.name, ProductView.c.price)
                row = connection.execute(column_query).fetchone()
                assert row.name == "Widget"
                assert row.price == Decimal("9.99")

        finally:
            with pg_engine.begin() as connection:
                connection.execute(DropView(ProductView._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)

    def test_view_class_missing_tablename_raises(self) -> None:
        """Defining a view class without ``__tablename__`` raises TypeError."""

        class Base(DeclarativeBase):
            pass

        # The error is expected at class-definition time.
        with pytest.raises(TypeError, match="must define __tablename__"):

            class BadView(ViewBase, Base):
                __select__ = select()

    def test_view_class_missing_select_raises(self) -> None:
        """Defining a view class without ``__select__`` raises TypeError."""

        class Base(DeclarativeBase):
            pass

        # The error is expected at class-definition time.
        with pytest.raises(TypeError, match="must define __select__"):

            class BadView(ViewBase, Base):
                __tablename__ = "bad_view"
|
|
|
|
|
|
class TestMaterializedViewBase:
    """Exercise materialized views declared through ``MaterializedViewBase``."""

    def test_basic_materialized_view_class(self, pg_engine: Engine) -> None:
        """Check that view contents change only after an explicit refresh."""

        class Base(DeclarativeBase):
            pass

        class Order(Base):
            __tablename__ = "test_orders_mvb"
            id: Mapped[int] = mapped_column(primary_key=True)
            total: Mapped[Decimal] = mapped_column(Numeric(10, 2))

        class OrderStats(MaterializedViewBase, Base):
            __tablename__ = "test_order_stats_mvb"
            __select__ = select(
                func.count(Order.id).label("order_count"),
                func.sum(Order.total).label("total_revenue"),
            )

        try:
            Base.metadata.create_all(pg_engine)

            # Nothing inserted yet, so the aggregate counts zero orders.
            with pg_engine.connect() as connection:
                stats_query = select(OrderStats.as_table())
                stats = connection.execute(stats_query).fetchone()
                assert stats.order_count == 0

            with Session(pg_engine) as db:
                db.add_all(
                    [
                        Order(id=1, total=Decimal("100.00")),
                        Order(id=2, total=Decimal("200.00")),
                    ]
                )
                db.commit()

            # Without a refresh the materialized view still reports the old count.
            with pg_engine.connect() as connection:
                stats_query = select(OrderStats.as_table())
                stats = connection.execute(stats_query).fetchone()
                assert stats.order_count == 0

            # Refresh through the class-level helper.
            with pg_engine.begin() as connection:
                OrderStats.refresh(connection)

            with pg_engine.connect() as connection:
                stats_query = select(OrderStats.as_table())
                stats = connection.execute(stats_query).fetchone()
                assert stats.order_count == 2
                assert stats.total_revenue == Decimal("300.00")

        finally:
            # Drop the materialized view explicitly before dropping the tables.
            with pg_engine.begin() as connection:
                connection.execute(DropMaterializedView(OrderStats._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)

    def test_materialized_view_with_data_false(self, pg_engine: Engine) -> None:
        """Declare a view with ``__with_data__ = False`` and read it after a refresh."""

        class Base(DeclarativeBase):
            pass

        class Item(Base):
            __tablename__ = "test_items_mvb_nodata"
            id: Mapped[int] = mapped_column(primary_key=True)

        class ItemCount(MaterializedViewBase, Base):
            __tablename__ = "test_item_count_mvb"
            __select__ = select(func.count(Item.id).label("count"))
            __with_data__ = False

        try:
            Base.metadata.create_all(pg_engine)

            with Session(pg_engine) as db:
                db.add(Item(id=1))
                db.commit()

            # The view is refreshed before it is queried for the first time.
            with pg_engine.begin() as connection:
                ItemCount.refresh(connection)

            with pg_engine.connect() as connection:
                count_query = select(ItemCount.as_table())
                row = connection.execute(count_query).fetchone()
                assert row.count == 1

        finally:
            with pg_engine.begin() as connection:
                connection.execute(DropMaterializedView(ItemCount._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)

    def test_materialized_view_auto_refresh(self, pg_engine: Engine) -> None:
        """Register auto-refresh on a session class and read the view without refreshing."""

        class Base(DeclarativeBase):
            pass

        class Counter(Base):
            __tablename__ = "test_counter_mvb_ar"
            id: Mapped[int] = mapped_column(primary_key=True)
            value: Mapped[int] = mapped_column(Integer)

        class CounterSum(MaterializedViewBase, Base):
            __tablename__ = "test_counter_sum_mvb"
            __select__ = select(func.sum(Counter.value).label("total"))

        # A dedicated Session subclass is the one passed to auto_refresh_on.
        class TestSession(Session):
            pass

        CounterSum.auto_refresh_on(TestSession, Counter.__table__)

        try:
            Base.metadata.create_all(pg_engine)

            with TestSession(pg_engine) as db:
                db.add_all(
                    [
                        Counter(id=1, value=10),
                        Counter(id=2, value=20),
                    ]
                )
                db.commit()

            # No explicit refresh call here: the committed rows are expected
            # to be reflected in the view already.
            with pg_engine.connect() as connection:
                sum_query = select(CounterSum.as_table())
                row = connection.execute(sum_query).fetchone()
                assert row.total == 30

        finally:
            with pg_engine.begin() as connection:
                connection.execute(DropMaterializedView(CounterSum._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)
|
|
|
|
|
|
class TestMultipleInheritance:
    """Exercise view classes that inherit from both a view base and a DeclarativeBase."""

    def test_view_inherits_metadata(self, pg_engine: Engine) -> None:
        """A ViewBase subclass is usable after ``Base.metadata.create_all``."""

        class Base(DeclarativeBase):
            pass

        class Product(Base):
            __tablename__ = "test_product_mi"
            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str] = mapped_column(String(100))
            price: Mapped[Decimal] = mapped_column(Numeric(10, 2))

        class ExpensiveProducts(ViewBase, Base):
            __tablename__ = "test_expensive_products_mi"
            __select__ = select(Product.id, Product.name, Product.price).where(
                Product.price > 50
            )

        try:
            Base.metadata.create_all(pg_engine)

            # One product below the price threshold, two above it.
            with Session(pg_engine) as db:
                db.add_all(
                    [
                        Product(id=1, name="Cheap", price=Decimal("10.00")),
                        Product(id=2, name="Expensive", price=Decimal("100.00")),
                        Product(id=3, name="Premium", price=Decimal("200.00")),
                    ]
                )
                db.commit()

            with pg_engine.connect() as connection:
                view_query = select(ExpensiveProducts.as_table())
                rows = connection.execute(view_query).fetchall()
                assert len(rows) == 2
                seen_names = {row.name for row in rows}
                assert seen_names == {"Expensive", "Premium"}

        finally:
            # Drop the view explicitly before dropping the mapped tables.
            with pg_engine.begin() as connection:
                connection.execute(DropView(ExpensiveProducts._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)

    def test_materialized_view_inherits_metadata(self, pg_engine: Engine) -> None:
        """A MaterializedViewBase subclass is usable after ``Base.metadata.create_all``."""

        class Base(DeclarativeBase):
            pass

        class Sale(Base):
            __tablename__ = "test_sale_mi"
            id: Mapped[int] = mapped_column(primary_key=True)
            amount: Mapped[Decimal] = mapped_column(Numeric(10, 2))

        class SaleStats(MaterializedViewBase, Base):
            __tablename__ = "test_sale_stats_mi"
            __select__ = select(
                func.count(Sale.id).label("sale_count"),
                func.sum(Sale.amount).label("total_amount"),
            )

        try:
            Base.metadata.create_all(pg_engine)

            with Session(pg_engine) as db:
                db.add_all(
                    [
                        Sale(id=1, amount=Decimal("100.00")),
                        Sale(id=2, amount=Decimal("200.00")),
                    ]
                )
                db.commit()

            # Refresh so the inserted sales are included in the aggregates.
            with pg_engine.begin() as connection:
                SaleStats.refresh(connection)

            with pg_engine.connect() as connection:
                stats_query = select(SaleStats.as_table())
                stats = connection.execute(stats_query).fetchone()
                assert stats.sale_count == 2
                assert stats.total_amount == Decimal("300.00")

        finally:
            with pg_engine.begin() as connection:
                connection.execute(DropMaterializedView(SaleStats._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)

    def test_multiple_views_same_base(self, pg_engine: Engine) -> None:
        """Two views and one materialized view share a single DeclarativeBase."""

        class Base(DeclarativeBase):
            pass

        class User(Base):
            __tablename__ = "test_user_multi"
            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str] = mapped_column(String(100))
            role: Mapped[str] = mapped_column(String(50))

        class AdminUsers(ViewBase, Base):
            __tablename__ = "test_admin_users"
            __select__ = select(User.id, User.name).where(User.role == "admin")

        class RegularUsers(ViewBase, Base):
            __tablename__ = "test_regular_users"
            __select__ = select(User.id, User.name).where(User.role == "user")

        class UserStats(MaterializedViewBase, Base):
            __tablename__ = "test_user_stats_multi"
            __select__ = select(func.count(User.id).label("count"))

        try:
            Base.metadata.create_all(pg_engine)

            # One admin and two regular users.
            with Session(pg_engine) as db:
                db.add_all(
                    [
                        User(id=1, name="Alice", role="admin"),
                        User(id=2, name="Bob", role="user"),
                        User(id=3, name="Charlie", role="user"),
                    ]
                )
                db.commit()

            with pg_engine.connect() as connection:
                admin_rows = connection.execute(select(AdminUsers.as_table())).fetchall()
                assert len(admin_rows) == 1
                assert admin_rows[0].name == "Alice"

                regular_rows = connection.execute(
                    select(RegularUsers.as_table())
                ).fetchall()
                assert len(regular_rows) == 2

            # The materialized view is refreshed before its count is read.
            with pg_engine.begin() as connection:
                UserStats.refresh(connection)

            with pg_engine.connect() as connection:
                stats_query = select(UserStats.as_table())
                user_stats = connection.execute(stats_query).fetchone()
                assert user_stats.count == 3

        finally:
            # Drop all three views explicitly before dropping the mapped tables.
            with pg_engine.begin() as connection:
                connection.execute(DropView(AdminUsers._view, if_exists=True))
                connection.execute(DropView(RegularUsers._view, if_exists=True))
                connection.execute(DropMaterializedView(UserStats._view, if_exists=True))
            Base.metadata.drop_all(pg_engine)
|