Skip to content

Commit

Permalink
divide mean image by number of training data points
Browse files (browse the repository at this point in the history)
  • Loading branch information
robertnishihara committed Feb 20, 2016
1 parent 463c9bf commit e8073fa
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/main/scala/apps/ImageNetApp.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,21 +57,20 @@ object ImageNetApp {
  var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema)
  var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a, b)}, schema)

+ val numTrainData = trainDF.count()
+ logger.log("numTrainData = " + numTrainData.toString)
+ val numTestData = testDF.count()
+ logger.log("numTestData = " + numTestData.toString)
+
  logger.log("computing mean image")
  val meanImage = trainDF.map(row => row(0).asInstanceOf[Array[Byte]].map(e => e.toLong))
  .reduce((a, b) => (a, b).zipped.map(_ + _))
- .map(e => e.toFloat)
+ .map(e => (e.toDouble / numTrainData).toFloat)

  logger.log("coalescing") // if you want to shuffle your data, replace coalesce with repartition
  trainDF = trainDF.coalesce(numWorkers)
  testDF = testDF.coalesce(numWorkers)

- val numTrainData = trainDF.count()
- logger.log("numTrainData = " + numTrainData.toString)
-
- val numTestData = testDF.count()
- logger.log("numTestData = " + numTestData.toString)
-
  val trainPartitionSizes = trainDF.mapPartitions(iter => Array(iter.size).iterator).persist()
  val testPartitionSizes = testDF.mapPartitions(iter => Array(iter.size).iterator).persist()
  trainPartitionSizes.foreach(size => workerStore.put("trainPartitionSize", size))
Expand Down

0 comments on commit e8073fa

Please sign in to comment.