Revert "Tasks references upstream and downstream tasks using strings …
Browse files Browse the repository at this point in the history
…instead of references"

This reverts commit 6c1207b.
mistercrunch committed Apr 12, 2016
1 parent c8ee5aa commit 173b193
Showing 1 changed file with 29 additions and 24 deletions.
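
For context, the commit being reverted had switched the upstream/downstream bookkeeping on tasks from direct object references (`_upstream_list` / `_downstream_list`) to task-id strings resolved through the DAG (`_upstream_task_ids` / `_downstream_task_ids`), which is what the diff below undoes. A minimal sketch of the two wiring styles, using hypothetical stand-ins (`DagStub`, `TaskByReference`, `TaskByTaskId` are illustrative names, not Airflow's actual classes):

# Hypothetical, simplified sketch -- not Airflow's real classes.

class DagStub:
    """Minimal DAG-like registry used by the string-id variant below."""
    def __init__(self):
        self.task_map = {}

    def add_task(self, task):
        self.task_map[task.task_id] = task

    def get_task(self, task_id):
        return self.task_map[task_id]


class TaskByReference:
    """Restored behavior: relatives are held as direct object references."""
    def __init__(self, task_id):
        self.task_id = task_id
        self._upstream_list = []
        self._downstream_list = []

    def set_downstream(self, other):
        # The task objects themselves are stored; no lookup is needed later.
        self._downstream_list.append(other)
        other._upstream_list.append(self)

    @property
    def downstream_list(self):
        return self._downstream_list


class TaskByTaskId:
    """Reverted behavior: relatives are held as task-id strings."""
    def __init__(self, task_id, dag):
        self.task_id = task_id
        self.dag = dag
        dag.add_task(self)
        self._upstream_task_ids = []
        self._downstream_task_ids = []

    def set_downstream(self, other):
        # Only the ids are stored; the objects are resolved on demand.
        self._downstream_task_ids.append(other.task_id)
        other._upstream_task_ids.append(self.task_id)

    @property
    def downstream_list(self):
        # Resolution is deferred to the DAG's task registry.
        return [self.dag.get_task(tid) for tid in self._downstream_task_ids]


if __name__ == '__main__':
    a, b = TaskByReference('a'), TaskByReference('b')
    a.set_downstream(b)
    assert a.downstream_list[0] is b

    dag = DagStub()
    c, d = TaskByTaskId('c', dag), TaskByTaskId('d', dag)
    c.set_downstream(d)
    assert c.downstream_list[0] is d

With string ids, every relative lookup has to go through dag.get_task(), which is the indirection the restored upstream_list/downstream_list properties in the diff below avoid by returning the stored lists directly.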
53 changes: 29 additions & 24 deletions airflow/models.py
@@ -844,17 +844,18 @@ def are_dependents_done(self, session=None):
         """
         task = self.task

-        if not task._downstream_task_ids:
+        if not task._downstream_list:
             return True

+        downstream_task_ids = [t.task_id for t in task._downstream_list]
         ti = session.query(func.count(TaskInstance.task_id)).filter(
             TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id.in_(task._downstream_task_ids),
+            TaskInstance.task_id.in_(downstream_task_ids),
             TaskInstance.execution_date == self.execution_date,
             TaskInstance.state == State.SUCCESS,
         )
         count = ti[0][0]
-        return count == len(task._downstream_task_ids)
+        return count == len(task._downstream_list)

     @provide_session
     def are_dependencies_met(
@@ -910,9 +911,10 @@ def are_dependencies_met(
                 return False

         # Checking that all upstream dependencies have succeeded
-        if not task.upstream_list or task.trigger_rule == TR.DUMMY:
+        if not task._upstream_list or task.trigger_rule == TR.DUMMY:
             return True

+        upstream_task_ids = [t.task_id for t in task._upstream_list]
         qry = (
             session
             .query(
@@ -928,15 +930,15 @@ def are_dependencies_met(
             )
             .filter(
                 TI.dag_id == self.dag_id,
-                TI.task_id.in_(task._upstream_task_ids),
+                TI.task_id.in_(upstream_task_ids),
                 TI.execution_date == self.execution_date,
                 TI.state.in_([
                     State.SUCCESS, State.FAILED,
                     State.UPSTREAM_FAILED, State.SKIPPED]),
             )
         )
         successes, skipped, failed, upstream_failed, done = qry.first()
-        upstream = len(task._upstream_task_ids)
+        upstream = len(task._upstream_list)
         tr = task.trigger_rule
         upstream_done = done >= upstream

@@ -1661,8 +1663,8 @@ def __init__(
         self.dag = dag

         # Private attributes
-        self._upstream_task_ids = []
-        self._downstream_task_ids = []
+        self._upstream_list = []
+        self._downstream_list = []

         self._comps = {
             'task_id',
@@ -1768,6 +1770,8 @@ def __deepcopy__(self, memo):
         result = cls.__new__(cls)
         memo[id(self)] = result

+        self._upstream_list = sorted(self._upstream_list, key=lambda x: x.task_id)
+        self._downstream_list = sorted(self._downstream_list, key=lambda x: x.task_id)
         for k, v in list(self.__dict__.items()):
             if k not in ('user_defined_macros', 'params'):
                 setattr(result, k, copy.deepcopy(v, memo))
@@ -1842,12 +1846,12 @@ def resolve_template_files(self):
     @property
     def upstream_list(self):
         """@property: list of tasks directly upstream"""
-        return [self.dag.get_task(tid) for tid in self._upstream_task_ids]
+        return self._upstream_list

     @property
     def downstream_list(self):
         """@property: list of tasks directly downstream"""
-        return [self.dag.get_task(tid) for tid in self._downstream_task_ids]
+        return self._downstream_list

     def clear(
             self, start_date=None, end_date=None,
@@ -1869,12 +1873,12 @@ def clear(
         tasks = [self.task_id]

         if upstream:
-            tasks += [
-                t.task_id for t in self.get_flat_relatives(upstream=True)]
+            tasks += \
+                [t.task_id for t in self.get_flat_relatives(upstream=True)]

         if downstream:
-            tasks += [
-                t.task_id for t in self.get_flat_relatives(upstream=False)]
+            tasks += \
+                [t.task_id for t in self.get_flat_relatives(upstream=False)]

         qry = qry.filter(TI.task_id.in_(tasks))

@@ -1992,11 +1996,11 @@ def _set_relatives(self, task_or_task_list, upstream=False):
             if not isinstance(task, BaseOperator):
                 raise AirflowException('Expecting a task')
             if upstream:
-                task.append_only_new(task._downstream_task_ids, self.task_id)
-                self.append_only_new(self._upstream_task_ids, task.task_id)
+                task.append_only_new(task._downstream_list, self)
+                self.append_only_new(self._upstream_list, task)
             else:
-                self.append_only_new(self._downstream_task_ids, task.task_id)
-                task.append_only_new(task._upstream_task_ids, self.task_id)
+                self.append_only_new(self._downstream_list, task)
+                task.append_only_new(task._upstream_list, self)

         self.detect_downstream_cycle()

@@ -2613,16 +2617,17 @@ def sub_dag(self, task_regex, include_downstream=False,
                 also_include += t.get_flat_relatives(upstream=False)
             if include_upstream:
                 also_include += t.get_flat_relatives(upstream=True)

         # Compiling the unique list of tasks that made the cut
-        dag.tasks = list(set(regex_match + also_include))
+        tasks = list(set(regex_match + also_include))
+        dag.tasks = tasks
         for t in dag.tasks:
             # Removing upstream/downstream references to tasks that did not
             # made the cut
-            t._upstream_task_ids = [
-                tid for tid in t._upstream_task_ids if tid in dag.task_ids]
-            t._downstream_task_ids = [
-                tid for tid in t._downstream_task_ids if tid in dag.task_ids]
+            t._upstream_list = [
+                ut for ut in t._upstream_list if utils.is_in(ut, tasks)]
+            t._downstream_list = [
+                ut for ut in t._downstream_list if utils.is_in(ut, tasks)]

         return dag

     def has_task(self, task_id):