Skip to content

Commit

Permalink
Improve code quality of ExternalTaskSensor (apache#12574)
Browse files Browse the repository at this point in the history
  • Loading branch information
turbaszek authored Nov 24, 2020
1 parent 74ed92b commit b57b932
Showing 1 changed file with 39 additions and 42 deletions.
81 changes: 39 additions & 42 deletions airflow/sensors/external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import datetime
import os
from typing import FrozenSet, Optional, Union
from typing import Any, Callable, FrozenSet, List, Optional, Union

from sqlalchemy import func

Expand Down Expand Up @@ -91,13 +91,13 @@ def operator_extra_links(self):
def __init__(
self,
*,
external_dag_id,
external_task_id=None,
allowed_states=None,
failed_states=None,
execution_delta=None,
execution_date_fn=None,
check_existence=False,
external_dag_id: str,
external_task_id: Optional[str] = None,
allowed_states: Optional[List[str]] = None,
failed_states: Optional[List[str]] = None,
execution_delta: Optional[datetime.timedelta] = None,
execution_date_fn: Optional[Callable] = None,
check_existence: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -136,8 +136,7 @@ def __init__(
self.external_dag_id = external_dag_id
self.external_task_id = external_task_id
self.check_existence = check_existence
# we only check the existence for the first time.
self.has_checked_existence = False
self._has_checked_existence = False

@provide_session
def poke(self, context, session=None):
Expand All @@ -149,38 +148,22 @@ def poke(self, context, session=None):
dttm = context['execution_date']

dttm_filter = dttm if isinstance(dttm, list) else [dttm]
serialized_dttm_filter = ','.join([datetime.isoformat() for datetime in dttm_filter])
serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter)

self.log.info(
'Poking for %s.%s on %s ... ', self.external_dag_id, self.external_task_id, serialized_dttm_filter
)

DM = DagModel
# we only do the check for 1st time, no need for subsequent poke
if self.check_existence and not self.has_checked_existence:
dag_to_wait = session.query(DM).filter(DM.dag_id == self.external_dag_id).first()

if not dag_to_wait:
raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.')
elif not os.path.exists(dag_to_wait.fileloc):
raise AirflowException(f'The external DAG {self.external_dag_id} was deleted.')

if self.external_task_id:
refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id)
if not refreshed_dag_info.has_task(self.external_task_id):
raise AirflowException(
f'The external task {self.external_task_id} in '
f'DAG {self.external_dag_id} does not exist.'
)
self.has_checked_existence = True
# In poke mode this will check dag existence only once
if self.check_existence and not self._has_checked_existence:
self._check_for_existence(session=session)

count_allowed = self.get_count(dttm_filter, session, self.allowed_states)

count_failed = -1
if len(self.failed_states) > 0:
if self.failed_states:
count_failed = self.get_count(dttm_filter, session, self.failed_states)

session.commit()
if count_failed == len(dttm_filter):
if self.external_task_id:
raise AirflowException(
Expand All @@ -191,7 +174,25 @@ def poke(self, context, session=None):

return count_allowed == len(dttm_filter)

def get_count(self, dttm_filter, session, states):
def _check_for_existence(self, session) -> None:
    """
    Verify that the external DAG (and, when set, the external task) actually
    exists, raising ``AirflowException`` otherwise. Marks the sensor so this
    check runs only once per sensor lifetime.
    """
    dag_record = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first()

    # No row in the DagModel table means the DAG was never registered.
    if dag_record is None:
        raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.')

    # The DAG row may linger after its source file was removed from disk.
    if not os.path.exists(dag_record.fileloc):
        raise AirflowException(f'The external DAG {self.external_dag_id} was deleted.')

    if self.external_task_id:
        # Re-parse the DAG file to confirm the task is still defined in it.
        external_dag = DagBag(dag_record.fileloc).get_dag(self.external_dag_id)
        if not external_dag.has_task(self.external_task_id):
            raise AirflowException(
                f'The external task {self.external_task_id} in '
                f'DAG {self.external_dag_id} does not exist.'
            )
    self._has_checked_existence = True

def get_count(self, dttm_filter, session, states) -> int:
"""
Get the count of records against dttm filter and states
Expand All @@ -205,11 +206,9 @@ def get_count(self, dttm_filter, session, states):
"""
TI = TaskInstance
DR = DagRun

if self.external_task_id:
# .count() is inefficient
count = (
session.query(func.count())
session.query(func.count()) # .count() is inefficient
.filter(
TI.dag_id == self.external_dag_id,
TI.task_id == self.external_task_id,
Expand All @@ -219,7 +218,6 @@ def get_count(self, dttm_filter, session, states):
.scalar()
)
else:
# .count() is inefficient
count = (
session.query(func.count())
.filter(
Expand All @@ -231,7 +229,7 @@ def get_count(self, dttm_filter, session, states):
)
return count

def _handle_execution_date_fn(self, context):
def _handle_execution_date_fn(self, context) -> Any:
"""
This function is to handle backwards compatibility with how this operator was
previously where it only passes the execution date, but also allow for the newer
Expand Down Expand Up @@ -279,8 +277,8 @@ class ExternalTaskMarker(DummyOperator):
def __init__(
self,
*,
external_dag_id,
external_task_id,
external_dag_id: str,
external_task_id: str,
execution_date: Optional[Union[str, datetime.datetime]] = "{{ execution_date.isoformat() }}",
recursion_depth: int = 10,
**kwargs,
Expand All @@ -294,10 +292,9 @@ def __init__(
self.execution_date = execution_date
else:
raise TypeError(
'Expected str or datetime.datetime type for execution_date. Got {}'.format(
type(execution_date)
)
f'Expected str or datetime.datetime type for execution_date. Got {type(execution_date)}'
)

if recursion_depth <= 0:
raise ValueError("recursion_depth should be a positive integer")
self.recursion_depth = recursion_depth
Expand Down

0 comments on commit b57b932

Please sign in to comment.