[train] Storage refactor: Support PBT and BOHB (ray-project#38736)
Signed-off-by: Kai Fricke <[email protected]>
krfricke authored Aug 25, 2023
1 parent bfb57f7 commit 5e3b2f7
Showing 24 changed files with 500 additions and 229 deletions.
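The heart of the change, repeated across the example migrations below, is the move from dict-based checkpoints (Checkpoint.from_dict / to_dict) to directory-based checkpoints imported from ray.train._checkpoint. A minimal sketch of the pattern the examples adopt, assuming it runs inside a Train/Tune worker function; the data.ckpt filename and the two helper wrappers are illustrative, not part of the Ray API:

import os
import tempfile

import ray.cloudpickle as cpickle
from ray import train
from ray.train._checkpoint import Checkpoint


def save_and_report(state: dict, metrics: dict) -> None:
    # Pickle the state into a file, then hand the containing directory
    # to Checkpoint.from_directory() when reporting.
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "data.ckpt"), "wb") as fp:
            cpickle.dump(state, fp)
        train.report(metrics, checkpoint=Checkpoint.from_directory(checkpoint_dir))


def load_state() -> dict:
    # Materialize the latest checkpoint as a local directory and unpickle
    # the file written above; an empty dict signals a fresh run.
    checkpoint = train.get_checkpoint()
    if not checkpoint:
        return {}
    with checkpoint.as_directory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "data.ckpt"), "rb") as fp:
            return cpickle.load(fp)
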
19 changes: 10 additions & 9 deletions .buildkite/pipeline.ml.yml
@@ -44,14 +44,15 @@
--test_tag_filters=-gpu_only,-gpu,-minimal,-tune,-doctest,-new_storage
python/ray/train/...

- label: ":steam_locomotive: :octopus: Train + Tune tests and examples"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_TRAIN_AFFECTED"]
instance_size: medium
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- TRAIN_TESTING=1 TUNE_TESTING=1 ./ci/env/install-dependencies.sh
- ./ci/env/env_info.sh
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=tune,-gpu_only,-ray_air,-gpu,-doctest,-new_storage python/ray/train/...
# Currently empty test suite
#- label: ":steam_locomotive: :octopus: Train + Tune tests and examples"
# conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_TRAIN_AFFECTED"]
# instance_size: medium
# commands:
# - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
# - TRAIN_TESTING=1 TUNE_TESTING=1 ./ci/env/install-dependencies.sh
# - ./ci/env/env_info.sh
# - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=tune,-gpu_only,-ray_air,-gpu,-doctest,-new_storage python/ray/train/...


- label: ":brain: RLlib: Benchmarks (Torch 2.x)"
@@ -523,7 +524,7 @@
# (see https://github.com/ray-project/ray/pull/38432/)
- pip install "transformers==4.30.2" "datasets==2.14.0"
- ./ci/env/env_info.sh
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-timeseries_libs,-external,-ray_air,-gpu,-post_wheel_build,-doctest,-datasets_train,-highly_parallel doc/...
- bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-timeseries_libs,-external,-ray_air,-gpu,-post_wheel_build,-doctest,-datasets_train,-highly_parallel,-new_storage doc/...

- label: ":book: Doc tests and examples with time series libraries"
conditions:
14 changes: 14 additions & 0 deletions doc/source/tune/examples/BUILD
@@ -17,14 +17,28 @@ py_test_run_all_notebooks(
size = "medium",
include = ["*.ipynb"],
exclude = [
"bohb_example.ipynb",
"pbt_ppo_example.ipynb",
"pbt_guide.ipynb",
"tune-xgboost.ipynb",
"tune-xgboost.ipynb",
"sigopt_example.ipynb", # REGRESSION: no credentials
],
data = ["//doc/source/tune/examples:tune_examples"],
tags = ["exclusive", "team:ml"],
)

py_test_run_all_notebooks(
size = "medium",
include = [
"bohb_example.ipynb",
"pbt_guide.ipynb"
],
exclude = [],
data = ["//doc/source/tune/examples:tune_examples"],
tags = ["exclusive", "team:ml", "new_storage"],
)

# GPU tests
py_test_run_all_notebooks(
size = "large",
2 changes: 1 addition & 1 deletion python/ray/air/_internal/checkpoint_manager.py
@@ -476,7 +476,7 @@ def _process_persistent_checkpoint(
# Only remove if checkpoint data is different
if worst_checkpoint.dir_or_data != checkpoint.dir_or_data:
self._maybe_delete_persisted_checkpoint(worst_checkpoint)
logger.debug(f"Removed worst checkpoint from " f"{worst_checkpoint}.")
logger.debug(f"Removed worst checkpoint from {worst_checkpoint}.")

self._replace_latest_persisted_checkpoint(checkpoint)
else:
4 changes: 2 additions & 2 deletions python/ray/train/BUILD
@@ -78,7 +78,7 @@ py_test(
size = "medium",
main = "examples/pytorch/tune_cifar_torch_pbt_example.py",
srcs = ["examples/pytorch/tune_cifar_torch_pbt_example.py"],
tags = ["team:ml", "exclusive", "pytorch", "tune", "no_new_storage"],
tags = ["team:ml", "exclusive", "pytorch", "tune", "new_storage"],
deps = [":train_lib"],
args = ["--smoke-test"]
)
@@ -109,7 +109,7 @@ py_test(
name = "horovod_cifar_pbt_example",
size = "small",
srcs = ["examples/horovod/horovod_cifar_pbt_example.py"],
tags = ["team:ml", "exlusive", "no_new_storage"],
tags = ["team:ml", "exlusive", "new_storage"],
deps = [":train_lib"],
args = ["--smoke-test"]
)
30 changes: 30 additions & 0 deletions python/ray/train/_internal/session.py
@@ -80,6 +80,36 @@ class TrainingResult:
metadata: Optional[Dict] = None


class _FutureTrainingResult:
"""A future that will be resolved to a `_TrainingResult`.
This is needed for specific schedulers such as PBT that schedule saves.
This wrapper should be removed after refactoring PBT to not schedule saves anymore.
"""

def __init__(self, future: ray.ObjectRef):
self.future = future

def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
"""Resolve into ``_TrainingResult``.
This will return None for function trainables if no checkpoint has been
saved before.
"""
if block:
timeout = None
else:
timeout = 1e-9
try:
return ray.get(self.future, timeout=timeout)
except TimeoutError:
# Not ready, yet
pass
except Exception as exc:
logger.error(f"Error resolving result: {exc}")


class _TrainingResult:
"""A (checkpoint, metrics) result reported by the user."""

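_FutureTrainingResult wraps a plain ray.ObjectRef; resolve(block=False) polls by calling ray.get with a near-zero timeout so it returns immediately when the save has not finished. A self-contained sketch of the same polling pattern; _slow_save is an illustrative stand-in for a scheduler-triggered checkpoint save:

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def _slow_save() -> dict:
    # Stand-in for the remote save a scheduler like PBT kicks off.
    return {"checkpoint": "ckpt_000005", "metrics": {"loss": 0.1}}


future = _slow_save.remote()

# Non-blocking poll, as in resolve(block=False): the tiny timeout makes
# ray.get raise TimeoutError right away if the result is not ready.
try:
    result = ray.get(future, timeout=1e-9)
except TimeoutError:
    result = None  # not ready yet

# Blocking resolution, as in resolve(block=True).
result = ray.get(future)
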
31 changes: 21 additions & 10 deletions python/ray/train/examples/horovod/horovod_cifar_pbt_example.py
@@ -1,3 +1,6 @@
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn
@@ -7,13 +10,14 @@
from torchvision.models import resnet18

import ray
import ray.cloudpickle as cpickle
from ray.train import (
Checkpoint,
CheckpointConfig,
FailureConfig,
RunConfig,
ScalingConfig,
)
from ray.train._checkpoint import Checkpoint
import ray.train.torch
from ray.train.horovod import HorovodTrainer
from ray import train, tune
@@ -52,7 +56,10 @@ def train_loop_per_worker(config):

checkpoint = train.get_checkpoint()
if checkpoint:
checkpoint_dict = checkpoint.to_dict()
with checkpoint.as_directory() as checkpoint_dir:
with open(os.path.join(checkpoint_dir, "data.ckpt"), "rb") as fp:
checkpoint_dict = cpickle.load(fp)

model_state = checkpoint_dict["model_state"]
optimizer_state = checkpoint_dict["optimizer_state"]
epoch = checkpoint_dict["epoch"] + 1
@@ -111,14 +118,18 @@ def train_loop_per_worker(config):
if config["smoke_test"]:
break

checkpoint = Checkpoint.from_dict(
dict(
model_state=net.state_dict(),
optimizer_state=optimizer.state_dict(),
epoch=epoch,
)
)
train.report(dict(loss=running_loss / epoch_steps), checkpoint=checkpoint)
with tempfile.TemporaryDirectory() as checkpoint_dir:
with open(os.path.join(checkpoint_dir, "data.ckpt"), "wb") as fp:
cpickle.dump(
dict(
model_state=net.state_dict(),
optimizer_state=optimizer.state_dict(),
epoch=epoch,
),
fp,
)
checkpoint = Checkpoint.from_directory(checkpoint_dir)
train.report(dict(loss=running_loss / epoch_steps), checkpoint=checkpoint)


if __name__ == "__main__":
29 changes: 19 additions & 10 deletions python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py
@@ -1,5 +1,6 @@
import argparse
import os
import tempfile

import torch
import torch.nn as nn
@@ -10,8 +11,10 @@
from torchvision.models import resnet18

import ray
import ray.cloudpickle as cpickle
from ray import train, tune
from ray.train import Checkpoint, FailureConfig, RunConfig, ScalingConfig
from ray.train import FailureConfig, RunConfig, ScalingConfig
from ray.train._checkpoint import Checkpoint
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.tune_config import TuneConfig
@@ -80,7 +83,9 @@ def train_func(config):

starting_epoch = 0
if train.get_checkpoint():
checkpoint_dict = train.get_checkpoint().to_dict()
with train.get_checkpoint().as_directory() as checkpoint_dir:
with open(os.path.join(checkpoint_dir, "data.ckpt"), "rb") as fp:
checkpoint_dict = cpickle.load(fp)

# Load in model
model_state = checkpoint_dict["model"]
@@ -144,15 +149,19 @@ def train_func(config):
for epoch in range(starting_epoch, epochs):
train_epoch(train_loader, model, criterion, optimizer)
result = validate_epoch(validation_loader, model, criterion)
checkpoint = Checkpoint.from_dict(
{
"epoch": epoch,
"model": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}
)

train.report(result, checkpoint=checkpoint)
with tempfile.TemporaryDirectory() as checkpoint_dir:
with open(os.path.join(checkpoint_dir, "data.ckpt"), "wb") as fp:
cpickle.dump(
{
"epoch": epoch,
"model": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
},
fp,
)
checkpoint = Checkpoint.from_directory(checkpoint_dir)
train.report(result, checkpoint=checkpoint)


if __name__ == "__main__":
22 changes: 12 additions & 10 deletions python/ray/tune/BUILD
@@ -351,23 +351,23 @@ py_test(
size = "large",
srcs = ["tests/test_trial_scheduler.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "medium_instance", "no_new_storage"],
tags = ["team:ml", "exclusive", "medium_instance", "new_storage"],
)

py_test(
name = "test_trial_scheduler_pbt",
size = "large",
srcs = ["tests/test_trial_scheduler_pbt.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "medium_instance", "no_new_storage"],
tags = ["team:ml", "exclusive", "medium_instance", "new_storage"],
)

py_test(
name = "test_trial_scheduler_resource_changing",
size = "small",
srcs = ["tests/test_trial_scheduler_resource_changing.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "no_new_storage"],
tags = ["team:ml", "exclusive"],
)

py_test(
@@ -557,7 +557,7 @@ py_test(
size = "medium",
srcs = ["examples/bohb_example.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "no_new_storage"]
tags = ["team:ml", "exclusive", "example", "new_storage"]
)

py_test(
@@ -753,7 +753,7 @@ py_test(
size = "small",
srcs = ["examples/pbt_convnet_example.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "no_new_storage"],
tags = ["team:ml", "exclusive", "example"],
args = ["--smoke-test"]
)

@@ -762,7 +762,7 @@ py_test(
size = "small",
srcs = ["examples/pbt_convnet_function_example.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "no_new_storage"],
tags = ["team:ml", "exclusive", "example"],
args = ["--smoke-test"]
)

@@ -771,7 +771,7 @@ py_test(
size = "medium",
srcs = ["examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "no_new_storage"],
tags = ["team:ml", "exclusive", "example"],
args = ["--smoke-test"]
)

@@ -780,7 +780,7 @@ py_test(
size = "medium",
srcs = ["examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "no_new_storage"],
tags = ["team:ml", "exclusive", "example"],
args = ["--smoke-test"]
)

@@ -789,7 +789,7 @@ py_test(
size = "small",
srcs = ["examples/pbt_example.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "no_new_storage"],
tags = ["team:ml", "exclusive", "example"],
args = ["--smoke-test"]
)

@@ -807,7 +807,7 @@ py_test(
size = "small",
srcs = ["examples/pbt_memnn_example.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "no_new_storage"],
tags = ["team:ml", "exclusive", "example"],
args = ["--smoke-test"]
)

@@ -821,6 +821,8 @@ py_test(
# args = ["--smoke-test"]
# )

# Exclude from new storage tests as transformers still uses the old tune
# API.
py_test(
name = "pbt_transformers",
size = "small",
8 changes: 0 additions & 8 deletions python/ray/tune/examples/pbt_convnet_example.py
@@ -130,11 +130,3 @@ def stop_all(self):

best_result = results.get_best_result()
best_checkpoint = best_result.checkpoint

restored_trainable = PytorchTrainable()
restored_trainable.restore(best_checkpoint)
best_model = restored_trainable.model
# Note that test only runs on a small random set of the test data, thus the
# accuracy may be different from metrics shown in tuning process.
test_acc = test_func(best_model, get_data_loaders()[1])
print("best model accuracy: ", test_acc)
6 changes: 3 additions & 3 deletions python/ray/tune/examples/pbt_convnet_function_example.py
@@ -10,7 +10,7 @@

import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.train._checkpoint import Checkpoint
from ray.tune.schedulers import PopulationBasedTraining

# __tutorial_imports_end__
@@ -65,7 +65,7 @@ def train_convnet(config):
# __train_end__


def test_best_model(results: tune.ResultGrid):
def eval_best_model(results: tune.ResultGrid):
"""Test the best model given output of tuner.fit()."""
with results.get_best_result().checkpoint.as_directory() as best_checkpoint_path:
best_model = ConvNet()
@@ -141,4 +141,4 @@ def stop_all(self):
results = tuner.fit()
# __tune_end__

test_best_model(results)
eval_best_model(results)
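The rename from test_best_model to eval_best_model drops the test_ prefix, which test tooling commonly treats as marking a test case. A sketch of the evaluation flow the helper wraps; SmallNet and the checkpoint.pt filename are placeholders standing in for the example's ConvNet and its actual checkpoint contents:

import os

import torch
import torch.nn as nn
from ray import tune


class SmallNet(nn.Module):
    # Placeholder for the example's ConvNet.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))


def eval_best_model(results: "tune.ResultGrid") -> None:
    # Materialize the best trial's checkpoint locally and restore weights.
    with results.get_best_result().checkpoint.as_directory() as ckpt_dir:
        model = SmallNet()
        model.load_state_dict(torch.load(os.path.join(ckpt_dir, "checkpoint.pt")))
    model.eval()
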
2 changes: 1 addition & 1 deletion python/ray/tune/examples/pbt_example.py
@@ -123,7 +123,7 @@ def reset_config(self, new_config):
checkpoint_config=train.CheckpointConfig(
checkpoint_frequency=perturbation_interval,
checkpoint_score_attribute="mean_accuracy",
num_to_keep=2,
num_to_keep=4,
),
),
tune_config=tune.TuneConfig(
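In PBT, exploited trials clone a better-performing peer's checkpoint, so checkpoint_frequency is tied to perturbation_interval, and num_to_keep must be large enough that a peer's checkpoint from the last perturbation is still retained; the bump from 2 to 4 above plausibly adds headroom for that. A hedged sketch of the full configuration, with PBTBenchmark as a toy stand-in for the example's trainable:

import json
import os

from ray import train, tune
from ray.tune.schedulers import PopulationBasedTraining


class PBTBenchmark(tune.Trainable):
    # Toy trainable: accuracy climbs at a rate set by the current lr.
    def setup(self, config):
        self.lr = config["lr"]
        self.accuracy = 0.0

    def step(self):
        self.accuracy = min(1.0, self.accuracy + self.lr)
        return {"mean_accuracy": self.accuracy}

    def save_checkpoint(self, checkpoint_dir):
        # PBT copies this state into trials that exploit this one.
        with open(os.path.join(checkpoint_dir, "state.json"), "w") as f:
            json.dump({"accuracy": self.accuracy, "lr": self.lr}, f)

    def load_checkpoint(self, checkpoint_dir):
        with open(os.path.join(checkpoint_dir, "state.json")) as f:
            self.accuracy = json.load(f)["accuracy"]

    def reset_config(self, new_config):
        # Lets Tune reuse the actor after a perturbation.
        self.lr = new_config["lr"]
        return True


perturbation_interval = 5
pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    perturbation_interval=perturbation_interval,
    hyperparam_mutations={"lr": tune.uniform(1e-4, 1e-1)},
)

tuner = tune.Tuner(
    PBTBenchmark,
    param_space={"lr": tune.uniform(1e-4, 1e-1)},
    run_config=train.RunConfig(
        stop={"training_iteration": 20},
        checkpoint_config=train.CheckpointConfig(
            # Save exactly as often as PBT perturbs, and keep enough
            # checkpoints that exploit() can still find a peer's state.
            checkpoint_frequency=perturbation_interval,
            checkpoint_score_attribute="mean_accuracy",
            num_to_keep=4,
        ),
    ),
    tune_config=tune.TuneConfig(
        metric="mean_accuracy",
        mode="max",
        scheduler=pbt,
        num_samples=4,
    ),
)
results = tuner.fit()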