Skip to content

Commit

Permalink
Fix how we recurse through collection-like template fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill Kourtchikov committed Oct 7, 2015
1 parent 9ce1cdb commit 376cdd4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 24 deletions.
60 changes: 36 additions & 24 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,21 +1059,8 @@ def render_templates(self):
for attr in task.__class__.template_fields:
content = getattr(task, attr)
if content:
if isinstance(content, basestring):
result = rt(content, jinja_context)
elif isinstance(content, (list, tuple)):
result = [rt(s, jinja_context) for s in content]
elif isinstance(content, dict):
result = {
k: rt(v, jinja_context)
for k, v in list(content.items())}
else:
param_type = type(content)
msg = (
"Type '{param_type}' used for parameter '{attr}' is "
"not supported for templating").format(**locals())
raise AirflowException(msg)
setattr(task, attr, result)
rendered_content = self.task.render_template(content, jinja_context)
setattr(task, attr, rendered_content)

def email_alert(self, exception, is_retry=False):
task = self.task
Expand Down Expand Up @@ -1519,18 +1506,43 @@ def __deepcopy__(self, memo):

return result

def render_template(self, content, context):
if hasattr(self, 'dag'):
env = self.dag.get_template_env()
def render_template_from_field(self, content, context, jinja_env):
'''
Renders a template from a field. If the field is a string, it will
simply render the string and return the result. If it is a collection or
nested set of collections, it will traverse the structure and render
all strings in it.
'''
rt = self.render_template_from_field
if isinstance(content, basestring):
result = jinja_env.from_string(content).render(**context)
elif isinstance(content, (list, tuple)):
result = [rt(e, context, jinja_env) for e in content]
elif isinstance(content, dict):
result = {
k: rt(v, context, jinja_env)
for k, v in list(content.items())}
else:
env = jinja2.Environment(cache_size=0)
param_type = type(content)
msg = (
"Type '{param_type}' used for parameter '{attr}' is "
"not supported for templating").format(**locals())
raise AirflowException(msg)
return result

def render_template(self, content, context):
'''
Renders a template either from a file or directly in a field, and returns
the rendered result.
'''
jinja_env = self.dag.get_template_env() \
if hasattr(self, 'dag') \
else jinja2.Environment(cache_size=0)

exts = self.__class__.template_ext
if any([content.endswith(ext) for ext in exts]):
template = env.get_template(content)
else:
template = env.from_string(content)
return template.render(**context)
return jinja_env.get_template(content).render(**context) \
if isinstance(content, basestring) and any([content.endswith(ext) for ext in exts]) \
else self.render_template_from_field(content, context, jinja_env)

def prepare_template(self):
'''
Expand Down
22 changes: 22 additions & 0 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,28 @@ def test_py_op(templates_dict, ds, **kwargs):
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)


def test_complex_template(self):
class OperatorSubclass(operators.BaseOperator):
template_fields = ['some_templated_field']
def __init__(self, some_templated_field, *args, **kwargs):
super(OperatorSubclass, self).__init__(*args, **kwargs)
self.some_templated_field = some_templated_field
def execute(*args, **kwargs):
pass
def test_some_templated_field_template_render(context):
self.assertEqual(context['ti'].task.some_templated_field['bar'][1], context['ds'])
t = OperatorSubclass(
task_id='test_complex_template',
provide_context=True,
some_templated_field={
'foo':'123',
'bar':['baz', '{{ ds }}']
},
on_success_callback=test_some_templated_field_template_render,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)

def test_import_examples(self):
self.assertEqual(len(self.dagbag.dags), NUM_EXAMPLE_DAGS)

Expand Down

0 comments on commit 376cdd4

Please sign in to comment.