diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py
index e874197182..6294a774ed 100644
--- a/tfx/orchestration/experimental/core/pipeline_ops.py
+++ b/tfx/orchestration/experimental/core/pipeline_ops.py
@@ -217,6 +217,17 @@ def stop_node(
         mlmd_handle, active_executions[0].id, timeout_secs=timeout_secs)
 
 
+@_to_status_not_ok_error
+def initiate_pipeline_update(
+    mlmd_handle: metadata.Metadata,
+    pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
+  """Initiates pipeline update."""
+  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
+  with pstate.PipelineState.load(mlmd_handle, pipeline_uid) as pipeline_state:
+    pipeline_state.initiate_update(pipeline)
+  return pipeline_state
+
+
 @_to_status_not_ok_error
 def _wait_for_inactivation(
     mlmd_handle: metadata.Metadata,
@@ -271,10 +282,13 @@ def orchestrate(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
 
   active_pipeline_states = []
   stop_initiated_pipeline_states = []
+  update_initiated_pipeline_states = []
   for pipeline_state in pipeline_states:
     with pipeline_state:
-      if pipeline_state.stop_initiated_reason() is not None:
+      if pipeline_state.is_stop_initiated():
         stop_initiated_pipeline_states.append(pipeline_state)
+      elif pipeline_state.is_update_initiated():
+        update_initiated_pipeline_states.append(pipeline_state)
       elif pipeline_state.is_active():
         active_pipeline_states.append(pipeline_state)
       else:
@@ -289,6 +303,12 @@ def orchestrate(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
     _orchestrate_stop_initiated_pipeline(mlmd_handle, task_queue,
                                          service_job_manager, pipeline_state)
 
+  for pipeline_state in update_initiated_pipeline_states:
+    logging.info('Orchestrating update-initiated pipeline: %s',
+                 pipeline_state.pipeline_uid)
+    _orchestrate_update_initiated_pipeline(mlmd_handle, task_queue,
+                                           service_job_manager, pipeline_state)
+
   for pipeline_state in active_pipeline_states:
     logging.info('Orchestrating pipeline: %s', pipeline_state.pipeline_uid)
     _orchestrate_active_pipeline(mlmd_handle, task_queue, service_job_manager,
@@ -315,32 +335,53 @@ def _get_pipeline_states(
   return result
 
 
-def _orchestrate_stop_initiated_pipeline(
-    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
-    service_job_manager: service_jobs.ServiceJobManager,
-    pipeline_state: pstate.PipelineState) -> None:
-  """Orchestrates stop initiated pipeline."""
-  with pipeline_state:
-    stop_reason = pipeline_state.stop_initiated_reason()
-    assert stop_reason is not None
+def _cancel_nodes(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
+                  service_job_manager: service_jobs.ServiceJobManager,
+                  pipeline_state: pstate.PipelineState, pause: bool) -> bool:
+  """Cancels pipeline nodes and returns `True` if any node is currently active."""
   pipeline = pipeline_state.pipeline
-  has_active_executions = False
+  is_active = False
   for node in pstate.get_all_pipeline_nodes(pipeline):
     if service_job_manager.is_pure_service_node(pipeline_state,
                                                 node.node_info.id):
       service_job_manager.stop_node_services(pipeline_state,
                                              node.node_info.id)
-    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
-                                          task_queue):
-      has_active_executions = True
+    elif _maybe_enqueue_cancellation_task(
+        mlmd_handle, pipeline, node, task_queue, pause=pause):
+      is_active = True
     elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                    node.node_info.id):
       service_job_manager.stop_node_services(pipeline_state,
                                              node.node_info.id)
-  if not has_active_executions:
+  return is_active
+
+
+def _orchestrate_stop_initiated_pipeline(
+    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
+    service_job_manager: service_jobs.ServiceJobManager,
+    pipeline_state: pstate.PipelineState) -> None:
+  """Orchestrates stop initiated pipeline."""
+  with pipeline_state:
+    stop_reason = pipeline_state.stop_initiated_reason()
+    assert stop_reason is not None
+  is_active = _cancel_nodes(
+      mlmd_handle, task_queue, service_job_manager, pipeline_state, pause=False)
+  if not is_active:
     with pipeline_state:
       # Update pipeline execution state in MLMD.
       pipeline_state.set_pipeline_execution_state_from_status(stop_reason)
 
 
+def _orchestrate_update_initiated_pipeline(
+    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
+    service_job_manager: service_jobs.ServiceJobManager,
+    pipeline_state: pstate.PipelineState) -> None:
+  """Orchestrates an update-initiated pipeline."""
+  is_active = _cancel_nodes(
+      mlmd_handle, task_queue, service_job_manager, pipeline_state, pause=True)
+  if not is_active:
+    with pipeline_state:
+      pipeline_state.apply_pipeline_update()
+
+
 @attr.s(auto_attribs=True, kw_only=True)
 class _NodeInfo:
   """A convenience container of pipeline node and its state."""
@@ -464,14 +505,16 @@ def _get_node_infos(pipeline_state: pstate.PipelineState) -> List[_NodeInfo]:
 def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
                                      pipeline: pipeline_pb2.Pipeline,
                                      node: pipeline_pb2.PipelineNode,
-                                     task_queue: tq.TaskQueue) -> bool:
+                                     task_queue: tq.TaskQueue,
+                                     pause: bool = False) -> bool:
   """Enqueues a node cancellation task if not already stopped.
 
   If the node has an ExecNodeTask in the task queue, issue a cancellation.
-  Otherwise, if the node has an active execution in MLMD but no ExecNodeTask
-  enqueued, it may be due to orchestrator restart after stopping was initiated
-  but before the schedulers could finish. So, enqueue an ExecNodeTask with
-  is_cancelled set to give a chance for the scheduler to finish gracefully.
+  Otherwise, when pause=False, if the node has an active execution in MLMD but
+  no ExecNodeTask enqueued, it may be due to orchestrator restart after stopping
+  was initiated but before the schedulers could finish. So, enqueue an
+  ExecNodeTask with is_cancelled set to give a chance for the scheduler to
+  finish gracefully.
 
   Args:
     mlmd_handle: A handle to the MLMD db.
@@ -479,6 +522,8 @@ def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
     node: The node to cancel.
     task_queue: A `TaskQueue` instance into which any cancellation tasks will
       be enqueued.
+    pause: Whether the cancellation is to pause the node rather than cancelling
+      the execution.
 
   Returns:
     `True` if a cancellation task was enqueued. `False` if node is already
@@ -489,9 +534,10 @@ def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
   if task_queue.contains_task_id(exec_node_task_id):
     task_queue.enqueue(
         task_lib.CancelNodeTask(
-            node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node)))
+            node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node),
+            pause=pause))
     return True
-  else:
+  if not pause:
     executions = task_gen_utils.get_executions(mlmd_handle, node)
     exec_node_task = task_gen_utils.generate_task_from_active_execution(
         mlmd_handle, pipeline, node, executions, is_cancelled=True)
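Note (illustrative, not part of the patch): a minimal caller-side sketch of the new update entry point. `request_update` is a hypothetical helper name and `connection_config` is a placeholder for whatever MLMD connection config the deployment uses.

    from tfx.orchestration import metadata
    from tfx.orchestration.experimental.core import pipeline_ops

    def request_update(connection_config, updated_pipeline_ir):
      # Records the updated IR against the active orchestrator execution in
      # MLMD; the orchestration loop later pauses active nodes and applies it.
      with metadata.Metadata(connection_config) as mlmd_handle:
        return pipeline_ops.initiate_pipeline_update(
            mlmd_handle, updated_pipeline_ir)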
diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py
index 098880d7c0..0f964f9f39 100644
--- a/tfx/orchestration/experimental/core/pipeline_ops_test.py
+++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -353,8 +353,10 @@ def test_orchestrate_active_pipelines(self, mock_async_task_gen,
   @mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
   @mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
   @mock.patch.object(task_gen_utils, 'generate_task_from_active_execution')
-  def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
-                                    mock_async_task_gen, mock_sync_task_gen):
+  def test_orchestrate_stop_initiated_pipelines(self, pipeline,
+                                                mock_gen_task_from_active,
+                                                mock_async_task_gen,
+                                                mock_sync_task_gen):
     with self._mlmd_connection as m:
       pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
       pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
@@ -459,6 +461,83 @@ def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
            mock.call(mock.ANY, 'Transform')],
           any_order=True)
 
+  @parameterized.parameters(
+      _test_pipeline('pipeline1'),
+      _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC))
+  def test_orchestrate_update_initiated_pipelines(self, pipeline):
+    with self._mlmd_connection as m:
+      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
+      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
+      pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
+      pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
+
+      mock_service_job_manager = mock.create_autospec(
+          service_jobs.ServiceJobManager, instance=True)
+      mock_service_job_manager.is_pure_service_node.side_effect = (
+          lambda _, node_id: node_id == 'ExampleGen')
+      mock_service_job_manager.is_mixed_service_node.side_effect = (
+          lambda _, node_id: node_id == 'Transform')
+
+      pipeline_ops.initiate_pipeline_start(m, pipeline)
+
+      task_queue = tq.TaskQueue()
+
+      for node_id in ('Transform', 'Trainer', 'Evaluator'):
+        task_queue.enqueue(
+            test_utils.create_exec_node_task(
+                task_lib.NodeUid(
+                    pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
+                    node_id=node_id)))
+
+      pipeline_state = pipeline_ops.initiate_pipeline_update(m, pipeline)
+      with pipeline_state:
+        self.assertTrue(pipeline_state.is_update_initiated())
+
+      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
+
+      # stop_node_services should be called for ExampleGen which is a pure
+      # service node.
+      mock_service_job_manager.stop_node_services.assert_called_once_with(
+          mock.ANY, 'ExampleGen')
+      mock_service_job_manager.reset_mock()
+
+      # Simulate completion of all the exec node tasks.
+      for node_id in ('Transform', 'Trainer', 'Evaluator'):
+        task = task_queue.dequeue()
+        task_queue.task_done(task)
+        self.assertTrue(task_lib.is_exec_node_task(task))
+        self.assertEqual(node_id, task.node_uid.node_id)
+
+      # Verify that cancellation tasks were enqueued in the last `orchestrate`
+      # call, and dequeue them.
+      for node_id in ('Transform', 'Trainer', 'Evaluator'):
+        task = task_queue.dequeue()
+        task_queue.task_done(task)
+        self.assertTrue(task_lib.is_cancel_node_task(task))
+        self.assertEqual(node_id, task.node_uid.node_id)
+        self.assertTrue(task.pause)
+
+      self.assertTrue(task_queue.is_empty())
+
+      # Pipeline continues to be in update initiated state until all
+      # ExecNodeTasks have been dequeued (which was not the case when last
+      # `orchestrate` call was made).
+      with pipeline_state:
+        self.assertTrue(pipeline_state.is_update_initiated())
+
+      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
+
+      # stop_node_services should be called for Transform (mixed service node)
+      # too since corresponding ExecNodeTask has been processed.
+      mock_service_job_manager.stop_node_services.assert_has_calls(
+          [mock.call(mock.ANY, 'ExampleGen'),
+           mock.call(mock.ANY, 'Transform')])
+
+      # Pipeline should no longer be in update-initiated state but be active.
+      with pipeline_state:
+        self.assertFalse(pipeline_state.is_update_initiated())
+        self.assertTrue(pipeline_state.is_active())
+
   @parameterized.parameters(
       _test_pipeline('pipeline1'),
       _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC))
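As the test above exercises, an update settles over at least two `orchestrate` passes: the first enqueues pause-cancellations for active nodes, and a later pass applies the update once nothing is active. A hypothetical polling helper, sketched under the assumption that a driver owns the loop (the name `wait_until_updated` and the one-second interval are invented for illustration):

    import time

    from tfx.orchestration.experimental.core import pipeline_ops
    from tfx.orchestration.experimental.core import pipeline_state as pstate

    def wait_until_updated(mlmd_handle, task_queue, service_job_manager,
                           pipeline_uid):
      # Keep orchestrating until the pending update has been applied.
      while True:
        pipeline_ops.orchestrate(mlmd_handle, task_queue, service_job_manager)
        with pstate.PipelineState.load(mlmd_handle, pipeline_uid) as state:
          if not state.is_update_initiated():
            return
        time.sleep(1.0)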
diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py
index 2067d85bd7..88f36b5441 100644
--- a/tfx/orchestration/experimental/core/pipeline_state.py
+++ b/tfx/orchestration/experimental/core/pipeline_state.py
@@ -17,7 +17,7 @@
 import contextlib
 import threading
 import time
-from typing import Dict, Iterator, List, Mapping, Optional
+from typing import Dict, Iterator, List, Mapping, Optional, Tuple
 
 import attr
 from tfx import types
@@ -41,6 +41,7 @@
 _PIPELINE_STATUS_MSG = 'pipeline_status_msg'
 _NODE_STATES = 'node_states'
 _PIPELINE_RUN_METADATA = 'pipeline_run_metadata'
+_UPDATED_PIPELINE_IR = 'updated_pipeline_ir'
 _ORCHESTRATOR_EXECUTION_TYPE = metadata_store_pb2.ExecutionType(
     name=_ORCHESTRATOR_RESERVED_ID,
     properties={_PIPELINE_IR: metadata_store_pb2.STRING})
@@ -139,7 +140,6 @@ def __init__(self, mlmd_handle: metadata.Metadata,
     self.mlmd_handle = mlmd_handle
     self.pipeline = pipeline
     self.execution_id = execution_id
-    self.pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
 
     # Only set within the pipeline state context.
     self._mlmd_execution_atomic_op_context = None
@@ -181,8 +181,7 @@ def new(
           message=f'Pipeline with uid {pipeline_uid} already active.')
 
     exec_properties = {
-        _PIPELINE_IR:
-            base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
+        _PIPELINE_IR: _base64_encode_pipeline(pipeline)
     }
     if pipeline_run_metadata:
       exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
@@ -258,6 +257,10 @@ def load_from_orchestrator_context(
         pipeline=pipeline,
         execution_id=active_execution.id)
 
+  @property
+  def pipeline_uid(self) -> task_lib.PipelineUid:
+    return task_lib.PipelineUid.from_pipeline(self.pipeline)
+
   def is_active(self) -> bool:
     """Returns `True` if pipeline is active."""
     self._check_context()
@@ -277,6 +280,60 @@ def initiate_stop(self, status: status_lib.Status) -> None:
                                       status.message)
     record_state_change_time()
 
+  def initiate_update(self, updated_pipeline: pipeline_pb2.Pipeline) -> None:
+    """Initiates pipeline update process."""
+    self._check_context()
+
+    if self.pipeline.execution_mode != updated_pipeline.execution_mode:
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.INVALID_ARGUMENT,
+          message=('Updating execution_mode of an active pipeline is not '
+                   'supported'))
+
+    # TODO(b/194311197): We require that structure of the updated pipeline
+    # exactly matches the original. There is scope to relax this restriction.
+
+    def _structure(
+        pipeline: pipeline_pb2.Pipeline
+    ) -> List[Tuple[str, List[str], List[str]]]:
+      return [(node.node_info.id, list(node.upstream_nodes),
+               list(node.downstream_nodes))
+              for node in get_all_pipeline_nodes(pipeline)]
+
+    if _structure(self.pipeline) != _structure(updated_pipeline):
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.INVALID_ARGUMENT,
+          message=('Updated pipeline should have the same structure as the '
+                   'original.'))
+
+    data_types_utils.set_metadata_value(
+        self._execution.custom_properties[_UPDATED_PIPELINE_IR],
+        _base64_encode_pipeline(updated_pipeline))
+    record_state_change_time()
+
+  def is_update_initiated(self) -> bool:
+    self._check_context()
+    return self.is_active() and self._execution.custom_properties.get(
+        _UPDATED_PIPELINE_IR) is not None
+
+  def apply_pipeline_update(self) -> None:
+    """Applies pipeline update that was previously initiated."""
+    self._check_context()
+    updated_pipeline_ir = _get_metadata_value(
+        self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
+    if not updated_pipeline_ir:
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.INVALID_ARGUMENT,
+          message='No updated pipeline IR to apply')
+    data_types_utils.set_metadata_value(
+        self._execution.custom_properties[_PIPELINE_IR], updated_pipeline_ir)
+    del self._execution.custom_properties[_UPDATED_PIPELINE_IR]
+    self.pipeline = _base64_decode_pipeline(updated_pipeline_ir)
+
+  def is_stop_initiated(self) -> bool:
+    self._check_context()
+    return self.stop_initiated_reason() is not None
+
   def stop_initiated_reason(self) -> Optional[status_lib.Status]:
     """Returns status object if stop initiated, `None` otherwise."""
     self._check_context()
@@ -552,9 +609,7 @@ def _get_pipeline_from_orchestrator_execution(
     execution: metadata_store_pb2.Execution) -> pipeline_pb2.Pipeline:
   pipeline_ir_b64 = data_types_utils.get_metadata_value(
       execution.properties[_PIPELINE_IR])
-  pipeline = pipeline_pb2.Pipeline()
-  pipeline.ParseFromString(base64.b64decode(pipeline_ir_b64))
-  return pipeline
+  return _base64_decode_pipeline(pipeline_ir_b64)
 
 
 def _get_active_execution(
@@ -587,3 +642,13 @@ def _get_creation_time(execution):
     return execution.create_time_since_epoch
 
   return max(executions, key=_get_creation_time)
+
+
+def _base64_encode_pipeline(pipeline: pipeline_pb2.Pipeline) -> str:
+  return base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
+
+
+def _base64_decode_pipeline(pipeline_encoded: str) -> pipeline_pb2.Pipeline:
+  result = pipeline_pb2.Pipeline()
+  result.ParseFromString(base64.b64decode(pipeline_encoded))
+  return result
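For reference, the compatibility gate enforced by `initiate_update` reduces to the predicate below (a standalone sketch mirroring the patch; `update_compatible` is not a name in the codebase, and sub-pipelines are assumed absent so every `nodes` entry carries a `pipeline_node`):

    from typing import List, Tuple

    from tfx.proto.orchestration import pipeline_pb2

    def _structure(
        pipeline: pipeline_pb2.Pipeline
    ) -> List[Tuple[str, List[str], List[str]]]:
      # Node ids plus upstream/downstream edges; parameter values may change
      # freely across an update, the topology may not.
      return [(n.pipeline_node.node_info.id,
               list(n.pipeline_node.upstream_nodes),
               list(n.pipeline_node.downstream_nodes))
              for n in pipeline.nodes]

    def update_compatible(old: pipeline_pb2.Pipeline,
                          new: pipeline_pb2.Pipeline) -> bool:
      return (old.execution_mode == new.execution_mode and
              _structure(old) == _structure(new))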
diff --git a/tfx/orchestration/experimental/core/pipeline_state_test.py b/tfx/orchestration/experimental/core/pipeline_state_test.py
index dc80c6e385..1a212b16e1 100644
--- a/tfx/orchestration/experimental/core/pipeline_state_test.py
+++ b/tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -30,11 +30,14 @@
 
 def _test_pipeline(pipeline_id,
                    execution_mode: pipeline_pb2.Pipeline.ExecutionMode = (
-                       pipeline_pb2.Pipeline.ASYNC)):
+                       pipeline_pb2.Pipeline.ASYNC),
+                   param=1):
   pipeline = pipeline_pb2.Pipeline()
   pipeline.pipeline_info.id = pipeline_id
   pipeline.execution_mode = execution_mode
   pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
+  pipeline.nodes[0].pipeline_node.parameters.parameters[
+      'param'].field_value.int_value = param
   return pipeline
 
 
@@ -192,7 +195,45 @@ def test_pipeline_stop_initiation(self):
           m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
         self.assertEqual(status, pipeline_state.stop_initiated_reason())
 
-  def test_node_start_and_stop(self):
+  def test_update_initiation_and_apply(self):
+    with self._mlmd_connection as m:
+      pipeline = _test_pipeline('pipeline1', param=1)
+      updated_pipeline = _test_pipeline('pipeline1', param=2)
+      with pstate.PipelineState.new(m, pipeline) as pipeline_state:
+        self.assertFalse(pipeline_state.is_update_initiated())
+        pipeline_state.initiate_update(updated_pipeline)
+        self.assertTrue(pipeline_state.is_update_initiated())
+
+      # Reload from MLMD and verify.
+      with pstate.PipelineState.load(
+          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
+        self.assertTrue(pipeline_state.is_update_initiated())
+        self.assertEqual(pipeline, pipeline_state.pipeline)
+        pipeline_state.apply_pipeline_update()
+        self.assertFalse(pipeline_state.is_update_initiated())
+        self.assertTrue(pipeline_state.is_active())
+        self.assertEqual(updated_pipeline, pipeline_state.pipeline)
+
+      # Update should fail if execution mode is different.
+      updated_pipeline = _test_pipeline(
+          'pipeline1', execution_mode=pipeline_pb2.Pipeline.SYNC)
+      with pstate.PipelineState.load(
+          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
+        with self.assertRaisesRegex(status_lib.StatusNotOkError,
+                                    'Updating execution_mode.*not supported'):
+          pipeline_state.initiate_update(updated_pipeline)
+
+      # Update should fail if pipeline structure changed.
+      updated_pipeline = _test_pipeline('pipeline1')
+      updated_pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
+      with pstate.PipelineState.load(
+          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
+        with self.assertRaisesRegex(status_lib.StatusNotOkError,
+                                    'same structure'):
+          pipeline_state.initiate_update(updated_pipeline)
+
+  def test_initiate_node_start_stop(self):
     with self._mlmd_connection as m:
       pipeline = _test_pipeline('pipeline1')
       node_uid = task_lib.NodeUid(
diff --git a/tfx/orchestration/experimental/core/task.py b/tfx/orchestration/experimental/core/task.py
index 3ed1440f8b..6d1ab1b79b 100644
--- a/tfx/orchestration/experimental/core/task.py
+++ b/tfx/orchestration/experimental/core/task.py
@@ -18,8 +18,7 @@
 """
 
 import abc
-import typing
-from typing import Dict, List, Type, TypeVar
+from typing import Dict, Hashable, List, Type, TypeVar
 
 import attr
 from tfx import types
@@ -68,7 +67,7 @@ def from_pipeline_node(cls: Type['NodeUid'], pipeline: pipeline_pb2.Pipeline,
 
 
 # Task id can be any hashable type.
-TaskId = typing.Hashable
+TaskId = TypeVar('TaskId', bound=Hashable)
 
 _TaskT = TypeVar('_TaskT', bound='Task')
 
@@ -134,8 +133,15 @@ def get_pipeline_node(self) -> pipeline_pb2.PipelineNode:
 
 @attr.s(auto_attribs=True, frozen=True)
 class CancelNodeTask(Task):
-  """Task to instruct cancellation of an ongoing node execution."""
+  """Task to instruct cancellation of an ongoing node execution.
+
+  Attributes:
+    node_uid: Uid of the node to be cancelled.
+    pause: If `True`, the node is being paused with the intention of resuming
+      the same execution after restart.
+  """
   node_uid: NodeUid
+  pause: bool = False
 
   @property
   def task_id(self) -> TaskId:
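Purely illustrative construction of the two cancellation flavors defined above (the uids are made up; this mirrors the `_test_cancel_node_task` helper further below):

    from tfx.orchestration.experimental.core import task as task_lib

    node_uid = task_lib.NodeUid(
        pipeline_uid=task_lib.PipelineUid(pipeline_id='my_pipeline'),
        node_id='Trainer')

    # Hard cancel: the execution is finalized once the scheduler returns.
    cancel_task = task_lib.CancelNodeTask(node_uid=node_uid)

    # Pause: the scheduler is stopped, but the MLMD execution stays open so a
    # new ExecNodeTask can resume it after the update is applied.
    pause_task = task_lib.CancelNodeTask(node_uid=node_uid, pause=True)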
diff --git a/tfx/orchestration/experimental/core/task_manager.py b/tfx/orchestration/experimental/core/task_manager.py
index e45d727666..25c9ef973a 100644
--- a/tfx/orchestration/experimental/core/task_manager.py
+++ b/tfx/orchestration/experimental/core/task_manager.py
@@ -16,7 +16,7 @@
 from concurrent import futures
 import threading
 import typing
-from typing import Optional
+from typing import Dict, Optional
 
 from absl import logging
 from tfx.dsl.io import fileio
@@ -50,6 +50,21 @@ def __init__(self, errors):
     self.errors = errors
 
 
+class _SchedulerWrapper:
+  """Wraps a TaskScheduler to store additional details."""
+
+  def __init__(self, task_scheduler: ts.TaskScheduler):
+    self._task_scheduler = task_scheduler
+    self.pause = False
+
+  def schedule(self) -> ts.TaskSchedulerResult:
+    return self._task_scheduler.schedule()
+
+  def cancel(self, pause: bool = False) -> None:
+    self.pause = pause
+    self._task_scheduler.cancel()
+
+
 class TaskManager:
   """TaskManager acts on the tasks fetched from the task queues.
 
@@ -83,7 +98,7 @@ def __init__(self,
 
     self._tm_lock = threading.Lock()
     self._stop_event = threading.Event()
-    self._scheduler_by_node_uid = {}
+    self._scheduler_by_node_uid: Dict[task_lib.NodeUid, _SchedulerWrapper] = {}
 
     # Async executor for the main task management thread.
     self._main_executor = futures.ThreadPoolExecutor(max_workers=1)
@@ -171,8 +186,9 @@ def _handle_exec_node_task(self, task: task_lib.ExecNodeTask) -> None:
         raise RuntimeError(
             'Cannot create multiple task schedulers for the same task; '
             'task_id: {}'.format(task.task_id))
-      scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
-          self._mlmd_handle, task.pipeline, task)
+      scheduler = _SchedulerWrapper(
+          ts.TaskSchedulerRegistry.create_task_scheduler(
+              self._mlmd_handle, task.pipeline, task))
       self._scheduler_by_node_uid[node_uid] = scheduler
       self._ts_futures.add(
           self._ts_executor.submit(self._process_exec_node_task, scheduler,
@@ -189,10 +205,10 @@ def _handle_cancel_node_task(self, task: task_lib.CancelNodeTask) -> None:
            'No task scheduled for node uid: %s. The task might have already '
            'completed before it could be cancelled.', task.node_uid)
      else:
-        scheduler.cancel()
+        scheduler.cancel(task.pause)
      self._task_queue.task_done(task)
 
-  def _process_exec_node_task(self, scheduler: ts.TaskScheduler,
+  def _process_exec_node_task(self, scheduler: _SchedulerWrapper,
                               task: task_lib.ExecNodeTask) -> None:
     """Processes an `ExecNodeTask` using the given task scheduler."""
     # This is a blocking call to the scheduler which can take a long time to
@@ -210,8 +226,12 @@ def _process_exec_node_task(self, scheduler: _SchedulerWrapper,
               code=status_lib.Code.ABORTED, message=str(e)))
     logging.info('For ExecNodeTask id: %s, task-scheduler result status: %s',
                  task.task_id, result.status)
-    _publish_execution_results(
-        mlmd_handle=self._mlmd_handle, task=task, result=result)
+    # If the node was paused, we do not complete the execution as it is expected
+    # that a new ExecNodeTask would be issued for resuming the execution.
+    if not (scheduler.pause and
+            result.status.code == status_lib.Code.CANCELLED):
+      _publish_execution_results(
+          mlmd_handle=self._mlmd_handle, task=task, result=result)
     with self._tm_lock:
       del self._scheduler_by_node_uid[task.node_uid]
       self._task_queue.task_done(task)
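The publish-or-skip decision in `_process_exec_node_task` reads as the predicate below (extracted for clarity only; `should_publish` does not exist in the module, and the `status_lib` import assumes the same alias these modules already use). A paused scheduler that nonetheless finishes successfully still publishes its result:

    from tfx.utils import status as status_lib

    def should_publish(paused: bool, result_code: int) -> bool:
      # Skip publishing only when the node was paused AND the scheduler
      # actually reported CANCELLED; any other outcome is recorded as usual.
      return not (paused and result_code == status_lib.Code.CANCELLED)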
diff --git a/tfx/orchestration/experimental/core/task_manager_test.py b/tfx/orchestration/experimental/core/task_manager_test.py
index f78fc1db3b..69ee42d8fb 100644
--- a/tfx/orchestration/experimental/core/task_manager_test.py
+++ b/tfx/orchestration/experimental/core/task_manager_test.py
@@ -46,11 +46,11 @@ def _test_exec_node_task(node_id, pipeline_id, pipeline=None):
   return test_utils.create_exec_node_task(node_uid, pipeline=pipeline)
 
 
-def _test_cancel_node_task(node_id, pipeline_id):
+def _test_cancel_node_task(node_id, pipeline_id, pause=False):
   node_uid = task_lib.NodeUid(
       pipeline_uid=task_lib.PipelineUid(pipeline_id=pipeline_id),
       node_id=node_id)
-  return task_lib.CancelNodeTask(node_uid=node_uid)
+  return task_lib.CancelNodeTask(node_uid=node_uid, pause=pause)
 
 
 class _Collector:
@@ -108,10 +108,12 @@ def setUp(self):
     deployment_config.executor_specs['Trainer'].Pack(executor_spec)
     deployment_config.executor_specs['Transform'].Pack(executor_spec)
     deployment_config.executor_specs['Evaluator'].Pack(executor_spec)
+    deployment_config.executor_specs['Pusher'].Pack(executor_spec)
     pipeline = pipeline_pb2.Pipeline()
     pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
     pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
+    pipeline.nodes.add().pipeline_node.node_info.id = 'Pusher'
     pipeline.pipeline_info.id = 'test-pipeline'
     pipeline.deployment_config.Pack(deployment_config)
 
@@ -140,7 +142,7 @@ def test_task_handling(self, mock_publish):
         self._type_url,
         functools.partial(
             _FakeTaskScheduler,
-            block_nodes={'Trainer', 'Transform'},
+            block_nodes={'Trainer', 'Transform', 'Pusher'},
             collector=collector))
 
     task_queue = tq.TaskQueue()
@@ -160,16 +162,27 @@ def test_task_handling(self, mock_publish):
           'Evaluator', 'test-pipeline', pipeline=self._pipeline)
       task_queue.enqueue(evaluator_exec_task)
       task_queue.enqueue(_test_cancel_node_task('Transform', 'test-pipeline'))
+      pusher_exec_task = _test_exec_node_task(
+          'Pusher', 'test-pipeline', pipeline=self._pipeline)
+      task_queue.enqueue(pusher_exec_task)
+      task_queue.enqueue(
+          _test_cancel_node_task('Pusher', 'test-pipeline', pause=True))
 
     self.assertTrue(task_manager.done())
     self.assertIsNone(task_manager.exception())
 
     # Ensure that all exec and cancellation tasks were processed correctly.
-    self.assertCountEqual(
-        [trainer_exec_task, transform_exec_task, evaluator_exec_task],
-        collector.scheduled_tasks)
-    self.assertCountEqual([trainer_exec_task, transform_exec_task],
-                          collector.cancelled_tasks)
+    self.assertCountEqual([
+        trainer_exec_task,
+        transform_exec_task,
+        evaluator_exec_task,
+        pusher_exec_task,
+    ], collector.scheduled_tasks)
+    self.assertCountEqual([
+        trainer_exec_task,
+        transform_exec_task,
+        pusher_exec_task,
+    ], collector.cancelled_tasks)
 
     result_ok = ts.TaskSchedulerResult(
         status=status_lib.Status(
@@ -191,6 +204,8 @@ def test_task_handling(self, mock_publish):
                 mlmd_handle=mock.ANY, task=evaluator_exec_task,
                 result=result_ok),
         ],
         any_order=True)
+    # It is expected that publish is not called for Pusher because it was
+    # cancelled with pause=True so there must be only 3 calls.
     self.assertLen(mock_publish.mock_calls, 3)
 
   @mock.patch.object(tm, '_publish_execution_results')
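Taken together, the update path introduced by this patch is: (1) `initiate_pipeline_update` stores the new IR under the `updated_pipeline_ir` custom property of the orchestrator execution; (2) the next `orchestrate` pass classifies the pipeline as update-initiated, stops service jobs, and enqueues `CancelNodeTask(pause=True)` for nodes with active executions; (3) `TaskManager` cancels the paused schedulers but skips publishing their CANCELLED results, leaving the MLMD executions open; (4) once `_cancel_nodes` reports no active nodes, `apply_pipeline_update` swaps `pipeline_ir` for the stored update, deletes `updated_pipeline_ir`, and the pipeline continues as active under the new IR.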