Fix bug where all runs fail if DAG fails (MarquezProject#107)
* Fix bug where all runs fail if DAG fails

Reports task state information for each run instead of DAG state.

Also reworks the handle_callback logic to remove a superfluous call to report_task for operators with extract_on_complete (BigQuery).
Signed-off-by: henneberger <[email protected]>

* Fix style

Signed-off-by: henneberger <[email protected]>

* Fix bug where not all runs are reported on task complete. Also, ensure the job ID mapping is only popped once during handle_callback

Signed-off-by: henneberger <[email protected]>

* Remove semicolon

Signed-off-by: henneberger <[email protected]>

Co-authored-by: Willy Lulciuc <[email protected]>
henneberger and wslulciuc authored Nov 30, 2020
1 parent 5570016 commit 965a794
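In short: before this change, handle_callback used the DAG-level success flag, so one failing task marked every sibling run as failed. The fix consults each TaskInstance's own state instead. A minimal sketch of that idea (hedged: report_run_states and run_ids_by_task are hypothetical names introduced only for illustration; the client methods mirror the marquez_client calls in the diff below):

# Sketch of the behavioral change, not the project's code as-is.
from airflow.utils.state import State

def report_run_states(client, task_instances, run_ids_by_task):
    for ti in task_instances:
        for run_id in run_ids_by_task.get(ti.task_id, []):
            if ti.state in {State.SUCCESS, State.SKIPPED}:
                # A skipped task should not fail its run.
                client.mark_job_run_as_completed(run_id=run_id)
            else:
                client.mark_job_run_as_failed(run_id=run_id)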
Showing 2 changed files with 66 additions and 118 deletions.
162 changes: 61 additions & 101 deletions marquez_airflow/dag.py
@@ -19,6 +19,7 @@
from airflow import LoggingMixin
from airflow.contrib.operators.bigquery_operator import BigQueryOperator
from airflow.operators.postgres_operator import PostgresOperator
from airflow.utils.state import State
from marquez_client import Clients
from marquez_client.models import JobType
from pendulum import Pendulum
@@ -125,6 +126,17 @@ def _timed_log_message(self, start_time):
f'marquez_namespace={self.marquez_namespace} ' \
f'duration_ms={(self._now_ms() - start_time)}'

def _handle_task_state(self, marquez_job_run_ids, ti):
for marquez_job_run_id in marquez_job_run_ids:
if ti.state in {State.SUCCESS, State.SKIPPED}:
self.log.info(f"Setting success: {ti.task_id}")
self.get_or_create_marquez_client(). \
mark_job_run_as_completed(run_id=marquez_job_run_id)
else:
self.log.info(f"Setting failed: {ti.task_id}")
self.get_or_create_marquez_client().mark_job_run_as_failed(
run_id=marquez_job_run_id)

def handle_callback(self, *args, **kwargs):
self.log.debug(f"handle_callback({args}, {kwargs})")

@@ -133,58 +145,64 @@ def handle_callback(self, *args, **kwargs):
self.log.debug(f"handle_callback() dagrun : {dagrun}")
task_instances = dagrun.get_task_instances()
self.log.info(f"{task_instances}")

for ti in task_instances:
task = dagrun.get_dag().get_task(ti.task_id)
self.log.info(f"ti: {ti} of task: {task}")
ti.task = self.get_task(ti.task_id)
task = ti.task

extractor = self._get_extractor(task)

ti_location = self._get_location(task)
self.log.info(f"{ti}")
job_name = self._marquez_job_name_from_ti(ti)
session = kwargs.get('session')
marquez_job_run_ids = self._job_id_mapping.pop(
job_name, dagrun.run_id, session)
if marquez_job_run_ids is None:
self.log.error(f'No runs associated with task {ti}')
continue

if extractor:

steps_meta = add_airflow_info_to(
task,
extractor(task).extract_on_complete(ti))

for step in steps_meta:
self.log.info(f'step: {step}')

marquez_run_id = self._get_marquez_run_id(
ti, dagrun, kwargs)
self.log.info(f'marquez_run_id: {marquez_run_id}')

self.register_datasets(step.inputs)
self.register_datasets(step.outputs, marquez_run_id)
inputs = self._to_dataset_ids(step.inputs)
outputs = self._to_dataset_ids(step.outputs)
self.log.info(
f'inputs: {inputs} '
f'outputs: {outputs} '
)
self.get_or_create_marquez_client().create_job(
namespace_name=self.marquez_namespace,
job_name=step.name,
job_type=JobType.BATCH,
location=step.location or ti_location,
input_dataset=inputs,
output_dataset=outputs,
description=self.description,
context=step.context,
run_id=marquez_run_id)

self.log.info(f"client.create_job(run_id="
f"{marquez_run_id}) successful.")

try:
self.report_jobrun_change(
ti, dagrun.run_id, **kwargs)
except Exception as e:
self.log.error(
f'Failed to record task run state change: {e} '
f'dag_id={self.dag_id}',
exc_info=True)

for marquez_run_id in marquez_job_run_ids:
self.log.info(f'marquez_run_id: {marquez_run_id}')

inputs = None
if step.inputs is not None:
self.register_datasets(step.inputs)
inputs = self._to_dataset_ids(step.inputs)
self.log.info(
f'inputs: {inputs} '
)
outputs = None
if step.outputs is not None:
self.register_datasets(step.outputs,
marquez_run_id)
outputs = self._to_dataset_ids(step.outputs)
self.log.info(
f'outputs: {outputs} '
)

ti_location = self._get_location(task)
self.get_or_create_marquez_client().create_job(
namespace_name=self.marquez_namespace,
job_name=step.name,
job_type=JobType.BATCH,
location=step.location or ti_location,
input_dataset=inputs,
output_dataset=outputs,
description=self.description,
context=step.context,
run_id=marquez_run_id)

self.log.info(f"client.create_job(run_id="
f"{marquez_run_id}) successful.")
self._handle_task_state(marquez_job_run_ids, ti)
return
except Exception as e:
self.log.error(
f'Failed to record dagrun state change: {e} '
@@ -313,13 +331,9 @@ def report_task(self,
nominal_start_time=start_time,
nominal_end_time=end_time)

if external_run_id:
marquez_jobrun_ids.append(external_run_id)
marquez_client.mark_job_run_as_started(run_id=external_run_id)
else:
self.log.error(
f'Failed to get run id: {step.name} {task_info}'
)
marquez_jobrun_ids.append(external_run_id)
marquez_client.mark_job_run_as_started(run_id=external_run_id)

self.log.info(
f'Successfully recorded job run: {step.name} {task_info}'
f'airflow_dag_execution_time={start_time} '
@@ -340,48 +354,6 @@ def compute_endtime(self, execution_date):
def compute_endtime(self, execution_date):
return self.following_schedule(execution_date)

def report_jobrun_change(self, ti, run_id, **kwargs):
job_name = self._marquez_job_name_from_ti(ti)
session = kwargs.get('session')
marquez_job_run_ids = self._job_id_mapping.pop(
job_name, run_id, session)

task_info = \
f'airflow_dag_id={self.dag_id} ' \
f'task_id={ti.task_id} ' \
f'airflow_run_id={run_id} ' \
f'marquez_run_id={marquez_job_run_ids} ' \
f'marquez_namespace={self.marquez_namespace}' \
f'marquez_job_name={job_name} '

if marquez_job_run_ids:
self.log.info(
f'Found job runs: {task_info}'
f'marquez_run_ids={marquez_job_run_ids} ')

state = 'UNKNOWN'
if kwargs.get('success'):
state = 'COMPLETED'
for marquez_job_run_id in marquez_job_run_ids:
for task_id, task in self.task_dict.items():
if task_id == ti.task_id:
self.log.info(f'task_id: {task_id}')
self.report_task(
run_id,
None,
task,
self._get_extractor(task),
marquez_job_run_id=marquez_job_run_id)
self.get_or_create_marquez_client(). \
mark_job_run_as_completed(run_id=marquez_job_run_id)
else:
state = 'FAILED'
for marquez_job_run_id in marquez_job_run_ids:
self.get_or_create_marquez_client().mark_job_run_as_failed(
run_id=marquez_job_run_id)

self.log.info(f'Marked job run(s) as {state}. {task_info}')

def get_or_create_marquez_client(self):
if not self._marquez_client:
self._marquez_client = Clients.new_write_only_client()
@@ -461,15 +433,3 @@ def _marquez_job_name_from_ti(ti):
@staticmethod
def _marquez_job_name(dag_id, task_id):
return f'{dag_id}.{task_id}'

def _get_marquez_run_id(self, ti, dagrun, kwargs):
self.log.debug(f"_get_marquez_run_id({ti}, {dagrun}, {kwargs})")
job_name = self._marquez_job_name_from_ti(ti)
session = kwargs.get('session')
job_run_ids = self._job_id_mapping.get(
job_name, dagrun.run_id, session)
self.log.info(f"job_run_ids: {job_run_ids}")
if job_run_ids:
return job_run_ids[0]
else:
return None
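A note on the "popped once" point in the commit message: handle_callback now performs a single _job_id_mapping.pop(job_name, dagrun.run_id, session) per task instance, replacing the removed _get_marquez_run_id get plus the second pop that lived in report_jobrun_change. The mapping's implementation is not part of this diff; the sketch below is a hypothetical in-memory stand-in that only illustrates the destructive-read contract the new code relies on.

# Hypothetical stand-in for _job_id_mapping (the real class is not shown in
# this diff). First pop returns the run ids recorded for (job_name, run_id);
# any later pop returns None.
class InMemoryJobIdMapping:
    def __init__(self):
        self._store = {}

    def set(self, job_name, run_id, marquez_run_ids):
        self._store[(job_name, run_id)] = marquez_run_ids

    def pop(self, job_name, run_id, session=None):
        # session is unused here; kept only to mirror the call shape above.
        return self._store.pop((job_name, run_id), None)

mapping = InMemoryJobIdMapping()
mapping.set('example_dag.task_a', 'manual_run_1', ['run-1', 'run-2'])
assert mapping.pop('example_dag.task_a', 'manual_run_1') == ['run-1', 'run-2']
assert mapping.pop('example_dag.task_a', 'manual_run_1') is None  # already popped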
22 changes: 5 additions & 17 deletions tests/test_marquez_dag.py
@@ -197,8 +197,12 @@ def test_marquez_dag(mock_get_or_create_marquez_client, mock_uuid,
# dataset call is not invoked.
mock_marquez_client.create_dataset.assert_not_called()

# session = settings.Session()
# (6) Start task that will be marked as failed
task_will_fail.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
ti1 = TaskInstance(task=task_will_fail, execution_date=DEFAULT_DATE)
ti1.state = State.FAILED
session.add(ti1)
session.commit()

dag.handle_callback(dagrun, success=False, session=session)
mock_marquez_client.mark_job_run_as_failed.assert_called_once_with(
@@ -402,21 +406,6 @@ def test_marquez_dag_with_extractor(mock_get_or_create_marquez_client,
description=DAG_DESCRIPTION,
namespace_name=DAG_NAMESPACE,
run_id=run_id
),
# TODO: consolidate the two calls, this second call is spurious
mock.call(
job_name=f"{dag_id}.{TASK_ID_COMPLETED}",
job_type=JobType.BATCH,
location=completed_task_location,
input_dataset=[
{'namespace': 'default', 'name': 'extract_input1'}
],
output_dataset=[
{'namespace': 'default', 'name': 'extract_output1'}
],
context=mock.ANY,
description=DAG_DESCRIPTION,
namespace_name=DAG_NAMESPACE,
)
])

@@ -458,6 +447,5 @@ def test_marquez_dag_with_extractor(mock_get_or_create_marquez_client,
'create_dataset', # we would expect only the output to be updated
'create_dataset',
'create_job',
'create_job', # we would expect only one call to update the job
'mark_job_run_as_completed'
]
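The trimmed expectations above follow from removing the duplicate create_job call. For readers unfamiliar with the assertion style, a hedged illustration using a plain unittest.mock.MagicMock in place of the Marquez client (not the test's actual fixtures):

# Standalone illustration of asserting a client's call sequence.
from unittest import mock

client = mock.MagicMock()
client.create_job(job_name='example_dag.task', run_id='run-1')
client.mark_job_run_as_completed(run_id='run-1')

# method_calls records one entry per call, so removing a duplicate call
# shortens the expected list -- exactly the change in this test.
assert [name for name, _, _ in client.method_calls] == [
    'create_job',
    'mark_job_run_as_completed',
]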
