mirror of
https://github.com/d3vyce/sqlalchemy-pgview.git
synced 2026-03-01 21:40:47 +01:00
Initial commit
This commit is contained in:
395
tests/test_declarative.py
Normal file
395
tests/test_declarative.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""Tests for declarative view classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import (
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
|
||||
|
||||
from sqlalchemy_pgview import (
|
||||
MaterializedViewBase,
|
||||
ViewBase,
|
||||
)
|
||||
from sqlalchemy_pgview.ddl import (
|
||||
DropMaterializedView,
|
||||
DropView,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pg_engine() -> Engine:
|
||||
"""Create a PostgreSQL engine for testing."""
|
||||
url = os.environ.get("POSTGRES_URL", "postgresql://test:test@localhost:5432/testdb")
|
||||
return create_engine(url)
|
||||
|
||||
|
||||
class TestViewBase:
|
||||
"""Tests for ViewBase declarative views."""
|
||||
|
||||
def test_basic_view_class(self, pg_engine: Engine) -> None:
|
||||
"""Test basic ViewBase subclass creation and querying."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "test_users_vb"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
is_active: Mapped[int] = mapped_column(Integer)
|
||||
|
||||
class ActiveUsers(ViewBase, Base):
|
||||
__tablename__ = "test_active_users_vb"
|
||||
__select__ = select(User.id, User.name).where(User.is_active == 1)
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
with Session(pg_engine) as session:
|
||||
session.add(User(id=1, name="Alice", is_active=1))
|
||||
session.add(User(id=2, name="Bob", is_active=0))
|
||||
session.add(User(id=3, name="Charlie", is_active=1))
|
||||
session.commit()
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(ActiveUsers.as_table())).fetchall()
|
||||
assert len(result) == 2
|
||||
names = {row.name for row in result}
|
||||
assert names == {"Alice", "Charlie"}
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropView(ActiveUsers._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
|
||||
def test_view_class_columns_shorthand(self, pg_engine: Engine) -> None:
|
||||
"""Test that .c shorthand works for columns."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = "test_products_vb"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
price: Mapped[Decimal] = mapped_column(Numeric(10, 2))
|
||||
|
||||
class ProductView(ViewBase, Base):
|
||||
__tablename__ = "test_product_view"
|
||||
__select__ = select(Product.id, Product.name, Product.price)
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
with Session(pg_engine) as session:
|
||||
session.add(Product(id=1, name="Widget", price=Decimal("9.99")))
|
||||
session.commit()
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(
|
||||
select(ProductView.c.name, ProductView.c.price)
|
||||
).fetchone()
|
||||
assert result.name == "Widget"
|
||||
assert result.price == Decimal("9.99")
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropView(ProductView._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
|
||||
def test_view_class_missing_tablename_raises(self) -> None:
|
||||
"""Test that missing __tablename__ raises TypeError."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError, match="must define __tablename__"):
|
||||
|
||||
class BadView(ViewBase, Base):
|
||||
__select__ = select()
|
||||
|
||||
def test_view_class_missing_select_raises(self) -> None:
|
||||
"""Test that missing __select__ raises TypeError."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError, match="must define __select__"):
|
||||
|
||||
class BadView(ViewBase, Base):
|
||||
__tablename__ = "bad_view"
|
||||
|
||||
|
||||
class TestMaterializedViewBase:
|
||||
"""Tests for MaterializedViewBase declarative views."""
|
||||
|
||||
def test_basic_materialized_view_class(self, pg_engine: Engine) -> None:
|
||||
"""Test basic MaterializedViewBase subclass."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Order(Base):
|
||||
__tablename__ = "test_orders_mvb"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
total: Mapped[Decimal] = mapped_column(Numeric(10, 2))
|
||||
|
||||
class OrderStats(MaterializedViewBase, Base):
|
||||
__tablename__ = "test_order_stats_mvb"
|
||||
__select__ = select(
|
||||
func.count(Order.id).label("order_count"),
|
||||
func.sum(Order.total).label("total_revenue"),
|
||||
)
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
# Initial state - empty
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(OrderStats.as_table())).fetchone()
|
||||
assert result.order_count == 0
|
||||
|
||||
with Session(pg_engine) as session:
|
||||
session.add(Order(id=1, total=Decimal("100.00")))
|
||||
session.add(Order(id=2, total=Decimal("200.00")))
|
||||
session.commit()
|
||||
|
||||
# MV is stale until refresh
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(OrderStats.as_table())).fetchone()
|
||||
assert result.order_count == 0
|
||||
|
||||
# Refresh using class method
|
||||
with pg_engine.begin() as conn:
|
||||
OrderStats.refresh(conn)
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(OrderStats.as_table())).fetchone()
|
||||
assert result.order_count == 2
|
||||
assert result.total_revenue == Decimal("300.00")
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropMaterializedView(OrderStats._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
|
||||
def test_materialized_view_with_data_false(self, pg_engine: Engine) -> None:
|
||||
"""Test MaterializedViewBase with __with_data__ = False."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Item(Base):
|
||||
__tablename__ = "test_items_mvb_nodata"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
class ItemCount(MaterializedViewBase, Base):
|
||||
__tablename__ = "test_item_count_mvb"
|
||||
__select__ = select(func.count(Item.id).label("count"))
|
||||
__with_data__ = False
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
with Session(pg_engine) as session:
|
||||
session.add(Item(id=1))
|
||||
session.commit()
|
||||
|
||||
with pg_engine.begin() as conn:
|
||||
ItemCount.refresh(conn)
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(ItemCount.as_table())).fetchone()
|
||||
assert result.count == 1
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropMaterializedView(ItemCount._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
|
||||
def test_materialized_view_auto_refresh(self, pg_engine: Engine) -> None:
|
||||
"""Test MaterializedViewBase auto-refresh functionality."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Counter(Base):
|
||||
__tablename__ = "test_counter_mvb_ar"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
value: Mapped[int] = mapped_column(Integer)
|
||||
|
||||
class CounterSum(MaterializedViewBase, Base):
|
||||
__tablename__ = "test_counter_sum_mvb"
|
||||
__select__ = select(func.sum(Counter.value).label("total"))
|
||||
|
||||
class TestSession(Session):
|
||||
pass
|
||||
|
||||
CounterSum.auto_refresh_on(TestSession, Counter.__table__)
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
with TestSession(pg_engine) as session:
|
||||
session.add(Counter(id=1, value=10))
|
||||
session.add(Counter(id=2, value=20))
|
||||
session.commit()
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(CounterSum.as_table())).fetchone()
|
||||
assert result.total == 30
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropMaterializedView(CounterSum._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
|
||||
|
||||
class TestMultipleInheritance:
|
||||
"""Tests for multiple inheritance with DeclarativeBase."""
|
||||
|
||||
def test_view_inherits_metadata(self, pg_engine: Engine) -> None:
|
||||
"""Test that ViewBase inherits metadata from DeclarativeBase."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = "test_product_mi"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
price: Mapped[Decimal] = mapped_column(Numeric(10, 2))
|
||||
|
||||
class ExpensiveProducts(ViewBase, Base):
|
||||
__tablename__ = "test_expensive_products_mi"
|
||||
__select__ = select(Product.id, Product.name, Product.price).where(
|
||||
Product.price > 50
|
||||
)
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
with Session(pg_engine) as session:
|
||||
session.add(Product(id=1, name="Cheap", price=Decimal("10.00")))
|
||||
session.add(Product(id=2, name="Expensive", price=Decimal("100.00")))
|
||||
session.add(Product(id=3, name="Premium", price=Decimal("200.00")))
|
||||
session.commit()
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(ExpensiveProducts.as_table())).fetchall()
|
||||
assert len(result) == 2
|
||||
names = {row.name for row in result}
|
||||
assert names == {"Expensive", "Premium"}
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropView(ExpensiveProducts._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
|
||||
def test_materialized_view_inherits_metadata(self, pg_engine: Engine) -> None:
|
||||
"""Test that MaterializedViewBase inherits metadata from DeclarativeBase."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Sale(Base):
|
||||
__tablename__ = "test_sale_mi"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
amount: Mapped[Decimal] = mapped_column(Numeric(10, 2))
|
||||
|
||||
class SaleStats(MaterializedViewBase, Base):
|
||||
__tablename__ = "test_sale_stats_mi"
|
||||
__select__ = select(
|
||||
func.count(Sale.id).label("sale_count"),
|
||||
func.sum(Sale.amount).label("total_amount"),
|
||||
)
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
with Session(pg_engine) as session:
|
||||
session.add(Sale(id=1, amount=Decimal("100.00")))
|
||||
session.add(Sale(id=2, amount=Decimal("200.00")))
|
||||
session.commit()
|
||||
|
||||
with pg_engine.begin() as conn:
|
||||
SaleStats.refresh(conn)
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
result = conn.execute(select(SaleStats.as_table())).fetchone()
|
||||
assert result.sale_count == 2
|
||||
assert result.total_amount == Decimal("300.00")
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropMaterializedView(SaleStats._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
|
||||
def test_multiple_views_same_base(self, pg_engine: Engine) -> None:
|
||||
"""Test multiple views sharing the same DeclarativeBase."""
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "test_user_multi"
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
role: Mapped[str] = mapped_column(String(50))
|
||||
|
||||
class AdminUsers(ViewBase, Base):
|
||||
__tablename__ = "test_admin_users"
|
||||
__select__ = select(User.id, User.name).where(User.role == "admin")
|
||||
|
||||
class RegularUsers(ViewBase, Base):
|
||||
__tablename__ = "test_regular_users"
|
||||
__select__ = select(User.id, User.name).where(User.role == "user")
|
||||
|
||||
class UserStats(MaterializedViewBase, Base):
|
||||
__tablename__ = "test_user_stats_multi"
|
||||
__select__ = select(func.count(User.id).label("count"))
|
||||
|
||||
try:
|
||||
Base.metadata.create_all(pg_engine)
|
||||
|
||||
with Session(pg_engine) as session:
|
||||
session.add(User(id=1, name="Alice", role="admin"))
|
||||
session.add(User(id=2, name="Bob", role="user"))
|
||||
session.add(User(id=3, name="Charlie", role="user"))
|
||||
session.commit()
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
admins = conn.execute(select(AdminUsers.as_table())).fetchall()
|
||||
assert len(admins) == 1
|
||||
assert admins[0].name == "Alice"
|
||||
|
||||
regulars = conn.execute(select(RegularUsers.as_table())).fetchall()
|
||||
assert len(regulars) == 2
|
||||
|
||||
with pg_engine.begin() as conn:
|
||||
UserStats.refresh(conn)
|
||||
|
||||
with pg_engine.connect() as conn:
|
||||
stats = conn.execute(select(UserStats.as_table())).fetchone()
|
||||
assert stats.count == 3
|
||||
|
||||
finally:
|
||||
with pg_engine.begin() as conn:
|
||||
conn.execute(DropView(AdminUsers._view, if_exists=True))
|
||||
conn.execute(DropView(RegularUsers._view, if_exists=True))
|
||||
conn.execute(DropMaterializedView(UserStats._view, if_exists=True))
|
||||
Base.metadata.drop_all(pg_engine)
|
||||
Reference in New Issue
Block a user