Support fail_fast setting in SyncPipelineTaskGenerator
The default continues to be fail_fast=True. The plumbing to enable setting fail_fast=False is TBD.

PiperOrigin-RevId: 402716811
goutham authored and tfx-copybara committed Oct 13, 2021
1 parent 26b787d commit f314cc7
Showing 7 changed files with 270 additions and 88 deletions.
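
For context, the fail_fast flag appears to control whether the sync task generator finalizes the pipeline run as soon as any node fails, rather than only once no remaining node can make progress. The standalone Python sketch below illustrates that intended behavior only; it is not the TFX implementation, and the helper name and state strings are hypothetical stand-ins.

from typing import Mapping

# Hypothetical state strings standing in for NodeState values.
COMPLETE, FAILED, RUNNING = 'complete', 'failed', 'running'

def should_finalize_run(node_states: Mapping[str, str], fail_fast: bool) -> bool:
  """Illustrative only: decides whether a sync pipeline run is done."""
  if fail_fast and any(s == FAILED for s in node_states.values()):
    # Fail fast: finalize as soon as any node has failed.
    return True
  # Otherwise finalize only once no node can make further progress
  # (approximated here as: nothing is still running).
  return all(s in (COMPLETE, FAILED) for s in node_states.values())

# With fail_fast=True the run is finalized even though Trainer is still
# running; with fail_fast=False it is not.
states = {'ExampleGen': COMPLETE, 'Transform': FAILED, 'Trainer': RUNNING}
assert should_finalize_run(states, fail_fast=True)
assert not should_finalize_run(states, fail_fast=False)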
12 changes: 12 additions & 0 deletions tfx/orchestration/experimental/core/async_pipeline_task_gen.py
@@ -71,6 +71,8 @@ def __init__(self, mlmd_handle: metadata.Metadata,
self._pipeline = pipeline
self._is_task_id_tracked_fn = is_task_id_tracked_fn
self._service_job_manager = service_job_manager
# TODO(b/201294315): Remove once the underlying issue is fixed.
self._generate_invoked = False

def generate(self) -> List[task_lib.Task]:
"""Generates tasks for all executable nodes in the async pipeline.
@@ -80,7 +82,17 @@ def generate(self) -> List[task_lib.Task]:
Returns:
A `list` of tasks to execute.
Raises:
      RuntimeError: If `generate` is invoked more than once on the same instance.
"""
# TODO(b/201294315): Remove this artificial restriction once the underlying
# issue is fixed.
if self._generate_invoked:
raise RuntimeError(
'Invoking `generate` more than once on the same instance of '
'AsyncPipelineTaskGenerator is restricted due to a bug.')
self._generate_invoked = True
result = []
for node in [n.pipeline_node for n in self._pipeline.nodes]:
node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
3 changes: 2 additions & 1 deletion tfx/orchestration/experimental/core/pipeline_ops.py
@@ -576,9 +576,10 @@ def _filter_by_state(node_infos: List[_NodeInfo],

# Initialize task generator for the pipeline.
if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
# TODO(b/200618482): Remove fail_fast=True.
generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
mlmd_handle, pipeline_state, task_queue.contains_task_id,
-        service_job_manager)
+        service_job_manager, fail_fast=True)
elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
mlmd_handle, pipeline_state, task_queue.contains_task_id,
26 changes: 24 additions & 2 deletions tfx/orchestration/experimental/core/pipeline_state.py
@@ -77,7 +77,7 @@ class NodeState(json_utils.Jsonable):
STARTING = 'starting' # Pending work before state can change to STARTED.
STARTED = 'started' # Node is ready for execution.
STOPPING = 'stopping' # Pending work before state can change to STOPPED.
- STOPPED = 'stopped' # Node execution is stoped.
+ STOPPED = 'stopped' # Node execution is stopped.
RUNNING = 'running' # Node is under active execution (i.e. triggered).
COMPLETE = 'complete' # Node execution completed successfully.
SKIPPED = 'skipped' # Node execution skipped due to conditional.
@@ -121,6 +121,20 @@ def is_stoppable(self) -> bool:
return self.state in set(
[self.STARTING, self.STARTED, self.RUNNING, self.PAUSED])

def is_success(self) -> bool:
return is_node_state_success(self.state)

def is_failure(self) -> bool:
return is_node_state_failure(self.state)


def is_node_state_success(state: str) -> bool:
return state in (NodeState.COMPLETE, NodeState.SKIPPED)


def is_node_state_failure(state: str) -> bool:
return state == NodeState.FAILED


_NODE_STATE_TO_RUN_STATE_MAP = {
NodeState.STARTING: run_state_pb2.RunState.UNKNOWN,
@@ -174,7 +188,7 @@ class PipelineState:
mlmd_handle: Handle to MLMD db.
pipeline: The pipeline proto associated with this `PipelineState` object.
TODO(b/201294315): Fix self.pipeline going out of sync with the actual
-      pipeline proto stored in the underlying MLMD execution in some cases.
+        pipeline proto stored in the underlying MLMD execution in some cases.
execution_id: Id of the underlying execution in MLMD.
pipeline_uid: Unique id of the pipeline.
"""
@@ -437,6 +451,14 @@ def get_node_state(self, node_uid: task_lib.NodeUid) -> NodeState:
node_states_dict = _get_node_states_dict(self._execution)
return node_states_dict.get(node_uid.node_id, NodeState())

def get_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]:
self._check_context()
result = {}
for node in get_all_pipeline_nodes(self.pipeline):
node_uid = task_lib.NodeUid.from_pipeline_node(self.pipeline, node)
result[node_uid] = self.get_node_state(node_uid)
return result

def get_pipeline_execution_state(self) -> metadata_store_pb2.Execution.State:
"""Returns state of underlying pipeline execution."""
self._check_context()
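The NodeState helpers and PipelineState.get_node_states_dict added above give callers a simple way to summarize a run. A minimal usage sketch (not part of this commit), assuming the module layout shown in this diff and a PipelineState obtained via PipelineState.load as in the test below:

from tfx.orchestration.experimental.core import pipeline_state as pstate

def summarize_run(pipeline_state: pstate.PipelineState) -> str:
  """Illustrative only: collapses per-node states into a run-level summary."""
  node_states = pipeline_state.get_node_states_dict()
  if any(ns.is_failure() for ns in node_states.values()):
    return 'failed'
  if all(ns.is_success() for ns in node_states.values()):
    return 'succeeded'
  return 'in progress'

Because is_node_state_failure treats only FAILED as a failure, SKIPPED nodes (e.g. those skipped by a conditional) still count toward success here, matching is_node_state_success above.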
37 changes: 37 additions & 0 deletions tfx/orchestration/experimental/core/pipeline_state_test.py
@@ -291,6 +291,43 @@ def test_initiate_node_start_stop(self):
node_state = pipeline_state.get_node_state(node_uid)
self.assertEqual(pstate.NodeState.STARTED, node_state.state)

def test_get_node_states_dict(self):
with self._mlmd_connection as m:
pipeline = pipeline_pb2.Pipeline()
pipeline.pipeline_info.id = 'pipeline1'
pipeline.execution_mode = pipeline_pb2.Pipeline.SYNC
pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')
with pstate.PipelineState.new(m, pipeline) as pipeline_state:
with pipeline_state.node_state_update_context(
eg_node_uid) as node_state:
node_state.update(pstate.NodeState.COMPLETE)
with pipeline_state.node_state_update_context(
transform_node_uid) as node_state:
node_state.update(pstate.NodeState.RUNNING)
with pipeline_state.node_state_update_context(
trainer_node_uid) as node_state:
node_state.update(pstate.NodeState.STARTING)
with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
self.assertEqual(
{
eg_node_uid:
pstate.NodeState(state=pstate.NodeState.COMPLETE),
transform_node_uid:
pstate.NodeState(state=pstate.NodeState.RUNNING),
trainer_node_uid:
pstate.NodeState(state=pstate.NodeState.STARTING),
evaluator_node_uid:
pstate.NodeState(state=pstate.NodeState.STARTED),
}, pipeline_state.get_node_states_dict())

def test_save_and_remove_property(self):
property_key = 'key'
property_value = 'value'