[sgd] Add checkpointing (ray-project#3638)
pschafhalter authored and robertnishihara committed Jan 8, 2019
1 parent 5e76d52 commit 5945b92
Showing 6 changed files with 136 additions and 0 deletions.
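
This commit adds weight checkpointing to the experimental distributed SGD API: the `Model` interface gains `get_weights`/`set_weights`, `DistributedSGD` gains `save_checkpoint` and `restore_checkpoint` (which round-trip a flat weight array through a `model.npy` file), a save/restore test is added, and the Jenkins suite runs it under both strategies. A minimal usage sketch, assuming the same setup as `test_save_and_restore.py` below (argument values here are illustrative, not part of this commit):

import ray
from ray.experimental.sgd.sgd import DistributedSGD
from ray.experimental.sgd.tfbench.test_model import TFBenchModel

ray.init()

# Two workers, one CPU model replica each (illustrative values).
sgd = DistributedSGD(
    lambda worker_idx, device_idx: TFBenchModel(batch=1, use_cpus=True),
    num_workers=2,
    devices_per_worker=1,
    gpu=False,
    strategy="simple")

sgd.step()
sgd.save_checkpoint("/tmp")     # writes /tmp/model.npy
sgd.restore_checkpoint("/tmp")  # reloads the weights into every replica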
6 changes: 6 additions & 0 deletions python/ray/experimental/sgd/mnist_example.py
@@ -101,6 +101,12 @@ def get_metrics(self):
        })
        return {"accuracy": accuracy}

    def get_weights(self):
        return self.variables.get_flat()

    def set_weights(self, weights):
        self.variables.set_flat(weights)


def train_mnist(config, reporter):
    args = config["args"]
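
Both methods delegate to Ray's `TensorFlowVariables` helper, which packs every variable the loss depends on into a single flat numpy array. A minimal sketch of that round trip, using a hypothetical toy graph rather than the MNIST model:

import numpy as np
import tensorflow as tf
from ray.experimental.tfutils import TensorFlowVariables

# Hypothetical toy graph: one trainable variable feeding a scalar loss.
x = tf.Variable(tf.ones([3]))
loss = tf.reduce_sum(x)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

variables = TensorFlowVariables(loss, sess)
flat = variables.get_flat()              # 1-D numpy array of all weights
variables.set_flat(np.zeros_like(flat))  # load weights back from a flat array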
24 changes: 24 additions & 0 deletions python/ray/experimental/sgd/model.py
@@ -46,3 +46,27 @@ def get_feed_dict(self):
            TensorFlow feed_dict to add to the gradient operation.
        """
        return {}

    def get_weights(self):
        """Return weights from the model.

        Implementing `get_weights` is required for checkpointing and fault
        tolerance.

        Returns:
            Numpy array of weights from the model.
        """
        raise NotImplementedError(
            "get_weights of %s is not implemented" % self.__class__.__name__)

    def set_weights(self, weights):
        """Sets the model weights.

        Implementing `set_weights` is required for checkpointing and fault
        tolerance.

        Args:
            weights: numpy array of weights for the model.
        """
        raise NotImplementedError(
            "set_weights of %s is not implemented" % self.__class__.__name__)
11 changes: 11 additions & 0 deletions python/ray/experimental/sgd/sgd.py
@@ -3,6 +3,7 @@
from __future__ import print_function

import logging
import os
import random
import time

@@ -177,6 +178,16 @@ def warmup(self):
        ray.get([w.warmup.remote() for w in self.workers])
        logger.info("Warmup complete")

    def save_checkpoint(self, path):
        # Weights are taken from a single model replica; the SGD step keeps
        # all replicas synchronized, so one copy suffices.
        w0 = self.for_model(lambda m: m.get_weights())
        filename = os.path.join(path, "model.npy")
        np.save(filename, w0)

    def restore_checkpoint(self, path):
        filename = os.path.join(path, "model.npy")
        w0 = np.load(filename)
        # Broadcast the restored weights to every model replica.
        self.foreach_model(lambda m: m.set_weights(w0))


def _average_gradients(grads):
    out = []
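
The checkpoint format is deliberately simple: `save_checkpoint` writes one replica's flat weight array to `<path>/model.npy`, and `restore_checkpoint` loads that array and broadcasts it to all replicas. Since the file is an ordinary numpy array, it can be inspected directly (a sketch, assuming a checkpoint was previously saved under /tmp):

import numpy as np

weights = np.load("/tmp/model.npy")
print(weights.shape)  # one flat 1-D vector of all model weights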
77 changes: 77 additions & 0 deletions python/ray/experimental/sgd/test_save_and_restore.py
@@ -0,0 +1,77 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import time

import ray
from ray.experimental.sgd.tfbench.test_model import TFBenchModel
from ray.experimental.sgd.sgd import DistributedSGD

parser = argparse.ArgumentParser()
parser.add_argument("--redis-address", default=None, type=str)
parser.add_argument("--num-iters", default=10, type=int)
parser.add_argument("--batch-size", default=1, type=int)
parser.add_argument("--num-workers", default=2, type=int)
parser.add_argument("--grad-shard-bytes", default=10000000, type=int)
parser.add_argument("--devices-per-worker", default=2, type=int)
parser.add_argument("--all-reduce-alg", default="simple", type=str)
parser.add_argument("--object-store-memory", default=None, type=int)
parser.add_argument("--checkpoint-dir", default="/tmp", type=str)
parser.add_argument(
    "--strategy", default="simple", type=str, help="One of 'simple' or 'ps'")
parser.add_argument(
    "--gpu", action="store_true", help="Use GPUs for optimization")

if __name__ == "__main__":
    args, _ = parser.parse_known_args()
    ray.init(
        redis_address=args.redis_address,
        object_store_memory=args.object_store_memory)

    model_creator = (
        lambda worker_idx, device_idx: TFBenchModel(
            batch=args.batch_size, use_cpus=not args.gpu))

    sgd = DistributedSGD(
        model_creator,
        num_workers=args.num_workers,
        devices_per_worker=args.devices_per_worker,
        gpu=args.gpu,
        strategy=args.strategy,
        grad_shard_bytes=args.grad_shard_bytes,
        all_reduce_alg=args.all_reduce_alg)

    if not os.path.exists(args.checkpoint_dir):
        raise ValueError(
            "Checkpoint directory does not exist: %s" % args.checkpoint_dir)

    def step(i):
        start = time.time()
        print("== Step {} ==".format(i))
        stats = sgd.step(fetch_stats=True)
        ips = ((args.batch_size * args.num_workers * args.devices_per_worker) /
               (time.time() - start))
        print("Iteration time", time.time() - start, "Images per second", ips)
        print("Current loss", stats)

    i = 0
    while i < args.num_iters:
        step(i)
        i += 1

    print("Saving checkpoint...")
    sgd.save_checkpoint(args.checkpoint_dir)
    print("Done saving checkpoint")

    step(i)

    print("Restoring checkpoint")
    sgd.restore_checkpoint(args.checkpoint_dir)
    print("Done restoring checkpoint")

    step(i)
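
The script runs `--num-iters` training steps, saves a checkpoint, takes one more step, restores the checkpoint, and steps again, printing the loss at each step so the save/restore round trip can be checked by eye. The Jenkins entries below run it with `--num-iters=2 --batch-size=1` under both strategies.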
10 changes: 10 additions & 0 deletions python/ray/experimental/sgd/tfbench/test_model.py
@@ -6,6 +6,7 @@

from tfbench import model_config
from ray.experimental.sgd.model import Model
from ray.experimental.tfutils import TensorFlowVariables


class MockDataset():
@@ -46,6 +47,9 @@ def __init__(self, batch=64, use_cpus=False):
        self.loss = tf.reduce_mean(loss, name='xentropy-loss')
        self.optimizer = tf.train.GradientDescentOptimizer(1e-6)

        self.variables = TensorFlowVariables(self.loss,
                                             tf.get_default_session())

    def get_loss(self):
        return self.loss
@@ -54,3 +58,9 @@ def get_optimizer(self):

    def get_feed_dict(self):
        return {}

    def get_weights(self):
        return self.variables.get_flat()

    def set_weights(self, weights):
        self.variables.set_flat(weights)
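
Note that `TensorFlowVariables(self.loss, tf.get_default_session())` discovers the variables to snapshot by walking the graph from the loss op, so the model does not have to enumerate them explicitly (this describes the helper's general behavior, not code changed in this commit).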
8 changes: 8 additions & 0 deletions test/jenkins_tests/run_multi_node_tests.sh
@@ -405,6 +405,14 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
    --batch-size=1 --strategy=ps

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
    --batch-size=1 --strategy=simple

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
    --batch-size=1 --strategy=ps

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
    --num-workers=1 --devices-per-worker=1 --strategy=ps
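
As with the existing `test_sgd.py` entries, the new save/restore test is exercised under both the 'simple' and 'ps' strategies.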
