diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py index 5276f6e40ca3e..4e4cb3bef9e41 100644 --- a/airflow/operators/sensors.py +++ b/airflow/operators/sensors.py @@ -69,12 +69,12 @@ def poke(self, context): def execute(self, context): started_at = datetime.now() while not self.poke(context): - sleep(self.poke_interval) - if (datetime.now() - started_at).seconds > self.timeout: + if (datetime.now() - started_at).total_seconds() > self.timeout: if self.soft_fail: raise AirflowSkipException('Snap. Time is OUT.') else: raise AirflowSensorTimeout('Snap. Time is OUT.') + sleep(self.poke_interval) logging.info("Success criteria met. Exiting.") diff --git a/tests/operators/sensors.py b/tests/operators/sensors.py index 025790e28fd2f..325ee8db305bd 100644 --- a/tests/operators/sensors.py +++ b/tests/operators/sensors.py @@ -12,11 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import logging import os +import time import unittest -from airflow.operators.sensors import HttpSensor -from airflow.exceptions import AirflowException +from datetime import datetime, timedelta + +from airflow import DAG, configuration +from airflow.operators.sensors import HttpSensor, BaseSensorOperator +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import (AirflowException, + AirflowSensorTimeout, + AirflowSkipException) +configuration.test_mode() + +DEFAULT_DATE = datetime(2015, 1, 1) +TEST_DAG_ID = 'unit_test_dag' + + +class TimeoutTestSensor(BaseSensorOperator): + """ + Sensor that always returns the return_value provided + + :param return_value: Set to true to mark the task as SKIPPED on failure + :type return_value: any + """ + + @apply_defaults + def __init__( + self, + return_value=False, + *args, + **kwargs): + self.return_value = return_value + super(TimeoutTestSensor, self).__init__(*args, **kwargs) + + def poke(self, context): + return self.return_value + + def execute(self, context): + started_at = datetime.now() + time_jump = self.params.get('time_jump') + while not self.poke(context): + if time_jump: + started_at -= time_jump + if (datetime.now() - started_at).total_seconds() > self.timeout: + if self.soft_fail: + raise AirflowSkipException('Snap. Time is OUT.') + else: + raise AirflowSensorTimeout('Snap. Time is OUT.') + time.sleep(self.poke_interval) + logging.info("Success criteria met. Exiting.") + + +class SensorTimeoutTest(unittest.TestCase): + def setUp(self): + configuration.test_mode() + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + dag = DAG(TEST_DAG_ID, default_args=args) + self.dag = dag + + def test_timeout(self): + t = TimeoutTestSensor( + task_id='test_timeout', + execution_timeout=timedelta(days=2), + return_value=False, + poke_interval=5, + params={'time_jump': timedelta(days=2, seconds=1)}, + dag=self.dag + ) + self.assertRaises( + AirflowSensorTimeout, + t.run, + start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True) class HttpSensorTests(unittest.TestCase):