Merge pull request amplab#43 from amplab/workerstore
create nets in main and hold refs in WorkerStore
pcmoritz committed Dec 14, 2015
2 parents 571544c + 1fd3f18 commit 0b3fa9d
Showing 3 changed files with 59 additions and 30 deletions.
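
This commit moves net construction out of the apps' object initializers, where it previously ran during class loading on the driver (and again wherever the object was loaded), and into an explicit Spark job that runs once per worker. Each worker parks its CaffeNet in workerStore, a field of the app object; because a Scala object is initialized independently in every JVM, each executor gets its own store, and state placed there by one job is visible to later jobs on the same executor. A minimal sketch of the pattern, assuming long-lived executors (one JVM per worker for the lifetime of the app); everything except the WorkerStore idea is illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import scala.collection.mutable

object PerWorkerStateSketch {
  // One copy of this field per JVM: the driver has one, and every
  // executor builds its own when this object is first loaded there.
  val store = mutable.Map[String, String]()

  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("sketch"))
    val numWorkers = 4
    val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers)

    // Job 1: do expensive per-worker initialization exactly once.
    workers.foreach(_ => store("net") = "initialized")

    // Job 2: a later job on the same executors sees the state from job 1.
    workers.foreach(_ => assert(store.get("net") == Some("initialized")))

    sc.stop()
  }
}
```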
37 changes: 22 additions & 15 deletions src/main/scala/apps/CifarApp.scala
@@ -20,13 +20,7 @@ object CifarApp {
val imShape = Array(channels, height, width)
val size = imShape.product

-// initialize nets on workers
-val sparkNetHome = "/root/SparkNet"
-System.load(sparkNetHome + "/build/libccaffe.so")
-var netParameter = ProtoLoader.loadNetPrototxt(sparkNetHome + "/caffe/examples/cifar10/cifar10_full_train_test.prototxt")
-netParameter = ProtoLoader.replaceDataLayers(netParameter, trainBatchSize, testBatchSize, channels, height, width)
-val solverParameter = ProtoLoader.loadSolverPrototxtWithNet(sparkNetHome + "/caffe/examples/cifar10/cifar10_full_solver.prototxt", netParameter, None)
-val net = CaffeNet(solverParameter)
+val workerStore = new WorkerStore()

def main(args: Array[String]) {
val numWorkers = args(0).toInt
@@ -36,6 +30,8 @@ object CifarApp {
.set("spark.task.maxFailures", "1")
val sc = new SparkContext(conf)

+val sparkNetHome = sys.env("SPARKNET_HOME")
+
// information for logging
val startTime = System.currentTimeMillis()
val trainingLog = new PrintWriter(new File("training_log_" + startTime.toString + ".txt" ))
@@ -49,8 +45,6 @@
trainingLog.flush()
}

-var netWeights = net.getWeights()
-
val loader = new CifarLoader(sparkNetHome + "/caffe/data/cifar10/")
log("loading train data")
var trainRDD = sc.parallelize(loader.trainImages.zip(loader.trainLabels))
@@ -83,12 +77,25 @@

val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers)

+// initialize nets on workers
+workers.foreach(_ => {
+  System.load(sparkNetHome + "/build/libccaffe.so")
+  var netParameter = ProtoLoader.loadNetPrototxt(sparkNetHome + "/caffe/examples/cifar10/cifar10_full_train_test.prototxt")
+  netParameter = ProtoLoader.replaceDataLayers(netParameter, trainBatchSize, testBatchSize, channels, height, width)
+  val solverParameter = ProtoLoader.loadSolverPrototxtWithNet(sparkNetHome + "/caffe/examples/cifar10/cifar10_full_solver.prototxt", netParameter, None)
+  val net = CaffeNet(solverParameter)
+  workerStore.setNet("net", net)
+})
+
+// initialize weights on master
+var netWeights = workers.map(_ => workerStore.getNet("net").getWeights()).collect()(0)
+
var i = 0
while (true) {
log("broadcasting weights", i)
val broadcastWeights = sc.broadcast(netWeights)
log("setting weights on workers", i)
-workers.foreach(_ => net.setWeights(broadcastWeights.value))
+workers.foreach(_ => workerStore.getNet("net").setWeights(broadcastWeights.value))

if (i % 10 == 0) {
log("testing, i")
@@ -98,8 +105,8 @@
val len = lenIt.next
assert(!lenIt.hasNext)
val minibatchSampler = new MinibatchSampler(testMinibatchIt, len, len)
-net.setTestData(minibatchSampler, len, None)
-Array(net.test()).iterator // do testing
+workerStore.getNet("net").setTestData(minibatchSampler, len, None)
+Array(workerStore.getNet("net").test()).iterator // do testing
}
).cache()
val testScoresAggregate = testScores.reduce((a, b) => (a, b).zipped.map(_ + _))
@@ -115,14 +122,14 @@
val len = lenIt.next
assert(!lenIt.hasNext)
val minibatchSampler = new MinibatchSampler(trainMinibatchIt, len, syncInterval)
-net.setTrainData(minibatchSampler, None)
-net.train(syncInterval)
+workerStore.getNet("net").setTrainData(minibatchSampler, None)
+workerStore.getNet("net").train(syncInterval)
Array(0).iterator
}
).foreachPartition(_ => ())

log("collecting weights", i)
-netWeights = workers.map(_ => { net.getWeights() }).reduce((a, b) => WeightCollection.add(a, b))
+netWeights = workers.map(_ => { workerStore.getNet("net").getWeights() }).reduce((a, b) => WeightCollection.add(a, b))
netWeights.scalarDivide(1F * numWorkers)
i += 1
}
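
The loop in main above is plain synchronous parameter averaging: broadcast the master's weights, let each of the numWorkers executors train its own net for syncInterval minibatches, then fetch every weight set, sum them with WeightCollection.add, and divide by numWorkers, i.e. netWeights becomes (1/numWorkers) * sum of the per-worker weights. A toy illustration of that reduce-then-divide step, with Array[Float] standing in for WeightCollection (the stand-in is ours, not SparkNet's):

```scala
// Two fake "workers" return their weights after local training.
val workerWeights = Seq(Array(1.0f, 2.0f), Array(3.0f, 6.0f))

// Elementwise sum, written with the same (a, b).zipped.map(_ + _)
// idiom the test-score aggregation above uses.
val summed = workerWeights.reduce((a, b) => (a, b).zipped.map(_ + _))

// Divide by the worker count to get the average: Array(2.0f, 4.0f).
val averaged = summed.map(_ / workerWeights.size)
```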
37 changes: 22 additions & 15 deletions src/main/scala/apps/ImageNetApp.scala
@@ -27,13 +27,7 @@ object ImageNetApp {
val fullImShape = Array(channels, fullHeight, fullWidth)
val fullImSize = fullImShape.product

-// initialize nets on workers
-val sparkNetHome = "/root/SparkNet"
-System.load(sparkNetHome + "/build/libccaffe.so")
-var netParameter = ProtoLoader.loadNetPrototxt(sparkNetHome + "/caffe/models/bvlc_reference_caffenet/train_val.prototxt")
-netParameter = ProtoLoader.replaceDataLayers(netParameter, trainBatchSize, testBatchSize, channels, croppedHeight, croppedWidth)
-val solverParameter = ProtoLoader.loadSolverPrototxtWithNet(sparkNetHome + "/caffe/models/bvlc_reference_caffenet/solver.prototxt", netParameter, None)
-val net = CaffeNet(solverParameter)
+val workerStore = new WorkerStore()

def main(args: Array[String]) {
val numWorkers = args(0).toInt
@@ -44,6 +38,8 @@
.set("spark.eventLog.enabled", "true")
val sc = new SparkContext(conf)

+val sparkNetHome = sys.env("SPARKNET_HOME")
+
// information for logging
val startTime = System.currentTimeMillis()
val trainingLog = new PrintWriter(new File("training_log_" + startTime.toString + ".txt" ))
@@ -57,8 +53,6 @@
trainingLog.flush()
}

-var netWeights = net.getWeights()
-
val loader = new ImageNetLoader("sparknet")
log("loading train data")
var trainRDD = loader.apply(sc, "ILSVRC2012_training/", "train.txt")
@@ -96,12 +90,25 @@

val workers = sc.parallelize(Array.range(0, numWorkers), numWorkers)

+// initialize nets on workers
+workers.foreach(_ => {
+  System.load(sparkNetHome + "/build/libccaffe.so")
+  var netParameter = ProtoLoader.loadNetPrototxt(sparkNetHome + "/caffe/models/bvlc_reference_caffenet/train_val.prototxt")
+  netParameter = ProtoLoader.replaceDataLayers(netParameter, trainBatchSize, testBatchSize, channels, croppedHeight, croppedWidth)
+  val solverParameter = ProtoLoader.loadSolverPrototxtWithNet(sparkNetHome + "/caffe/models/bvlc_reference_caffenet/solver.prototxt", netParameter, None)
+  val net = CaffeNet(solverParameter)
+  workerStore.setNet("net", net)
+})
+
+// initialize weights on master
+var netWeights = workers.map(_ => workerStore.getNet("net").getWeights()).collect()(0)
+
var i = 0
while (true) {
log("broadcasting weights", i)
val broadcastWeights = sc.broadcast(netWeights)
log("setting weights on workers", i)
-workers.foreach(_ => net.setWeights(broadcastWeights.value))
+workers.foreach(_ => workerStore.getNet("net").setWeights(broadcastWeights.value))

if (i % 10 == 0) {
log("testing", i)
@@ -130,8 +137,8 @@
}
}
val minibatchSampler = new MinibatchSampler(testMinibatchIt, len, len)
-net.setTestData(minibatchSampler, len, Some(imageNetTestPreprocessing))
-Array(net.test()).iterator // do testing
+workerStore.getNet("net").setTestData(minibatchSampler, len, Some(imageNetTestPreprocessing))
+Array(workerStore.getNet("net").test()).iterator // do testing
}
).cache() // the function inside has side effects, so we need the cache to ensure we don't redo it
// add up test accuracies (a and b are arrays in case there are multiple test layers)
@@ -168,14 +175,14 @@
}
}
val minibatchSampler = new MinibatchSampler(trainMinibatchIt, len, syncInterval)
-net.setTrainData(minibatchSampler, Some(imageNetTrainPreprocessing))
-net.train(syncInterval) // train for syncInterval minibatches
+workerStore.getNet("net").setTrainData(minibatchSampler, Some(imageNetTrainPreprocessing))
+workerStore.getNet("net").train(syncInterval) // train for syncInterval minibatches
Array(0).iterator // give the closure the right signature
}
).foreachPartition(_ => ())

log("collecting weights", i)
-netWeights = workers.map(_ => net.getWeights()).reduce((a, b) => WeightCollection.add(a, b))
+netWeights = workers.map(_ => workerStore.getNet("net").getWeights()).reduce((a, b) => WeightCollection.add(a, b))
netWeights.scalarDivide(1F * numWorkers)

i += 1
15 changes: 15 additions & 0 deletions src/main/scala/libs/WorkerStore.scala
@@ -0,0 +1,15 @@
+package libs
+
+import libs._
+
+class WorkerStore() {
+  var nets: Map[String, CaffeNet] = Map()
+
+  def setNet(name: String, net: CaffeNet) = {
+    nets += (name -> net)
+  }
+
+  def getNet(name: String): CaffeNet = {
+    return nets(name)
+  }
+}
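
Two small notes on the class above: nets is an immutable Map held in a var, so setNet swaps in an extended map rather than mutating a shared one, and rebinding an existing key replaces the old net; the apps run one task per executor, so no synchronization is added. A self-contained analogue with String payloads in place of CaffeNet, just to show the registry semantics (illustrative, not part of the commit):

```scala
class Store[A]() {
  var items: Map[String, A] = Map()

  def set(name: String, item: A) = { items += (name -> item) }

  def get(name: String): A = items(name)
}

val s = new Store[String]()
s.set("net", "first")
s.set("net", "second") // += on an immutable Map replaces the binding
assert(s.get("net") == "second")
```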
