Skip to content

Commit

Permalink
[AIRFLOW-3677] Improve CheckOperator test coverage (apache#4756)
Browse files Browse the repository at this point in the history
Add tests for check_operator module

- add missing tests cases for CheckOperator
- add missing tests cases for ValueCheckOperator
- refactor all three classes
- replace **locals in str.format by explicit args
  • Loading branch information
feluelle authored and Fokko committed Apr 9, 2019
1 parent 6970b23 commit efa5ba8
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 54 deletions.
97 changes: 59 additions & 38 deletions airflow/operators/check_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
# specific language governing permissions and limitations
# under the License.

from builtins import zip
from builtins import str
from builtins import str, zip
from typing import Optional, Any, Iterable, Dict, SupportsAbs

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -82,12 +81,14 @@ def __init__(
def execute(self, context=None):
self.log.info('Executing SQL check: %s', self.sql)
records = self.get_db_hook().get_first(self.sql)

self.log.info('Record: %s', records)
if not records:
raise AirflowException("The query returned None")
elif not all([bool(r) for r in records]):
exceptstr = "Test failed.\nQuery:\n{q}\nResults:\n{r!s}"
raise AirflowException(exceptstr.format(q=self.sql, r=records))
raise AirflowException("Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format(
query=self.sql, records=records))

self.log.info("Success.")

def get_db_hook(self):
Expand Down Expand Up @@ -149,38 +150,51 @@ def __init__(
def execute(self, context=None):
self.log.info('Executing SQL check: %s', self.sql)
records = self.get_db_hook().get_first(self.sql)

if not records:
raise AirflowException("The query returned None")

pass_value_conv = _convert_to_float_if_possible(self.pass_value)
is_numeric_value_check = isinstance(pass_value_conv, float)

tolerance_pct_str = None
if self.tol is not None:
tolerance_pct_str = str(self.tol * 100) + '%'
tolerance_pct_str = str(self.tol * 100) + '%' if self.has_tolerance else None
error_msg = ("Test failed.\nPass value:{pass_value_conv}\n"
"Tolerance:{tolerance_pct_str}\n"
"Query:\n{sql}\nResults:\n{records!s}").format(
pass_value_conv=pass_value_conv,
tolerance_pct_str=tolerance_pct_str,
sql=self.sql,
records=records
)

except_temp = ("Test failed.\nPass value:{pass_value_conv}\n"
"Tolerance:{tolerance_pct_str}\n"
"Query:\n{sql}\nResults:\n{records!s}".format(
pass_value_conv=pass_value_conv, tolerance_pct_str=tolerance_pct_str, # noqa: E122
sql=self.sql, records=records))
if not is_numeric_value_check:
tests = [str(r) == pass_value_conv for r in records]
tests = self._get_string_matches(records, pass_value_conv)
elif is_numeric_value_check:
try:
num_rec = [float(r) for r in records]
numeric_records = self._to_float(records)
except (ValueError, TypeError):
cvestr = "Converting a result to float failed.\n"
raise AirflowException(cvestr + except_temp)
if self.has_tolerance:
tests = [
pass_value_conv * (1 - self.tol) <=
r <= pass_value_conv * (1 + self.tol)
for r in num_rec]
else:
tests = [r == pass_value_conv for r in num_rec]
raise AirflowException("Converting a result to float failed.\n{}".format(error_msg))
tests = self._get_numeric_matches(numeric_records, pass_value_conv)
else:
tests = []

if not all(tests):
raise AirflowException(except_temp)
raise AirflowException(error_msg)

def _to_float(self, records):
return [float(record) for record in records]

def _get_string_matches(self, records, pass_value_conv):
return [str(record) == pass_value_conv for record in records]

def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
if self.has_tolerance:
return [
numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
for record in numeric_records
]

return [record == numeric_pass_value_conv for record in numeric_records]

def get_db_hook(self):
return BaseHook.get_hook(conn_id=self.conn_id)
Expand Down Expand Up @@ -229,15 +243,16 @@ class IntervalCheckOperator(BaseOperator):

@apply_defaults
def __init__(
self,
table, # type: str
metrics_thresholds, # type: Dict[str, int]
date_filter_column='ds', # type: Optional[str]
days_back=-7, # type: SupportsAbs[int]
ratio_formula='max_over_min', # type: Optional[str]
ignore_zero=True, # type: Optional[bool]
conn_id=None, # type: Optional[str]
*args, **kwargs):
self,
table, # type: str
metrics_thresholds, # type: Dict[str, int]
date_filter_column='ds', # type: Optional[str]
days_back=-7, # type: SupportsAbs[int]
ratio_formula='max_over_min', # type: Optional[str]
ignore_zero=True, # type: Optional[bool]
conn_id=None, # type: Optional[str]
*args, **kwargs
):
super(IntervalCheckOperator, self).__init__(*args, **kwargs)
if ratio_formula not in self.ratio_formulas:
msg_template = "Invalid diff_method: {diff_method}. " \
Expand All @@ -256,9 +271,10 @@ def __init__(
self.days_back = -abs(days_back)
self.conn_id = conn_id
sqlexp = ', '.join(self.metrics_sorted)
sqlt = ("SELECT {sqlexp} FROM {table}"
" WHERE {date_filter_column}=").format(
sqlexp=sqlexp, table=table, date_filter_column=date_filter_column)
sqlt = "SELECT {sqlexp} FROM {table} WHERE {date_filter_column}=".format(
sqlexp=sqlexp, table=table, date_filter_column=date_filter_column
)

self.sql1 = sqlt + "'{{ ds }}'"
self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"

Expand All @@ -269,14 +285,18 @@ def execute(self, context=None):
row2 = hook.get_first(self.sql2)
self.log.info('Executing SQL check: %s', self.sql1)
row1 = hook.get_first(self.sql1)

if not row2:
raise AirflowException("The query {q} returned None".format(q=self.sql2))
raise AirflowException("The query {} returned None".format(self.sql2))
if not row1:
raise AirflowException("The query {q} returned None".format(q=self.sql1))
raise AirflowException("The query {} returned None".format(self.sql1))

current = dict(zip(self.metrics_sorted, row1))
reference = dict(zip(self.metrics_sorted, row2))

ratios = {}
test_results = {}

for m in self.metrics_sorted:
cur = current[m]
ref = reference[m]
Expand Down Expand Up @@ -307,6 +327,7 @@ def execute(self, context=None):
)
raise AirflowException("The following tests have failed:\n {0}".format(", ".join(
sorted(failed_tests))))

self.log.info("All tests have passed")

def get_db_hook(self):
Expand Down
45 changes: 29 additions & 16 deletions tests/operators/test_check_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,37 @@

import unittest
from datetime import datetime
from airflow.models import DAG
from airflow.exceptions import AirflowException

from airflow.operators.check_operator import IntervalCheckOperator, ValueCheckOperator
from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.operators.check_operator import ValueCheckOperator, CheckOperator, IntervalCheckOperator
from tests.compat import mock


class ValueCheckOperatorTest(unittest.TestCase):
class TestCheckOperator(unittest.TestCase):

@mock.patch.object(CheckOperator, 'get_db_hook')
def test_execute_no_records(self, mock_get_db_hook):
mock_get_db_hook.return_value.get_first.return_value = []

with self.assertRaises(AirflowException):
CheckOperator(sql='sql').execute()

@mock.patch.object(CheckOperator, 'get_db_hook')
def test_execute_not_all_records_are_true(self, mock_get_db_hook):
mock_get_db_hook.return_value.get_first.return_value = ["data", ""]

with self.assertRaises(AirflowException):
CheckOperator(sql='sql').execute()


class TestValueCheckOperator(unittest.TestCase):

def setUp(self):
self.task_id = 'test_task'
self.conn_id = 'default_conn'

def __construct_operator(self, sql, pass_value, tolerance=None):

def _construct_operator(self, sql, pass_value, tolerance=None):
dag = DAG('test_dag', start_date=datetime(2017, 1, 1))

return ValueCheckOperator(
Expand All @@ -46,44 +62,41 @@ def __construct_operator(self, sql, pass_value, tolerance=None):

def test_pass_value_template_string(self):
pass_value_str = "2018-03-22"
operator = self.__construct_operator('select date from tab1;', "{{ ds }}")
result = operator.render_template('pass_value', operator.pass_value,
{'ds': pass_value_str})
operator = self._construct_operator('select date from tab1;', "{{ ds }}")

result = operator.render_template('pass_value', operator.pass_value, {'ds': pass_value_str})

self.assertEqual(operator.task_id, self.task_id)
self.assertEqual(result, pass_value_str)

def test_pass_value_template_string_float(self):
pass_value_float = 4.0
operator = self.__construct_operator('select date from tab1;', pass_value_float)
operator = self._construct_operator('select date from tab1;', pass_value_float)

result = operator.render_template('pass_value', operator.pass_value, {})

self.assertEqual(operator.task_id, self.task_id)
self.assertEqual(result, str(pass_value_float))

@mock.patch.object(ValueCheckOperator, 'get_db_hook')
def test_execute_pass(self, mock_get_db_hook):

mock_hook = mock.Mock()
mock_hook.get_first.return_value = [10]
mock_get_db_hook.return_value = mock_hook

sql = 'select value from tab1 limit 1;'

operator = self.__construct_operator(sql, 5, 1)
operator = self._construct_operator(sql, 5, 1)

operator.execute(None)

mock_hook.get_first.assert_called_with(sql)

@mock.patch.object(ValueCheckOperator, 'get_db_hook')
def test_execute_fail(self, mock_get_db_hook):

mock_hook = mock.Mock()
mock_hook.get_first.return_value = [11]
mock_get_db_hook.return_value = mock_hook

operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1)
operator = self._construct_operator('select value from tab1 limit 1;', 5, 1)

with self.assertRaisesRegexp(AirflowException, 'Tolerance:100.0%'):
operator.execute()
Expand Down

0 comments on commit efa5ba8

Please sign in to comment.