Skip to content

Commit

Permalink
Merge pull request amplab#32 from amplab/callbackperf
Browse files Browse the repository at this point in the history
Callback performance improvements
  • Loading branch information
robertnishihara committed Nov 30, 2015
2 parents d6976cc + 2f08ddb commit 15e4f1d
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 1 deletion.
59 changes: 59 additions & 0 deletions src/main/java/libs/ByteImage.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// this only works for a 2D image, but a similar class can be written for ND tensors

import java.awt.image.BufferedImage;

class ByteImage {
byte[] red;
byte[] green;
byte[] blue;

final int width;
final int height;

// create a byte image from a BufferedImage
public ByteImage(BufferedImage image) {
width = image.getWidth();
height = image.getHeight();
int[] pixels = image.getRGB(0, 0, width, height, null, 0, width);
red = new byte[width * height];
green = new byte[width * height];
blue = new byte[width * height];


for (int row = 0; row < height; row++) {
for (int col = 0; col < width; col++) {
int rgb = pixels[row * width + col];
red[row * width + col] = (byte)((rgb >> 16) & 0xFF);
green[row * width + col] = (byte)((rgb >> 8) & 0xFF);
blue[row * width + col] = (byte)(rgb & 0xFF);
}
}
}

// create an "empty" byte image
public ByteImage(int width, int height) {
this.width = width;
this.height = height;
red = new byte[width * height];
green = new byte[width * height];
blue = new byte[width * height];
}

void cropInto(float[] buffer, int[] lowerOffsets, int[] upperOffsets) {
assert(0 <= lowerOffsets[0] && lowerOffsets[0] < upperOffsets[0] && upperOffsets[0] <= height);
assert(0 <= lowerOffsets[1] && lowerOffsets[1] < upperOffsets[1] && upperOffsets[1] <= width);

final int h = upperOffsets[0] - lowerOffsets[0];
final int w = upperOffsets[1] - lowerOffsets[1];
final int lr = lowerOffsets[0];
final int lc = lowerOffsets[1];

for(int row = 0; row < w; row++) {
for(int col = 0; col < h; col++) {
buffer[0 * w * h + row * w + col] = red[(row + lr) * width + (col + lc)];
buffer[1 * w * h + row * w + col] = green[(row + lr) * width + (col + lc)];
buffer[2 * w * h + row * w + col] = blue[(row + lr) * width + (col + lc)];
}
}
}
}
168 changes: 168 additions & 0 deletions src/test/scala/apps/CallbackBenchmarkSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
Testing the speed of various ways to do callbacks; on pcm's laptop:
0.149: before empty callback
0.149: before copy callback
0.176: before simple callback
0.708: before byte image callback
0.78: before full callback
1.947: end
(these are cumulative numbers, so the elapsed times are the differences between
successive lines)
Takaway: We can copy stuff out of the array, but fancy indexing is very expensive,
and also setting floats via JNA in a loop is expensive
*/

import org.scalatest._
import java.io._
import scala.util.Random
import com.sun.jna.Pointer
import com.sun.jna.Memory


import libs._
import loaders._
import preprocessing._

class CallbackBenchmarkSpec extends FlatSpec {
val trainBatchSize = 256
val testBatchSize = 50
val channels = 3
val fullWidth = 256
val fullHeight = 256
val croppedWidth = 227
val croppedHeight = 227

val intSize = 4
val dtypeSize = 4

val startTime = System.currentTimeMillis()

def log(message: String, i: Int = -1) {
val elapsedTime = 1F * (System.currentTimeMillis() - startTime) / 1000
if (i == -1) {
print(elapsedTime.toString + ": " + message + "\n")
} else {
print(elapsedTime.toString + ", i = " + i.toString + ": "+ message + "\n")
}
}

val minibatch = Array.range(0, trainBatchSize).map(
i => {
ByteNDArray(new Array[Byte](channels * fullWidth * fullHeight), Array(channels, fullWidth, fullHeight))
}
).toArray

val byteImageMinibatch = Array.range(0, trainBatchSize).map(
i => {
new ByteImage(fullWidth, fullHeight)
}
).toArray

val emptyCallback = makeEmptyImageCallback(minibatch)
val copyCallback = makeCopyImageCallback(minibatch)
val simpleCallback = makeSimpleImageCallback(minibatch)
val preprocessing = (image: ByteImage, buffer: Array[Float]) => {
image.cropInto(buffer, Array[Int](0, 0), Array[Int](227, 227))
var i = 0
while (i < 227 * 227 * 3) {
buffer(i) -= 100.0F
i += 1
}
()
}
val byteImageCallback = makeByteImageCallback(byteImageMinibatch, Some(preprocessing))
val fullCallback = makeImageCallback(minibatch)
val data = new Memory(channels * fullWidth * fullHeight * trainBatchSize * dtypeSize);

log("before empty callback")
emptyCallback.invoke(data, trainBatchSize, 3, new Pointer(0))
log("before copy callback")
copyCallback.invoke(data, trainBatchSize, 3, new Pointer(0))
log("before simple callback")
simpleCallback.invoke(data, trainBatchSize, 3, new Pointer(0))
log("before byte image callback")
byteImageCallback.invoke(data, trainBatchSize, 3, new Pointer(0))
log("before full callback")
fullCallback.invoke(data, trainBatchSize, 3, new Pointer(0))
log("end")

private def makeEmptyImageCallback(minibatch: Array[ByteNDArray], preprocessing: Option[ByteNDArray => NDArray] = None): CaffeLibrary.java_callback_t = {
return new CaffeLibrary.java_callback_t() {
def invoke(data: Pointer, batchSize: Int, numDims: Int, shape: Pointer) {
}
}
}

private def makeCopyImageCallback(minibatch: Array[ByteNDArray], preprocessing: Option[ByteNDArray => NDArray] = None): CaffeLibrary.java_callback_t = {
return new CaffeLibrary.java_callback_t() {
def invoke(data: Pointer, batchSize: Int, numDims: Int, shape: Pointer) {
val currentImageBatch = minibatch
var flatArray = new Array[Float](channels * fullWidth * fullHeight)
for (j <- 0 to batchSize - 1) {
val currentImage = currentImageBatch(j)
currentImage.copyBufferToFloatArray(flatArray)
}
}
}
}

private def makeSimpleImageCallback(minibatch: Array[ByteNDArray], preprocessing: Option[ByteNDArray => NDArray] = None): CaffeLibrary.java_callback_t = {
return new CaffeLibrary.java_callback_t() {
def invoke(data: Pointer, batchSize: Int, numDims: Int, shape: Pointer) {
for (j <- 0 to batchSize - 1) {
var i = 0
val flatSize = channels * fullWidth * fullHeight
while (i < flatSize) {
data.setFloat((j * flatSize + i) * dtypeSize, 0.0F)
i += 1
}
}
}
}
}

private def makeByteImageCallback(minibatch: Array[ByteImage], preprocessing: Option[(ByteImage, Array[Float]) => Unit] = None): CaffeLibrary.java_callback_t = {
return new CaffeLibrary.java_callback_t() {
def invoke(data: Pointer, batchSize: Int, numDims: Int, shape: Pointer) {
var buffer = new Array[Float](227 * 227 * 3)
for (j <- 0 to batchSize - 1) {
preprocessing.get(minibatch(j), buffer)
data.write(j * 227 * 227 * 3, buffer, 0, 227 * 227 * 3);
}
}
}
}

private def makeImageCallback(minibatch: Array[ByteNDArray], preprocessing: Option[ByteNDArray => NDArray] = None): CaffeLibrary.java_callback_t = {
return new CaffeLibrary.java_callback_t() {
def invoke(data: Pointer, batchSize: Int, numDims: Int, shape: Pointer) {
val currentImageBatch = minibatch
assert(currentImageBatch.length == batchSize)

for (j <- 0 to batchSize - 1) {
val currentImage = currentImageBatch(j)
val processedImage = {
if (preprocessing.isEmpty) {
currentImage.toFloatNDArray() // didn't test this code path
} else {
preprocessing.get(currentImage)
}
}

val flatImage = processedImage.toFlat() // this allocation could be avoided
val flatSize = flatImage.length
var i = 0
while (i < flatSize) {
data.setFloat((j * flatSize + i) * dtypeSize, flatImage(i))
i += 1
}
}
}
}
}
}
2 changes: 1 addition & 1 deletion src/test/scala/libs/WeightCollectionSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class WeightCollectionSpec extends FlatSpec {
val imgSize = 227

var netParameter = ProtoLoader.loadNetPrototxt(sparkNetHome + "/caffe/models/bvlc_reference_caffenet/train_val.prototxt")
netParameter = ProtoLoader.replaceDataLayers(netParameter, batchsize, channels, imgSize, imgSize)
netParameter = ProtoLoader.replaceDataLayers(netParameter, batchsize, batchsize, channels, imgSize, imgSize)
val solverParameter = ProtoLoader.loadSolverPrototxtWithNet(sparkNetHome + "/caffe/models/bvlc_reference_caffenet/solver.prototxt", netParameter, None)
val net = CaffeNet(solverParameter)
var netWeights = net.getWeights()
Expand Down

0 comments on commit 15e4f1d

Please sign in to comment.