Skip to content

Commit

Permalink
Logging and returning info about query execution SnowflakeHook (apach…
Browse files Browse the repository at this point in the history
  • Loading branch information
JavierLopezT authored Jul 1, 2021
1 parent a2d6aa0 commit 8b41c2e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
18 changes: 15 additions & 3 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from snowflake import connector
from snowflake.connector import SnowflakeConnection
from snowflake.connector import DictCursor, SnowflakeConnection
from snowflake.connector.util_text import split_statements

from airflow.hooks.dbapi import DbApiHook
Expand Down Expand Up @@ -251,7 +251,10 @@ def run(self, sql: Union[str, list], autocommit: bool = False, parameters: Optio
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
sequentially
sequentially. The variable execution_info is returned so that
it can be used in the Operators to modify the behavior
depending on the result of the query (i.e fail the operator
if the copy has processed 0 files)
:param sql: the sql string to be executed with possibly multiple statements,
or a list of sql statements to execute
Expand All @@ -273,14 +276,21 @@ def run(self, sql: Union[str, list], autocommit: bool = False, parameters: Optio
sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string]

self.log.debug("Executing %d statements against Snowflake DB", len(sql))
with closing(conn.cursor()) as cur:
with closing(conn.cursor(DictCursor)) as cur:

for sql_statement in sql:

self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if parameters:
cur.execute(sql_statement, parameters)
else:
cur.execute(sql_statement)

execution_info = []
for row in cur:
self.log.info("Statement execution info - %s", row)
execution_info.append(row)

self.log.info("Rows affected: %s", cur.rowcount)
self.log.info("Snowflake query id: %s", cur.sfqid)
self.query_ids.append(cur.sfqid)
Expand All @@ -289,3 +299,5 @@ def run(self, sql: Union[str, list], autocommit: bool = False, parameters: Optio
# or if db does not supports autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()

return execution_info
7 changes: 6 additions & 1 deletion airflow/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
snowflake_conn_id: str = 'snowflake_default',
parameters: Optional[dict] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
database: Optional[str] = None,
role: Optional[str] = None,
Expand All @@ -89,6 +90,7 @@ def __init__(
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
Expand Down Expand Up @@ -118,5 +120,8 @@ def execute(self, context: Any) -> None:
"""Run query on snowflake"""
self.log.info('Executing: %s', self.sql)
hook = self.get_hook()
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
execution_info = hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
self.query_ids = hook.query_ids

if self.do_xcom_push:
return execution_info
3 changes: 2 additions & 1 deletion tests/providers/snowflake/operators/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ def test_snowflake_operator(self, mock_get_hook):
dummy VARCHAR(50)
);
"""
operator = SnowflakeOperator(task_id='basic_snowflake', sql=sql, dag=self.dag)
operator = SnowflakeOperator(task_id='basic_snowflake', sql=sql, dag=self.dag, do_xcom_push=False)
# do_xcom_push=False because otherwise the XCom test will fail due to the mocking (it actually works)
operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

0 comments on commit 8b41c2e

Please sign in to comment.