Skip to content

Commit

Permalink
[AIRFLOW 1149][AIRFLOW-1149] Allow for custom filters in Jinja2 templ…
Browse files Browse the repository at this point in the history
…ates

Closes apache#2258 from
NielsZeilemaker/jinja_custom_filters
  • Loading branch information
NielsZeilemaker authored and bolkedebruin committed Apr 29, 2017
1 parent 66168ef commit 48135ad
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
19 changes: 16 additions & 3 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2249,11 +2249,13 @@ def __deepcopy__(self, memo):
memo[id(self)] = result

for k, v in list(self.__dict__.items()):
if k not in ('user_defined_macros', 'params'):
if k not in ('user_defined_macros', 'user_defined_filters', 'params'):
setattr(result, k, copy.deepcopy(v, memo))
result.params = self.params
if hasattr(self, 'user_defined_macros'):
result.user_defined_macros = self.user_defined_macros
if hasattr(self, 'user_defined_filters'):
result.user_defined_filters = self.user_defined_filters
return result

def render_template_from_field(self, attr, content, context, jinja_env):
Expand Down Expand Up @@ -2644,6 +2646,12 @@ class DAG(BaseDag, LoggingMixin):
templates related to this DAG. Note that you can pass any
type of object here.
:type user_defined_macros: dict
:param user_defined_filters: a dictionary of filters that will be exposed
in your jinja templates. For example, passing
``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows
you to ``{{ 'world' | hello }}`` in all jinja templates related to
this DAG.
:type user_defined_filters: dict
:param default_args: A dictionary of default parameters to be used
as constructor keyword parameters when initialising operators.
Note that operators have the same hook, and precede those defined
Expand Down Expand Up @@ -2684,6 +2692,7 @@ def __init__(
full_filepath=None,
template_searchpath=None,
user_defined_macros=None,
user_defined_filters=None,
default_args=None,
concurrency=configuration.getint('core', 'dag_concurrency'),
max_active_runs=configuration.getint(
Expand All @@ -2696,6 +2705,7 @@ def __init__(
params=None):

self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = default_args or {}
self.params = params or {}

Expand Down Expand Up @@ -3034,7 +3044,7 @@ def crawl_for_tasks(objects):
def get_template_env(self):
"""
Returns a jinja2 Environment while taking into account the DAGs
template_searchpath and user_defined_macros
template_searchpath, user_defined_macros and user_defined_filters
"""
searchpath = [self.folder]
if self.template_searchpath:
Expand All @@ -3046,6 +3056,8 @@ def get_template_env(self):
cache_size=0)
if self.user_defined_macros:
env.globals.update(self.user_defined_macros)
if self.user_defined_filters:
env.filters.update(self.user_defined_filters)

return env

Expand Down Expand Up @@ -3212,10 +3224,11 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in list(self.__dict__.items()):
if k not in ('user_defined_macros', 'params'):
if k not in ('user_defined_macros', 'user_defined_filters', 'params'):
setattr(result, k, copy.deepcopy(v, memo))

result.user_defined_macros = self.user_defined_macros
result.user_defined_filters = self.user_defined_filters
result.params = self.params
return result

Expand Down
10 changes: 10 additions & 0 deletions docs/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,16 @@ different languages, and general flexibility in structuring pipelines. It is
also possible to define your ``template_searchpath`` as pointing to any folder
locations in the DAG constructor call.

Using that same DAG constructor call, it is possible to define
``user_defined_macros`` which allow you to specify your own variables.
For example, passing ``dict(foo='bar')`` to this argument allows you
to use ``{{ foo }}`` in your templates. Moreover, specifying
``user_defined_filters`` allow you to register you own filters. For example,
passing ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows
you to use ``{{ 'world' | hello }}`` in your templates. For more information
regarding custom filters have a look at the
`Jinja Documentation <http://jinja.pocoo.org/docs/dev/api/#writing-filters>`_

For more information on the variables and macros that can be referenced
in templates, make sure to read through the :ref:`macros` section

Expand Down
55 changes: 55 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,61 @@ def test_get_num_task_instances(self):
states=[None, State.QUEUED, State.RUNNING], session=session))
session.close()

def test_render_template_field(self):
"""Tests if render_template from a field works"""

dag = DAG('test-dag',
start_date=DEFAULT_DATE)

with dag:
task = DummyOperator(task_id='op1')

result = task.render_template('', '{{ foo }}', dict(foo='bar'))
self.assertEqual(result, 'bar')

def test_render_template_field_macro(self):
""" Tests if render_template from a field works,
if a custom filter was defined"""

dag = DAG('test-dag',
start_date=DEFAULT_DATE,
user_defined_macros = dict(foo='bar'))

with dag:
task = DummyOperator(task_id='op1')

result = task.render_template('', '{{ foo }}', dict())
self.assertEqual(result, 'bar')

def test_user_defined_filters(self):
def jinja_udf(name):
return 'Hello %s' %name

dag = models.DAG('test-dag',
start_date=DEFAULT_DATE,
user_defined_filters=dict(hello=jinja_udf))
jinja_env = dag.get_template_env()

self.assertIn('hello', jinja_env.filters)
self.assertEqual(jinja_env.filters['hello'], jinja_udf)

def test_render_template_field_filter(self):
""" Tests if render_template from a field works,
if a custom filter was defined"""

def jinja_udf(name):
return 'Hello %s' %name

dag = DAG('test-dag',
start_date=DEFAULT_DATE,
user_defined_filters = dict(hello=jinja_udf))

with dag:
task = DummyOperator(task_id='op1')

result = task.render_template('', "{{ 'world' | hello}}", dict())
self.assertEqual(result, 'Hello world')


class DagStatTest(unittest.TestCase):
def test_dagstats_crud(self):
Expand Down

0 comments on commit 48135ad

Please sign in to comment.