Skip to content

Commit

Permalink
Merge pull request apache#1213 from airbnb/task_ref_str
Browse files Browse the repository at this point in the history
Tasks refer to upstream and downstream tasks using strings instead of obj refs
  • Loading branch information
mistercrunch committed Mar 27, 2016
2 parents 333e45b + c28fbd7 commit 182a4cb
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,18 +798,17 @@ def are_dependents_done(self, session=None):
"""
task = self.task

if not task._downstream_list:
if not task.downstream_task_ids:
return True

downstream_task_ids = [t.task_id for t in task._downstream_list]
ti = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(downstream_task_ids),
TaskInstance.task_id.in_(task.downstream_task_ids),
TaskInstance.execution_date == self.execution_date,
TaskInstance.state == State.SUCCESS,
)
count = ti[0][0]
return count == len(task._downstream_list)
return count == len(task.downstream_task_ids)

@provide_session
def are_dependencies_met(
Expand Down Expand Up @@ -859,10 +858,9 @@ def are_dependencies_met(
return False

# Checking that all upstream dependencies have succeeded
if not task._upstream_list or task.trigger_rule == TR.DUMMY:
if not task.upstream_list or task.trigger_rule == TR.DUMMY:
return True

upstream_task_ids = [t.task_id for t in task._upstream_list]
qry = (
session
.query(
Expand All @@ -878,15 +876,15 @@ def are_dependencies_met(
)
.filter(
TI.dag_id == self.dag_id,
TI.task_id.in_(upstream_task_ids),
TI.task_id.in_(task.upstream_task_ids),
TI.execution_date == self.execution_date,
TI.state.in_([
State.SUCCESS, State.FAILED,
State.UPSTREAM_FAILED, State.SKIPPED]),
)
)
successes, skipped, failed, upstream_failed, done = qry.first()
upstream = len(task._upstream_list)
upstream = len(task.upstream_task_ids)
tr = task.trigger_rule
upstream_done = done >= upstream

Expand Down Expand Up @@ -1585,8 +1583,8 @@ def __init__(
self.dag = dag

# Private attributes
self._upstream_list = []
self._downstream_list = []
self._upstream_task_ids = []
self._downstream_task_ids = []

self._comps = {
'task_id',
Expand Down Expand Up @@ -1692,8 +1690,6 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result

self._upstream_list = sorted(self._upstream_list, key=lambda x: x.task_id)
self._downstream_list = sorted(self._downstream_list, key=lambda x: x.task_id)
for k, v in list(self.__dict__.items()):
if k not in ('user_defined_macros', 'params'):
setattr(result, k, copy.deepcopy(v, memo))
Expand Down Expand Up @@ -1768,12 +1764,20 @@ def resolve_template_files(self):
@property
def upstream_list(self):
"""@property: list of tasks directly upstream"""
return self._upstream_list
return [self.dag.get_task(tid) for tid in self._upstream_task_ids]

@property
def upstream_task_ids(self):
return self._upstream_task_ids

@property
def downstream_list(self):
"""@property: list of tasks directly downstream"""
return self._downstream_list
return [self.dag.get_task(tid) for tid in self._downstream_task_ids]

@property
def downstream_task_ids(self):
return self._downstream_task_ids

def clear(
self, start_date=None, end_date=None,
Expand All @@ -1795,12 +1799,12 @@ def clear(
tasks = [self.task_id]

if upstream:
tasks += \
[t.task_id for t in self.get_flat_relatives(upstream=True)]
tasks += [
t.task_id for t in self.get_flat_relatives(upstream=True)]

if downstream:
tasks += \
[t.task_id for t in self.get_flat_relatives(upstream=False)]
tasks += [
t.task_id for t in self.get_flat_relatives(upstream=False)]

qry = qry.filter(TI.task_id.in_(tasks))

Expand Down Expand Up @@ -1911,11 +1915,11 @@ def _set_relatives(self, task_or_task_list, upstream=False):
if not isinstance(task, BaseOperator):
raise AirflowException('Expecting a task')
if upstream:
task.append_only_new(task._downstream_list, self)
self.append_only_new(self._upstream_list, task)
task.append_only_new(task._downstream_task_ids, self.task_id)
self.append_only_new(self._upstream_task_ids, task.task_id)
else:
self.append_only_new(self._downstream_list, task)
task.append_only_new(task._upstream_list, self)
self.append_only_new(self._downstream_task_ids, task.task_id)
task.append_only_new(task._upstream_task_ids, self.task_id)

self.detect_downstream_cycle()

Expand Down Expand Up @@ -2478,17 +2482,16 @@ def sub_dag(self, task_regex, include_downstream=False,
also_include += t.get_flat_relatives(upstream=False)
if include_upstream:
also_include += t.get_flat_relatives(upstream=True)

# Compiling the unique list of tasks that made the cut
tasks = list(set(regex_match + also_include))
dag.tasks = tasks
dag.tasks = list(set(regex_match + also_include))
for t in dag.tasks:
            # Removing upstream/downstream references to tasks that did not
            # make the cut
t._upstream_list = [
ut for ut in t._upstream_list if utils.is_in(ut, tasks)]
t._downstream_list = [
ut for ut in t._downstream_list if utils.is_in(ut, tasks)]

t._upstream_task_ids = [
tid for tid in t._upstream_task_ids if tid in dag.task_ids]
t._downstream_task_ids = [
tid for tid in t._downstream_task_ids if tid in dag.task_ids]
return dag

def has_task(self, task_id):
Expand Down

0 comments on commit 182a4cb

Please sign in to comment.