
Register multiple executions for a list of input artifacts
PiperOrigin-RevId: 434827877
tfx-copybara committed Mar 15, 2022
1 parent aa5f5c1 commit 25257bd
Showing 5 changed files with 215 additions and 36 deletions.
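The core of the change: task generation previously registered a single MLMD execution per node (via execution_publish_utils.register_execution) and rejected input resolution that produced more than one input dict; it now registers one execution per resolved input dict through the new task_gen_utils.register_executions, tagging every execution in the set with shared __execution_set_size__ and __execution_timestamp__ custom properties. A minimal usage sketch, adapted from the new unit test further down (the in-memory MLMD setup is illustrative, not part of the commit):

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import task_gen_utils
from tfx.types import standard_artifacts

# In-memory MLMD instance, for illustration only.
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.SetInParent()
with metadata.Metadata(connection_config) as m:
  executions = task_gen_utils.register_executions(
      m,
      metadata_store_pb2.ExecutionType(name='my_ex_type'),
      contexts=[],
      input_dicts=[{'examples': [standard_artifacts.Examples()]},
                   {'examples': [standard_artifacts.Examples()]}])
  assert len(executions) == 2  # One NEW execution per input dict.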
71 changes: 47 additions & 24 deletions tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
@@ -21,17 +21,18 @@
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import constants
from tfx.orchestration.experimental.core import mlmd_state
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import service_jobs
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.experimental.core import task_gen
from tfx.orchestration.experimental.core import task_gen_utils
from tfx.orchestration.portable import execution_publish_utils
from tfx.orchestration.portable import outputs_utils
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import status as status_lib
from tfx.utils import topsort
from ml_metadata.proto import metadata_store_pb2


class SyncPipelineTaskGenerator(task_gen.TaskGenerator):
@@ -241,28 +242,36 @@ def _generate_tasks_for_node(
return result

node_executions = task_gen_utils.get_executions(self._mlmd_handle, node)
latest_execution = task_gen_utils.get_latest_execution(node_executions)
latest_executions_set = task_gen_utils.get_latest_executions_set(
node_executions)

# If the latest execution is successful, we're done.
if latest_execution and execution_lib.is_execution_successful(
latest_execution):
# If all the executions in the set for the node are successful, we're done.
if latest_executions_set and all(
execution_lib.is_execution_successful(e)
for e in latest_executions_set):
logging.info('Node successful: %s', node_uid)
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid, state=pstate.NodeState.COMPLETE))
return result

# If the latest execution failed or cancelled, the pipeline should be
# aborted if the node is not in state STARTING. For nodes that are
# in state STARTING, a new execution is created.
if (latest_execution and
not execution_lib.is_execution_active(latest_execution) and
node_state.state != pstate.NodeState.STARTING):
error_msg_value = latest_execution.custom_properties.get(
constants.EXECUTION_ERROR_MSG_KEY)
error_msg = data_types_utils.get_metadata_value(
error_msg_value) if error_msg_value else ''
error_msg = f'node failed; node uid: {node_uid}; error: {error_msg}'
# If one of the executions in the set for the node failed or was cancelled,
# the pipeline should be aborted if the node is not in state STARTING.
# For nodes that are in state STARTING, new executions are created.
# TODO(b/223627713): A node in a ForEach is not restartable; prevent
# restarting for now.
failed_executions = [
e for e in latest_executions_set if execution_lib.is_execution_failed(e)
]
if failed_executions and (len(latest_executions_set) > 1 or
node_state.state != pstate.NodeState.STARTING):
error_msg = f'node {node_uid} failed; '
for e in failed_executions:
error_msg_value = e.custom_properties.get(
constants.EXECUTION_ERROR_MSG_KEY)
error_msg_value = data_types_utils.get_metadata_value(
error_msg_value) if error_msg_value else ''
error_msg += f'error: {error_msg_value}; '
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid,
@@ -271,13 +280,20 @@
code=status_lib.Code.ABORTED, message=error_msg)))
return result

exec_node_task = task_gen_utils.generate_task_from_active_execution(
self._mlmd_handle, self._pipeline, node, node_executions)
if exec_node_task:
latest_active_execution = task_gen_utils.get_latest_active_execution(
latest_executions_set)
if latest_active_execution:
with mlmd_state.mlmd_execution_atomic_op(
mlmd_handle=self._mlmd_handle,
execution_id=latest_active_execution.id) as execution:
execution.last_known_state = metadata_store_pb2.Execution.RUNNING
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid, state=pstate.NodeState.RUNNING))
result.append(exec_node_task)
result.append(
task_gen_utils.generate_task_from_execution(self._mlmd_handle,
self._pipeline, node,
execution))
return result

# Finally, we are ready to generate tasks for the node by resolving inputs.
@@ -308,15 +324,22 @@ def _resolve_inputs_and_generate_tasks_for_node(
status=status_lib.Status(
code=status_lib.Code.ABORTED, message=error_msg)))
return result
# TODO(b/207038460): Update sync pipeline to support ForEach.
input_artifacts = resolved_info.input_artifacts[0]

execution = execution_publish_utils.register_execution(
executions = task_gen_utils.register_executions(
metadata_handler=self._mlmd_handle,
execution_type=node.node_info.type,
contexts=resolved_info.contexts,
input_artifacts=input_artifacts,
input_dicts=resolved_info.input_artifacts,
exec_properties=resolved_info.exec_properties)

# Selects the first input artifacts to create an exec task.
input_artifacts = resolved_info.input_artifacts[0]
# Selects the first execution and marks it as RUNNING.
with mlmd_state.mlmd_execution_atomic_op(
mlmd_handle=self._mlmd_handle,
execution_id=executions[0].id) as execution:
execution.last_known_state = metadata_store_pb2.Execution.RUNNING

outputs_resolver = outputs_utils.OutputsResolver(
node, self._pipeline.pipeline_info, self._pipeline.runtime_spec,
self._pipeline.execution_mode)
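Taken together, _generate_tasks_for_node now reasons about the whole latest execution set rather than a single latest execution. A condensed, self-contained restatement of the decision order (plain strings stand in for MLMD and node states, assuming the usual TFX convention that COMPLETE/CACHED count as successful and NEW/RUNNING as active; this is a sketch, not the TFX code):

from typing import List

def node_decision(exec_states: List[str], node_state: str) -> str:
  """Returns the action for a node given its latest execution set."""
  if exec_states and all(s in ('COMPLETE', 'CACHED') for s in exec_states):
    return 'COMPLETE'      # Every execution in the set succeeded.
  failed = [s for s in exec_states if s in ('FAILED', 'CANCELED')]
  if failed and (len(exec_states) > 1 or node_state != 'STARTING'):
    return 'ABORT'         # Failures abort unless a lone execution can be
                           # restarted from STARTING (ForEach restart is TODO).
  if any(s in ('NEW', 'RUNNING') for s in exec_states):
    return 'RESUME'        # Mark the latest active execution RUNNING.
  return 'REGISTER_NEW'    # Resolve inputs, register one execution per dict.

assert node_decision(['COMPLETE', 'COMPLETE'], 'STARTED') == 'COMPLETE'
assert node_decision(['FAILED', 'COMPLETE'], 'STARTED') == 'ABORT'
assert node_decision(['FAILED'], 'STARTING') == 'REGISTER_NEW'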
118 changes: 106 additions & 12 deletions tfx/orchestration/experimental/core/task_gen_utils.py
@@ -14,7 +14,9 @@
"""Utilities for task generation."""

import itertools
from typing import Dict, Iterable, List, Optional, Sequence
import time
from typing import Dict, Iterable, List, Mapping, Optional, Sequence
import uuid

from absl import logging
import attr
@@ -35,6 +37,9 @@
from ml_metadata.proto import metadata_store_pb2
from google.protobuf import any_pb2

_EXECUTION_SET_SIZE = '__execution_set_size__'
_EXECUTION_TIMESTAMP = '__execution_timestamp__'


@attr.s(auto_attribs=True)
class ResolvedInfo:
@@ -43,12 +48,15 @@ class ResolvedInfo:
input_artifacts: List[Optional[typing_utils.ArtifactMultiMap]]


def _generate_task_from_execution(metadata_handler: metadata.Metadata,
pipeline: pipeline_pb2.Pipeline,
node: pipeline_pb2.PipelineNode,
execution: metadata_store_pb2.Execution,
is_cancelled: bool = False) -> task_lib.Task:
def generate_task_from_execution(metadata_handler: metadata.Metadata,
pipeline: pipeline_pb2.Pipeline,
node: pipeline_pb2.PipelineNode,
execution: metadata_store_pb2.Execution,
is_cancelled: bool = False) -> task_lib.Task:
"""Generates `ExecNodeTask` given execution."""
if not execution_lib.is_execution_active(execution):
raise RuntimeError(f'Execution is not active: {execution}.')

contexts = metadata_handler.store.get_contexts_by_execution(execution.id)
exec_properties = extract_properties(execution)
input_artifacts = execution_lib.get_artifacts_dict(
@@ -104,10 +112,13 @@ def generate_task_from_active_execution(
if not active_executions:
return None
if len(active_executions) > 1:
# TODO(b/223627713): A node in a ForEach is not restartable; prevent
# restarting for now.
raise RuntimeError(
'Unexpected multiple active executions for the node: {}\n executions: '
'{}'.format(node.node_info.id, active_executions))
return _generate_task_from_execution(
'{}. Updating/restarting a foreach node is not supported yet'.format(
node.node_info.id, active_executions))
return generate_task_from_execution(
metadata_handler,
pipeline,
node,
@@ -179,10 +190,6 @@ def generate_resolved_info(
return None
assert isinstance(resolved_input_artifacts, inputs_utils.Trigger)
assert resolved_input_artifacts
# TODO(b/197741942): Support multiple dicts.
if len(resolved_input_artifacts) > 1:
raise NotImplementedError(
'Handling more than one input dicts not implemented.')

return ResolvedInfo(
contexts=contexts,
@@ -238,6 +245,16 @@ def is_latest_execution_successful(
execution) if execution else False


def get_latest_active_execution(
executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
"""Returns the latest active execution or `None` if no active executions exist."""
active_executions = [
e for e in executions if execution_lib.is_execution_active(e)
]
return get_latest_execution(active_executions)


def get_latest_successful_execution(
executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
@@ -252,10 +269,39 @@ def get_latest_execution(
executions: Iterable[metadata_store_pb2.Execution]
) -> Optional[metadata_store_pb2.Execution]:
"""Returns latest execution or `None` if iterable is empty."""
# TODO(guoweihe): After b/207038460, multiple executions can have the same
# creation time. We may need another custom property to determine their order.
sorted_executions = execution_lib.sort_executions_newest_to_oldest(executions)
return sorted_executions[0] if sorted_executions else None


def get_latest_executions_set(
executions: Iterable[metadata_store_pb2.Execution]
) -> List[metadata_store_pb2.Execution]:
"""Returns latest set of executions."""
sorted_executions = execution_lib.sort_executions_newest_to_oldest(executions)
if not sorted_executions:
return []

size = sorted_executions[0].custom_properties.get(_EXECUTION_SET_SIZE)
if not size:
return [sorted_executions[0]]

# TODO(b/217390865): Once several executions can be registered in one
# transaction, the following code can be simplified. Until that feature is
# implemented, partially registered execution sets are abandoned: for
# example, if the orchestrator fails after publishing executions 1/3 and 2/3
# but before 3/3, this function returns an empty list.
timestamp = sorted_executions[0].custom_properties.get(
_EXECUTION_TIMESTAMP).int_value
latest_execution_set = [
e for e in sorted_executions[:size.int_value]
if e.custom_properties.get(_EXECUTION_TIMESTAMP).int_value == timestamp
]
return [] if len(latest_execution_set) != size.int_value else list(
reversed(latest_execution_set))
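The timestamp check above is what guards against partially registered sets. A worked example (protos built by hand for illustration; in production the two custom properties are set by register_executions further down, and sorting is assumed to be by create_time_since_epoch):

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.experimental.core import task_gen_utils

def make_execution(eid, create_time, set_size, timestamp):
  e = metadata_store_pb2.Execution(id=eid,
                                   create_time_since_epoch=create_time)
  e.custom_properties['__execution_set_size__'].int_value = set_size
  e.custom_properties['__execution_timestamp__'].int_value = timestamp
  return e

# The orchestrator intended a set of 3 but crashed after registering 2:
partial = [make_execution(1, 100, 3, 1000), make_execution(2, 101, 3, 1000)]
# Only 2 of the expected 3 executions share the newest timestamp, so the
# partial set is abandoned and an empty list is returned.
assert task_gen_utils.get_latest_executions_set(partial) == []

# A complete set of 2 is returned oldest-first:
full = [make_execution(3, 200, 2, 2000), make_execution(4, 201, 2, 2000)]
assert [e.id for e in task_gen_utils.get_latest_executions_set(full)] == [3, 4]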


# TODO(b/182944474): Raise error in _get_executor_spec if executor spec is
# missing for a non-system node.
def get_executor_spec(pipeline: pipeline_pb2.Pipeline,
@@ -267,3 +313,51 @@ def get_executor_spec(pipeline: pipeline_pb2.Pipeline,
depl_config = pipeline_pb2.IntermediateDeploymentConfig()
pipeline.deployment_config.Unpack(depl_config)
return depl_config.executor_specs.get(node_id)


def register_executions(
metadata_handler: metadata.Metadata,
execution_type: metadata_store_pb2.ExecutionType,
contexts: Sequence[metadata_store_pb2.Context],
input_dicts: List[typing_utils.ArtifactMultiMap],
exec_properties: Optional[Mapping[str, types.ExecPropertyTypes]] = None,
) -> List[metadata_store_pb2.Execution]:
"""Registers multiple executions in MLMD.
Along with the execution:
- the input artifacts will be linked to the executions.
- the contexts will be linked to both the executions and its input artifacts.
Args:
metadata_handler: A handler to access MLMD.
execution_type: The type of the execution.
contexts: MLMD contexts to associated with the executions.
input_dicts: A list of dictionaries of artifacts. One execution will be
registered for each of the input_dict.
exec_properties: Execution properties. Will be attached to the executions.
Returns:
A list of MLMD executions that are registered in MLMD, with id populated.
All regiested executions have state of NEW.
"""
executions = []
# TODO(b/207038460): Use the new batch execution update feature once it is
# implemented (b/209883142).
timestamp = int(time.time() * 1e6)
for input_artifacts in input_dicts:
execution = execution_lib.prepare_execution(
metadata_handler,
execution_type,
metadata_store_pb2.Execution.NEW,
exec_properties,
execution_name=str(uuid.uuid4()))
execution.custom_properties[_EXECUTION_SET_SIZE].int_value = len(
input_dicts)
execution.custom_properties[_EXECUTION_TIMESTAMP].int_value = timestamp
executions.append(
execution_lib.put_execution(
metadata_handler,
execution,
contexts,
input_artifacts=input_artifacts))
return executions
42 changes: 42 additions & 0 deletions tfx/orchestration/experimental/core/task_gen_utils_test.py
@@ -20,6 +20,7 @@
from tfx.orchestration.experimental.core import task_gen_utils
from tfx.orchestration.experimental.core import test_utils as otu
from tfx.orchestration.experimental.core.testing import test_async_pipeline
from tfx.types import standard_artifacts
from tfx.utils import test_case_utils as tu

from ml_metadata.proto import metadata_store_pb2
@@ -221,6 +222,47 @@ def test_get_latest_successful_execution(self):
self.assertEqual(execs[1],
task_gen_utils.get_latest_successful_execution(execs))

def test_get_latest_active_execution_set(self):
with self._mlmd_connection as m:
# Registers two sets of executions.
task_gen_utils.register_executions(
m,
metadata_store_pb2.ExecutionType(name='my_ex_type'), {},
input_dicts=[{
'input_example': [standard_artifacts.Examples()]
}, {
'input_example': [standard_artifacts.Examples()]
}])
newer_execution_set = task_gen_utils.register_executions(
m,
metadata_store_pb2.ExecutionType(name='my_ex_type'), {},
input_dicts=[{
'input_example': [standard_artifacts.Examples()]
}, {
'input_example': [standard_artifacts.Examples()]
}])

executions = m.store.get_executions()
self.assertLen(executions, 4)

latest_execution_set = task_gen_utils.get_latest_executions_set(
executions)
self.assertLen(latest_execution_set, 2)
self.assertProtoPartiallyEquals(
newer_execution_set[0],
latest_execution_set[0],
ignored_fields=[
'type_id', 'create_time_since_epoch',
'last_update_time_since_epoch'
])
self.assertProtoPartiallyEquals(
newer_execution_set[1],
latest_execution_set[1],
ignored_fields=[
'type_id', 'create_time_since_epoch',
'last_update_time_since_epoch'
])


if __name__ == '__main__':
tf.test.main()
@@ -123,9 +123,16 @@ def test_importer_task_scheduler(self):

[execution
] = m.store.get_executions_by_id([self._importer_task.execution_id])
del execution.custom_properties['__execution_timestamp__']
self.assertProtoPartiallyEquals(
"""
last_known_state: COMPLETE
custom_properties {
key: "__execution_set_size__"
value {
int_value: 1
}
}
custom_properties {
key: "artifact_uri"
value {
13 changes: 13 additions & 0 deletions tfx/orchestration/portable/mlmd/execution_lib.py
@@ -64,6 +64,19 @@ def is_execution_active(execution: metadata_store_pb2.Execution) -> bool:
execution.last_known_state == metadata_store_pb2.Execution.RUNNING)


def is_execution_failed(execution: metadata_store_pb2.Execution) -> bool:
"""Whether or not an execution is failed.
Args:
execution: An execution message.
Returns:
A bool value indicating whether or not the execution is failed.
"""
return not is_execution_successful(execution) and not is_execution_active(
execution)
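A quick sanity check of the new predicate (standard MLMD Execution states; this snippet is illustrative and assumes is_execution_successful treats COMPLETE/CACHED as successful, as elsewhere in TFX):

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.portable.mlmd import execution_lib

def _with_state(s):
  return metadata_store_pb2.Execution(last_known_state=s)

assert execution_lib.is_execution_failed(
    _with_state(metadata_store_pb2.Execution.FAILED))
assert execution_lib.is_execution_failed(
    _with_state(metadata_store_pb2.Execution.CANCELED))
assert not execution_lib.is_execution_failed(
    _with_state(metadata_store_pb2.Execution.RUNNING))
assert not execution_lib.is_execution_failed(
    _with_state(metadata_store_pb2.Execution.COMPLETE))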


def is_internal_key(key: str) -> bool:
"""Returns `True` if the key is an internal-only execution property key."""
return key.startswith('__')
