Skip to content

Commit

Permalink
In async pipelines, do not trigger a node if there are no changes sin…
Browse files Browse the repository at this point in the history
…ce previous run

A new node run is triggered only if the input artifacts, exec properties, or executor spec differ from those of the last execution, even if the last execution failed.

PiperOrigin-RevId: 406211796
  • Loading branch information
goutham authored and tfx-copybara committed Oct 28, 2021
1 parent b634b08 commit 2c7a497
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 32 deletions.
53 changes: 44 additions & 9 deletions tfx/orchestration/experimental/core/async_pipeline_task_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# limitations under the License.
"""TaskGenerator implementation for async pipelines."""

import hashlib
import itertools
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional

from absl import logging
from tfx import types
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import constants
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import service_jobs
from tfx.orchestration.experimental.core import task as task_lib
Expand Down Expand Up @@ -187,13 +190,19 @@ def _generate_tasks_for_node(
'are resolved.', node.node_info.id)
return result

# If the latest successful execution had the same resolved input artifacts,
# the component should not be triggered, so task is not generated.
# TODO(b/170231077): This logic should be handled by the resolver when it's
# implemented. Also, currently only the artifact ids of previous execution
# are checked to decide if a new execution is warranted but it may also be
# necessary to factor in the difference of execution properties.
latest_exec = task_gen_utils.get_latest_successful_execution(executions)
executor_spec_fingerprint = hashlib.sha256()
executor_spec = task_gen_utils.get_executor_spec(
self._pipeline_state.pipeline, node.node_info.id)
if executor_spec is not None:
executor_spec_fingerprint.update(
executor_spec.SerializeToString(deterministic=True))
resolved_info.exec_properties[
constants
.EXECUTOR_SPEC_FINGERPRINT_KEY] = executor_spec_fingerprint.hexdigest()

# If the latest execution had the same resolved input artifacts, execution
# properties and executor specs, we should not trigger a new execution.
latest_exec = task_gen_utils.get_latest_execution(executions)
if latest_exec:
artifact_ids_by_event_type = (
execution_lib.get_artifact_ids_by_event_type_for_execution_id(
Expand All @@ -203,7 +212,16 @@ def _generate_tasks_for_node(
current_exec_input_artifact_ids = set(
a.id
for a in itertools.chain(*resolved_info.input_artifacts.values()))
if latest_exec_input_artifact_ids == current_exec_input_artifact_ids:
latest_exec_properties = task_gen_utils.extract_properties(latest_exec)
current_exec_properties = resolved_info.exec_properties
latest_exec_executor_spec_fp = latest_exec_properties[
constants.EXECUTOR_SPEC_FINGERPRINT_KEY]
current_exec_executor_spec_fp = resolved_info.exec_properties[
constants.EXECUTOR_SPEC_FINGERPRINT_KEY]
if (latest_exec_input_artifact_ids == current_exec_input_artifact_ids and
_exec_properties_match(latest_exec_properties,
current_exec_properties) and
latest_exec_executor_spec_fp == current_exec_executor_spec_fp):
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid, state=pstate.NodeState.STARTED))
Expand Down Expand Up @@ -270,3 +288,20 @@ def _ensure_node_services_if_mixed(
return self._service_job_manager.ensure_node_services(
self._pipeline_state, node_id)
return None


def _exec_properties_match(
exec_props1: Dict[str, types.ExecPropertyTypes],
exec_props2: Dict[str, types.ExecPropertyTypes]) -> bool:
"""Returns True if exec properties match."""

def _filter_out_internal_keys(
props: Dict[str, types.ExecPropertyTypes]
) -> Dict[str, types.ExecPropertyTypes]:
return {
key: value for key, value in props.items() if not key.startswith('__')
}

exec_props1 = _filter_out_internal_keys(exec_props1)
exec_props2 = _filter_out_internal_keys(exec_props2)
return exec_props1 == exec_props2
115 changes: 113 additions & 2 deletions tfx/orchestration/experimental/core/async_pipeline_task_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@
import tensorflow as tf
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import async_pipeline_task_gen as asptg
from tfx.orchestration.experimental.core import mlmd_state
from tfx.orchestration.experimental.core import pipeline_state as pstate
from tfx.orchestration.experimental.core import service_jobs
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.experimental.core import task_gen_utils
from tfx.orchestration.experimental.core import task_queue as tq
from tfx.orchestration.experimental.core import test_utils
from tfx.orchestration.experimental.core.testing import test_async_pipeline
from tfx.utils import status as status_lib

from google.protobuf import any_pb2
from ml_metadata.proto import metadata_store_pb2


class AsyncPipelineTaskGeneratorTest(test_utils.TfxTest,
parameterized.TestCase):
Expand Down Expand Up @@ -101,7 +106,8 @@ def _generate_and_test(self,
num_tasks_generated,
num_new_executions,
num_active_executions,
expected_exec_nodes=None):
expected_exec_nodes=None,
ignore_update_node_state_tasks=False):
"""Generates tasks and tests the effects."""
return test_utils.run_generator_and_test(
self,
Expand All @@ -115,7 +121,8 @@ def _generate_and_test(self,
num_tasks_generated=num_tasks_generated,
num_new_executions=num_new_executions,
num_active_executions=num_active_executions,
expected_exec_nodes=expected_exec_nodes)
expected_exec_nodes=expected_exec_nodes,
ignore_update_node_state_tasks=ignore_update_node_state_tasks)

@parameterized.parameters(0, 1)
def test_no_tasks_generated_when_no_inputs(self, min_count):
Expand Down Expand Up @@ -372,6 +379,110 @@ def _ensure_node_services(unused_pipeline_state, node_id):
self.assertTrue(task_lib.is_update_node_state_task(update_task))
self.assertEqual(status_lib.Code.ABORTED, update_task.status.code)

def test_triggering_upon_exec_properties_change(self):
  """Verifies a node is re-triggered only when exec properties differ."""
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Initial generation: a single exec task for the transform node is expected.
  [transform_task] = self._generate_and_test(
      False,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._transform],
      ignore_update_node_state_tasks=True)

  # Mark the execution registered above as failed.
  with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
      m, transform_task.execution_id) as failed_execution:
    failed_execution.last_known_state = metadata_store_pb2.Execution.FAILED

  # With unchanged execution properties nothing differs from the last run, so
  # no task should be generated.
  self._generate_and_test(
      False,
      num_initial_executions=2,
      num_tasks_generated=0,
      num_new_executions=0,
      num_active_executions=0,
      ignore_update_node_state_tasks=True)

  # Rewrite `a_param` recorded on the last run so that it no longer matches
  # the current execution properties.
  with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
      m, transform_task.execution_id) as last_execution:
    last_execution.custom_properties['a_param'].int_value = 20

  # The execution properties now differ, so a new execution should trigger.
  self._generate_and_test(
      False,
      num_initial_executions=2,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._transform],
      ignore_update_node_state_tasks=True)

def test_triggering_upon_executor_spec_change(self):
  """Verifies a node is re-triggered only when the executor spec differs."""
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Initial generation with a fake executor spec built from the value 1.
  with mock.patch.object(
      task_gen_utils, 'get_executor_spec',
      side_effect=_fake_executor_spec(1)):
    [transform_task] = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform],
        ignore_update_node_state_tasks=True)

  # Mark the execution registered above as failed.
  with self._mlmd_connection as m, mlmd_state.mlmd_execution_atomic_op(
      m, transform_task.execution_id) as failed_execution:
    failed_execution.last_known_state = metadata_store_pb2.Execution.FAILED

  # The same executor spec means no change since the last run, so no task
  # should be generated.
  with mock.patch.object(
      task_gen_utils, 'get_executor_spec',
      side_effect=_fake_executor_spec(1)):
    self._generate_and_test(
        False,
        num_initial_executions=2,
        num_tasks_generated=0,
        num_new_executions=0,
        num_active_executions=0,
        ignore_update_node_state_tasks=True)

  # A different executor spec should trigger a new execution.
  with mock.patch.object(
      task_gen_utils, 'get_executor_spec',
      side_effect=_fake_executor_spec(2)):
    self._generate_and_test(
        False,
        num_initial_executions=2,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform],
        ignore_update_node_state_tasks=True)


def _fake_executor_spec(val):
  """Builds a stand-in for `get_executor_spec` parameterized by `val`.

  Args:
    val: Integer stored in the `Value` proto that is packed into the result.

  Returns:
    A callable that ignores its arguments and returns an `Any` proto wrapping
    `metadata_store_pb2.Value(int_value=val)`.
  """

  def _get_executor_spec(*unused_args, **unused_kwargs):
    packed_spec = any_pb2.Any()
    packed_spec.Pack(metadata_store_pb2.Value(int_value=val))
    return packed_spec

  return _get_executor_spec


if __name__ == '__main__':
tf.test.main()
1 change: 1 addition & 0 deletions tfx/orchestration/experimental/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Constants shared across modules."""

EXECUTION_ERROR_MSG_KEY = '__execution_error_msg__'
EXECUTOR_SPEC_FINGERPRINT_KEY = '__executor_spec_fingerprint__'

IMPORTER_NODE_TYPE = 'tfx.dsl.components.common.importer.Importer'
RESOLVER_NODE_TYPE = 'tfx.dsl.components.common.resolver.Resolver'
Expand Down
18 changes: 2 additions & 16 deletions tfx/orchestration/experimental/core/sync_pipeline_task_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
from tfx.utils import status as status_lib
from tfx.utils import topsort

from google.protobuf import any_pb2


class SyncPipelineTaskGenerator(task_gen.TaskGenerator):
"""Task generator for executing a sync pipeline.
Expand Down Expand Up @@ -335,7 +333,8 @@ def _resolve_inputs_and_generate_tasks_for_node(
self._mlmd_handle,
pipeline_node=node,
pipeline_info=self._pipeline.pipeline_info,
executor_spec=_get_executor_spec(self._pipeline, node.node_info.id),
executor_spec=task_gen_utils.get_executor_spec(self._pipeline,
node.node_info.id),
input_artifacts=resolved_info.input_artifacts,
output_artifacts=output_artifacts,
parameters=resolved_info.exec_properties)
Expand Down Expand Up @@ -424,19 +423,6 @@ def _abort_task(self, error_msg: str) -> task_lib.FinalizePipelineTask:
code=status_lib.Code.ABORTED, message=error_msg))


# TODO(b/182944474): Raise error in _get_executor_spec if executor spec is
# missing for a non-system node.
def _get_executor_spec(pipeline: pipeline_pb2.Pipeline,
                       node_id: str) -> Optional[any_pb2.Any]:
  """Returns the executor spec stored for `node_id` in the pipeline IR, else None."""
  is_intermediate_config = pipeline.deployment_config.Is(
      pipeline_pb2.IntermediateDeploymentConfig.DESCRIPTOR)
  if not is_intermediate_config:
    # The deployment config is of another type; there is nothing to look up.
    return None
  unpacked_config = pipeline_pb2.IntermediateDeploymentConfig()
  pipeline.deployment_config.Unpack(unpacked_config)
  executor_specs = unpacked_config.executor_specs
  return executor_specs.get(node_id)


def _topsorted_layers(
pipeline: pipeline_pb2.Pipeline) -> List[List[pipeline_pb2.PipelineNode]]:
"""Returns pipeline nodes in topologically sorted layers."""
Expand Down
18 changes: 16 additions & 2 deletions tfx/orchestration/experimental/core/task_gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2
from google.protobuf import any_pb2


@attr.s(auto_attribs=True)
Expand All @@ -49,7 +50,7 @@ def _generate_task_from_execution(metadata_handler: metadata.Metadata,
is_cancelled: bool = False) -> task_lib.Task:
"""Generates `ExecNodeTask` given execution."""
contexts = metadata_handler.store.get_contexts_by_execution(execution.id)
exec_properties = _extract_properties(execution)
exec_properties = extract_properties(execution)
input_artifacts = execution_lib.get_artifacts_dict(
metadata_handler, execution.id, [metadata_store_pb2.Event.INPUT])
outputs_resolver = outputs_utils.OutputsResolver(node, pipeline.pipeline_info,
Expand Down Expand Up @@ -113,7 +114,7 @@ def generate_task_from_active_execution(
is_cancelled=is_cancelled)


def _extract_properties(
def extract_properties(
execution: metadata_store_pb2.Execution
) -> Dict[str, types.ExecPropertyTypes]:
"""Extracts execution properties from mlmd Execution."""
Expand Down Expand Up @@ -253,3 +254,16 @@ def get_latest_execution(
"""Returns latest execution or `None` if iterable is empty."""
sorted_executions = execution_lib.sort_executions_newest_to_oldest(executions)
return sorted_executions[0] if sorted_executions else None


# TODO(b/182944474): Raise error in get_executor_spec if executor spec is
# missing for a non-system node.
def get_executor_spec(pipeline: pipeline_pb2.Pipeline,
                      node_id: str) -> Optional[any_pb2.Any]:
  """Looks up the executor spec of a node in the pipeline IR.

  Args:
    pipeline: Pipeline IR whose `deployment_config` is inspected.
    node_id: Id of the node whose executor spec is requested.

  Returns:
    The executor spec for `node_id` if it exists in the pipeline IR, `None`
    otherwise (including when the deployment config is not an
    `IntermediateDeploymentConfig`).
  """
  packed_config = pipeline.deployment_config
  config_type = pipeline_pb2.IntermediateDeploymentConfig
  if packed_config.Is(config_type.DESCRIPTOR):
    deployment_config = config_type()
    packed_config.Unpack(deployment_config)
    return deployment_config.executor_specs.get(node_id)
  return None
13 changes: 10 additions & 3 deletions tfx/orchestration/experimental/core/testing/test_async_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tfx.dsl.compiler import compiler
from tfx.dsl.component.experimental.annotations import InputArtifact
from tfx.dsl.component.experimental.annotations import OutputArtifact
from tfx.dsl.component.experimental.annotations import Parameter
from tfx.dsl.component.experimental.decorators import component
from tfx.orchestration import pipeline as pipeline_lib
from tfx.proto.orchestration import pipeline_pb2
Expand All @@ -27,11 +28,16 @@ def _example_gen(examples: OutputArtifact[standard_artifacts.Examples]):
del examples


# pytype: disable=wrong-arg-types
@component
def _transform(
examples: InputArtifact[standard_artifacts.Examples],
transform_graph: OutputArtifact[standard_artifacts.TransformGraph]):
del examples, transform_graph
transform_graph: OutputArtifact[standard_artifacts.TransformGraph],
a_param: Parameter[int]):
del examples, transform_graph, a_param


# pytype: enable=wrong-arg-types


@component
Expand All @@ -46,7 +52,8 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
# pylint: disable=no-value-for-parameter
example_gen = _example_gen().with_id('my_example_gen')
transform = _transform(
examples=example_gen.outputs['examples']).with_id('my_transform')
examples=example_gen.outputs['examples'],
a_param=10).with_id('my_transform')
trainer = _trainer(
examples=example_gen.outputs['examples'],
transform_graph=transform.outputs['transform_graph']).with_id(
Expand Down

0 comments on commit 2c7a497

Please sign in to comment.