diff --git a/.rat-excludes b/.rat-excludes
index 7b66e5144b4f5..070dec0297135 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -81,3 +81,6 @@ pylint_todo.txt
 .bash_history
 .bash_aliases
 .inputrc
+
+# the example notebook is ASF 2 licensed but RAT cannot read this
+input_notebook.ipynb
diff --git a/airflow/api/client/api_client.py b/airflow/api/client/api_client.py
index df909e431b283..4d8dde056edec 100644
--- a/airflow/api/client/api_client.py
+++ b/airflow/api/client/api_client.py
@@ -70,3 +70,12 @@ def delete_pool(self, name):
         :param name: pool name
         """
         raise NotImplementedError()
+
+    def get_lineage(self, dag_id: str, execution_date: str):
+        """
+        Return the lineage information for the dag on this execution date
+        :param dag_id:
+        :param execution_date:
+        :return:
+        """
+        raise NotImplementedError()
diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py
index cc4280f285e30..7422255b5519b 100644
--- a/airflow/api/client/json_client.py
+++ b/airflow/api/client/json_client.py
@@ -93,3 +93,9 @@ def delete_pool(self, name):
         url = urljoin(self._api_base_url, endpoint)
         pool = self._request(url, method='DELETE')
         return pool['pool'], pool['slots'], pool['description']
+
+    def get_lineage(self, dag_id: str, execution_date: str):
+        endpoint = f"/api/experimental/lineage/{dag_id}/{execution_date}"
+        url = urljoin(self._api_base_url, endpoint)
+        data = self._request(url, method='GET')
+        return data['message']
diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py
index 37c9785740bf5..8c3f174abd8ee 100644
--- a/airflow/api/client/local_client.py
+++ b/airflow/api/client/local_client.py
@@ -20,6 +20,7 @@

 from airflow.api.client import api_client
 from airflow.api.common.experimental import delete_dag, pool, trigger_dag
+from airflow.api.common.experimental.get_lineage import get_lineage as get_lineage_api


 class Client(api_client.Client):
@@ -50,3 +51,7 @@ def create_pool(self, name, slots, description):
     def delete_pool(self, name):
         the_pool = pool.delete_pool(name=name)
         return the_pool.pool, the_pool.slots, the_pool.description
+
+    def get_lineage(self, dag_id, execution_date):
+        lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date)
+        return lineage
diff --git a/airflow/api/common/experimental/get_lineage.py b/airflow/api/common/experimental/get_lineage.py
new file mode 100644
index 0000000000000..cbecfae28aba9
--- /dev/null
+++ b/airflow/api/common/experimental/get_lineage.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" +Lineage apis +""" +import datetime +from typing import Any, Dict, List + +from airflow.api.common.experimental import check_and_get_dag, check_and_get_dagrun +from airflow.lineage import PIPELINE_INLETS, PIPELINE_OUTLETS +from airflow.models.xcom import XCom +from airflow.utils.session import provide_session + + +@provide_session +def get_lineage(dag_id: str, execution_date: datetime.datetime, session=None) -> Dict[str, Dict[str, Any]]: + """ + Gets the lineage information for dag specified + """ + dag = check_and_get_dag(dag_id) + check_and_get_dagrun(dag, execution_date) + + inlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date, + key=PIPELINE_INLETS, session=session).all() + outlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date, + key=PIPELINE_OUTLETS, session=session).all() + + lineage: Dict[str, Dict[str, Any]] = {} + for meta in inlets: + lineage[meta.task_id] = {'inlets': meta.value} + + for meta in outlets: + lineage[meta.task_id]['outlets'] = meta.value + + return {'task_ids': lineage} diff --git a/airflow/example_dags/example_papermill_operator.py b/airflow/example_dags/example_papermill_operator.py new file mode 100644 index 0000000000000..70e2258c8949f --- /dev/null +++ b/airflow/example_dags/example_papermill_operator.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os
+from datetime import timedelta
+
+import scrapbook as sb
+
+import airflow
+from airflow.lineage import AUTO
+from airflow.models import DAG
+from airflow.operators.papermill_operator import PapermillOperator
+from airflow.operators.python_operator import PythonOperator
+
+
+def check_notebook(inlets, execution_date):
+    """
+    Verify the message in the notebook
+    """
+    notebook = sb.read_notebook(inlets[0].url)
+    message = notebook.scraps['message']
+    print(f"Message in notebook {message} for {execution_date}")
+
+    if message.data != f"Ran from Airflow at {execution_date}!":
+        return False
+
+    return True
+
+
+args = {
+    'owner': 'airflow',
+    'start_date': airflow.utils.dates.days_ago(2)
+}
+
+dag = DAG(
+    dag_id='example_papermill_operator', default_args=args,
+    schedule_interval='0 0 * * *',
+    dagrun_timeout=timedelta(minutes=60))
+
+run_this = PapermillOperator(
+    task_id="run_example_notebook",
+    dag=dag,
+    input_nb=os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                          "input_notebook.ipynb"),
+    output_nb="/tmp/out-{{ execution_date }}.ipynb",
+    parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"}
+)
+
+check_output = PythonOperator(
+    task_id='check_out',
+    python_callable=check_notebook,
+    dag=dag,
+    inlets=AUTO)
+
+check_output.set_upstream(run_this)
+
+if __name__ == "__main__":
+    dag.cli()
diff --git a/airflow/example_dags/input_notebook.ipynb b/airflow/example_dags/input_notebook.ipynb
new file mode 100644
index 0000000000000..eb73f825e9f7c
--- /dev/null
+++ b/airflow/example_dags/input_notebook.ipynb
@@ -0,0 +1,120 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " Licensed to the Apache Software Foundation (ASF) under one\n",
+    " or more contributor license agreements. See the NOTICE file\n",
+    " distributed with this work for additional information\n",
+    " regarding copyright ownership. The ASF licenses this file\n",
+    " to you under the Apache License, Version 2.0 (the\n",
+    " \"License\"); you may not use this file except in compliance\n",
+    " with the License. You may obtain a copy of the License at\n",
+    "\n",
+    "   http://www.apache.org/licenses/LICENSE-2.0\n",
+    "\n",
+    " Unless required by applicable law or agreed to in writing,\n",
+    " software distributed under the License is distributed on an\n",
+    " \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+    " KIND, either express or implied. See the License for the\n",
+    " specific language governing permissions and limitations\n",
+    " under the License."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This is an example jupyter notebook for Apache Airflow that shows how to use\n",
+    "papermill in combination with scrapbook"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import scrapbook as sb"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The parameter tag for cells is used to tell papermill where it can find variables it needs to set"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "tags": [
+     "parameters"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "msgs = \"Hello!\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Inside the notebook you can save data by calling the glue function. Then later you can read the results of that notebook by “scrap” name (see the Airflow Papermill example DAG)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/scrapbook.scrap.text+json": {
+       "data": "Hello!",
+       "encoder": "text",
+       "name": "message",
+       "version": 1
+      }
+     },
+     "metadata": {
+      "scrapbook": {
+       "data": true,
+       "display": false,
+       "name": "message"
+      }
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "sb.glue('message', msgs)"
+   ]
+  }
+ ],
+ "metadata": {
+  "celltoolbar": "Tags",
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index 85c53a6cb895f..35e5bf366e8af 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -154,8 +154,6 @@ def wrapper(self, context, *args, **kwargs):
             self.inlets.extend(_inlets)
             self.inlets.extend(self._inlets)

-            self.inlets = [_render_object(i, context)
-                           for i in self.inlets if attr.has(i)]

         elif self._inlets:
             raise AttributeError("inlets is not a list, operator, string or attr annotated object")
@@ -165,8 +163,12 @@ def wrapper(self, context, *args, **kwargs):

             self.outlets.extend(self._outlets)

-        self.outlets = list(map(lambda i: _render_object(i, context),
-                                filter(attr.has, self.outlets)))
+        # render inlets and outlets
+        self.inlets = [_render_object(i, context)
+                       for i in self.inlets if attr.has(i)]
+
+        self.outlets = [_render_object(i, context)
+                        for i in self.outlets if attr.has(i)]

         self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
         return func(self, context, *args, **kwargs)
diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py
index 5e1a37b3cc19e..83a06820d05db 100644
--- a/airflow/www/api/experimental/endpoints.py
+++ b/airflow/www/api/experimental/endpoints.py
@@ -24,6 +24,7 @@
 from airflow.api.common.experimental.get_code import get_code
 from airflow.api.common.experimental.get_dag_run_state import get_dag_run_state
 from airflow.api.common.experimental.get_dag_runs import get_dag_runs
+from airflow.api.common.experimental.get_lineage import get_lineage as get_lineage_api
 from airflow.api.common.experimental.get_task import get_task
 from airflow.api.common.experimental.get_task_instance import get_task_instance
 from airflow.exceptions import AirflowException
@@ -353,3 +354,33 @@ def delete_pool(name):
         return response
     else:
         return jsonify(pool.to_json())
+
+
+@csrf.exempt
+@api_experimental.route('/lineage/<string:dag_id>/<string:execution_date>',
+                        methods=['GET'])
+@requires_authentication
+def get_lineage(dag_id: str, execution_date: str):
+    # Convert string datetime into actual datetime
+    try:
+        execution_date = timezone.parse(execution_date)
+    except ValueError:
+        error_message = (
+            'Given execution date, {}, could not be identified '
+            'as a date. Example date format: 2015-11-16T14:34:15+00:00'.format(
+                execution_date))
+        _log.info(error_message)
+        response = jsonify({'error': error_message})
+        response.status_code = 400
+
+        return response
+
+    try:
+        lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date)
+    except AirflowException as err:
+        _log.error(err)
+        response = jsonify(error=f"{err}")
+        response.status_code = err.status_code
+        return response
+    else:
+        return jsonify(lineage)
diff --git a/docs/rest-api-ref.rst b/docs/rest-api-ref.rst
index c7c4c9820b4d5..51832c059345d 100644
--- a/docs/rest-api-ref.rst
+++ b/docs/rest-api-ref.rst
@@ -98,3 +98,7 @@ Endpoints
 .. http:delete:: /api/experimental/pools/<string:name>

   Delete pool.
+
+.. http:get:: /api/experimental/lineage/<DAG_ID>/<string:execution_date>/
+
+  Returns the lineage information for the dag.
diff --git a/setup.py b/setup.py
index 5407c0ffe0cd8..aef0b6f4c82be 100644
--- a/setup.py
+++ b/setup.py
@@ -290,8 +290,8 @@ def write_version(filename: str = os.path.join(*["airflow", "git_version"])):
     'pypd>=1.1.0',
 ]
 papermill = [
-    'papermill[all]>=1.0.0',
-    'nteract-scrapbook[all]>=0.2.1',
+    'papermill[all]>=1.2.1',
+    'nteract-scrapbook[all]>=0.3.1',
 ]
 password = [
     'bcrypt>=2.0.0',
@@ -437,7 +437,7 @@ def do_setup():
         version=version,
         packages=find_packages(exclude=['tests*']),
         package_data={
-            '': ['airflow/alembic.ini', "airflow/git_version"],
+            '': ['airflow/alembic.ini', "airflow/git_version", "*.ipynb"],
             'airflow.serialization': ["*.json"],
         },
         include_package_data=True,
diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py
index ab2bfcf2f4d7c..0d38fe24b431b 100644
--- a/tests/lineage/test_lineage.py
+++ b/tests/lineage/test_lineage.py
@@ -98,3 +98,28 @@ def test_lineage(self):
         op5.pre_execute(ctx5)
         self.assertEqual(len(op5.inlets), 2)
         op5.post_execute(ctx5)
+
+    def test_lineage_render(self):
+        # tests inlets / outlets are rendered if they are added
+        # after initialization
+        dag = DAG(
+            dag_id='test_lineage_render',
+            start_date=DEFAULT_DATE
+        )
+
+        with dag:
+            op1 = DummyOperator(task_id='task1')
+
+        f1s = "/tmp/does_not_exist_1-{}"
+        file1 = File(f1s.format("{{ execution_date }}"))
+
+        op1.inlets.append(file1)
+        op1.outlets.append(file1)
+
+        # execution_date is set in the context in order to avoid creating task instances
+        ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE),
+                "execution_date": DEFAULT_DATE}
+
+        op1.pre_execute(ctx1)
+        self.assertEqual(op1.inlets[0].url, f1s.format(DEFAULT_DATE))
+        self.assertEqual(op1.outlets[0].url, f1s.format(DEFAULT_DATE))
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index a3eb49d5b3674..23c9dfc5fe605 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -16,7 +16,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 import json
 import unittest
 from datetime import timedelta
@@ -306,6 +305,49 @@ def test_dagrun_status(self):
         self.assertEqual(400, response.status_code)
         self.assertIn('error', response.data.decode('utf-8'))

+    def test_lineage_info(self):
+        url_template = '/api/experimental/lineage/{}/{}'
+        dag_id = 'example_papermill_operator'
+        execution_date = utcnow().replace(microsecond=0)
+        datetime_string = quote_plus(execution_date.isoformat())
+        wrong_datetime_string = quote_plus(
+            datetime(1990, 1, 1, 1, 1, 1).isoformat()
+        )
+
+        # create DagRun
+        trigger_dag(dag_id=dag_id,
+                    run_id='test_lineage_info_run',
+                    execution_date=execution_date)
+
+        # test correct execution
+        response = self.client.get(
+            url_template.format(dag_id, datetime_string)
+        )
+        self.assertEqual(200, response.status_code)
+        self.assertIn('task_ids', response.data.decode('utf-8'))
+        self.assertNotIn('error', response.data.decode('utf-8'))
+
+        # Test error for nonexistent dag
+        response = self.client.get(
+            url_template.format('does_not_exist_dag', datetime_string),
+        )
+        self.assertEqual(404, response.status_code)
+        self.assertIn('error', response.data.decode('utf-8'))
+
+        # Test error for nonexistent dag run (wrong execution_date)
+        response = self.client.get(
+            url_template.format(dag_id, wrong_datetime_string)
+        )
+        self.assertEqual(404, response.status_code)
+        self.assertIn('error', response.data.decode('utf-8'))
+
+        # Test error for bad datetime format
+        response = self.client.get(
+            url_template.format(dag_id, 'not_a_datetime')
+        )
+        self.assertEqual(400, response.status_code)
+        self.assertIn('error', response.data.decode('utf-8'))
+

 class TestPoolApiExperimental(TestBase):