Skip to content

Commit

Permalink
Merge pull request amplab#83 from amplab/loaderspeedup
Browse files Browse the repository at this point in the history
read cifar in as floats
  • Loading branch information
pcmoritz committed Feb 19, 2016
2 parents 8ac5d11 + aebe126 commit 463c9bf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/apps/CifarApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ object CifarApp {

// convert to dataframes
val schema = StructType(StructField("data", ArrayType(FloatType), false) :: StructField("label", IntegerType, false) :: Nil)
- var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a.map(x => x.toFloat), b)}, schema)
- var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a.map(x => x.toFloat), b)}, schema)
+ var trainDF = sqlContext.createDataFrame(trainRDD.map{ case (a, b) => Row(a, b)}, schema)
+ var testDF = sqlContext.createDataFrame(testRDD.map{ case (a, b) => Row(a, b)}, schema)

logger.log("repartition data")
trainDF = trainDF.repartition(numWorkers).cache()
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/loaders/CifarLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class CifarLoader(path: String) {
val nBatches = 5
val nData = nBatches * batchSize

- val trainImages = new Array[Array[Byte]](nData)
+ val trainImages = new Array[Array[Float]](nData)
val trainLabels = new Array[Int](nData)

- val testImages = new Array[Array[Byte]](batchSize)
+ val testImages = new Array[Array[Float]](batchSize)
val testLabels = new Array[Int](batchSize)

val r = new Random()
Expand Down Expand Up @@ -62,7 +62,7 @@ class CifarLoader(path: String) {
}
}

- def readBatch(file: File, batch: Int, images: Array[Array[Byte]], labels: Array[Int], perm: Vector[Int]) {
+ def readBatch(file: File, batch: Int, images: Array[Array[Float]], labels: Array[Int], perm: Vector[Int]) {
val buffer = new Array[Byte](1 + size)
val inputStream = new FileInputStream(file)

Expand All @@ -72,11 +72,11 @@ class CifarLoader(path: String) {
while(nRead != -1) {
assert(i < batchSize)
labels(perm(batch * batchSize + i)) = (buffer(0) & 0xFF) // convert to unsigned
- images(perm(batch * batchSize + i)) = new Array[Byte](size)
+ images(perm(batch * batchSize + i)) = new Array[Float](size)
var j = 0
while (j < size) {
// we access buffer(j + 1) because the 0th position holds the label
- images(perm(batch * batchSize + i))(j) = buffer(j + 1)
+ images(perm(batch * batchSize + i))(j) = buffer(j + 1) & 0xFF
j += 1
}
nRead = inputStream.read(buffer)
Expand Down

0 comments on commit 463c9bf

Please sign in to comment.