diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py index 2267eef4c2..4a66f942be 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops.py +++ b/tfx/orchestration/experimental/core/pipeline_ops.py @@ -475,13 +475,17 @@ def _cancel_nodes(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue, for node in pstate.get_all_pipeline_nodes(pipeline): if service_job_manager.is_pure_service_node(pipeline_state, node.node_info.id): - service_job_manager.stop_node_services(pipeline_state, node.node_info.id) + if not service_job_manager.stop_node_services(pipeline_state, + node.node_info.id): + is_active = True elif _maybe_enqueue_cancellation_task( mlmd_handle, pipeline, node, task_queue, pause=pause): is_active = True elif service_job_manager.is_mixed_service_node(pipeline_state, node.node_info.id): - service_job_manager.stop_node_services(pipeline_state, node.node_info.id) + if not service_job_manager.stop_node_services(pipeline_state, + node.node_info.id): + is_active = True return is_active @@ -548,17 +552,17 @@ def _filter_by_state(node_infos: List[_NodeInfo], for node_info in stopping_node_infos: if service_job_manager.is_pure_service_node(pipeline_state, node_info.node.node_info.id): - service_job_manager.stop_node_services(pipeline_state, - node_info.node.node_info.id) - stopped_node_infos.append(node_info) + if service_job_manager.stop_node_services(pipeline_state, + node_info.node.node_info.id): + stopped_node_infos.append(node_info) elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node_info.node, task_queue): pass elif service_job_manager.is_mixed_service_node(pipeline_state, node_info.node.node_info.id): - service_job_manager.stop_node_services(pipeline_state, - node_info.node.node_info.id) - stopped_node_infos.append(node_info) + if service_job_manager.stop_node_services(pipeline_state, + node_info.node.node_info.id): + stopped_node_infos.append(node_info) else: stopped_node_infos.append(node_info) diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py index a4c3442c5a..f2568a3329 100644 --- a/tfx/orchestration/experimental/core/pipeline_ops_test.py +++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py @@ -74,6 +74,15 @@ def setUp(self): self._mlmd_connection = metadata.Metadata( connection_config=connection_config) + mock_service_job_manager = mock.create_autospec( + service_jobs.ServiceJobManager, instance=True) + mock_service_job_manager.is_pure_service_node.side_effect = ( + lambda _, node_id: node_id == 'ExampleGen') + mock_service_job_manager.is_mixed_service_node.side_effect = ( + lambda _, node_id: node_id == 'Transform') + mock_service_job_manager.stop_node_services.return_value = True + self._mock_service_job_manager = mock_service_job_manager + @parameterized.named_parameters( dict(testcase_name='async', pipeline=_test_pipeline('pipeline1')), dict( @@ -219,8 +228,7 @@ def test_stop_node_wait_for_inactivation_timeout(self): pstate.PipelineState.new(m, pipeline) with self.assertRaisesRegex( status_lib.StatusNotOkError, - 'Timed out.*waiting for node inactivation.' - ) as exception_context: + 'Timed out.*waiting for node inactivation.') as exception_context: pipeline_ops.stop_node(m, node_uid, timeout_secs=1.0) self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED, exception_context.exception.code) @@ -332,13 +340,6 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline, pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer' pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator' - mock_service_job_manager = mock.create_autospec( - service_jobs.ServiceJobManager, instance=True) - mock_service_job_manager.is_pure_service_node.side_effect = ( - lambda _, node_id: node_id == 'ExampleGen') - mock_service_job_manager.is_mixed_service_node.side_effect = ( - lambda _, node_id: node_id == 'Transform') - pipeline_ops.initiate_pipeline_start(m, pipeline) with pstate.PipelineState.load( m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state: @@ -365,7 +366,7 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline, is_cancelled=True), None, None, None, None ] - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) # There are no active pipelines so these shouldn't be called. mock_async_task_gen.assert_not_called() @@ -373,9 +374,9 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline, # stop_node_services should be called for ExampleGen which is a pure # service node. - mock_service_job_manager.stop_node_services.assert_called_once_with( + self._mock_service_job_manager.stop_node_services.assert_called_once_with( mock.ANY, 'ExampleGen') - mock_service_job_manager.reset_mock() + self._mock_service_job_manager.reset_mock() task_queue.task_done(transform_task) # Pop out transform task. @@ -417,7 +418,7 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline, # Call `orchestrate` again; this time there are no more active node # executions so the pipeline should be marked as cancelled. - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) self.assertTrue(task_queue.is_empty()) [execution] = m.store.get_executions_by_id([pipeline_execution_id]) self.assertEqual(metadata_store_pb2.Execution.CANCELED, @@ -425,7 +426,7 @@ def test_orchestrate_stop_initiated_pipelines(self, pipeline, # stop_node_services should be called on both ExampleGen and Transform # which are service nodes. - mock_service_job_manager.stop_node_services.assert_has_calls( + self._mock_service_job_manager.stop_node_services.assert_has_calls( [mock.call(mock.ANY, 'ExampleGen'), mock.call(mock.ANY, 'Transform')], any_order=True) @@ -440,13 +441,6 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline): pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer' pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator' - mock_service_job_manager = mock.create_autospec( - service_jobs.ServiceJobManager, instance=True) - mock_service_job_manager.is_pure_service_node.side_effect = ( - lambda _, node_id: node_id == 'ExampleGen') - mock_service_job_manager.is_mixed_service_node.side_effect = ( - lambda _, node_id: node_id == 'Transform') - pipeline_ops.initiate_pipeline_start(m, pipeline) task_queue = tq.TaskQueue() @@ -462,13 +456,13 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline): with pipeline_state: self.assertTrue(pipeline_state.is_update_initiated()) - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) # stop_node_services should be called for ExampleGen which is a pure # service node. - mock_service_job_manager.stop_node_services.assert_called_once_with( + self._mock_service_job_manager.stop_node_services.assert_called_once_with( mock.ANY, 'ExampleGen') - mock_service_job_manager.reset_mock() + self._mock_service_job_manager.reset_mock() # Simulate completion of all the exec node tasks. for node_id in ('Transform', 'Trainer', 'Evaluator'): @@ -494,11 +488,11 @@ def test_orchestrate_update_initiated_pipelines(self, pipeline): with pipeline_state: self.assertTrue(pipeline_state.is_update_initiated()) - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) # stop_node_services should be called for Transform (mixed service node) # too since corresponding ExecNodeTask has been processed. - mock_service_job_manager.stop_node_services.assert_has_calls( + self._mock_service_job_manager.stop_node_services.assert_has_calls( [mock.call(mock.ANY, 'ExampleGen'), mock.call(mock.ANY, 'Transform')]) @@ -558,10 +552,6 @@ def test_active_pipelines_with_stopped_nodes(self, pipeline, pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer' pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator' - mock_service_job_manager = mock.create_autospec( - service_jobs.ServiceJobManager, instance=True) - mock_service_job_manager.is_pure_service_node.side_effect = ( - lambda _, node_id: node_id == 'ExampleGen') example_gen_node_uid = task_lib.NodeUid.from_pipeline_node( pipeline, pipeline.nodes[0].pipeline_node) @@ -607,12 +597,12 @@ def test_active_pipelines_with_stopped_nodes(self, pipeline, # Simulate Evaluator having an active execution in MLMD. mock_gen_task_from_active.side_effect = [evaluator_task] - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) self.assertEqual(1, mock_task_gen.return_value.generate.call_count) # stop_node_services should be called on example-gen which is a pure # service node. - mock_service_job_manager.stop_node_services.assert_called_once_with( + self._mock_service_job_manager.stop_node_services.assert_called_once_with( mock.ANY, 'ExampleGen') # Verify that tasks are enqueued in the expected order: @@ -797,9 +787,6 @@ def test_pure_service_node_stop_then_start_flow(self, pipeline, pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline) pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen' - mock_service_job_manager = mock.create_autospec( - service_jobs.ServiceJobManager, instance=True) - mock_service_job_manager.is_pure_service_node.return_value = True example_gen_node_uid = task_lib.NodeUid.from_pipeline_node( pipeline, pipeline.nodes[0].pipeline_node) @@ -813,11 +800,11 @@ def test_pure_service_node_stop_then_start_flow(self, pipeline, task_queue = tq.TaskQueue() - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) # stop_node_services should be called for ExampleGen which is a pure # service node. - mock_service_job_manager.stop_node_services.assert_called_once_with( + self._mock_service_job_manager.stop_node_services.assert_called_once_with( mock.ANY, 'ExampleGen') with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state: @@ -826,7 +813,7 @@ def test_pure_service_node_stop_then_start_flow(self, pipeline, self.assertEqual(status_lib.Code.CANCELLED, node_state.status.code) pipeline_ops.initiate_node_start(m, example_gen_node_uid) - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state: node_state = pipeline_state.get_node_state(example_gen_node_uid) @@ -844,11 +831,6 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline, pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline) pipeline.nodes.add().pipeline_node.node_info.id = 'Transform' - mock_service_job_manager = mock.create_autospec( - service_jobs.ServiceJobManager, instance=True) - mock_service_job_manager.is_pure_service_node.return_value = False - mock_service_job_manager.is_mixed_service_node.return_value = True - transform_node_uid = task_lib.NodeUid.from_pipeline_node( pipeline, pipeline.nodes[0].pipeline_node) @@ -868,11 +850,11 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline, node_uid=transform_node_uid) task_queue.enqueue(transform_task) - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) # stop_node_services should not be called as there was an active # ExecNodeTask for Transform which is a mixed service node. - mock_service_job_manager.stop_node_services.assert_not_called() + self._mock_service_job_manager.stop_node_services.assert_not_called() # Dequeue pre-existing transform task. task = task_queue.dequeue() @@ -890,11 +872,11 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline, self.assertEqual(pstate.NodeState.STOPPING, node_state.state) self.assertEqual(status_lib.Code.CANCELLED, node_state.status.code) - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) # stop_node_services should be called for Transform which is a mixed # service node and corresponding ExecNodeTask has been dequeued. - mock_service_job_manager.stop_node_services.assert_called_once_with( + self._mock_service_job_manager.stop_node_services.assert_called_once_with( mock.ANY, 'Transform') with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state: @@ -903,7 +885,7 @@ def test_mixed_service_node_stop_then_start_flow(self, pipeline, self.assertEqual(status_lib.Code.CANCELLED, node_state.status.code) pipeline_ops.initiate_node_start(m, transform_node_uid) - pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager) + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state: node_state = pipeline_state.get_node_state(transform_node_uid) @@ -1027,6 +1009,55 @@ def test_update_node_state_tasks_handling(self, mock_sync_task_gen): code=status_lib.Code.ABORTED, message='foobar error'), pipeline_state.get_node_state(evaluator_node_uid).status) + @parameterized.parameters( + _test_pipeline('pipeline1'), + _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)) + @mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator') + @mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator') + def test_stop_node_services_failure(self, pipeline, mock_async_task_gen, + mock_sync_task_gen): + with self._mlmd_connection as m: + pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline) + pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen' + pipeline.nodes.add().pipeline_node.node_info.id = 'Transform' + + example_gen_node_uid = task_lib.NodeUid.from_pipeline_node( + pipeline, pipeline.nodes[0].pipeline_node) + transform_node_uid = task_lib.NodeUid.from_pipeline_node( + pipeline, pipeline.nodes[1].pipeline_node) + + pipeline_ops.initiate_pipeline_start(m, pipeline) + with pstate.PipelineState.load( + m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state: + with pipeline_state.node_state_update_context( + example_gen_node_uid) as node_state: + node_state.update(pstate.NodeState.STOPPING, + status_lib.Status(code=status_lib.Code.CANCELLED)) + with pipeline_state.node_state_update_context( + transform_node_uid) as node_state: + node_state.update(pstate.NodeState.STOPPING, + status_lib.Status(code=status_lib.Code.CANCELLED)) + + task_queue = tq.TaskQueue() + + # Simulate failure of stop_node_services. + self._mock_service_job_manager.stop_node_services.return_value = False + + pipeline_ops.orchestrate(m, task_queue, self._mock_service_job_manager) + + self._mock_service_job_manager.stop_node_services.assert_has_calls( + [mock.call(mock.ANY, 'ExampleGen'), + mock.call(mock.ANY, 'Transform')], + any_order=True) + + # Node state should be STOPPING, not STOPPED since stop_node_services + # failed. + with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state: + node_state = pipeline_state.get_node_state(example_gen_node_uid) + self.assertEqual(pstate.NodeState.STOPPING, node_state.state) + node_state = pipeline_state.get_node_state(transform_node_uid) + self.assertEqual(pstate.NodeState.STOPPING, node_state.state) + if __name__ == '__main__': tf.test.main() diff --git a/tfx/orchestration/experimental/core/service_jobs.py b/tfx/orchestration/experimental/core/service_jobs.py index fb42332bad..c64a9bafa2 100644 --- a/tfx/orchestration/experimental/core/service_jobs.py +++ b/tfx/orchestration/experimental/core/service_jobs.py @@ -60,7 +60,7 @@ def ensure_node_services(self, pipeline_state: pstate.PipelineState, @abc.abstractmethod def stop_node_services(self, pipeline_state: pstate.PipelineState, - node_id: str) -> None: + node_id: str) -> bool: """Stops service jobs (if any) associated with the node. Note that this method will only be called if either `is_pure_service_node` @@ -69,6 +69,9 @@ def stop_node_services(self, pipeline_state: pstate.PipelineState, Args: pipeline_state: A `PipelineState` object for an active pipeline. node_id: Id of the node to stop services. + + Returns: + `True` if the operation was successful, `False` otherwise. """ @abc.abstractmethod @@ -107,7 +110,7 @@ def ensure_node_services(self, pipeline_state: pstate.PipelineState, raise NotImplementedError('Service jobs not supported.') def stop_node_services(self, pipeline_state: pstate.PipelineState, - node_id: str) -> None: + node_id: str) -> bool: del pipeline_state, node_id raise NotImplementedError('Service jobs not supported.') @@ -122,7 +125,6 @@ def is_mixed_service_node(self, pipeline_state: pstate.PipelineState, return False -# TODO(b/201346378): Also handle exceptions in stop_node_services. class ExceptionHandlingServiceJobManagerWrapper(ServiceJobManager): """Wraps a ServiceJobManager instance and does some basic exception handling.""" @@ -140,8 +142,14 @@ def ensure_node_services(self, pipeline_state: pstate.PipelineState, return ServiceStatus.FAILED def stop_node_services(self, pipeline_state: pstate.PipelineState, - node_id: str) -> None: - self._service_job_manager.stop_node_services(pipeline_state, node_id) + node_id: str) -> bool: + try: + return self._service_job_manager.stop_node_services( + pipeline_state, node_id) + except Exception: # pylint: disable=broad-except + logging.exception( + 'Exception raised by underlying `ServiceJobManager` instance.') + return False def is_pure_service_node(self, pipeline_state: pstate.PipelineState, node_id: str) -> bool: diff --git a/tfx/orchestration/experimental/core/service_jobs_test.py b/tfx/orchestration/experimental/core/service_jobs_test.py index 83c0d7c8ae..a3ced40516 100644 --- a/tfx/orchestration/experimental/core/service_jobs_test.py +++ b/tfx/orchestration/experimental/core/service_jobs_test.py @@ -27,6 +27,7 @@ def setUp(self): service_jobs.ServiceJobManager, instance=True) self._mock_service_job_manager.ensure_node_services.return_value = ( service_jobs.ServiceStatus.SUCCESS) + self._mock_service_job_manager.stop_node_services.return_value = True self._mock_service_job_manager.is_pure_service_node.return_value = True self._mock_service_job_manager.is_mixed_service_node.return_value = False self._wrapper = service_jobs.ExceptionHandlingServiceJobManagerWrapper( @@ -35,7 +36,7 @@ def setUp(self): def test_calls_forwarded_to_underlying_instance(self): self.assertEqual(service_jobs.ServiceStatus.SUCCESS, self._wrapper.ensure_node_services(mock.Mock(), 'node1')) - self._wrapper.stop_node_services(mock.Mock(), 'node2') + self.assertTrue(self._wrapper.stop_node_services(mock.Mock(), 'node2')) self.assertTrue(self._wrapper.is_pure_service_node(mock.Mock(), 'node3')) self.assertFalse(self._wrapper.is_mixed_service_node(mock.Mock(), 'node4')) self._mock_service_job_manager.ensure_node_services.assert_called_once_with( @@ -47,7 +48,7 @@ def test_calls_forwarded_to_underlying_instance(self): self._mock_service_job_manager.is_mixed_service_node.assert_called_once_with( mock.ANY, 'node4') - def test_exception_handling(self): + def test_ensure_node_services_exception_handling(self): self._mock_service_job_manager.ensure_node_services.side_effect = RuntimeError( 'test error') self.assertEqual(service_jobs.ServiceStatus.FAILED, @@ -55,6 +56,13 @@ def test_exception_handling(self): self._mock_service_job_manager.ensure_node_services.assert_called_once_with( mock.ANY, 'node1') + def test_stop_node_services_exception_handling(self): + self._mock_service_job_manager.stop_node_services.side_effect = RuntimeError( + 'test error') + self.assertFalse(self._wrapper.stop_node_services(mock.Mock(), 'node2')) + self._mock_service_job_manager.stop_node_services.assert_called_once_with( + mock.ANY, 'node2') + if __name__ == '__main__': tf.test.main()