Skip to content

Commit

Permalink
Support tmp_dir
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 426772459
  • Loading branch information
goutham authored and tfx-copybara committed Feb 6, 2022
1 parent 80f8525 commit 75cf0aa
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def _generate_tasks_for_node(
execution.id),
stateful_working_dir=outputs_resolver
.get_stateful_working_directory(execution.id),
tmp_dir=outputs_resolver.make_tmp_dir(execution.id),
pipeline=self._pipeline))
return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def _resolve_inputs_and_generate_tasks_for_node(
execution.id),
stateful_working_dir=outputs_resolver
.get_stateful_working_directory(execution.id),
tmp_dir=outputs_resolver.make_tmp_dir(execution.id),
pipeline=self._pipeline))
return result

Expand Down
2 changes: 2 additions & 0 deletions tfx/orchestration/experimental/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class ExecNodeTask(Task):
output_artifacts: Output artifacts dict.
executor_output_uri: URI for the executor output.
stateful_working_dir: Working directory for the node execution.
tmp_dir: Temporary directory for the node execution.
pipeline: The pipeline IR proto containing the node to be executed.
is_cancelled: Indicates whether this is a cancelled execution. The task
scheduler is expected to gracefully exit after doing any necessary
Expand All @@ -114,6 +115,7 @@ class ExecNodeTask(Task):
output_artifacts: Dict[str, List[types.Artifact]]
executor_output_uri: str
stateful_working_dir: str
tmp_dir: str
pipeline: pipeline_pb2.Pipeline
is_cancelled: bool = False

Expand Down
1 change: 1 addition & 0 deletions tfx/orchestration/experimental/core/task_gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _generate_task_from_execution(metadata_handler: metadata.Metadata,
execution.id),
stateful_working_dir=outputs_resolver.get_stateful_working_directory(
execution.id),
tmp_dir=outputs_resolver.make_tmp_dir(execution.id),
pipeline=pipeline,
is_cancelled=is_cancelled)

Expand Down
7 changes: 7 additions & 0 deletions tfx/orchestration/experimental/core/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,15 @@ def _remove_output_dirs(task: task_lib.ExecNodeTask,


def _remove_task_dirs(task: task_lib.ExecNodeTask) -> None:
"""Removes directories created for the task."""
if task.stateful_working_dir:
outputs_utils.remove_stateful_working_dir(task.stateful_working_dir)
if task.tmp_dir:
try:
fileio.rmtree(task.tmp_dir)
except fileio.NotFoundError:
logging.warning(
'tmp_dir %s not found while attempting to delete, ignoring.',
task.tmp_dir)
if task.executor_output_uri:
try:
fileio.remove(task.executor_output_uri)
Expand Down
19 changes: 14 additions & 5 deletions tfx/orchestration/experimental/core/task_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def setUp(self):
self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
self.assertEqual('my_transform', tasks[1].node_uid.node_id)
self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
self.assertTrue(os.path.exists(tasks[1].tmp_dir))
self._task = tasks[1]
self._output_artifact_uri = self._task.output_artifacts['transform_graph'][
0].uri
Expand Down Expand Up @@ -411,8 +412,9 @@ def test_successful_execution_resulting_in_executor_output(self):
self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
execution.last_known_state)

# Check that stateful working dir is removed.
# Check that stateful working dir and tmp_dir are removed.
self.assertFalse(os.path.exists(self._task.stateful_working_dir))
self.assertFalse(os.path.exists(self._task.tmp_dir))
# Output artifact URI remains as execution was successful.
self.assertTrue(os.path.exists(self._output_artifact_uri))

Expand All @@ -434,8 +436,9 @@ def test_successful_execution_resulting_in_output_artifacts(self):
self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
execution.last_known_state)

# Check that stateful working dir is removed.
# Check that stateful working dir and tmp_dir are removed.
self.assertFalse(os.path.exists(self._task.stateful_working_dir))
self.assertFalse(os.path.exists(self._task.tmp_dir))
# Output artifact URI remains as execution was successful.
self.assertTrue(os.path.exists(self._output_artifact_uri))

Expand All @@ -459,8 +462,10 @@ def test_scheduler_failure(self):
data_types_utils.get_metadata_value(
execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY]))

# Check that stateful working dir and output artifact URI are removed.
# Check that stateful working dir, tmp_dir and output artifact URI are
# removed.
self.assertFalse(os.path.exists(self._task.stateful_working_dir))
self.assertFalse(os.path.exists(self._task.tmp_dir))
self.assertFalse(os.path.exists(self._output_artifact_uri))

def test_executor_failure(self):
Expand Down Expand Up @@ -488,8 +493,10 @@ def test_executor_failure(self):
data_types_utils.get_metadata_value(
execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY]))

# Check that stateful working dir and output artifact URI are removed.
# Check that stateful working dir, tmp_dir and output artifact URI are
# removed.
self.assertFalse(os.path.exists(self._task.stateful_working_dir))
self.assertFalse(os.path.exists(self._task.tmp_dir))
self.assertFalse(os.path.exists(self._output_artifact_uri))

def test_scheduler_raises_exception(self):
Expand All @@ -505,8 +512,10 @@ def test_scheduler_raises_exception(self):
self.assertEqual(metadata_store_pb2.Execution.FAILED,
execution.last_known_state)

# Check that stateful working dir and output artifact URI are removed.
# Check that stateful working dir, tmp_dir and output artifact URI are
# removed.
self.assertFalse(os.path.exists(self._task.stateful_working_dir))
self.assertFalse(os.path.exists(self._task.tmp_dir))
self.assertFalse(os.path.exists(self._output_artifact_uri))


Expand Down
2 changes: 2 additions & 0 deletions tfx/orchestration/experimental/core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def create_exec_node_task(node_uid,
output_artifacts=None,
executor_output_uri=None,
stateful_working_dir=None,
tmp_dir=None,
pipeline=None,
is_cancelled=False) -> task_lib.ExecNodeTask:
"""Creates an `ExecNodeTask` for testing."""
Expand All @@ -159,6 +160,7 @@ def create_exec_node_task(node_uid,
output_artifacts=output_artifacts or {},
executor_output_uri=executor_output_uri or '',
stateful_working_dir=stateful_working_dir or '',
tmp_dir=tmp_dir or '',
pipeline=pipeline or mock.Mock(),
is_cancelled=is_cancelled)

Expand Down

0 comments on commit 75cf0aa

Please sign in to comment.