Surface granular node run states for sync pipelines
PiperOrigin-RevId: 396762590
goutham authored and tfx-copybara committed Sep 15, 2021
1 parent c96262e commit d21f213
Showing 14 changed files with 605 additions and 225 deletions.
@@ -32,6 +32,8 @@
from ml_metadata.proto import metadata_store_pb2


# TODO(b/199908896): Surface granular node states similar to sync pipeline task
# generator.
class AsyncPipelineTaskGenerator(task_gen.TaskGenerator):
"""Task generator for executing an async pipeline.
11 changes: 9 additions & 2 deletions tfx/orchestration/experimental/core/mlmd_state.py
@@ -18,7 +18,7 @@
import copy
import threading
import typing
-from typing import Iterator, MutableMapping
+from typing import Callable, Iterator, MutableMapping, Optional

import cachetools
from tfx.orchestration import metadata
@@ -107,7 +107,9 @@ def clear_cache(self):
@contextlib.contextmanager
def mlmd_execution_atomic_op(
mlmd_handle: metadata.Metadata,
-    execution_id: int) -> Iterator[metadata_store_pb2.Execution]:
+    execution_id: int,
+    on_commit: Optional[Callable[[], None]] = None
+) -> Iterator[metadata_store_pb2.Execution]:
"""Context manager for accessing or mutating an execution atomically.
The idea of using this context manager is to ensure that the in-memory state
@@ -122,6 +124,9 @@ def mlmd_execution_atomic_op(
Args:
mlmd_handle: A handle to MLMD db.
execution_id: Id of the execution to yield.
    on_commit: An optional callback invoked after a successful MLMD execution
      commit. It is not invoked if the execution is not mutated within the
      context, since no MLMD commit is needed in that case.
Yields:
If execution with given id exists in MLMD, the execution is yielded under
@@ -143,6 +148,8 @@ def mlmd_execution_atomic_op(
# Make a copy before writing to cache as the yielded `execution_copy`
# object may be modified even after exiting the contextmanager.
_execution_cache.put_execution(mlmd_handle, copy.deepcopy(execution_copy))
if on_commit is not None:
on_commit()


def clear_in_memory_state():
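
For context, a minimal usage sketch of the new on_commit hook (illustrative, not part of the commit; the handle m, the execution id, and the helper names are assumed). The callback fires only when the yielded execution is actually mutated, i.e. only when an MLMD commit happens:

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.experimental.core import mlmd_state

def _log_commit() -> None:
  print('execution state committed to MLMD')

def cancel_execution(m, execution_id: int) -> None:
  # `m` is assumed to be an active metadata.Metadata handle.
  with mlmd_state.mlmd_execution_atomic_op(
      m, execution_id, on_commit=_log_commit) as execution:
    # Mutating the yielded execution triggers a commit (and hence the
    # callback) on context exit; leaving it untouched triggers neither.
    execution.last_known_state = metadata_store_pb2.Execution.CANCELED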
7 changes: 6 additions & 1 deletion tfx/orchestration/experimental/core/mlmd_state_test.py
@@ -90,20 +90,25 @@ def setUp(self):
connection_config=connection_config)

def test_mlmd_execution_update(self):
event_on_commit = threading.Event()
with self._mlmd_connection as m:
expected_execution = _write_test_execution(m)
# Mutate execution.
with mlmd_state.mlmd_execution_atomic_op(
-          m, expected_execution.id) as execution:
+          m, expected_execution.id, on_commit=event_on_commit.set) as execution:
self.assertEqual(expected_execution, execution)
execution.last_known_state = metadata_store_pb2.Execution.CANCELED
self.assertFalse(event_on_commit.is_set()) # not yet invoked.
# Test that updated execution is committed to MLMD.
[execution] = m.store.get_executions_by_id([execution.id])
self.assertEqual(metadata_store_pb2.Execution.CANCELED,
execution.last_known_state)
# Test that in-memory state is also in sync.
self.assertEqual(execution,
mlmd_state._execution_cache._cache[execution.id])
# Test that on_commit callback was invoked.
self.assertTrue(event_on_commit.is_set())
# Sanity checks that the updated execution is yielded in the next call.
with mlmd_state.mlmd_execution_atomic_op(
m, expected_execution.id) as execution2:
self.assertEqual(execution, execution2)
31 changes: 20 additions & 11 deletions tfx/orchestration/experimental/core/pipeline_ops.py
@@ -156,8 +156,7 @@ def initiate_node_start(mlmd_handle: metadata.Metadata,
with pstate.PipelineState.load(mlmd_handle,
node_uid.pipeline_uid) as pipeline_state:
with pipeline_state.node_state_update_context(node_uid) as node_state:
-      if node_state.state not in (pstate.NodeState.STARTING,
-                                  pstate.NodeState.STARTED):
+      if node_state.is_startable():
node_state.update(pstate.NodeState.STARTING)
return pipeline_state

@@ -193,8 +192,7 @@ def stop_node(mlmd_handle: metadata.Metadata,
f'{node_uid}'))
node = filtered_nodes[0]
with pipeline_state.node_state_update_context(node_uid) as node_state:
-    if node_state.state not in (pstate.NodeState.STOPPING,
-                                pstate.NodeState.STOPPED):
+    if node_state.is_stoppable():
node_state.update(
pstate.NodeState.STOPPING,
status_lib.Status(
@@ -556,15 +554,25 @@ def _filter_by_state(node_infos: List[_NodeInfo],

tasks = generator.generate()

-  # Change the state of all nodes in state STARTING to STARTED.
-  starting_node_infos = _filter_by_state(node_infos, pstate.NodeState.STARTING)
with pipeline_state:
-    for node_info in starting_node_infos:
-      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node_info.node)
+    # Handle all the UpdateNodeStateTasks by updating node states.
+    for task in tasks:
+      if task_lib.is_update_node_state_task(task):
+        task = typing.cast(task_lib.UpdateNodeStateTask, task)
+        with pipeline_state.node_state_update_context(
+            task.node_uid) as node_state:
+          node_state.update(task.state, task.status)
+
+    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

+    # If there are still nodes in state STARTING, change them to STARTED.
+    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
+      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
+                                                     node)
with pipeline_state.node_state_update_context(node_uid) as node_state:
-        node_state.update(pstate.NodeState.STARTED)
+        if node_state.state == pstate.NodeState.STARTING:
+          node_state.update(pstate.NodeState.STARTED)

-  with pipeline_state:
for task in tasks:
if task_lib.is_exec_node_task(task):
task = typing.cast(task_lib.ExecNodeTask, task)
@@ -574,7 +582,8 @@ def _filter_by_state(node_infos: List[_NodeInfo],
task = typing.cast(task_lib.FinalizeNodeTask, task)
with pipeline_state.node_state_update_context(
task.node_uid) as node_state:
-          node_state.update(pstate.NodeState.STOPPING, task.status)
+          if node_state.is_stoppable():
+            node_state.update(pstate.NodeState.STOPPING, task.status)
else:
assert task_lib.is_finalize_pipeline_task(task)
assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
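
An informal sketch of the contract this change establishes (not from the commit; the node uids and states are made up, and the UpdateNodeStateTask fields match those exercised in pipeline_ops_test.py below): a task generator surfaces pure state changes as UpdateNodeStateTasks, which orchestrate applies and filters out before enqueuing the remaining exec/finalize tasks.

from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import task as task_lib
from tfx.utils import status as status_lib

def example_generate(trainer_node_uid, evaluator_node_uid):
  # Tasks a generator might emit to surface granular node run states.
  return [
      task_lib.UpdateNodeStateTask(
          node_uid=trainer_node_uid, state=pstate.NodeState.RUNNING),
      task_lib.UpdateNodeStateTask(
          node_uid=evaluator_node_uid,
          state=pstate.NodeState.FAILED,
          status=status_lib.Status(
              code=status_lib.Code.ABORTED, message='example error')),
  ]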
65 changes: 65 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -992,6 +992,71 @@ def test_resume_manual_node(self):
self.assertEqual(node_state.state,
manual_task_scheduler.ManualNodeState.COMPLETED)

@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
def test_update_node_state_tasks_handling(self, mock_sync_task_gen):
with self._mlmd_connection as m:
pipeline = _test_pipeline(
'pipeline1', execution_mode=pipeline_pb2.Pipeline.SYNC)
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')

with pipeline_ops.initiate_pipeline_start(m, pipeline) as pipeline_state:
# Set initial states for the nodes.
with pipeline_state.node_state_update_context(
eg_node_uid) as node_state:
node_state.update(pstate.NodeState.RUNNING)
with pipeline_state.node_state_update_context(
transform_node_uid) as node_state:
node_state.update(pstate.NodeState.STARTING)
with pipeline_state.node_state_update_context(
trainer_node_uid) as node_state:
node_state.update(pstate.NodeState.STARTED)
with pipeline_state.node_state_update_context(
evaluator_node_uid) as node_state:
node_state.update(pstate.NodeState.RUNNING)

mock_sync_task_gen.return_value.generate.side_effect = [
[
task_lib.UpdateNodeStateTask(
node_uid=eg_node_uid, state=pstate.NodeState.COMPLETE),
task_lib.UpdateNodeStateTask(
node_uid=trainer_node_uid, state=pstate.NodeState.RUNNING),
task_lib.UpdateNodeStateTask(
node_uid=evaluator_node_uid,
state=pstate.NodeState.FAILED,
status=status_lib.Status(
code=status_lib.Code.ABORTED, message='foobar error'))
],
]

task_queue = tq.TaskQueue()
pipeline_ops.orchestrate(m, task_queue,
service_jobs.DummyServiceJobManager())
self.assertEqual(1, mock_sync_task_gen.return_value.generate.call_count)

with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
self.assertEqual(pstate.NodeState.COMPLETE,
pipeline_state.get_node_state(eg_node_uid).state)
self.assertEqual(
pstate.NodeState.STARTED,
pipeline_state.get_node_state(transform_node_uid).state)
self.assertEqual(pstate.NodeState.RUNNING,
pipeline_state.get_node_state(trainer_node_uid).state)
self.assertEqual(
pstate.NodeState.FAILED,
pipeline_state.get_node_state(evaluator_node_uid).state)
self.assertEqual(
status_lib.Status(
code=status_lib.Code.ABORTED, message='foobar error'),
pipeline_state.get_node_state(evaluator_node_uid).status)


if __name__ == '__main__':
tf.test.main()
89 changes: 60 additions & 29 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -28,6 +28,7 @@
from tfx.orchestration.portable.mlmd import context_lib
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.proto.orchestration import pipeline_pb2
from tfx.proto.orchestration import run_state_pb2
from tfx.utils import json_utils
from tfx.utils import status as status_lib

@@ -59,24 +60,23 @@ class NodeState(json_utils.Jsonable):
status: Status of the node in state STOPPING or STOPPED.
"""

-  # This state indicates that the node is pending some internal work before its
-  # state can be changed to STARTED.
-  STARTING = 'starting'
-
-  # This state indicates that the node is started and can run when triggering
-  # events occur.
-  STARTED = 'started'
-
-  # This state indicates that the node is pending some internal work before its
-  # state can be changed to STOPPED.
-  STOPPING = 'stopping'
-
-  # This state indicates that the node is stopped.
-  STOPPED = 'stopped'
+  STARTING = 'starting'  # Pending work before state can change to STARTED.
+  STARTED = 'started'  # Node is ready for execution.
+  STOPPING = 'stopping'  # Pending work before state can change to STOPPED.
+  STOPPED = 'stopped'  # Node execution is stopped.
+  RUNNING = 'running'  # Node is under active execution (i.e. triggered).
+  COMPLETE = 'complete'  # Node execution completed successfully.
+  SKIPPED = 'skipped'  # Node execution skipped due to conditional.
+  PAUSED = 'paused'  # Node was paused and may be resumed in the future.
+  FAILED = 'failed'  # Node execution failed due to errors.

state: str = attr.ib(
default=STARTED,
-      validator=attr.validators.in_([STARTING, STARTED, STOPPING, STOPPED]))
+      validator=attr.validators.in_([
+          STARTING, STARTED, STOPPING, STOPPED, RUNNING, COMPLETE, SKIPPED,
+          PAUSED, FAILED
+      ]),
+      on_setattr=attr.setters.validate)
status_code: Optional[int] = None
status_msg: str = ''

@@ -97,6 +97,29 @@ def update(self,
self.status_code = None
self.status_msg = ''

def is_startable(self) -> bool:
"""Returns True if the node can be started."""
return self.state in set(
[self.PAUSED, self.STOPPING, self.STOPPED, self.FAILED])

def is_stoppable(self) -> bool:
"""Returns True if the node can be stopped."""
return self.state in set(
[self.STARTING, self.STARTED, self.RUNNING, self.PAUSED])


_NODE_STATE_TO_RUN_STATE_MAP = {
NodeState.STARTING: run_state_pb2.RunState.UNKNOWN,
NodeState.STARTED: run_state_pb2.RunState.READY,
NodeState.STOPPING: run_state_pb2.RunState.UNKNOWN,
NodeState.STOPPED: run_state_pb2.RunState.STOPPED,
NodeState.RUNNING: run_state_pb2.RunState.RUNNING,
NodeState.COMPLETE: run_state_pb2.RunState.COMPLETE,
NodeState.SKIPPED: run_state_pb2.RunState.SKIPPED,
NodeState.PAUSED: run_state_pb2.RunState.PAUSED,
NodeState.FAILED: run_state_pb2.RunState.FAILED
}
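
A small sketch of how the new predicates and the stricter validator behave (illustrative only; the names are from this diff):

from tfx.orchestration.experimental.core import pipeline_state as pstate

node_state = pstate.NodeState()  # Defaults to STARTED.
assert node_state.is_stoppable()  # STARTED nodes may be stopped...
assert not node_state.is_startable()  # ...but not started again.
node_state.update(pstate.NodeState.FAILED)
assert node_state.is_startable()  # FAILED nodes may be restarted.
# With on_setattr=attr.setters.validate, assigning a state outside the
# declared set raises a ValueError instead of silently corrupting state:
# node_state.state = 'bogus'  # Would raise.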


def record_state_change_time() -> None:
"""Records current time at the point of function call as state change time.
@@ -180,9 +203,7 @@ def new(
code=status_lib.Code.ALREADY_EXISTS,
message=f'Pipeline with uid {pipeline_uid} already active.')

-    exec_properties = {
-        _PIPELINE_IR: _base64_encode_pipeline(pipeline)
-    }
+    exec_properties = {_PIPELINE_IR: _base64_encode_pipeline(pipeline)}
if pipeline_run_metadata:
exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
pipeline_run_metadata)
@@ -285,7 +306,6 @@ def initiate_stop(self, status: status_lib.Status) -> None:
data_types_utils.set_metadata_value(
self._execution.custom_properties[_PIPELINE_STATUS_MSG],
status.message)
-    record_state_change_time()

def initiate_update(self, updated_pipeline: pipeline_pb2.Pipeline) -> None:
"""Initiates pipeline update process."""
@@ -327,7 +347,6 @@ def _structure(
data_types_utils.set_metadata_value(
self._execution.custom_properties[_UPDATED_PIPELINE_IR],
_base64_encode_pipeline(updated_pipeline))
-    record_state_change_time()

def is_update_initiated(self) -> bool:
self._check_context()
@@ -375,11 +394,10 @@ def node_state_update_context(
code=status_lib.Code.INVALID_ARGUMENT,
message=(f'Node {node_uid} does not belong to the pipeline '
f'{self.pipeline_uid}'))
-    node_states_dict = self._get_node_states_dict()
+    node_states_dict = _get_node_states_dict(self._execution)
node_state = node_states_dict.setdefault(node_uid.node_id, NodeState())
yield node_state
self._save_node_states_dict(node_states_dict)
-    record_state_change_time()

def get_node_state(self, node_uid: task_lib.NodeUid) -> NodeState:
self._check_context()
@@ -388,7 +406,7 @@ def get_node_state(self, node_uid: task_lib.NodeUid) -> NodeState:
code=status_lib.Code.INVALID_ARGUMENT,
message=(f'Node {node_uid} does not belong to the pipeline '
f'{self.pipeline_uid}'))
-    node_states_dict = self._get_node_states_dict()
+    node_states_dict = _get_node_states_dict(self._execution)
return node_states_dict.get(node_uid.node_id, NodeState())

def get_pipeline_execution_state(self) -> metadata_store_pb2.Execution.State:
@@ -425,19 +443,14 @@ def remove_property(self, property_key: str) -> None:
if self._execution.custom_properties.get(property_key):
del self._execution.custom_properties[property_key]

-  def _get_node_states_dict(self) -> Dict[str, NodeState]:
-    node_states_json = _get_metadata_value(
-        self._execution.custom_properties.get(_NODE_STATES))
-    return json_utils.loads(node_states_json) if node_states_json else {}

def _save_node_states_dict(self, node_states: Dict[str, NodeState]) -> None:
data_types_utils.set_metadata_value(
self._execution.custom_properties[_NODE_STATES],
json_utils.dumps(node_states))

def __enter__(self) -> 'PipelineState':
mlmd_execution_atomic_op_context = mlmd_state.mlmd_execution_atomic_op(
-        self.mlmd_handle, self.execution_id)
+        self.mlmd_handle, self.execution_id, record_state_change_time)
execution = mlmd_execution_atomic_op_context.__enter__()
self._mlmd_execution_atomic_op_context = mlmd_execution_atomic_op_context
self._execution = execution
@@ -560,6 +573,17 @@ def pipeline_run_metadata(self) -> Dict[str, types.Property]:
return json_utils.loads(
pipeline_run_metadata) if pipeline_run_metadata else {}

def get_node_run_states(self) -> Dict[str, run_state_pb2.RunState]:
"""Returns a dict mapping node id to current run state."""
result = {}
node_states_dict = _get_node_states_dict(self.execution)
for node in get_all_pipeline_nodes(self.pipeline):
node_state = node_states_dict.get(node.node_info.id, NodeState())
result[node.node_info.id] = run_state_pb2.RunState(
state=_NODE_STATE_TO_RUN_STATE_MAP[node_state.state],
status_msg=node_state.status_msg)
return result


def get_orchestrator_contexts(
mlmd_handle: metadata.Metadata) -> List[metadata_store_pb2.Context]:
@@ -670,3 +694,10 @@ def _base64_decode_pipeline(pipeline_encoded: str) -> pipeline_pb2.Pipeline:
result = pipeline_pb2.Pipeline()
result.ParseFromString(base64.b64decode(pipeline_encoded))
return result


def _get_node_states_dict(
pipeline_execution: metadata_store_pb2.Execution) -> Dict[str, NodeState]:
node_states_json = _get_metadata_value(
pipeline_execution.custom_properties.get(_NODE_STATES))
return json_utils.loads(node_states_json) if node_states_json else {}
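
Finally, a hypothetical read-side sketch of how the new get_node_run_states API might be consumed, e.g. by a status UI. Only get_node_run_states and the RunState fields come from this diff; the view argument and the helper itself are assumptions:

from tfx.proto.orchestration import run_state_pb2

def print_node_run_states(view) -> None:
  # `view` is assumed to expose the get_node_run_states() method added above.
  for node_id, run_state in view.get_node_run_states().items():
    # Assumes the nested enum in run_state.proto is named `State`, matching
    # the values (RUNNING, FAILED, ...) used in _NODE_STATE_TO_RUN_STATE_MAP.
    state_name = run_state_pb2.RunState.State.Name(run_state.state)
    print(f'{node_id}: {state_name} {run_state.status_msg}')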