feat!: snowpark (TobikoData#2666)
tobymao authored May 24, 2024
1 parent 64b273b commit f6250d5
Showing 10 changed files with 107 additions and 27 deletions.
36 changes: 35 additions & 1 deletion docs/concepts/models/python_models.md
@@ -33,7 +33,7 @@ The `execute` function is wrapped with the `@model` [decorator](https://wiki.pyt

Because SQLMesh creates tables before evaluating models, the schema of the output DataFrame is a required argument. The `@model` argument `columns` contains a dictionary of column names to types.

The function takes an `ExecutionContext` that is able to run queries and to retrieve the current time interval that is being processed, along with arbitrary key-value arguments passed in at runtime. The function can either return a Pandas or PySpark Dataframe instance.
The function takes an `ExecutionContext` that is able to run queries and to retrieve the current time interval that is being processed, along with arbitrary key-value arguments passed in at runtime. The function can either return a Pandas, PySpark, or Snowpark Dataframe instance.

If the function output is too large, it can also be returned in chunks using Python generators.

@@ -292,6 +292,40 @@ def execute(
return df
```


### Snowpark
This example demonstrates using the Snowpark DataFrame API. If you use Snowflake, the DataFrame API is preferred over Pandas because it allows the computation to run in Snowflake in a distributed fashion rather than locally.

```python linenums="1"
import typing as t
from datetime import datetime

import pandas as pd
from snowflake.snowpark.dataframe import DataFrame

from sqlmesh import ExecutionContext, model

@model(
    "docs_example.snowpark",
    columns={
        "id": "int",
        "name": "text",
        "country": "text",
    },
)
def execute(
    context: ExecutionContext,
    start: datetime,
    end: datetime,
    execution_time: datetime,
    **kwargs: t.Any,
) -> DataFrame:
    # returns the snowpark DataFrame directly, so no data is computed locally
    df = context.snowpark.create_dataframe([[1, "a", "usa"], [2, "b", "cad"]], schema=["id", "name", "country"])
    df = df.filter(df.id > 1)
    return df
```

### Batching
If the output of a Python model is very large and you cannot use Spark, it may be helpful to split the output into multiple batches.
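As a rough illustration of that pattern (not part of this commit's diff), a Python model can `yield` multiple DataFrames instead of returning a single one; the model name and data below are hypothetical:

```python
import typing as t
from datetime import datetime

import pandas as pd

from sqlmesh import ExecutionContext, model

@model(
    "docs_example.batched",  # hypothetical model name, for illustration only
    columns={"id": "int"},
)
def execute(
    context: ExecutionContext,
    start: datetime,
    end: datetime,
    execution_time: datetime,
    **kwargs: t.Any,
) -> t.Iterator[pd.DataFrame]:
    # Yield several small DataFrames instead of materializing one large one in memory.
    for batch in range(3):
        yield pd.DataFrame({"id": [batch]})
```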

1 change: 1 addition & 0 deletions setup.py
@@ -82,6 +82,7 @@
"pyspark~=3.5.0",
"pytz",
"snowflake-connector-python[pandas,secure-local-storage]>=3.0.2",
"snowflake-snowpark-python[pandas];python_version<'3.12'",
"sqlalchemy-stubs",
"tenacity==8.1.0",
"types-croniter",
5 changes: 4 additions & 1 deletion sqlmesh/cli/main.py
@@ -204,7 +204,10 @@ def evaluate(
execution_time=execution_time,
limit=limit,
)
ctx.obj.console.log_success(df)
if hasattr(df, "show"):
df.show(limit)
else:
ctx.obj.console.log_success(df)


@cli.command("format")
12 changes: 11 additions & 1 deletion sqlmesh/core/context.py
@@ -120,7 +120,12 @@
if t.TYPE_CHECKING:
from typing_extensions import Literal

from sqlmesh.core.engine_adapter._typing import DF, PySparkDataFrame, PySparkSession
from sqlmesh.core.engine_adapter._typing import (
DF,
PySparkDataFrame,
PySparkSession,
SnowparkSession,
)
from sqlmesh.core.snapshot import Node

ModelOrSnapshot = t.Union[str, Model, Snapshot]
@@ -152,6 +157,11 @@ def spark(self) -> t.Optional[PySparkSession]:
"""Returns the spark session if it exists."""
return self.engine_adapter.spark

@property
def snowpark(self) -> t.Optional[SnowparkSession]:
"""Returns the snowpark session if it exists."""
return self.engine_adapter.snowpark

@property
def default_catalog(self) -> t.Optional[str]:
raise NotImplementedError
17 changes: 16 additions & 1 deletion sqlmesh/core/engine_adapter/_typing.py
@@ -3,12 +3,27 @@
import pandas as pd
from sqlglot import exp

from sqlmesh.utils import optional_import

if t.TYPE_CHECKING:
import pyspark
import pyspark.sql.connect.dataframe

snowpark = optional_import("snowflake.snowpark")

Query = t.Union[exp.Query, exp.DerivedTable]
PySparkSession = t.Union[pyspark.sql.SparkSession, pyspark.sql.connect.dataframe.SparkSession]
PySparkDataFrame = t.Union[pyspark.sql.DataFrame, pyspark.sql.connect.dataframe.DataFrame]
DF = t.Union[pd.DataFrame, pyspark.sql.DataFrame, pyspark.sql.connect.dataframe.DataFrame]

# snowpark is not available on python 3.12
from snowflake.snowpark import Session as SnowparkSession # noqa
from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame

DF = t.Union[
pd.DataFrame,
pyspark.sql.DataFrame,
pyspark.sql.connect.dataframe.DataFrame,
SnowparkDataFrame,
]

QueryOrDF = t.Union[Query, DF]
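For context, the `optional_import` helper used above guards the Snowpark dependency so this module still imports when `snowflake-snowpark-python` is unavailable (e.g. on Python 3.12). A minimal sketch of what such a helper typically looks like — the actual implementation in `sqlmesh.utils` may differ:

```python
import importlib
import types
import typing as t


def optional_import(name: str) -> t.Optional[types.ModuleType]:
    """Return the named module if it can be imported, otherwise None."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None
```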
17 changes: 10 additions & 7 deletions sqlmesh/core/engine_adapter/base.py
@@ -53,6 +53,7 @@
PySparkSession,
Query,
QueryOrDF,
SnowparkSession,
)
from sqlmesh.core.node import IntervalUnit

@@ -146,14 +147,14 @@ def cursor(self) -> t.Any:
def spark(self) -> t.Optional[PySparkSession]:
return None

@property
def snowpark(self) -> t.Optional[SnowparkSession]:
return None

@property
def comments_enabled(self) -> bool:
return self._register_comments and self.COMMENT_CREATION_TABLE.is_supported

@classmethod
def is_pandas_df(cls, value: t.Any) -> bool:
return isinstance(value, pd.DataFrame)

@classmethod
def _casted_columns(cls, columns_to_types: t.Dict[str, exp.DataType]) -> t.List[exp.Alias]:
return [
@@ -244,7 +245,7 @@ def _columns_to_types(
) -> t.Optional[t.Dict[str, exp.DataType]]:
if columns_to_types:
return columns_to_types
if self.is_pandas_df(query_or_df):
if isinstance(query_or_df, pd.DataFrame):
return columns_to_types_from_df(t.cast(pd.DataFrame, query_or_df))
return columns_to_types

@@ -833,8 +834,10 @@ def create_view(
view_properties: Optional view properties to add to the view.
create_kwargs: Additional kwargs to pass into the Create expression
"""
if self.is_pandas_df(query_or_df):
values = list(t.cast(pd.DataFrame, query_or_df).itertuples(index=False, name=None))
if isinstance(query_or_df, pd.DataFrame):
values: t.List[t.Tuple[t.Any, ...]] = list(
query_or_df.itertuples(index=False, name=None)
)
columns_to_types = columns_to_types or self._columns_to_types(query_or_df)
if not columns_to_types:
raise SQLMeshError("columns_to_types must be provided for dataframes")
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/redshift.py
@@ -151,7 +151,7 @@ def replace_query(
If it does exist then we need to do the:
`CREATE TABLE...`, `INSERT INTO...`, `RENAME TABLE...`, `RENAME TABLE...`, DROP TABLE...` dance.
"""
if not self.is_pandas_df(query_or_df) or not self.table_exists(table_name):
if not isinstance(query_or_df, pd.DataFrame) or not self.table_exists(table_name):
return super().replace_query(
table_name,
query_or_df,
38 changes: 26 additions & 12 deletions sqlmesh/core/engine_adapter/snowflake.py
@@ -18,11 +18,14 @@
SourceQuery,
set_catalog,
)
from sqlmesh.utils import optional_import
from sqlmesh.utils.errors import SQLMeshError

snowpark = optional_import("snowflake.snowpark")

if t.TYPE_CHECKING:
from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
from sqlmesh.core.engine_adapter._typing import DF, Query
from sqlmesh.core.engine_adapter._typing import DF, Query, SnowparkSession


@set_catalog(
@@ -72,22 +75,30 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
yield
self.execute(f"USE WAREHOUSE {current_warehouse_sql}")

@property
def snowpark(self) -> t.Optional[SnowparkSession]:
if snowpark:
return snowpark.Session.builder.configs(
{"connection": self._connection_pool.get()}
).getOrCreate()
return None

def _df_to_source_queries(
self,
df: DF,
columns_to_types: t.Dict[str, exp.DataType],
batch_size: int,
target_table: TableName,
) -> t.List[SourceQuery]:
assert isinstance(df, pd.DataFrame)
temp_table = self._get_temp_table(target_table or "pandas")

def query_factory() -> Query:
from snowflake.connector.pandas_tools import write_pandas
if snowpark and isinstance(df, snowpark.dataframe.DataFrame):
df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect, identify=True))
elif isinstance(df, pd.DataFrame):
from snowflake.connector.pandas_tools import write_pandas

if not self.table_exists(temp_table):
# Workaround for https://github.com/snowflakedb/snowflake-connector-python/issues/1034
#
# The above issue has already been fixed upstream, but we keep the following
# line anyway in order to support a wider range of Snowflake versions.
schema = f'"{temp_table.db}"'
@@ -109,23 +120,26 @@ def query_factory() -> Query:
df[column] = pd.to_datetime(df[column]).dt.strftime(
"%Y-%m-%d %H:%M:%S.%f"
) # type: ignore
self.create_table(temp_table, columns_to_types, exists=False)
self.create_table(temp_table, columns_to_types)

write_pandas(
self._connection_pool.get(),
df,
temp_table.name,
schema=temp_table.db or None,
database=temp_table.catalog or None,
chunk_size=self.DEFAULT_BATCH_SIZE,
overwrite=True,
table_type="temp",
)
else:
raise SQLMeshError(
f"Unknown dataframe type: {type(df)} for {target_table}. Expecting pandas or snowpark."
)

return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)

return [
SourceQuery(
query_factory=query_factory,
cleanup_func=lambda: self.drop_table(temp_table),
)
]
return [SourceQuery(query_factory=query_factory)]

def _fetch_native_df(
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/spark.py
@@ -216,7 +216,7 @@ def try_get_pyspark_df(cls, value: t.Any) -> t.Optional[PySparkDataFrame]:

@classmethod
def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]:
if cls.is_pandas_df(value):
if isinstance(value, pd.DataFrame):
return value
return None

4 changes: 2 additions & 2 deletions sqlmesh/core/snapshot/evaluator.py
@@ -562,7 +562,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
if isinstance(query_or_df, pd.DataFrame):
return query_or_df.head(limit)
if not isinstance(query_or_df, exp.Expression):
# We assume that if this branch is reached, `query_or_df` is a pyspark dataframe,
# We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark dataframe,
# so we use `limit` instead of `head` to get back a dataframe instead of List[Row]
# https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head
return query_or_df.limit(limit)
@@ -591,7 +591,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
query_or_df = reduce(
lambda a, b: (
pd.concat([a, b], ignore_index=True) # type: ignore
if self.adapter.is_pandas_df(a)
if isinstance(a, pd.DataFrame)
else a.union_all(b) # type: ignore
), # type: ignore
queries_or_dfs,