Skip to content

Commit

Permalink
Add support for updating a pipeline in-flight
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 392597217
  • Loading branch information
goutham authored and tfx-copybara committed Aug 24, 2021
1 parent 3ef1134 commit 6167deb
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 52 deletions.
88 changes: 67 additions & 21 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,17 @@ def stop_node(
mlmd_handle, active_executions[0].id, timeout_secs=timeout_secs)


@_to_status_not_ok_error
def initiate_pipeline_update(
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
  """Initiates an in-flight update of a pipeline.

  Loads the state of the pipeline identified by the uid of `pipeline` and
  records `pipeline` on it as the pending update. The update itself is applied
  later, once the pipeline's nodes have been paused.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the updated pipeline; its uid selects the pipeline state
      to update.

  Returns:
    The `PipelineState` on which the update was initiated.
  """
  uid = task_lib.PipelineUid.from_pipeline(pipeline)
  loaded_state = pstate.PipelineState.load(mlmd_handle, uid)
  with loaded_state as pipeline_state:
    pipeline_state.initiate_update(pipeline)
  return pipeline_state


@_to_status_not_ok_error
def _wait_for_inactivation(
mlmd_handle: metadata.Metadata,
Expand Down Expand Up @@ -271,10 +282,13 @@ def orchestrate(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,

active_pipeline_states = []
stop_initiated_pipeline_states = []
update_initiated_pipeline_states = []
for pipeline_state in pipeline_states:
with pipeline_state:
if pipeline_state.stop_initiated_reason() is not None:
if pipeline_state.is_stop_initiated():
stop_initiated_pipeline_states.append(pipeline_state)
elif pipeline_state.is_update_initiated():
update_initiated_pipeline_states.append(pipeline_state)
elif pipeline_state.is_active():
active_pipeline_states.append(pipeline_state)
else:
Expand All @@ -289,6 +303,12 @@ def orchestrate(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
_orchestrate_stop_initiated_pipeline(mlmd_handle, task_queue,
service_job_manager, pipeline_state)

for pipeline_state in update_initiated_pipeline_states:
logging.info('Orchestrating update-initiated pipeline: %s',
pipeline_state.pipeline_uid)
_orchestrate_update_initiated_pipeline(mlmd_handle, task_queue,
service_job_manager, pipeline_state)

for pipeline_state in active_pipeline_states:
logging.info('Orchestrating pipeline: %s', pipeline_state.pipeline_uid)
_orchestrate_active_pipeline(mlmd_handle, task_queue, service_job_manager,
Expand All @@ -315,32 +335,53 @@ def _get_pipeline_states(
return result


def _orchestrate_stop_initiated_pipeline(
mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
service_job_manager: service_jobs.ServiceJobManager,
pipeline_state: pstate.PipelineState) -> None:
"""Orchestrates stop initiated pipeline."""
with pipeline_state:
stop_reason = pipeline_state.stop_initiated_reason()
assert stop_reason is not None
def _cancel_nodes(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
service_job_manager: service_jobs.ServiceJobManager,
pipeline_state: pstate.PipelineState, pause: bool) -> bool:
"""Cancels pipeline nodes and returns `True` if any node is currently active."""
pipeline = pipeline_state.pipeline
has_active_executions = False
is_active = False
for node in pstate.get_all_pipeline_nodes(pipeline):
if service_job_manager.is_pure_service_node(pipeline_state,
node.node_info.id):
service_job_manager.stop_node_services(pipeline_state, node.node_info.id)
elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
task_queue):
has_active_executions = True
elif _maybe_enqueue_cancellation_task(
mlmd_handle, pipeline, node, task_queue, pause=pause):
is_active = True
elif service_job_manager.is_mixed_service_node(pipeline_state,
node.node_info.id):
service_job_manager.stop_node_services(pipeline_state, node.node_info.id)
if not has_active_executions:
return is_active


def _orchestrate_stop_initiated_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates a pipeline for which a stop has been initiated.

  Cancels the pipeline's nodes (without pausing). Once no node is active any
  more, the pipeline execution state is set from the recorded stop reason.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Queue into which node cancellation tasks are enqueued.
    service_job_manager: Used to stop services of service nodes.
    pipeline_state: State of the stop-initiated pipeline.
  """
  with pipeline_state:
    stop_reason = pipeline_state.stop_initiated_reason()
  assert stop_reason is not None

  nodes_still_active = _cancel_nodes(
      mlmd_handle,
      task_queue,
      service_job_manager,
      pipeline_state,
      pause=False)
  if nodes_still_active:
    # Some node is still running; a later orchestration pass finishes the stop.
    return

  with pipeline_state:
    # Update pipeline execution state in MLMD.
    pipeline_state.set_pipeline_execution_state_from_status(stop_reason)


def _orchestrate_update_initiated_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates a pipeline for which an update has been initiated.

  Cancels the pipeline's nodes in pause mode. Once no node is active any more,
  the pending pipeline update is applied to the pipeline state.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Queue into which node cancellation tasks are enqueued.
    service_job_manager: Used to stop services of service nodes.
    pipeline_state: State of the update-initiated pipeline.
  """
  nodes_still_active = _cancel_nodes(
      mlmd_handle,
      task_queue,
      service_job_manager,
      pipeline_state,
      pause=True)
  if nodes_still_active:
    # Wait for a later orchestration pass before swapping in the new IR.
    return

  with pipeline_state:
    pipeline_state.apply_pipeline_update()


@attr.s(auto_attribs=True, kw_only=True)
class _NodeInfo:
"""A convenience container of pipeline node and its state."""
Expand Down Expand Up @@ -464,21 +505,25 @@ def _get_node_infos(pipeline_state: pstate.PipelineState) -> List[_NodeInfo]:
def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
pipeline: pipeline_pb2.Pipeline,
node: pipeline_pb2.PipelineNode,
task_queue: tq.TaskQueue) -> bool:
task_queue: tq.TaskQueue,
pause: bool = False) -> bool:
"""Enqueues a node cancellation task if not already stopped.
If the node has an ExecNodeTask in the task queue, issue a cancellation.
Otherwise, if the node has an active execution in MLMD but no ExecNodeTask
enqueued, it may be due to orchestrator restart after stopping was initiated
but before the schedulers could finish. So, enqueue an ExecNodeTask with
is_cancelled set to give a chance for the scheduler to finish gracefully.
Otherwise, when pause=False, if the node has an active execution in MLMD but
no ExecNodeTask enqueued, it may be due to orchestrator restart after stopping
was initiated but before the schedulers could finish. So, enqueue an
ExecNodeTask with is_cancelled set to give a chance for the scheduler to
finish gracefully.
Args:
mlmd_handle: A handle to the MLMD db.
pipeline: The pipeline containing the node to cancel.
node: The node to cancel.
task_queue: A `TaskQueue` instance into which any cancellation tasks will be
enqueued.
pause: Whether the cancellation is to pause the node rather than cancelling
the execution.
Returns:
`True` if a cancellation task was enqueued. `False` if node is already
Expand All @@ -489,9 +534,10 @@ def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
if task_queue.contains_task_id(exec_node_task_id):
task_queue.enqueue(
task_lib.CancelNodeTask(
node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node)))
node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node),
pause=pause))
return True
else:
if not pause:
executions = task_gen_utils.get_executions(mlmd_handle, node)
exec_node_task = task_gen_utils.generate_task_from_active_execution(
mlmd_handle, pipeline, node, executions, is_cancelled=True)
Expand Down
83 changes: 81 additions & 2 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ def test_orchestrate_active_pipelines(self, mock_async_task_gen,
@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
@mock.patch.object(task_gen_utils, 'generate_task_from_active_execution')
def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
mock_async_task_gen, mock_sync_task_gen):
def test_orchestrate_stop_initiated_pipelines(self, pipeline,
mock_gen_task_from_active,
mock_async_task_gen,
mock_sync_task_gen):
with self._mlmd_connection as m:
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
Expand Down Expand Up @@ -459,6 +461,83 @@ def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
mock.call(mock.ANY, 'Transform')],
any_order=True)

@parameterized.parameters(
_test_pipeline('pipeline1'),
_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC))
def test_orchestrate_update_initiated_pipelines(self, pipeline):
"""Checks orchestration of a pipeline for which an update was initiated.

Runs once with `_test_pipeline('pipeline1')` and once with the same pipeline
in SYNC execution mode. The scenario has ExampleGen as a pure service node,
Transform as a mixed service node, and ExecNodeTasks already enqueued for
Transform, Trainer and Evaluator when the update is initiated. The first
`orchestrate` call must enqueue pausing cancellation tasks; the second call,
made once the queue is drained, must apply the update and leave the pipeline
active.
"""
with self._mlmd_connection as m:
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

# ExampleGen is reported as a pure service node and Transform as a mixed
# service node; every other node is neither.
mock_service_job_manager = mock.create_autospec(
service_jobs.ServiceJobManager, instance=True)
mock_service_job_manager.is_pure_service_node.side_effect = (
lambda _, node_id: node_id == 'ExampleGen')
mock_service_job_manager.is_mixed_service_node.side_effect = (
lambda _, node_id: node_id == 'Transform')

pipeline_ops.initiate_pipeline_start(m, pipeline)

task_queue = tq.TaskQueue()

# Pre-populate the queue with ExecNodeTasks for the non-pure-service nodes.
for node_id in ('Transform', 'Trainer', 'Evaluator'):
task_queue.enqueue(
test_utils.create_exec_node_task(
task_lib.NodeUid(
pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
node_id=node_id)))

# The same pipeline IR is submitted as the update.
pipeline_state = pipeline_ops.initiate_pipeline_update(m, pipeline)
with pipeline_state:
self.assertTrue(pipeline_state.is_update_initiated())

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

# stop_node_services should be called for ExampleGen which is a pure
# service node.
mock_service_job_manager.stop_node_services.assert_called_once_with(
mock.ANY, 'ExampleGen')
mock_service_job_manager.reset_mock()

# Simulate completion of all the exec node tasks.
for node_id in ('Transform', 'Trainer', 'Evaluator'):
task = task_queue.dequeue()
task_queue.task_done(task)
self.assertTrue(task_lib.is_exec_node_task(task))
self.assertEqual(node_id, task.node_uid.node_id)

# Verify that cancellation tasks were enqueued in the last `orchestrate`
# call, and dequeue them.
for node_id in ('Transform', 'Trainer', 'Evaluator'):
task = task_queue.dequeue()
task_queue.task_done(task)
self.assertTrue(task_lib.is_cancel_node_task(task))
self.assertEqual(node_id, task.node_uid.node_id)
self.assertTrue(task.pause)

self.assertTrue(task_queue.is_empty())

# Pipeline continues to be in update initiated state until all
# ExecNodeTasks have been dequeued (which was not the case when last
# `orchestrate` call was made).
with pipeline_state:
self.assertTrue(pipeline_state.is_update_initiated())

pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

# stop_node_services should be called for Transform (mixed service node)
# too since corresponding ExecNodeTask has been processed.
mock_service_job_manager.stop_node_services.assert_has_calls(
[mock.call(mock.ANY, 'ExampleGen'),
mock.call(mock.ANY, 'Transform')])

# Pipeline should no longer be in update-initiated state but be active.
with pipeline_state:
self.assertFalse(pipeline_state.is_update_initiated())
self.assertTrue(pipeline_state.is_active())

@parameterized.parameters(
_test_pipeline('pipeline1'),
_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC))
Expand Down
79 changes: 72 additions & 7 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import contextlib
import threading
import time
from typing import Dict, Iterator, List, Mapping, Optional
from typing import Dict, Iterator, List, Mapping, Optional, Tuple

import attr
from tfx import types
Expand All @@ -41,6 +41,7 @@
_PIPELINE_STATUS_MSG = 'pipeline_status_msg'
_NODE_STATES = 'node_states'
_PIPELINE_RUN_METADATA = 'pipeline_run_metadata'
_UPDATED_PIPELINE_IR = 'updated_pipeline_ir'
_ORCHESTRATOR_EXECUTION_TYPE = metadata_store_pb2.ExecutionType(
name=_ORCHESTRATOR_RESERVED_ID,
properties={_PIPELINE_IR: metadata_store_pb2.STRING})
Expand Down Expand Up @@ -139,7 +140,6 @@ def __init__(self, mlmd_handle: metadata.Metadata,
self.mlmd_handle = mlmd_handle
self.pipeline = pipeline
self.execution_id = execution_id
self.pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)

# Only set within the pipeline state context.
self._mlmd_execution_atomic_op_context = None
Expand Down Expand Up @@ -181,8 +181,7 @@ def new(
message=f'Pipeline with uid {pipeline_uid} already active.')

exec_properties = {
_PIPELINE_IR:
base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
_PIPELINE_IR: _base64_encode_pipeline(pipeline)
}
if pipeline_run_metadata:
exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
Expand Down Expand Up @@ -258,6 +257,10 @@ def load_from_orchestrator_context(
pipeline=pipeline,
execution_id=active_execution.id)

@property
def pipeline_uid(self) -> task_lib.PipelineUid:
  """Uid of the pipeline, derived from the currently held pipeline IR.

  Computed from `self.pipeline` on every access rather than stored, so it
  follows `self.pipeline` when that attribute is replaced.
  """
  uid = task_lib.PipelineUid.from_pipeline(self.pipeline)
  return uid

def is_active(self) -> bool:
"""Returns `True` if pipeline is active."""
self._check_context()
Expand All @@ -277,6 +280,60 @@ def initiate_stop(self, status: status_lib.Status) -> None:
status.message)
record_state_change_time()

def initiate_update(self, updated_pipeline: pipeline_pb2.Pipeline) -> None:
  """Records `updated_pipeline` as a pending update of this pipeline.

  The updated IR is stored base64-encoded in the execution's custom properties
  under `_UPDATED_PIPELINE_IR`; the current pipeline IR is left untouched.

  Args:
    updated_pipeline: IR of the pipeline to update to.

  Raises:
    status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if the execution
      mode differs from the current pipeline's, or if the node ids and their
      upstream / downstream nodes do not exactly match the current pipeline's.
  """
  self._check_context()

  def _node_topology(
      ir: pipeline_pb2.Pipeline
  ) -> List[Tuple[str, List[str], List[str]]]:
    # One (id, upstream ids, downstream ids) entry per node, in pipeline order.
    topology = []
    for node in get_all_pipeline_nodes(ir):
      upstream = list(node.upstream_nodes)
      downstream = list(node.downstream_nodes)
      topology.append((node.node_info.id, upstream, downstream))
    return topology

  if updated_pipeline.execution_mode != self.pipeline.execution_mode:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message=('Updating execution_mode of an active pipeline is not '
                 'supported'))

  # TODO(b/194311197): We require that structure of the updated pipeline
  # exactly matches the original. There is scope to relax this restriction.
  current_topology = _node_topology(self.pipeline)
  updated_topology = _node_topology(updated_pipeline)
  if current_topology != updated_topology:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message=('Updated pipeline should have the same structure as the '
                 'original.'))

  pending_ir_slot = self._execution.custom_properties[_UPDATED_PIPELINE_IR]
  data_types_utils.set_metadata_value(
      pending_ir_slot, _base64_encode_pipeline(updated_pipeline))
  record_state_change_time()

def is_update_initiated(self) -> bool:
  """Returns `True` if the pipeline is active and has a pending update."""
  self._check_context()
  if not self.is_active():
    return False
  pending_ir = self._execution.custom_properties.get(_UPDATED_PIPELINE_IR)
  return pending_ir is not None

def apply_pipeline_update(self) -> None:
  """Applies pipeline update that was previously initiated.

  Replaces the stored pipeline IR with the pending updated IR, removes the
  pending-update entry from the execution's custom properties and refreshes
  `self.pipeline` with the decoded updated IR.

  Raises:
    status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if no updated
      pipeline IR has been recorded on the execution.
  """
  self._check_context()
  updated_pipeline_ir = _get_metadata_value(
      self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
  if not updated_pipeline_ir:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message='No updated pipeline IR to apply')
  # `_PIPELINE_IR` is a declared property of the orchestrator execution type
  # and is read back from `execution.properties` when pipeline state is
  # loaded, so the new IR has to be written there. Writing it to
  # `custom_properties` would leave the old IR in effect on every later load.
  data_types_utils.set_metadata_value(
      self._execution.properties[_PIPELINE_IR], updated_pipeline_ir)
  del self._execution.custom_properties[_UPDATED_PIPELINE_IR]
  self.pipeline = _base64_decode_pipeline(updated_pipeline_ir)

def is_stop_initiated(self) -> bool:
  """Returns `True` if a stop has been initiated for this pipeline."""
  self._check_context()
  stop_reason = self.stop_initiated_reason()
  return stop_reason is not None

def stop_initiated_reason(self) -> Optional[status_lib.Status]:
"""Returns status object if stop initiated, `None` otherwise."""
self._check_context()
Expand Down Expand Up @@ -552,9 +609,7 @@ def _get_pipeline_from_orchestrator_execution(
execution: metadata_store_pb2.Execution) -> pipeline_pb2.Pipeline:
pipeline_ir_b64 = data_types_utils.get_metadata_value(
execution.properties[_PIPELINE_IR])
pipeline = pipeline_pb2.Pipeline()
pipeline.ParseFromString(base64.b64decode(pipeline_ir_b64))
return pipeline
return _base64_decode_pipeline(pipeline_ir_b64)


def _get_active_execution(
Expand Down Expand Up @@ -587,3 +642,13 @@ def _get_creation_time(execution):
return execution.create_time_since_epoch

return max(executions, key=_get_creation_time)


def _base64_encode_pipeline(pipeline: pipeline_pb2.Pipeline) -> str:
  """Serializes `pipeline` and returns the bytes base64-encoded as text."""
  serialized = pipeline.SerializeToString()
  return base64.b64encode(serialized).decode('utf-8')


def _base64_decode_pipeline(pipeline_encoded: str) -> pipeline_pb2.Pipeline:
  """Parses a pipeline IR proto from its base64-encoded serialized form."""
  serialized = base64.b64decode(pipeline_encoded)
  decoded = pipeline_pb2.Pipeline()
  decoded.ParseFromString(serialized)
  return decoded
Loading

0 comments on commit 6167deb

Please sign in to comment.