Skip to content

Commit

Permalink
Merge pull request apache#260 from jlowin/XCom
Browse files Browse the repository at this point in the history
Add XCom (cross-communication) functionality
  • Loading branch information
mistercrunch committed Aug 20, 2015
2 parents 33b4593 + 0235a9b commit 767e6f5
Show file tree
Hide file tree
Showing 4 changed files with 366 additions and 10 deletions.
45 changes: 45 additions & 0 deletions airflow/example_dags/example_xcom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import print_function
import airflow
import datetime

dag = airflow.DAG(
'example_xcom',
start_date=datetime.datetime(2015, 1, 1),
default_args={'owner': 'airflow', 'provide_context': True})

value_1 = [1, 2, 3]
value_2 = {'a': 'b'}

def push(**kwargs):
# pushes an XCom without a specific target
kwargs['ti'].xcom_push(key='value from pusher 1', value=value_1)

def push_by_returning(**kwargs):
# pushes an XCom without a specific target, just by returning it
return value_2

def puller(**kwargs):
ti = kwargs['ti']

# get value_1
v1 = ti.xcom_pull(key=None, task_ids='push')
assert v1 == value_1

# get value_2
v2 = ti.xcom_pull(task_ids='push_by_returning')
assert v2 == value_2

# get both value_1 and value_2
v1, v2 = ti.xcom_pull(key=None, task_ids=['push', 'push_by_returning'])
assert (v1, v2) == (value_1, value_2)

push1 = airflow.operators.PythonOperator(
task_id='push', dag=dag, python_callable=push)

push2 = airflow.operators.PythonOperator(
task_id='push_by_returning', dag=dag, python_callable=push_by_returning)

pull = airflow.operators.PythonOperator(
task_id='puller', dag=dag, python_callable=puller)

pull.set_upstream([push1, push2])
267 changes: 259 additions & 8 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from builtins import object
import copy
from datetime import datetime, timedelta
import functools
import getpass
import imp
import jinja2
Expand All @@ -19,7 +20,7 @@
from sqlalchemy import (
Column, Integer, String, DateTime, Text, Boolean, ForeignKey, PickleType,
Index, BigInteger)
from sqlalchemy import case, func, or_
from sqlalchemy import case, func, or_, and_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.orm import relationship
Expand All @@ -28,12 +29,15 @@
from airflow.executors import DEFAULT_EXECUTOR, LocalExecutor
from airflow.configuration import conf
from airflow.utils import (
AirflowException, State, apply_defaults, provide_session)
AirflowException, State, apply_defaults, provide_session,
is_container, as_tuple)

Base = declarative_base()
ID_LEN = 250
SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
XCOM_RETURN_KEY = 'return_value'


if 'mysql' in SQL_ALCHEMY_CONN:
LongText = LONGTEXT
Expand Down Expand Up @@ -676,10 +680,10 @@ def are_dependencies_met(
qry = (
session
.query(
func.sum(
case([(TI.state == State.SUCCESS, 1)], else_=0)),
func.sum(
case([(TI.state == State.SKIPPED, 1)], else_=0)),
func.coalesce(func.sum(
case([(TI.state == State.SUCCESS, 1)], else_=0)), 0),
func.coalesce(func.sum(
case([(TI.state == State.SKIPPED, 1)], else_=0)), 0),
func.count(TI.task_id),
)
.filter(
Expand Down Expand Up @@ -843,12 +847,19 @@ def signal_handler(signum, frame):

# If a timout is specified for the task, make it fail
# if it goes beyond
result = None
if task_copy.execution_timeout:
with utils.timeout(int(
task_copy.execution_timeout.total_seconds())):
task_copy.execute(context=context)
result = task_copy.execute(context=context)

else:
task_copy.execute(context=context)
result = task_copy.execute(context=context)

# If the task returns a result, push an XCom containing it
if result is not None:
self.xcom_push(key=XCOM_RETURN_KEY, value=result)

task_copy.post_execute(context=context)
except (Exception, KeyboardInterrupt) as e:
self.handle_failure(e, test_mode, context)
Expand Down Expand Up @@ -1000,6 +1011,92 @@ def set_duration(self):
else:
self.duration = None

def xcom_push(
self,
key,
value,
execution_date=None):
"""
Make an XCom available for tasks to pull.
:param key: A key for the XCom
:type key: string
:param value: A value for the XCom. The value is pickled and stored
in the database.
:type value: any pickleable object
:param execution_date: if provided, the XCom will not be visible until
this date. This can be used, for example, to send a message to a
task on a future date without it being immediately visible.
:type execution_date: datetime
"""

if execution_date and execution_date < self.execution_date:
raise ValueError(
'execution_date can not be in the past (current '
'execution_date is {}; received {})'.format(
self.execution_date, execution_date))

XCom.set(
key=key,
value=value,
task_id=self.task_id,
dag_id=self.dag_id,
execution_date=execution_date or self.execution_date)

def xcom_pull(
self,
task_ids,
dag_id=None,
key=XCOM_RETURN_KEY,
include_prior_dates=False,
limit=None):
"""
Pull XComs that optionally meet certain criteria.
The default value for `key` ("{return_key}") limits the search to XComs
that were returned by other tasks (as opposed to those that were pushed
manually). To remove this filter, pass key=None (or any desired value).
If a single task_id string is provided, the result is the value of the
most recent matching XCom from that task_id. If multiple task_ids are
provided, a tuple of matching values is returned. None is returned
whenever no matches are found.
:param key: A key for the XCom. If provided, only XComs with matching
keys will be returned. The default value is "{return_key}",
the key automatically given to XComs returned by tasks (as opposed
to being pushed manually). To remove the filter, pass key=None.
:type key: string
:param task_ids: Only XComs from tasks with matching ids will be
pulled. Can pass None to remove the filter.
:type task_ids: string or iterable of strings (representing task_ids)
:param dag_id: If provided, only pulls XComs from this DAG.
If None (default), the DAG of the calling task is used.
:type dag_id: string
:param include_prior_dates: If False, only XComs from the current
execution_date are returned. If True, XComs from previous dates
are returned as well.
:type include_prior_dates: bool
:param limit: the maximum number of results to return. Pass None for
no limit.
:type limit: int
""".format(return_key=XCOM_RETURN_KEY)

if dag_id is None:
dag_id = self.dag_id

pull_fn = functools.partial(
XCom.get_one,
execution_date=self.execution_date,
key=key,
dag_id=dag_id,
include_prior_dates=include_prior_dates)

if is_container(task_ids):
return tuple(pull_fn(task_id=t) for t in task_ids)
else:
return pull_fn(task_id=task_ids)


class Log(Base):
"""
Expand Down Expand Up @@ -1466,6 +1563,36 @@ def set_upstream(self, task_or_task_list):
"""
self._set_relatives(task_or_task_list, upstream=True)

def xcom_push(
self,
context,
key,
value,
execution_date=None):
"""
See TaskInstance.xcom_push()
"""
context['ti'].xcom_push(
key=key,
value=value,
execution_date=execution_date)

def xcom_pull(
self,
context,
task_ids,
dag_id=None,
key=XCOM_RETURN_KEY,
include_prior_dates=None):
"""
See TaskInstance.xcom_pull()
"""
return context['ti'].xcom_pull(
key=key,
task_ids=task_ids,
dag_id=dag_id,
include_prior_dates=include_prior_dates)


class DagModel(Base):

Expand Down Expand Up @@ -2011,6 +2138,130 @@ def get(cls, key, session, deserialize_json=False):
return v


class XCom(Base):
"""
Base class for XCom objects.
"""
__tablename__ = "xcom"

id = Column(Integer, primary_key=True)
key = Column(String(512))
value = Column(PickleType(pickler=dill))
timestamp = Column(DateTime, server_default=func.current_timestamp())
execution_date = Column(DateTime, nullable=False)

# source information
task_id = Column(String(ID_LEN), nullable=False)
dag_id = Column(String(ID_LEN), nullable=False)

def __repr__(self):
return '<XCom "{key}" ({task_id} @ {execution_date})>'.format(
key=self.key,
task_id=self.task_id,
execution_date=self.execution_date)

@classmethod
@provide_session
def set(
cls,
key,
value,
execution_date,
task_id,
dag_id,
session=None):
"""
Store an XCom value.
"""
session.expunge_all()

# remove any duplicate XComs
session.query(cls).filter(
cls.key == key,
cls.execution_date == execution_date,
cls.task_id == task_id,
cls.dag_id == dag_id).delete()

# insert new XCom
session.add(XCom(
key=key,
value=value,
execution_date=execution_date,
task_id=task_id,
dag_id=dag_id))

session.commit()

@classmethod
@provide_session
def get_one(
cls,
execution_date,
key=None,
task_id=None,
dag_id=None,
include_prior_dates=False,
session=None):
"""
Retrieve an XCom value, optionally meeting certain criteria
"""
filters = []
if key:
filters.append(cls.key == key)
if task_id:
filters.append(cls.task_id == task_id)
if dag_id:
filters.append(cls.dag_id == dag_id)
if include_prior_dates:
filters.append(cls.execution_date <= execution_date)
else:
filters.append(cls.execution_date == execution_date)

query = (
session.query(cls.value)
.filter(and_(*filters))
.order_by(cls.execution_date.desc(), cls.timestamp.desc())
.limit(1))

result = query.first()
if result:
return result.value

@classmethod
@provide_session
def get_many(
cls,
execution_date,
key=None,
task_ids=None,
dag_ids=None,
include_prior_dates=False,
limit=100,
session=None):
"""
Retrieve an XCom value, optionally meeting certain criteria
"""
filters = []
if key:
filters.append(cls.key == key)
if task_ids:
filters.append(cls.task_id.in_(as_tuple(task_ids)))
if dag_ids:
filters.append(cls.dag_id.in_(as_tuple(dag_ids)))
if include_prior_dates:
filters.append(cls.execution_date <= execution_date)
else:
filters.append(cls.execution_date == execution_date)

query = (
session.query(cls)
.filter(and_(*filters))
.order_by(cls.execution_date.desc(), cls.timestamp.desc())
.limit(limit))

return query.all()


class Pool(Base):
__tablename__ = "slot_pool"

Expand Down
Loading

0 comments on commit 767e6f5

Please sign in to comment.