Skip to content

Commit

Permalink
Handle exceptions from stop_node_services
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 400889331
  • Loading branch information
goutham authored and tfx-copybara committed Oct 5, 2021
1 parent f62d96a commit f142075
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 63 deletions.
20 changes: 12 additions & 8 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,17 @@ def _cancel_nodes(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
for node in pstate.get_all_pipeline_nodes(pipeline):
if service_job_manager.is_pure_service_node(pipeline_state,
node.node_info.id):
service_job_manager.stop_node_services(pipeline_state, node.node_info.id)
if not service_job_manager.stop_node_services(pipeline_state,
node.node_info.id):
is_active = True
elif _maybe_enqueue_cancellation_task(
mlmd_handle, pipeline, node, task_queue, pause=pause):
is_active = True
elif service_job_manager.is_mixed_service_node(pipeline_state,
node.node_info.id):
service_job_manager.stop_node_services(pipeline_state, node.node_info.id)
if not service_job_manager.stop_node_services(pipeline_state,
node.node_info.id):
is_active = True
return is_active


Expand Down Expand Up @@ -548,17 +552,17 @@ def _filter_by_state(node_infos: List[_NodeInfo],
for node_info in stopping_node_infos:
if service_job_manager.is_pure_service_node(pipeline_state,
node_info.node.node_info.id):
service_job_manager.stop_node_services(pipeline_state,
node_info.node.node_info.id)
stopped_node_infos.append(node_info)
if service_job_manager.stop_node_services(pipeline_state,
node_info.node.node_info.id):
stopped_node_infos.append(node_info)
elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node_info.node,
task_queue):
pass
elif service_job_manager.is_mixed_service_node(pipeline_state,
node_info.node.node_info.id):
service_job_manager.stop_node_services(pipeline_state,
node_info.node.node_info.id)
stopped_node_infos.append(node_info)
if service_job_manager.stop_node_services(pipeline_state,
node_info.node.node_info.id):
stopped_node_infos.append(node_info)
else:
stopped_node_infos.append(node_info)

Expand Down
127 changes: 79 additions & 48 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def setUp(self):
self._mlmd_connection = metadata.Metadata(
connection_config=connection_config)

mock_service_job_manager = mock.create_autospec(
service_jobs.ServiceJobManager, instance=True)
mock_service_job_manager.is_pure_service_node.side_effect = (
lambda _, node_id: node_id == 'ExampleGen')
mock_service_job_manager.is_mixed_service_node.side_effect = (
lambda _, node_id: node_id == 'Transform')
mock_service_job_manager.stop_node_services.return_value = True
self._mock_service_job_manager = mock_service_job_manager

@parameterized.named_parameters(
dict(testcase_name='async', pipeline=_test_pipeline('pipeline1')),
dict(
Expand Down Expand Up @@ -219,8 +228,7 @@ def test_stop_node_wait_for_inactivation_timeout(self):
pstate.PipelineState.new(m, pipeline)
with self.assertRaisesRegex(
status_lib.StatusNotOkError,
'Timed out.*waiting for node inactivation.'
) as exception_context:
'Timed out.*waiting for node inactivation.') as exception_context:
pipeline_ops.stop_node(m, node_uid, timeout_secs=1.0)
self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED,
exception_context.exception.code)
Expand Down Expand Up @@ -332,13 +340,6 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline,
pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

mock_service_job_manager = mock.create_autospec(
service_jobs.ServiceJobManager, instance=True)
mock_service_job_manager.is_pure_service_node.side_effect = (
lambda _, node_id: node_id == 'ExampleGen')
mock_service_job_manager.is_mixed_service_node.side_effect = (
lambda _, node_id: node_id == 'Transform')

pipeline_ops.initiate_pipeline_start(m, pipeline)
with pstate.PipelineState.load(
m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
Expand All @@ -365,17 +366,17 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline,
is_cancelled=True), None, None, None, None
]

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

# There are no active pipelines so these shouldn't be called.
mock_async_task_gen.assert_not_called()
mock_sync_task_gen.assert_not_called()

# stop_node_services should be called for ExampleGen which is a pure
# service node.
mock_service_job_manager.stop_node_services.assert_called_once_with(
self._mock_service_job_manager.stop_node_services.assert_called_once_with(
mock.ANY, 'ExampleGen')
mock_service_job_manager.reset_mock()
self._mock_service_job_manager.reset_mock()

task_queue.task_done(transform_task) # Pop out transform task.

Expand Down Expand Up @@ -417,15 +418,15 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline,

# Call `orchestrate` again; this time there are no more active node
# executions so the pipeline should be marked as cancelled.
pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)
self.assertTrue(task_queue.is_empty())
[execution] = m.store.get_executions_by_id([pipeline_execution_id])
self.assertEqual(metadata_store_pb2.Execution.CANCELED,
execution.last_known_state)

# stop_node_services should be called on both ExampleGen and Transform
# which are service nodes.
mock_service_job_manager.stop_node_services.assert_has_calls(
self._mock_service_job_manager.stop_node_services.assert_has_calls(
[mock.call(mock.ANY, 'ExampleGen'),
mock.call(mock.ANY, 'Transform')],
any_order=True)
Expand All @@ -440,13 +441,6 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline):
pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

mock_service_job_manager = mock.create_autospec(
service_jobs.ServiceJobManager, instance=True)
mock_service_job_manager.is_pure_service_node.side_effect = (
lambda _, node_id: node_id == 'ExampleGen')
mock_service_job_manager.is_mixed_service_node.side_effect = (
lambda _, node_id: node_id == 'Transform')

pipeline_ops.initiate_pipeline_start(m, pipeline)

task_queue = tq.TaskQueue()
Expand All @@ -462,13 +456,13 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline):
with pipeline_state:
self.assertTrue(pipeline_state.is_update_initiated())

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

# stop_node_services should be called for ExampleGen which is a pure
# service node.
mock_service_job_manager.stop_node_services.assert_called_once_with(
self._mock_service_job_manager.stop_node_services.assert_called_once_with(
mock.ANY, 'ExampleGen')
mock_service_job_manager.reset_mock()
self._mock_service_job_manager.reset_mock()

# Simulate completion of all the exec node tasks.
for node_id in ('Transform', 'Trainer', 'Evaluator'):
Expand All @@ -494,11 +488,11 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline):
with pipeline_state:
self.assertTrue(pipeline_state.is_update_initiated())

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

# stop_node_services should be called for Transform (mixed service node)
# too since corresponding ExecNodeTask has been processed.
mock_service_job_manager.stop_node_services.assert_has_calls(
self._mock_service_job_manager.stop_node_services.assert_has_calls(
[mock.call(mock.ANY, 'ExampleGen'),
mock.call(mock.ANY, 'Transform')])

Expand Down Expand Up @@ -558,10 +552,6 @@ def test_active_pipelines_with_stopped_nodes(self, pipeline,
pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

mock_service_job_manager = mock.create_autospec(
service_jobs.ServiceJobManager, instance=True)
mock_service_job_manager.is_pure_service_node.side_effect = (
lambda _, node_id: node_id == 'ExampleGen')
example_gen_node_uid = task_lib.NodeUid.from_pipeline_node(
pipeline, pipeline.nodes[0].pipeline_node)

Expand Down Expand Up @@ -607,12 +597,12 @@ def test_active_pipelines_with_stopped_nodes(self, pipeline,
# Simulate Evaluator having an active execution in MLMD.
mock_gen_task_from_active.side_effect = [evaluator_task]

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)
self.assertEqual(1, mock_task_gen.return_value.generate.call_count)

# stop_node_services should be called on example-gen which is a pure
# service node.
mock_service_job_manager.stop_node_services.assert_called_once_with(
self._mock_service_job_manager.stop_node_services.assert_called_once_with(
mock.ANY, 'ExampleGen')

# Verify that tasks are enqueued in the expected order:
Expand Down Expand Up @@ -797,9 +787,6 @@ def test_pure_service_node_stop_then_start_flow(self, pipeline,
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'

mock_service_job_manager = mock.create_autospec(
service_jobs.ServiceJobManager, instance=True)
mock_service_job_manager.is_pure_service_node.return_value = True
example_gen_node_uid = task_lib.NodeUid.from_pipeline_node(
pipeline, pipeline.nodes[0].pipeline_node)

Expand All @@ -813,11 +800,11 @@ def test_pure_service_node_stop_then_start_flow(self, pipeline,

task_queue = tq.TaskQueue()

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

# stop_node_services should be called for ExampleGen which is a pure
# service node.
mock_service_job_manager.stop_node_services.assert_called_once_with(
self._mock_service_job_manager.stop_node_services.assert_called_once_with(
mock.ANY, 'ExampleGen')

with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
Expand All @@ -826,7 +813,7 @@ def test_pure_service_node_stop_then_start_flow(self, pipeline,
self.assertEqual(status_lib.Code.CANCELLED, node_state.status.code)

pipeline_ops.initiate_node_start(m, example_gen_node_uid)
pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
node_state = pipeline_state.get_node_state(example_gen_node_uid)
Expand All @@ -844,11 +831,6 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline,
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'

mock_service_job_manager = mock.create_autospec(
service_jobs.ServiceJobManager, instance=True)
mock_service_job_manager.is_pure_service_node.return_value = False
mock_service_job_manager.is_mixed_service_node.return_value = True

transform_node_uid = task_lib.NodeUid.from_pipeline_node(
pipeline, pipeline.nodes[0].pipeline_node)

Expand All @@ -868,11 +850,11 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline,
node_uid=transform_node_uid)
task_queue.enqueue(transform_task)

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

# stop_node_services should not be called as there was an active
# ExecNodeTask for Transform which is a mixed service node.
mock_service_job_manager.stop_node_services.assert_not_called()
self._mock_service_job_manager.stop_node_services.assert_not_called()

# Dequeue pre-existing transform task.
task = task_queue.dequeue()
Expand All @@ -890,11 +872,11 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline,
self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
self.assertEqual(status_lib.Code.CANCELLED, node_state.status.code)

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

# stop_node_services should be called for Transform which is a mixed
# service node and corresponding ExecNodeTask has been dequeued.
mock_service_job_manager.stop_node_services.assert_called_once_with(
self._mock_service_job_manager.stop_node_services.assert_called_once_with(
mock.ANY, 'Transform')

with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
Expand All @@ -903,7 +885,7 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline,
self.assertEqual(status_lib.Code.CANCELLED, node_state.status.code)

pipeline_ops.initiate_node_start(m, transform_node_uid)
pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager)

with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
node_state = pipeline_state.get_node_state(transform_node_uid)
Expand Down Expand Up @@ -1027,6 +1009,55 @@ def test_update_node_state_tasks_handling(self, mock_sync_task_gen):
code=status_lib.Code.ABORTED, message='foobar error'),
pipeline_state.get_node_state(evaluator_node_uid).status)

@parameterized.parameters(
    _test_pipeline('pipeline1'),
    _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC))
@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
def test_stop_node_services_failure(self, pipeline, mock_async_task_gen,
                                    mock_sync_task_gen):
  """Verifies nodes stay STOPPING when `stop_node_services` reports failure.

  Both an `ExampleGen` and a `Transform` node are put into the STOPPING state,
  the mocked service job manager is made to return `False` from
  `stop_node_services`, and `orchestrate` is run once. The task generator
  classes are patched; the resulting mocks are not otherwise used here.

  Args:
    pipeline: Pipeline IR supplied by the parameterized decorator.
    mock_async_task_gen: Patched `AsyncPipelineTaskGenerator` (unused).
    mock_sync_task_gen: Patched `SyncPipelineTaskGenerator` (unused).
  """
  service_node_ids = ('ExampleGen', 'Transform')

  with self._mlmd_connection as m:
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    for node_id in service_node_ids:
      pipeline.nodes.add().pipeline_node.node_info.id = node_id

    # Uids of the first two pipeline nodes, i.e. the ones added above
    # (assumes the parameterized test pipeline starts out with no nodes).
    node_uids = [
        task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[index].pipeline_node)
        for index in range(len(service_node_ids))
    ]

    pipeline_ops.initiate_pipeline_start(m, pipeline)

    # Mark every node as STOPPING with a CANCELLED status.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      for node_uid in node_uids:
        with pipeline_state.node_state_update_context(node_uid) as node_state:
          node_state.update(
              pstate.NodeState.STOPPING,
              status_lib.Status(code=status_lib.Code.CANCELLED))

    task_queue = tq.TaskQueue()

    # Simulate failure of stop_node_services.
    manager = self._mock_service_job_manager
    manager.stop_node_services.return_value = False

    pipeline_ops.orchestrate(m, task_queue, manager)

    expected_calls = [
        mock.call(mock.ANY, node_id) for node_id in service_node_ids
    ]
    manager.stop_node_services.assert_has_calls(
        expected_calls, any_order=True)

    # Since stop_node_services failed, the nodes must not have advanced to
    # STOPPED; they should still be STOPPING.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      for node_uid in node_uids:
        node_state = pipeline_state.get_node_state(node_uid)
        self.assertEqual(pstate.NodeState.STOPPING, node_state.state)


if __name__ == '__main__':
tf.test.main()
18 changes: 13 additions & 5 deletions tfx/orchestration/experimental/core/service_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def ensure_node_services(self, pipeline_state: pstate.PipelineState,

@abc.abstractmethod
def stop_node_services(self, pipeline_state: pstate.PipelineState,
node_id: str) -> None:
node_id: str) -> bool:
"""Stops service jobs (if any) associated with the node.
Note that this method will only be called if either `is_pure_service_node`
Expand All @@ -69,6 +69,9 @@ def stop_node_services(self, pipeline_state: pstate.PipelineState,
Args:
pipeline_state: A `PipelineState` object for an active pipeline.
node_id: Id of the node to stop services.
Returns:
`True` if the operation was successful, `False` otherwise.
"""

@abc.abstractmethod
Expand Down Expand Up @@ -107,7 +110,7 @@ def ensure_node_services(self, pipeline_state: pstate.PipelineState,
raise NotImplementedError('Service jobs not supported.')

def stop_node_services(self, pipeline_state: pstate.PipelineState,
node_id: str) -> None:
node_id: str) -> bool:
del pipeline_state, node_id
raise NotImplementedError('Service jobs not supported.')

Expand All @@ -122,7 +125,6 @@ def is_mixed_service_node(self, pipeline_state: pstate.PipelineState,
return False


# TODO(b/201346378): Also handle exceptions in stop_node_services.
class ExceptionHandlingServiceJobManagerWrapper(ServiceJobManager):
"""Wraps a ServiceJobManager instance and does some basic exception handling."""

Expand All @@ -140,8 +142,14 @@ def ensure_node_services(self, pipeline_state: pstate.PipelineState,
return ServiceStatus.FAILED

def stop_node_services(self, pipeline_state: pstate.PipelineState,
node_id: str) -> None:
self._service_job_manager.stop_node_services(pipeline_state, node_id)
node_id: str) -> bool:
try:
return self._service_job_manager.stop_node_services(
pipeline_state, node_id)
except Exception: # pylint: disable=broad-except
logging.exception(
'Exception raised by underlying `ServiceJobManager` instance.')
return False

def is_pure_service_node(self, pipeline_state: pstate.PipelineState,
node_id: str) -> bool:
Expand Down
Loading

0 comments on commit f142075

Please sign in to comment.