Skip to content

Commit

Permalink
[tune] Fault tolerance improvements (ray-project#5877)
Browse files Browse the repository at this point in the history
* Precede ray.get with ray.wait.

* Trigger checkpoint deletes locally in Trainable

* Clean-up code.

* Minor changes.

* Track best checkpoint so far again

* Pulled checkpoint GC out of Trainable.

* Added comments, error logging.

* Immediate pull after checkpoint taken; rsync source delete on pull

* Minor doc fixes

* Fix checkpoint manager bug

* Fix bugs, tests, formatting

* Fix bugs, feature flag for force sync.

* Fix test.

* Fix minor bugs: clear proc and less verbose sync_on_checkpoint warnings.

* Fix bug: update IP of last_result.

* Fixed message.

* Added a lot of logging.

* Changes to ray trial executor.

* More bug fixes (logging after failure), better logging.

* Fix richards bug and logging

* Add comments.

* try-except

* Fix heapq bug.

* .

* Move handling of no available trials to ray_trial_executor (ray-project#1)

* Fix formatting bug, lint.

* Addressed Richard's comments

* Revert tests.

* fix rebase

* Fix trial location reporting.

* Fix test

* Fix lint

* Rebase, use ray.get w/ timeout, lint.

* lint

* fix rebase

* Address richard's comments
  • Loading branch information
ujvl authored and richardliaw committed Nov 18, 2019
1 parent 66edebc commit 2965dc1
Show file tree
Hide file tree
Showing 20 changed files with 837 additions and 451 deletions.
2 changes: 1 addition & 1 deletion doc/source/walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ You can also set a timeout to return early from a ``get`` that's blocking for to

.. code-block:: python
from ray.exceptions import RayTimeoutException
from ray.exceptions import RayTimeoutError
@ray.remote
def long_running_function()
Expand Down
22 changes: 18 additions & 4 deletions python/ray/tune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,22 @@
randn, loguniform)

__all__ = [
"Trainable", "TuneError", "grid_search", "register_env",
"register_trainable", "run", "run_experiments", "Experiment", "function",
"sample_from", "track", "uniform", "choice", "randint", "randn",
"loguniform", "progress_reporter", "ExperimentAnalysis", "Analysis"
"Trainable",
"TuneError",
"grid_search",
"register_env",
"register_trainable",
"run",
"run_experiments",
"Experiment",
"function",
"sample_from",
"track",
"uniform",
"choice",
"randint",
"randn",
"loguniform",
"ExperimentAnalysis",
"Analysis",
]
142 changes: 142 additions & 0 deletions python/ray/tune/checkpoint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import heapq
import logging
import os
import shutil

try:
FileNotFoundError
except NameError:
FileNotFoundError = IOError

logger = logging.getLogger(__name__)


class Checkpoint(object):
    """A snapshot of trial state, held either in memory or on disk.

    Attributes:
        storage (str): Where the checkpoint lives; ``Checkpoint.MEMORY`` or
            ``Checkpoint.DISK``.
        value: For a memory checkpoint, the Python object itself. For a disk
            checkpoint, the path of the checkpoint on disk.
        result (dict): Result stored alongside the checkpoint. An empty dict
            is used when no (or a falsy) result is supplied.
    """

    MEMORY = "memory"
    DISK = "disk"

    def __init__(self, storage, value, result=None):
        self.storage = storage
        self.value = value
        self.result = result if result else {}

    def delete(self):
        """Removes the on-disk data of a disk checkpoint.

        Memory checkpoints, and disk checkpoints with an empty value, are
        left untouched. When the stored path is a file, the directory that
        contains it is removed; otherwise the path itself is removed as a
        directory tree.

        Raises:
            FileNotFoundError: If the stored path does not exist.
        """
        if self.storage != Checkpoint.DISK or not self.value:
            return

        path = self.value
        if not os.path.exists(path):
            raise FileNotFoundError(
                "Attempted to delete checkpoint at {} but "
                "path was not found.".format(path))

        # A file path means the checkpoint owns its whole parent directory.
        target = os.path.dirname(path) if os.path.isfile(path) else path
        shutil.rmtree(target)

    @staticmethod
    def from_object(value=None):
        """Builds an in-memory checkpoint wrapping a Python object."""
        return Checkpoint(Checkpoint.MEMORY, value)


class QueueItem(object):
    """Pairs a value with the priority used to order it in a heap."""

    def __init__(self, priority, value):
        self.priority = priority
        self.value = value

    def __cmp__(self, other):
        # Python 2.7 ordering hook: 0 when equal, -1 when lower, 1 otherwise.
        mine, theirs = self.priority, other.priority
        if mine == theirs:
            return 0
        if mine < theirs:
            return -1
        return 1

    def __lt__(self, other):
        # Items compare by priority only; the value is never inspected.
        return self.priority < other.priority


class CheckpointManager(object):
    """Manages checkpoints on the driver for a trial.

    Tracks the newest checkpoint plus up to ``keep_checkpoints_num`` of the
    best-scoring checkpoints, and deletes the data of checkpoints that are
    neither.
    """

    def __init__(self, keep_checkpoints_num, checkpoint_score_attr):
        """Initializes a new CheckpointManager.

        Args:
            keep_checkpoints_num (int): Maximum number of best checkpoints
                to keep. A falsy value (None or 0) means no limit.
            checkpoint_score_attr (str): Key in the checkpoint's result dict
                used to score checkpoints. Prefix it with "min-" to treat
                lower values as better.
        """
        self.keep_checkpoints_num = keep_checkpoints_num or float("inf")
        assert self.keep_checkpoints_num > 0, (
            "keep_checkpoints_num must be greater than 0.")
        self._checkpoint_score_desc = checkpoint_score_attr.startswith("min-")
        if self._checkpoint_score_desc:
            # Strip the "min-" prefix to get the actual result key.
            self._checkpoint_score_attr = checkpoint_score_attr[4:]
        else:
            self._checkpoint_score_attr = checkpoint_score_attr

        self.newest_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
        # Min-heap of QueueItems; index 0 is the worst checkpoint kept.
        self._best_checkpoints = []
        # Checkpoints currently held in the heap, for O(1) membership tests.
        self._membership = set()

    def on_checkpoint(self, checkpoint):
        """Starts tracking checkpoint metadata on checkpoint.

        Sets newest checkpoint. Deletes previous checkpoint as long as it
        isn't one of the best ones. Also deletes the worst checkpoint if at
        capacity. The checkpoint passed in is never deleted by this call.

        If the score attribute is missing from the checkpoint's result, an
        error is logged and the checkpoint is only tracked as the newest one.

        Args:
            checkpoint (Checkpoint): Trial state checkpoint.
        """
        old_checkpoint = self.newest_checkpoint
        self.newest_checkpoint = checkpoint

        try:
            queue_item = QueueItem(self._priority(checkpoint), checkpoint)
        except KeyError:
            if old_checkpoint not in self._membership:
                old_checkpoint.delete()
            logger.error("Result dict has no key: {}. "
                         "checkpoint_score_attr must be set to a key in the "
                         "result dict.".format(self._checkpoint_score_attr))
            return

        evicted = None
        if len(self._best_checkpoints) < self.keep_checkpoints_num:
            heapq.heappush(self._best_checkpoints, queue_item)
            self._membership.add(checkpoint)
        elif queue_item.priority >= self._best_checkpoints[0].priority:
            evicted = heapq.heappushpop(self._best_checkpoints,
                                        queue_item).value
            self._membership.add(checkpoint)
            self._membership.discard(evicted)
            # On a priority tie heappushpop hands back the item just pushed.
            # Never delete the newest checkpoint here; since it is not in
            # _membership it is cleaned up on the next on_checkpoint() call.
            if evicted is not checkpoint:
                evicted.delete()

        # Remove the old checkpoint if it isn't one of the best ones. Skip it
        # when it was just evicted above: its data is already gone and a
        # second delete() would raise FileNotFoundError.
        if (old_checkpoint is not evicted
                and old_checkpoint not in self._membership):
            old_checkpoint.delete()

    def best_checkpoints(self):
        """Returns best checkpoints, sorted by score (worst first)."""
        checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority)
        return [queue_item.value for queue_item in checkpoints]

    def _priority(self, checkpoint):
        """Returns the heap priority of a checkpoint (higher is better).

        Raises:
            KeyError: If the score attribute is not in the checkpoint result.
        """
        priority = checkpoint.result[self._checkpoint_score_attr]
        # Negate "min-" scores so that a lower raw value ranks higher.
        return -priority if self._checkpoint_score_desc else priority
11 changes: 10 additions & 1 deletion python/ray/tune/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,19 @@ def make_parser(parser_creator=None, **kwargs):
action="store_true",
help="Whether to checkpoint at the end of the experiment. "
"Default is False.")
parser.add_argument(
"--no-sync-on-checkpoint",
action="store_true",
help="Disable sync-down of trial checkpoint, which is enabled by "
"default to guarantee recoverability. If set, checkpoint syncing from "
"worker to driver is asynchronous. Set this only if synchronous "
"checkpointing is too slow and trial restoration failures can be "
"tolerated")
parser.add_argument(
"--keep-checkpoints-num",
default=None,
type=int,
help="Number of last checkpoints to keep. Others get "
help="Number of best checkpoints to keep. Others get "
"deleted. Default (None) keeps all checkpoints.")
parser.add_argument(
"--checkpoint-score-attr",
Expand Down Expand Up @@ -177,6 +185,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
stopping_criterion=spec.get("stop", {}),
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
sync_on_checkpoint=not args.no_sync_on_checkpoint,
keep_checkpoints_num=args.keep_checkpoints_num,
checkpoint_score_attr=args.checkpoint_score_attr,
export_formats=spec.get("export_formats", []),
Expand Down
11 changes: 9 additions & 2 deletions python/ray/tune/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self,
sync_to_driver=None,
checkpoint_freq=0,
checkpoint_at_end=False,
sync_on_checkpoint=True,
keep_checkpoints_num=None,
checkpoint_score_attr=None,
export_formats=None,
Expand All @@ -80,6 +81,11 @@ def __init__(self,
repeat=None,
trial_resources=None,
sync_function=None):
"""Initialize a new Experiment.
The args here take the same meaning as the command line flags defined
in `tune.py:run`.
"""
if repeat:
_raise_deprecation_note("repeat", "num_samples", soft=False)
if trial_resources:
Expand All @@ -102,7 +108,7 @@ def __init__(self,
"criteria must take exactly 2 parameters.".format(stop))

config = config or {}
self._run_identifier = Experiment._register_if_needed(run)
self._run_identifier = Experiment.register_if_needed(run)
spec = {
"run": self._run_identifier,
"stop": stop,
Expand All @@ -117,6 +123,7 @@ def __init__(self,
"sync_to_driver": sync_to_driver,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"sync_on_checkpoint": sync_on_checkpoint,
"keep_checkpoints_num": keep_checkpoints_num,
"checkpoint_score_attr": checkpoint_score_attr,
"export_formats": export_formats or [],
Expand Down Expand Up @@ -156,7 +163,7 @@ def from_json(cls, name, spec):
return exp

@classmethod
def _register_if_needed(cls, run_object):
def register_if_needed(cls, run_object):
"""Registers Trainable or Function at runtime.
Assumes already registered if run_object is a string.
Expand Down
49 changes: 38 additions & 11 deletions python/ray/tune/log_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
_log_sync_warned = False


def log_sync_template():
def log_sync_template(options=""):
"""Syncs the local_dir between driver and worker if possible.
Requires ray cluster to be started with the autoscaler. Also requires
rsync to be installed.
Args:
options (str): Additional rsync options.
Returns:
Sync template with source and target parameters.
"""
if not distutils.spawn.find_executable("rsync"):
logger.error("Log sync requires rsync to be installed.")
Expand All @@ -36,12 +41,14 @@ def log_sync_template():
_log_sync_warned = True
return

return ("""rsync -savz -e "ssh -i {ssh_key} -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" {{source}} {{target}}"""
).format(ssh_key=quote(ssh_key))
rsh = "ssh -i {ssh_key} -o ConnectTimeout=120s -o StrictHostKeyChecking=no"
rsh = rsh.format(ssh_key=quote(ssh_key))
template = """rsync {options} -savz -e "{rsh}" {{source}} {{target}}"""
return template.format(options=options, rsh=rsh)


class NodeSyncMixin(object):
# TODO(ujvl): Refactor this code.
"""Mixin for syncing files to/from a remote dir to a local dir."""

def __init__(self):
Expand All @@ -53,23 +60,43 @@ def set_worker_ip(self, worker_ip):
"""Set the worker ip to sync logs from."""
self.worker_ip = worker_ip

def _check_valid_worker_ip(self):
def has_remote_target(self):
"""Returns whether the Syncer has a remote target."""
if not self.worker_ip:
logger.debug("Worker ip unknown, skipping log sync for {}".format(
self._local_dir))
logger.debug("Worker IP unknown, skipping log sync for %s",
self._local_dir)
return False
if self.worker_ip == self.local_ip:
logger.debug(
"Worker ip is local ip, skipping log sync for {}".format(
self._local_dir))
logger.debug("Worker IP is local IP, skipping log sync for %s",
self._local_dir)
return False
return True

def sync_up_if_needed(self):
    """Syncs up unless this syncer has no remote target.

    Returns:
        True when the sync is skipped because there is no remote target;
        otherwise whatever the parent class's ``sync_up`` returns.
    """
    if not self.has_remote_target():
        return True
    # Propagate the parent's result instead of implicitly returning None,
    # so both paths give the caller a usable value.
    # NOTE(review): this delegates to the parent's sync_up, not to a parent
    # sync_up_if_needed -- confirm that is intended.
    return super(NodeSyncMixin, self).sync_up()

def sync_down_if_needed(self):
    """Syncs down unless this syncer has no remote target.

    Returns:
        True when the sync is skipped because there is no remote target;
        otherwise whatever the parent class's ``sync_down`` returns.
    """
    if not self.has_remote_target():
        return True
    # Propagate the parent's result instead of implicitly returning None,
    # so both paths give the caller a usable value.
    # NOTE(review): this delegates to the parent's sync_down, not to a parent
    # sync_down_if_needed -- confirm that is intended.
    return super(NodeSyncMixin, self).sync_down()

def sync_down(self):
    """Runs the parent's sync_down, or returns True if there is no remote.

    Returns:
        The parent class's ``sync_down`` result when a remote target exists;
        True otherwise.
    """
    if self.has_remote_target():
        return super(NodeSyncMixin, self).sync_down()
    return True

def sync_up(self):
    """Runs the parent's sync_up, or returns True if there is no remote.

    Returns:
        The parent class's ``sync_up`` result when a remote target exists;
        True otherwise.
    """
    if self.has_remote_target():
        return super(NodeSyncMixin, self).sync_up()
    return True

@property
def _remote_path(self):
ssh_user = get_ssh_user()
global _log_sync_warned
if not self._check_valid_worker_ip():
if not self.has_remote_target():
return
if ssh_user is None:
if not _log_sync_warned:
Expand Down
Loading

0 comments on commit 2965dc1

Please sign in to comment.