[tune] added type hints (ray-project#10806)
Co-authored-by: Richard Liaw <[email protected]>
krfricke and richardliaw authored Sep 16, 2020
1 parent 5e030db commit c9fafe7
Showing 31 changed files with 709 additions and 472 deletions.
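For orientation, the convention applied throughout these files is sketched below using one of the signatures from this diff. This is only an illustration of the pattern (`Optional` parameters and `Dict`/`List` return types from the `typing` module), not an additional change in the commit.

```python
from typing import Dict, Optional

# Before: untyped signature with implicit Optional defaults
def get_best_config(self, metric=None, mode=None):
    ...

# After: explicit Optional parameters and an Optional[Dict] return type
def get_best_config(self,
                    metric: Optional[str] = None,
                    mode: Optional[str] = None) -> Optional[Dict]:
    ...
```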
1 change: 1 addition & 0 deletions doc/source/conf.py
@@ -40,6 +40,7 @@ def __getattr__(cls, name):
"horovod",
"horovod.ray",
"kubernetes",
"mxnet",
"mxnet.model",
"psutil",
"ray._raylet",
85 changes: 56 additions & 29 deletions python/ray/tune/analysis/experiment_analysis.py
@@ -1,9 +1,9 @@
import json
import logging
import os
from typing import Dict
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple

from ray.tune.checkpoint_manager import Checkpoint
from ray.tune.utils import flatten_dict

try:
@@ -37,7 +37,10 @@ class Analysis:
in the respective functions.
"""

def __init__(self, experiment_dir, default_metric=None, default_mode=None):
def __init__(self,
experiment_dir: str,
default_metric: Optional[str] = None,
default_mode: Optional[str] = None):
experiment_dir = os.path.expanduser(experiment_dir)
if not os.path.isdir(experiment_dir):
raise ValueError(
@@ -59,14 +62,14 @@ def __init__(self, experiment_dir, default_metric=None, default_mode=None):
else:
self.fetch_trial_dataframes()

def _validate_metric(self, metric):
def _validate_metric(self, metric: str) -> str:
if not metric and not self.default_metric:
raise ValueError(
"No `metric` has been passed and `default_metric` has "
"not been set. Please specify the `metric` parameter.")
return metric or self.default_metric

def _validate_mode(self, mode):
def _validate_mode(self, mode: str) -> str:
if not mode and not self.default_mode:
raise ValueError(
"No `mode` has been passed and `default_mode` has "
@@ -75,7 +78,9 @@ def _validate_mode(self, mode):
raise ValueError("If set, `mode` has to be one of [min, max]")
return mode or self.default_mode

def dataframe(self, metric=None, mode=None):
def dataframe(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> DataFrame:
"""Returns a pandas.DataFrame object constructed from the trials.
Args:
@@ -97,7 +102,9 @@ def dataframe(self, metric=None, mode=None):
rows[path].update(logdir=path)
return pd.DataFrame(list(rows.values()))

def get_best_config(self, metric=None, mode=None):
def get_best_config(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Optional[Dict]:
"""Retrieve the best config corresponding to the trial.
Args:
@@ -122,7 +129,9 @@ def get_best_config(self, metric=None, mode=None):
best_path = compare_op(rows, key=lambda k: rows[k][metric])
return all_configs[best_path]

def get_best_logdir(self, metric=None, mode=None):
def get_best_logdir(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Optional[str]:
"""Retrieve the logdir corresponding to the best trial.
Args:
@@ -148,7 +157,7 @@ def get_best_logdir(self, metric=None, mode=None):
self._experiment_dir))
return None

def fetch_trial_dataframes(self):
def fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
fail_count = 0
for path in self._get_trial_paths():
try:
@@ -162,15 +171,16 @@ def fetch_trial_dataframes(self):
"Couldn't read results from {} paths".format(fail_count))
return self.trial_dataframes

def get_all_configs(self, prefix=False):
def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
"""Returns a list of all configurations.
Args:
prefix (bool): If True, flattens the config dict
and prepends `config/`.
Returns:
List[dict]: List of all configurations of trials,
Dict[str, Dict]: Dict of all configurations of trials, indexed by
their trial dir.
"""
fail_count = 0
for path in self._get_trial_paths():
@@ -189,7 +199,10 @@ def get_all_configs(self, prefix=False):
"Couldn't read config from {} paths".format(fail_count))
return self._configs

def get_trial_checkpoints_paths(self, trial, metric=None):
def get_trial_checkpoints_paths(self,
trial: Trial,
metric: Optional[str] = None
) -> List[Tuple[str, Number]]:
"""Gets paths and metrics of all persistent checkpoints of a trial.
Args:
@@ -215,11 +228,14 @@ def get_trial_checkpoints_paths(self, trial, metric=None):
return path_metric_df[["chkpt_path", metric]].values.tolist()
elif isinstance(trial, Trial):
checkpoints = trial.checkpoint_manager.best_checkpoints()
return [[c.value, c.result[metric]] for c in checkpoints]
return [(c.value, c.result[metric]) for c in checkpoints]
else:
raise ValueError("trial should be a string or a Trial instance.")

def get_best_checkpoint(self, trial, metric=None, mode=None):
def get_best_checkpoint(self,
trial: Trial,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Optional[str]:
"""Gets best persistent checkpoint path of provided trial.
Args:
@@ -244,7 +260,9 @@ def get_best_checkpoint(self, trial, metric=None, mode=None):
else:
return min(checkpoint_paths, key=lambda x: x[1])[0]

def _retrieve_rows(self, metric=None, mode=None):
def _retrieve_rows(self,
metric: Optional[str] = None,
mode: Optional[str] = None) -> Dict[str, Any]:
assert mode is None or mode in ["max", "min"]
rows = {}
for path, df in self.trial_dataframes.items():
@@ -264,7 +282,7 @@ def _retrieve_rows(self, metric=None, mode=None):

return rows

def _get_trial_paths(self):
def _get_trial_paths(self) -> List[str]:
_trial_paths = []
for trial_path, _, files in os.walk(self._experiment_dir):
if EXPR_PROGRESS_FILE in files:
@@ -276,7 +294,7 @@ def _get_trial_paths(self):
return _trial_paths

@property
def trial_dataframes(self):
def trial_dataframes(self) -> Dict[str, DataFrame]:
"""List of all dataframes of the trials."""
return self._trial_dataframes

@@ -306,10 +324,10 @@ class ExperimentAnalysis(Analysis):
"""

def __init__(self,
experiment_checkpoint_path,
trials=None,
default_metric=None,
default_mode=None):
experiment_checkpoint_path: str,
trials: Optional[List[Trial]] = None,
default_metric: Optional[str] = None,
default_mode: Optional[str] = None):
experiment_checkpoint_path = os.path.expanduser(
experiment_checkpoint_path)
if not os.path.isfile(experiment_checkpoint_path):
@@ -365,8 +383,8 @@ def best_config(self) -> Dict:
return self.get_best_config(self.default_metric, self.default_mode)

@property
def best_checkpoint(self) -> Checkpoint:
"""Get the checkpoint of the best trial of the experiment
def best_checkpoint(self) -> str:
"""Get the checkpoint path of the best trial of the experiment
The best trial is determined by comparing the last trial results
using the `metric` and `mode` parameters passed to `tune.run()`.
@@ -471,7 +489,10 @@ def results_df(self) -> DataFrame:
],
index="trial_id")

def get_best_trial(self, metric=None, mode=None, scope="last"):
def get_best_trial(self,
metric: Optional[str] = None,
mode: Optional[str] = None,
scope: str = "last") -> Optional[Trial]:
"""Retrieve the best trial object.
Compares all trials' scores on ``metric``.
@@ -535,7 +556,10 @@ def get_best_trial(self, metric=None, mode=None, scope="last"):
"parameter?")
return best_trial

def get_best_config(self, metric=None, mode=None, scope="last"):
def get_best_config(self,
metric: Optional[str] = None,
mode: Optional[str] = None,
scope: str = "last") -> Optional[Dict]:
"""Retrieve the best config corresponding to the trial.
Compares all trials' scores on `metric`.
@@ -562,7 +586,10 @@ def get_best_config(self, metric=None, mode=None, scope="last"):
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.config if best_trial else None

def get_best_logdir(self, metric=None, mode=None, scope="last"):
def get_best_logdir(self,
metric: Optional[str] = None,
mode: Optional[str] = None,
scope: str = "last") -> Optional[str]:
"""Retrieve the logdir corresponding to the best trial.
Compares all trials' scores on `metric`.
@@ -589,15 +616,15 @@ def get_best_logdir(self, metric=None, mode=None, scope="last"):
best_trial = self.get_best_trial(metric, mode, scope)
return best_trial.logdir if best_trial else None

def stats(self):
def stats(self) -> Dict:
"""Returns a dictionary of the statistics of the experiment."""
return self._experiment_state.get("stats")

def runner_data(self):
def runner_data(self) -> Dict:
"""Returns a dictionary of the TrialRunner data."""
return self._experiment_state.get("runner_data")

def _get_trial_paths(self):
def _get_trial_paths(self) -> List[str]:
"""Overwrites Analysis to only have trials of one experiment."""
if self.trials:
_trial_paths = [t.logdir for t in self.trials]
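For context on how the newly annotated `ExperimentAnalysis` API is typically consumed, a minimal usage sketch follows. The experiment state path and metric name are illustrative placeholders; the example assumes a `tune.run()` experiment has already written results to that location.

```python
from ray.tune import ExperimentAnalysis

# Hypothetical experiment state file produced by an earlier tune.run() call.
analysis = ExperimentAnalysis(
    "~/ray_results/my_experiment/experiment_state-2020-09-16_00-00-00.json",
    default_metric="mean_loss",
    default_mode="min")

best_trial = analysis.get_best_trial()    # Optional[Trial], uses the defaults above
best_config = analysis.get_best_config()  # Optional[Dict]
best_logdir = analysis.get_best_logdir()  # Optional[str]
last_results = analysis.dataframe()       # pandas DataFrame, one row per trial
```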
33 changes: 18 additions & 15 deletions python/ray/tune/integration/horovod.py
@@ -1,5 +1,7 @@
import os
import logging
from typing import Callable, Dict, Type

from filelock import FileLock

import ray
@@ -15,11 +17,11 @@
logger = logging.getLogger(__name__)


def get_rank():
def get_rank() -> str:
return os.environ["HOROVOD_RANK"]


def logger_creator(log_config, logdir):
def logger_creator(log_config: Dict, logdir: str) -> NoopLogger:
"""Simple NOOP logger for worker trainables."""
index = get_rank()
worker_dir = os.path.join(logdir, "worker_{}".format(index))
@@ -51,7 +53,7 @@ class _HorovodTrainable(tune.Trainable):
def num_workers(self):
return self._num_hosts * self._num_slots

def setup(self, config):
def setup(self, config: Dict):
trainable = wrap_function(self.__class__._function)
# We use a filelock here to ensure that the file-writing
# process is safe across different trainables.
@@ -82,22 +84,22 @@ def setup(self, config):
"logger_creator": lambda cfg: logger_creator(cfg, logdir_)
})

def step(self):
def step(self) -> Dict:
if self._finished:
raise RuntimeError("Training has already finished.")
result = self.executor.execute(lambda w: w.step())[0]
if RESULT_DUPLICATE in result:
self._finished = True
return result

def save_checkpoint(self, checkpoint_dir):
def save_checkpoint(self, checkpoint_dir: str) -> str:
# TODO: optimize if colocated
save_obj = self.executor.execute_single(lambda w: w.save_to_object())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path

def load_checkpoint(self, checkpoint_dir):
def load_checkpoint(self, checkpoint_dir: str):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
x_id = ray.put(checkpoint_obj)
return self.executor.execute(lambda w: w.restore_from_object(x_id))
@@ -107,13 +109,14 @@ def stop(self):
self.executor.shutdown()


def DistributedTrainableCreator(func,
use_gpu=False,
num_hosts=1,
num_slots=1,
num_cpus_per_slot=1,
timeout_s=30,
replicate_pem=False):
def DistributedTrainableCreator(
func: Callable,
use_gpu: bool = False,
num_hosts: int = 1,
num_slots: int = 1,
num_cpus_per_slot: int = 1,
timeout_s: int = 30,
replicate_pem: bool = False) -> Type[_HorovodTrainable]:
"""Converts Horovod functions to be executable by Tune.
Requires horovod > 0.19 to work.
@@ -198,7 +201,7 @@ class WrappedHorovodTrainable(_HorovodTrainable):
_timeout_s = timeout_s

@classmethod
def default_resource_request(cls, config):
def default_resource_request(cls, config: Dict):
extra_gpu = int(num_hosts * num_slots) * int(use_gpu)
extra_cpu = int(num_hosts * num_slots * num_cpus_per_slot)

@@ -216,7 +219,7 @@ def default_resource_request(cls, config):
# that force us to include mocks as part of the module.


def _train_simple(config):
def _train_simple(config: Dict):
import horovod.torch as hvd
hvd.init()
from ray import tune
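As a usage sketch for the annotated `DistributedTrainableCreator` above: the training function below is illustrative (it mirrors the shape of `_train_simple`), and the example assumes Horovod and Tune are installed. Resource numbers are placeholders.

```python
from typing import Dict

from ray import tune
from ray.tune.integration.horovod import DistributedTrainableCreator


def train_fn(config: Dict):
    # Illustrative body: initialize Horovod and report a dummy metric.
    import horovod.torch as hvd
    hvd.init()
    for step in range(10):
        tune.report(loss=config["lr"] * (10 - step))


# Wrap the function so each trial launches one host with two worker slots.
trainable = DistributedTrainableCreator(
    train_fn, num_hosts=1, num_slots=2, use_gpu=False)

tune.run(trainable, config={"lr": tune.grid_search([0.01, 0.1])})
```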
2 changes: 1 addition & 1 deletion python/ray/tune/integration/keras.py
@@ -39,7 +39,7 @@ def __init__(self, on: Union[str, List[str]] = "validation_end"):
on, self._allowed))
self._on = on

def _handle(self, logs: Dict):
def _handle(self, logs: Dict, when: str):
raise NotImplementedError

def on_batch_begin(self, batch, logs=None):
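For context on the keras integration file touched above, a brief sketch of how its public callback is commonly attached to a model. The model, data, and metric mapping are illustrative, and the internal `_handle(logs, when)` hook changed in this diff is not called directly.

```python
from ray import tune
from ray.tune.integration.keras import TuneReportCallback
from tensorflow import keras


def train_mnist(config):
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 784) / 255.0
    model = keras.Sequential([
        keras.layers.Dense(config["hidden"], activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    # Forward the Keras "accuracy" log entry to Tune as "mean_accuracy"
    # at the end of every epoch.
    model.fit(x_train, y_train, epochs=2,
              callbacks=[TuneReportCallback(
                  {"mean_accuracy": "accuracy"}, on="epoch_end")])


tune.run(train_mnist, config={"hidden": tune.choice([32, 64])})
```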
