[tune] Sync logs from workers and improve tensorboard reporting (ray-…
ericl authored and richardliaw committed Feb 26, 2018
1 parent aefefcb commit 87e107e
Showing 6 changed files with 133 additions and 35 deletions.
23 changes: 23 additions & 0 deletions python/ray/tune/cluster_info.py
@@ -0,0 +1,23 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import getpass
import os


def get_ssh_user():
"""Returns ssh username for connecting to cluster workers."""

return getpass.getuser()


# TODO(ekl) this currently only works for clusters launched with
# ray create_or_update
def get_ssh_key():
"""Returns ssh key to connecting to cluster workers."""

path = os.path.expanduser("~/ray_bootstrap_key.pem")
if os.path.exists(path):
return path
return None
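A short usage sketch of these helpers (illustrative only, not part of the commit; the worker ip is made up). They feed the ssh/rsync command that log_sync.py below constructs:

from ray.tune.cluster_info import get_ssh_key, get_ssh_user

ssh_user = get_ssh_user()  # local username; assumed identical on the workers
ssh_key = get_ssh_key()    # None unless the autoscaler key file exists locally
if ssh_user is not None and ssh_key is not None:
    # Something like "ubuntu@10.0.0.2", paired with: ssh -i <ssh_key>
    worker_login = "{}@{}".format(ssh_user, "10.0.0.2")  # hypothetical worker ip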
88 changes: 71 additions & 17 deletions python/ray/tune/log_sync.py
@@ -4,9 +4,12 @@

import distutils.spawn
import os
+import pipes
import subprocess
import time

+import ray
+from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR

@@ -15,20 +18,21 @@
_syncers = {}


-def get_syncer(local_dir, remote_dir):
-    if not remote_dir.startswith("s3://"):
-        raise TuneError("Upload uri must start with s3://")
+def get_syncer(local_dir, remote_dir=None):
+    if remote_dir:
+        if not remote_dir.startswith("s3://"):
+            raise TuneError("Upload uri must start with s3://")

-    if not distutils.spawn.find_executable("aws"):
-        raise TuneError("Upload uri requires awscli tool to be installed")
+        if not distutils.spawn.find_executable("aws"):
+            raise TuneError("Upload uri requires awscli tool to be installed")

-    if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
-        rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
-        remote_dir = os.path.join(remote_dir, rel_path)
+        if local_dir.startswith(DEFAULT_RESULTS_DIR + "/"):
+            rel_path = os.path.relpath(local_dir, DEFAULT_RESULTS_DIR)
+            remote_dir = os.path.join(remote_dir, rel_path)

    key = (local_dir, remote_dir)
    if key not in _syncers:
-        _syncers[key] = _S3LogSyncer(local_dir, remote_dir)
+        _syncers[key] = _LogSyncer(local_dir, remote_dir)

    return _syncers[key]

@@ -38,23 +42,64 @@ def wait_for_log_sync():
        syncer.wait()


-class _S3LogSyncer(object):
-    def __init__(self, local_dir, remote_dir):
+class _LogSyncer(object):
+    """Log syncer for tune.
+    This syncs files from workers to the local node, and optionally also from
+    the local node to a remote directory (e.g. S3)."""
+
+    def __init__(self, local_dir, remote_dir=None):
        self.local_dir = local_dir
        self.remote_dir = remote_dir
        self.last_sync_time = 0
        self.sync_process = None
-        print("Created S3LogSyncer for {} -> {}".format(local_dir, remote_dir))
+        self.local_ip = ray.services.get_node_ip_address()
+        self.worker_ip = None
+        print("Created LogSyncer for {} -> {}".format(local_dir, remote_dir))

+    def set_worker_ip(self, worker_ip):
+        """Set the worker ip to sync logs from."""
+
+        self.worker_ip = worker_ip
+
    def sync_if_needed(self):
        if time.time() - self.last_sync_time > 300:
            self.sync_now()

    def sync_now(self, force=False):
-        print(
-            "Syncing files from {} -> {}".format(
-                self.local_dir, self.remote_dir))
        self.last_sync_time = time.time()
+        if not self.worker_ip:
+            print(
+                "Worker ip unknown, skipping log sync for {}".format(
+                    self.local_dir))
+            return
+
+        if self.worker_ip == self.local_ip:
+            worker_to_local_sync_cmd = None  # don't need to rsync
+        else:
+            ssh_key = get_ssh_key()
+            ssh_user = get_ssh_user()
+            if ssh_key is None or ssh_user is None:
+                print(
+                    "Error: log sync requires cluster to be setup with "
+                    "`ray create_or_update`.")
+                return
+            if not distutils.spawn.find_executable("rsync"):
+                print("Error: log sync requires rsync to be installed.")
+                return
+            worker_to_local_sync_cmd = (
+                ("""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
+                 """-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
+                     ssh_key, ssh_user, self.worker_ip,
+                     pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
+
+        if self.remote_dir:
+            local_to_remote_sync_cmd = (
+                "aws s3 sync '{}' '{}'".format(
+                    pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
+        else:
+            local_to_remote_sync_cmd = None
+
        if self.sync_process:
            self.sync_process.poll()
            if self.sync_process.returncode is None:
@@ -63,8 +108,17 @@ def sync_now(self, force=False):
                else:
                    print("Warning: last sync is still in progress, skipping")
                    return
-        self.sync_process = subprocess.Popen(
-            ["aws", "s3", "sync", self.local_dir, self.remote_dir])
+
+        if worker_to_local_sync_cmd or local_to_remote_sync_cmd:
+            final_cmd = ""
+            if worker_to_local_sync_cmd:
+                final_cmd += worker_to_local_sync_cmd
+            if local_to_remote_sync_cmd:
+                if final_cmd:
+                    final_cmd += " && "
+                final_cmd += local_to_remote_sync_cmd
+            print("Running log sync: {}".format(final_cmd))
+            self.sync_process = subprocess.Popen(final_cmd, shell=True)

    def wait(self):
        if self.sync_process:
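For orientation, a minimal sketch of how the new syncer in log_sync.py is driven end to end (illustrative only, not part of the diff; the paths, bucket, and worker ip are made up):

from ray.tune.log_sync import get_syncer, wait_for_log_sync

syncer = get_syncer("/home/ubuntu/ray_results/my_exp/trial_0",
                    remote_dir="s3://my-bucket/my_exp/trial_0")
syncer.set_worker_ip("172.31.5.10")  # normally taken from result.node_ip
syncer.sync_if_needed()    # rsync worker -> local, then aws s3 sync,
                           # at most once every 300 seconds
syncer.sync_now(force=True)  # force a sync, killing a still-running one
syncer.wait()                # block until the sync subprocess exits
wait_for_log_sync()          # or wait on every registered syncer at once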
50 changes: 32 additions & 18 deletions python/ray/tune/logger.py
@@ -24,10 +24,6 @@ class Logger(object):
    multiple formats (TensorBoard, rllab/viskit, plain json) at once.
    """

-    _attrs_to_log = [
-        "time_this_iter_s", "mean_loss", "mean_accuracy",
-        "episode_reward_mean", "episode_len_mean"]
-
    def __init__(self, config, logdir, upload_uri=None):
        self.config = config
        self.logdir = logdir
@@ -47,6 +43,11 @@ def close(self):

        pass

+    def flush(self):
+        """Flushes all disk writes to storage."""
+
+        pass
+

class UnifiedLogger(Logger):
    """Unified result logger for TensorBoard, rllab/viskit, plain json.
@@ -60,22 +61,22 @@ def _init(self):
print("TF not installed - cannot log with {}...".format(cls))
continue
self._loggers.append(cls(self.config, self.logdir, self.uri))
if self.uri:
self._log_syncer = get_syncer(self.logdir, self.uri)
else:
self._log_syncer = None
self._log_syncer = get_syncer(self.logdir, self.uri)

def on_result(self, result):
for logger in self._loggers:
logger.on_result(result)
if self._log_syncer:
self._log_syncer.sync_if_needed()
self._log_syncer.set_worker_ip(result.node_ip)
self._log_syncer.sync_if_needed()

def close(self):
for logger in self._loggers:
logger.close()
if self._log_syncer:
self._log_syncer.sync_now(force=True)
self._log_syncer.sync_now(force=True)

def flush(self):
self._log_syncer.sync_now(force=True)
self._log_syncer.wait()


class NoopLogger(Logger):
@@ -103,17 +104,30 @@ def close(self):
        self.local_out.close()


+def to_tf_values(result, path):
+    values = []
+    for attr, value in result.items():
+        if value is not None:
+            if type(value) in [int, float]:
+                values.append(tf.Summary.Value(
+                    tag="/".join(path + [attr]),
+                    simple_value=value))
+            elif type(value) is dict:
+                values.extend(to_tf_values(value, path + [attr]))
+    return values
+
+
class _TFLogger(Logger):
    def _init(self):
        self._file_writer = tf.summary.FileWriter(self.logdir)

    def on_result(self, result):
-        values = []
-        for attr in Logger._attrs_to_log:
-            if getattr(result, attr) is not None:
-                values.append(tf.Summary.Value(
-                    tag="ray/tune/{}".format(attr),
-                    simple_value=getattr(result, attr)))
+        tmp = result._asdict()
+        for k in [
+                "config", "pid", "timestamp", "time_total_s",
+                "timesteps_total"]:
+            del tmp[k]  # not useful to tf log these
+        values = to_tf_values(tmp, ["ray", "tune"])
        train_stats = tf.Summary(value=values)
        self._file_writer.add_summary(train_stats, result.timesteps_total)

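To illustrate the improved TensorBoard reporting, a small sketch of how to_tf_values() flattens a nested result dict into summary tags (illustrative only; the keys and numbers are made up):

from ray.tune.logger import to_tf_values  # requires tensorflow

result = {
    "mean_accuracy": 0.92,                # -> tag "ray/tune/mean_accuracy"
    "time_this_iter_s": 3.1,              # -> tag "ray/tune/time_this_iter_s"
    "info": {"num_steps_sampled": 4000},  # -> tag "ray/tune/info/num_steps_sampled"
    "done": None,                         # skipped: None values are ignored
}
values = to_tf_values(result, ["ray", "tune"])
# Each entry is a tf.Summary.Value(tag=..., simple_value=...); _TFLogger wraps them
# in tf.Summary(value=values) and writes them with its tf.summary.FileWriter, so
# nested stats now show up in TensorBoard instead of only a fixed attribute list.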
3 changes: 3 additions & 0 deletions python/ray/tune/result.py
@@ -85,6 +85,9 @@
    # (Auto-filled) The hostname of the machine hosting the training process.
    "hostname",

+    # (Auto-filled) The node ip of the machine hosting the training process.
+    "node_ip",
+
    # (Auto-filled) The current hyperparameter configuration.
    "config",
])
3 changes: 3 additions & 0 deletions python/ray/tune/trainable.py
@@ -13,6 +13,7 @@
import time
import uuid

+import ray
from ray.tune import TuneError
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
@@ -87,6 +88,7 @@ def __init__(self, config=None, registry=None, logger_creator=None):
        self._timesteps_total = 0
        self._setup()
        self._initialize_ok = True
+        self._local_ip = ray.services.get_node_ip_address()

    def train(self):
        """Runs one logical iteration of training.
@@ -136,6 +138,7 @@ def train(self):
            neg_mean_loss=neg_loss,
            pid=os.getpid(),
            hostname=os.uname()[1],
+            node_ip=self._local_ip,
            config=self.config)

        self._result_logger.on_result(result)
1 change: 1 addition & 0 deletions python/ray/tune/trial_runner.py
@@ -252,6 +252,7 @@ def _try_recover(self, trial, error_msg):
        try:
            print("Attempting to recover trial state from last checkpoint")
            trial.stop(error=True, error_msg=error_msg, stop_logger=False)
+            trial.result_logger.flush()  # make sure checkpoint is synced
            trial.start()
            self._running[trial.train_remote()] = trial
        except Exception:
