Making Jinja templates as files work
mistercrunch committed Nov 2, 2014
1 parent c8cc050 commit dc9537d
Showing 8 changed files with 78 additions and 16 deletions.
15 changes: 12 additions & 3 deletions airflow/hooks/hive_hook.py
@@ -22,7 +22,7 @@ def __init__(self, hive_dbid=settings.HIVE_DEFAULT_DBID):
             DatabaseConnection).filter(
                 DatabaseConnection.db_id == hive_dbid)
         if db.count() == 0:
-            raise Exception("The presto_dbid you provided isn't defined")
+            raise Exception("The dbid you provided isn't defined")
         else:
             db = db.all()[0]
         self.host = db.host
@@ -42,8 +42,8 @@ def __getstate__(self):
         return d
 
     def __setstate__(self, d):
-        d['hive'] = self.get_hive_client()
         self.__dict__.update(d)
+        d['hive'] = self.get_hive_client()
 
     def get_hive_client(self):
         transport = TSocket.TSocket(self.host, self.port)
@@ -98,9 +98,18 @@ def run(self, hql, schema=None):
     def run_cli(self, hql, schema=None):
         if schema:
             hql = "USE {schema};\n{hql}".format(**locals())
-        sp = subprocess.Popen(['hive', '-e', hql])
+        sp = subprocess.Popen(
+            ['hive', '-e', hql],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT)
+        all_err = ''
+        for line in iter(sp.stdout.readline, ''):
+            logging.info(line)
         sp.wait()
+
+        if sp.returncode:
+            raise Exception(all_err)
 
     def max_partition(self, schema, table):
         '''
         Returns the maximum value for all partitions in a table. Works only
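Outside the diff, here is a minimal standalone sketch of the pattern the new run_cli follows: stream the hive CLI's combined stdout/stderr into logging and raise when the process exits non-zero. It assumes a hive binary on the PATH; unlike the committed version (where all_err is declared but never appended to), the sketch accumulates the output so the raised exception actually carries it.

import logging
import subprocess


def run_hive_cli(hql, schema=None):
    # Prepend a USE statement when a schema is given, mirroring the hook.
    if schema:
        hql = "USE {schema};\n{hql}".format(schema=schema, hql=hql)
    sp = subprocess.Popen(
        ['hive', '-e', hql],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True)
    output = ''
    for line in iter(sp.stdout.readline, ''):
        logging.info(line.rstrip())
        output += line
    sp.wait()
    if sp.returncode:
        # Surface everything the CLI printed when it fails.
        raise Exception(output)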
47 changes: 36 additions & 11 deletions airflow/models.py
@@ -63,8 +63,8 @@ def process_file(self, filepath):
         else:
             for dag in m.__dict__.values():
                 if type(dag) == DAG:
-                    dag.filepath = filepath.replace(
-                        settings.AIRFLOW_HOME + '/', '')
+                    dag.full_filepath = filepath
+                    #.replace(settings.AIRFLOW_HOME + '/', '')
                     if dag.dag_id in self.dags:
                         raise Exception(
                             'Two DAGs with the same dag_id. No good.')
@@ -396,7 +396,7 @@ def run(
                     "Not ready for retry yet. " +
                     "Next run after {0}".format(next_run)
                 )
-            elif self.state in State.runnable():
+            elif force or self.state in State.runnable():
                 if self.state == State.UP_FOR_RETRY:
                     self.try_number += 1
                 else:
@@ -417,21 +417,35 @@ def run(
             try:
                 if not mark_success:
                     from airflow import macros
+                    env = jinja2.Environment(
+                        loader=jinja2.FileSystemLoader(
+                            self.task.dag.folder))
+                    tables = None
+                    if 'tables' in self.task.params:
+                        tables = self.task.params['tables']
                     jinja_context = {
-                        'ti': self,
-                        'execution_date': self.execution_date,
-                        'ds': self.execution_date.isoformat()[:10],
-                        'task': self.task,
+                        'dag': self.task.dag,
+                        'ds': self.execution_date.isoformat()[:10],
+                        'execution_date': self.execution_date,
                         'macros': macros,
                         'params': self.task.params,
+                        'tables': tables,
+                        'task': self.task,
+                        'task_instance': self,
+                        'ti': self,
                     }
                     task_copy = copy.copy(self.task)
                     for attr in task_copy.__class__.template_fields:
                         source = getattr(task_copy, attr)
+                        template = jinja2.Template(source)
+                        for ext in task_copy.__class__.template_ext:
+                            # Magic, if field has the right extension, look
+                            # for the file.
+                            if source.strip().endswith(ext):
+                                template = env.get_template(source)
                         setattr(
                             task_copy, attr,
-                            jinja2.Template(source).render(**jinja_context)
+                            template.render(**jinja_context)
                         )
                 task_copy.execute(self.execution_date)
             except Exception as e:
@@ -663,7 +677,10 @@ class derived from this one results in the creation of a task object,
     :type dag: DAG
     """
 
+    # For derived classes to define which fields will get jinjaified
     template_fields = []
+    # Defines which file extensions to look for in the templated fields
+    template_ext = []
 
     __tablename__ = "task"
 
@@ -898,7 +915,7 @@ class DAG(Base):
     dag_id = Column(String(ID_LEN), primary_key=True)
     task_count = Column(Integer)
     parallelism = Column(Integer)
-    filepath = Column(String(2000))
+    full_filepath = Column(String(2000))
 
     tasks = relationship(
         "BaseOperator", cascade="merge, delete, delete-orphan", backref='dag')
@@ -907,14 +924,14 @@ def __init__(
             self, dag_id,
             schedule_interval=timedelta(days=1),
             start_date=None, end_date=None, parallelism=0,
-            filepath=None):
+            full_filepath=None):
 
         utils.validate_key(dag_id)
         self.dag_id = dag_id
         self.end_date = end_date or datetime.now()
         self.parallelism = parallelism
         self.schedule_interval = schedule_interval
-        self.filepath = filepath if filepath else ''
+        self.full_filepath = full_filepath if full_filepath else ''
 
     def __repr__(self):
         return "<DAG: {self.dag_id}>".format(self=self)
@@ -923,6 +940,14 @@ def __repr__(self):
     def task_ids(self):
         return [t.task_id for t in self.tasks]
 
+    @property
+    def filepath(self):
+        return self.full_filepath.replace(settings.AIRFLOW_HOME + '/', '')
+
+    @property
+    def folder(self):
+        return os.path.dirname(self.full_filepath)
+
     @property
     def latest_execution_date(self):
         TI = TaskInstance
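To make the new template_fields / template_ext interplay concrete, here is a minimal standalone sketch of the resolution rule the run() change implements: if a templated attribute's value ends with one of the operator's template_ext extensions, it is loaded as a file from the DAG's folder through a FileSystemLoader; otherwise the string itself is the template body. The resolve_template helper name is illustrative, not something defined in this commit.

import jinja2


def resolve_template(source, template_ext, folder):
    # Values that look like filenames (by extension) are loaded from the
    # DAG's folder; anything else is treated as an inline template string.
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(folder))
    for ext in template_ext:
        if source.strip().endswith(ext):
            return env.get_template(source)
    return jinja2.Template(source)


# Inline string: rendered directly.
print(resolve_template(
    "echo {{ ds }}", ('.sh', '.bash'), '/tmp').render(ds='2014-11-02'))
# A value like 'my_script.sh' would instead be looked up in the folder.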
1 change: 1 addition & 0 deletions airflow/operators/bash_operator.py
@@ -6,6 +6,7 @@
 class BashOperator(BaseOperator):
 
     template_fields = ('bash_command',)
+    template_ext = ('.sh', '.bash',)
 
     __mapper_args__ = {
         'polymorphic_identity': 'BashOperator'
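As a usage sketch (the DAG id, script name, and file contents are hypothetical, not part of this commit, and the imports follow the example DAG files shown below): with template_ext = ('.sh', '.bash'), a bash_command ending in .sh is now resolved as a file living next to the DAG file and rendered with the same Jinja context before being executed.

# Hypothetical DAG file; templated_command.sh sits in the same folder and
# might contain e.g. "echo {{ ds }}". Because the command ends with '.sh',
# the new template_ext logic loads and renders the file instead of treating
# the string itself as the command.
from datetime import datetime

from airflow import DAG
from airflow.operators import BashOperator

default_args = {'owner': 'max', 'start_date': datetime(2014, 9, 1)}

dag = DAG('templated_bash_example')
task = BashOperator(
    task_id='run_templated_script',
    bash_command='templated_command.sh',
    **default_args)
dag.add_task(task)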
3 changes: 2 additions & 1 deletion airflow/operators/hive_operator.py
@@ -18,6 +18,7 @@ class HiveOperator(BaseOperator):
         'polymorphic_identity': 'HiveOperator'
     }
     template_fields = ('hql',)
+    template_ext = ('.hql', '.sql',)
 
     def __init__(
             self, hql, hive_dbid=settings.HIVE_DEFAULT_DBID,
@@ -30,4 +31,4 @@ def __init__(
 
     def execute(self, execution_date):
         logging.info('Executing: ' + self.hql)
-        self.hook.run_cli(hql=self.hql, schema=self.hive_dbid)
+        self.hook.run_cli(hql=self.hql)
1 change: 1 addition & 0 deletions airflow/operators/mysql_operator.py
@@ -12,6 +12,7 @@ class MySqlOperator(BaseOperator):
         'polymorphic_identity': 'MySqlOperator'
     }
     template_fields = ('sql',)
+    template_ext = ('.sql',)
 
     def __init__(self, sql, mysql_dbid, *args, **kwargs):
         """
2 changes: 2 additions & 0 deletions airflow/operators/sensors.py
@@ -38,6 +38,8 @@ class SqlSensor(BaseSensorOperator):
         in (0, '0', '')
     """
     template_fields = ('sql',)
+    template_ext = ('.hql', '.sql',)
+
     __mapper_args__ = {
         'polymorphic_identity': 'SqlSensor'
     }
12 changes: 11 additions & 1 deletion dags/examples/example2.py
@@ -26,7 +26,17 @@
 task.set_downstream(run_this)
 dag.add_task(task)
 
-task = BashOperator(task_id='also_run_this', bash_command='ls -l', **default_args)
+cmd = """\
+echo {{ params.tables.the_table }}
+"""
+task = BashOperator(
+    task_id='also_run_this', bash_command=cmd,
+    params={
+        'tables': {
+            'the_table': 'da_table',
+        }
+    },
+    **default_args)
 dag.add_task(task)
 task.set_downstream(run_this_last)
 task.set_upstream(run_this)
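For clarity, the new bash_command renders against the task's params, which the run() change above now exposes to the Jinja context (both as params and, via the 'tables' key, as the tables shortcut). A quick standalone check of what the template expands to:

import jinja2

cmd = "echo {{ params.tables.the_table }}\n"
params = {'tables': {'the_table': 'da_table'}}

# Jinja2 falls back to item lookup for attribute access, so dotted access
# into plain dicts works here; this prints "echo da_table".
print(jinja2.Template(cmd).render(params=params))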
13 changes: 13 additions & 0 deletions dags/examples/test_hivepartitionsensor.py
@@ -0,0 +1,13 @@
+from datetime import datetime
+default_args = {
+    'owner': 'max',
+    'start_date': datetime(2014, 9, 1),
+    'mysql_dbid': 'local_mysql',
+}
+from airflow.operators import HivePartitionSensor
+from airflow import settings
+from airflow import DAG
+
+dag = DAG("test_wfh")
+t = HivePartitionSensor(task_id="test_hps", schema='core_data', table="fct_revenue", partition="ds='{{ ds }}'", **default_args)
+dag.add_task(t)
