Skip to content

Commit

Permalink
standardizes template fields for BaseSQLOperator and adds `database…
Browse files Browse the repository at this point in the history
…` as a templated field (apache#39826)

* standardizes template fields for BaseSQLOperator

* adds template_fields sequence string type

* fixes hook params check in init
  • Loading branch information
nyoungstudios authored Jun 3, 2024
1 parent 3548636 commit 651a6d6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
29 changes: 16 additions & 13 deletions airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ class BaseSQLOperator(BaseOperator):

conn_id_field = "conn_id"

template_fields: Sequence[str] = ("conn_id", "database", "hook_params")

def __init__(
self,
*,
Expand All @@ -139,7 +141,7 @@ def __init__(
super().__init__(**kwargs)
self.conn_id = conn_id
self.database = database
self.hook_params = {} if hook_params is None else hook_params
self.hook_params = hook_params or {}
self.retry_on_failure = retry_on_failure

@cached_property
Expand Down Expand Up @@ -220,7 +222,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
:ref:`howto/operator:SQLExecuteQueryOperator`
"""

template_fields: Sequence[str] = ("conn_id", "sql", "parameters", "hook_params")
template_fields: Sequence[str] = ("sql", "parameters", *BaseSQLOperator.template_fields)
template_ext: Sequence[str] = (".sql", ".json")
template_fields_renderers = {"sql": "sql", "parameters": "json"}
ui_color = "#cdaaed"
Expand Down Expand Up @@ -425,7 +427,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
:ref:`howto/operator:SQLColumnCheckOperator`
"""

template_fields: Sequence[str] = ("partition_clause", "table", "sql", "hook_params")
template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields)
template_fields_renderers = {"sql": "sql"}

sql_check_template = """
Expand Down Expand Up @@ -653,7 +655,7 @@ class SQLTableCheckOperator(BaseSQLOperator):
:ref:`howto/operator:SQLTableCheckOperator`
"""

template_fields: Sequence[str] = ("partition_clause", "table", "sql", "conn_id", "hook_params")
template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields)

template_fields_renderers = {"sql": "sql"}

Expand Down Expand Up @@ -769,7 +771,7 @@ class SQLCheckOperator(BaseSQLOperator):
:param parameters: (optional) the parameters to render the SQL query with.
"""

template_fields: Sequence[str] = ("sql", "hook_params")
template_fields: Sequence[str] = ("sql", *BaseSQLOperator.template_fields)
template_ext: Sequence[str] = (
".hql",
".sql",
Expand Down Expand Up @@ -815,11 +817,7 @@ class SQLValueCheckOperator(BaseSQLOperator):
"""

__mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
template_fields: Sequence[str] = (
"sql",
"pass_value",
"hook_params",
)
template_fields: Sequence[str] = ("sql", "pass_value", *BaseSQLOperator.template_fields)
template_ext: Sequence[str] = (
".hql",
".sql",
Expand Down Expand Up @@ -916,7 +914,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
"""

__mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
template_fields: Sequence[str] = ("sql1", "sql2", "hook_params")
template_fields: Sequence[str] = ("sql1", "sql2", *BaseSQLOperator.template_fields)
template_ext: Sequence[str] = (
".hql",
".sql",
Expand Down Expand Up @@ -1044,7 +1042,12 @@ class SQLThresholdCheckOperator(BaseSQLOperator):
:param max_threshold: numerical value or max threshold sql to be executed (templated)
"""

template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold", "hook_params")
template_fields: Sequence[str] = (
"sql",
"min_threshold",
"max_threshold",
*BaseSQLOperator.template_fields,
)
template_ext: Sequence[str] = (
".hql",
".sql",
Expand Down Expand Up @@ -1142,7 +1145,7 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
:param parameters: (optional) the parameters to render the SQL query with.
"""

template_fields: Sequence[str] = ("sql",)
template_fields: Sequence[str] = ("sql", *BaseSQLOperator.template_fields)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
ui_color = "#a22034"
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/common/sql/operators/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def parse_boolean(val: str) -> str | bool: ...

class BaseSQLOperator(BaseOperator):
conn_id_field: str
template_fields: Sequence[str]
conn_id: Incomplete
database: Incomplete
hook_params: Incomplete
Expand Down
23 changes: 23 additions & 0 deletions tests/providers/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.common.sql.operators.sql import (
BaseSQLOperator,
BranchSQLOperator,
SQLCheckOperator,
SQLColumnCheckOperator,
Expand Down Expand Up @@ -59,6 +60,28 @@ def _get_mock_db_hook():
return MockHook()


class TestBaseSQLOperator:
def _construct_operator(self, **kwargs):
dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1), render_template_as_native_obj=True)
return BaseSQLOperator(
task_id="test_task",
conn_id="{{ conn_id }}",
database="{{ database }}",
hook_params="{{ hook_params }}",
**kwargs,
dag=dag,
)

def test_templated_fields(self):
operator = self._construct_operator()
operator.render_template_fields(
{"conn_id": "my_conn_id", "database": "my_database", "hook_params": {"key": "value"}}
)
assert operator.conn_id == "my_conn_id"
assert operator.database == "my_database"
assert operator.hook_params == {"key": "value"}


class TestSQLExecuteQueryOperator:
def _construct_operator(self, sql, **kwargs):
dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1))
Expand Down

0 comments on commit 651a6d6

Please sign in to comment.