forked from apache/mxnet
Commit
[Scala] add CustomOp support (apache#4118)
Showing 8 changed files with 1,225 additions and 0 deletions.
219 changes: 219 additions & 0 deletions
scala-package/core/src/main/scala/ml/dmlc/mxnet/Operator.scala
@@ -0,0 +1,219 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import scala.collection.mutable.ArrayBuffer

/**
 * Base class for operators implemented in Scala
 * @author Depeng Liang
 */
abstract class CustomOp {

  /**
   * forward interface. Override to create new operators.
   * @param isTrain : Boolean
   *           whether this is for training
   * @param req : array of String
   *           how to assign to outData. Can be "null", "write", "inplace", or "add".
   *           You can optionally use this.assign(dst, req, src) to handle this.
   * @param inData, outData, aux : array of NDArrays
   *           input, output, and auxiliary states for forward. See the documentation
   *           of the corresponding arguments of Operator::Forward.
   */
  def forward(isTrain: Boolean, req: Array[String],
    inData: Array[NDArray], outData: Array[NDArray], aux: Array[NDArray]): Unit

  /**
   * backward interface. Override to create new operators.
   * @param req : array of String
   *           how to assign to inGrad. Can be "null", "write", "inplace", or "add".
   *           You can optionally use this.assign(dst, req, src) to handle this.
   * @param outGrad, inData, outData, inGrad, aux : array of NDArrays
   *           input, output, and auxiliary states for backward. See the documentation
   *           of the corresponding arguments of Operator::Backward.
   */
  def backward(req: Array[String], outGrad: Array[NDArray],
    inData: Array[NDArray], outData: Array[NDArray],
    inGrad: Array[NDArray], aux: Array[NDArray]): Unit

  /**
   * Helper function for assigning into dst depending on requirements.
   */
  def assign(dst: NDArray, req: String, src: NDArray): Unit = req match {
    case "write" | "inplace" => dst.set(src)
    case "add" => dst += src
    case "null" => {}
  }

  /**
   * Scala callback for CustomOp::Forward
   */
  private[mxnet] def forwardEntry(numNdarray: Int, ndarrays: Array[NDArrayHandle],
    tags: Array[Int], reqs: Array[Int], isTrain: Boolean): Boolean = {
    var success = true
    try {
      val tensors = (0 until 5).toArray.map(_ => ArrayBuffer[NDArray]())
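      // Tags classify each incoming handle: 0 = inData, 1 = outData, 4 = aux
      // (2 and 3 are inGrad/outGrad, used in backwardEntry below).
      // In forward, only outData (tag 1) and aux (tag 4) are writable.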
      for (i <- 0 until numNdarray) {
        if (tags(i) == 1 || tags(i) == 4) {
          tensors(tags(i)) += new NDArray(ndarrays(i), writable = true)
        } else {
          tensors(tags(i)) += new NDArray(ndarrays(i), writable = false)
        }
      }
      val reqEnum = Array("null", "write", "inplace", "add")
      val reqsArr = tensors(1).indices.map(i => reqEnum(reqs(i))).toArray
      this.forward(isTrain = isTrain, req = reqsArr,
        inData = tensors(0).toArray, outData = tensors(1).toArray,
        aux = tensors(4).toArray)
    } catch {
      case ex: Throwable => {
        success = false
        ex.printStackTrace()
      }
    }
    success
  }

  /**
   * Scala callback for CustomOp::Backward
   */
  private[mxnet] def backwardEntry(numNdarray: Int, ndarrays: Array[NDArrayHandle],
    tags: Array[Int], reqs: Array[Int], isTrain: Boolean): Boolean = {
    var success = true
    try {
      val tensors = (0 until 5).toArray.map(_ => ArrayBuffer[NDArray]())
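      // In backward, only inGrad (tag 2) and aux (tag 4) are writable.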
      for (i <- 0 until numNdarray) {
        if (tags(i) == 2 || tags(i) == 4) {
          tensors(tags(i)) += new NDArray(ndarrays(i), writable = true)
        } else {
          tensors(tags(i)) += new NDArray(ndarrays(i), writable = false)
        }
      }
      val reqEnum = Array("null", "write", "inplace", "add")
      val reqsArr = tensors(2).indices.map(i => reqEnum(reqs(i))).toArray
      this.backward(req = reqsArr,
        inData = tensors(0).toArray, outData = tensors(1).toArray,
        inGrad = tensors(2).toArray, outGrad = tensors(3).toArray,
        aux = tensors(4).toArray)
    } catch {
      case ex: Throwable => {
        success = false
        ex.printStackTrace()
      }
    }
    success
  }
}

/**
 * Base class for operator property classes implemented in Scala.
 * MXNET_CPU_WORKER_NTHREADS must be greater than 1 for a custom op to work on CPU.
 * @param needTopGrad : Boolean
 *           The default declareBackwardDependency function uses this value
 *           to determine whether this operator needs gradient input from the layer above.
 */
abstract class CustomOpProp(needTopGrad: Boolean = false) {

  protected var kwargs: Map[String, String] = Map[String, String]()

  private[mxnet] def init(keys: Array[String], vals: Array[String]): Unit = {
    require(keys.length == vals.length)
    kwargs = keys.zip(vals).toMap
  }

  /**
   * inferShape interface. Override to create new operators.
   * @param inShape : array of Shape
   *           list of argument shapes in the same order as declared in listArguments().
   * @return
   * inShapes : array of Shape
   *           array of argument shapes. Can be modified from inShape.
   * outShapes : array of Shape
   *           array of output shapes calculated from inShape,
   *           in the same order as declared in listOutputs().
   * auxShapes : array of Shape
   *           array of aux shapes calculated from inShape,
   *           in the same order as declared in listAuxiliaryStates().
   */
  def inferShape(inShape: Array[Shape]):
    (Array[Shape], Array[Shape], Array[Shape])

  /**
   * Scala callback for CustomOp::InferShape
   */
  private[mxnet] def inferShapeEntry(
    numTensor: Int, inputShapes: Array[Array[Int]]): Array[Array[Int]] = {
    val nIn = this.listArguments().length
    val nOut = this.listOutputs().length
    val nAux = {
      val tmp = this.listAuxiliaryStates()
      if (tmp == null) 0 else tmp.length
    }
    require(numTensor == (nIn + nOut + nAux))
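    // The callback returns all shapes flattened back into one array,
    // ordered: argument shapes, then output shapes, then aux shapes.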
    val (inShapes, outShapes, auxShapes) =
      inferShape(inputShapes.map(Shape(_)))
    require(inShapes != null && inShapes.length != 0)
    require(outShapes != null && outShapes.length != 0)
    if (auxShapes != null && auxShapes.length != 0) {
      inShapes.map(_.toArray) ++ outShapes.map(_.toArray) ++ auxShapes.map(_.toArray)
    } else inShapes.map(_.toArray) ++ outShapes.map(_.toArray)
  }

  /**
   * listOutputs interface. Override to create new operators.
   * @return
   * outputs : array of String
   *           list of output blob names.
   */
  def listOutputs(): Array[String]

  /**
   * listArguments interface. Override to create new operators.
   * @return
   * arguments : array of String
   *           list of argument blob names.
   */
  def listArguments(): Array[String]

  /**
   * listAuxiliaryStates interface. Override to create new operators.
   * @return
   * auxs : array of String
   *           list of auxiliary state blob names.
   */
  def listAuxiliaryStates(): Array[String] = null

  /**
   * Declare dependencies of this operator for the backward pass.
   * @param outGrad : array of Int
   *           ids of outGrad blobs.
   * @param inData : array of Int
   *           ids of inData blobs.
   * @param outData : array of Int
   *           ids of outData blobs.
   * @return
   * deps : array of Int
   *           ids of the needed blobs.
   */
  def declareBackwardDependency(outGrad: Array[Int],
    inData: Array[Int], outData: Array[Int]): Array[Int] = {
    val deps = ArrayBuffer[Array[Int]]()
    if (this.needTopGrad) deps += outGrad
    deps += inData
    deps += outData
    deps.toArray.flatten
  }

  /**
   * Create an operator that carries out the real computation
   * given the context, input shapes, and input data types.
   */
  def createOperator(ctx: String, inShapes: Array[Array[Int]], inDtypes: Array[Int]): CustomOp

}

object Operator {
  def register(regName: String, opProp: CustomOpProp): Unit = {
    checkCall(_LIB.mxCustomOpRegister(regName, opProp))
  }
}
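
To make the API above concrete, here is a minimal sketch of defining and registering an operator. It is not part of this commit: the IdentityOp and IdentityOpProp classes and the "identity" registration key are hypothetical, and only the members defined above are assumed.

import ml.dmlc.mxnet.{CustomOp, CustomOpProp, NDArray, Operator, Shape}

// A pass-through operator: forward copies the input to the output,
// backward copies the incoming gradient to the input gradient.
class IdentityOp extends CustomOp {
  override def forward(isTrain: Boolean, req: Array[String],
    inData: Array[NDArray], outData: Array[NDArray], aux: Array[NDArray]): Unit = {
    // honor the write/inplace/add/null request via the assign helper
    this.assign(outData(0), req(0), inData(0))
  }

  override def backward(req: Array[String], outGrad: Array[NDArray],
    inData: Array[NDArray], outData: Array[NDArray],
    inGrad: Array[NDArray], aux: Array[NDArray]): Unit = {
    // the gradient of the identity is the gradient from above
    this.assign(inGrad(0), req(0), outGrad(0))
  }
}

// needTopGrad = true: backward needs the gradient from the layer above.
class IdentityOpProp extends CustomOpProp(needTopGrad = true) {
  override def listArguments(): Array[String] = Array("data")
  override def listOutputs(): Array[String] = Array("output")

  // one input, one output of the same shape, no auxiliary states
  override def inferShape(inShape: Array[Shape]):
    (Array[Shape], Array[Shape], Array[Shape]) = (inShape, inShape, null)

  override def createOperator(ctx: String, inShapes: Array[Array[Int]],
    inDtypes: Array[Int]): CustomOp = new IdentityOp()
}

object IdentityOpExample {
  def main(args: Array[String]): Unit = {
    // makes the operator available under the (hypothetical) name "identity"
    Operator.register("identity", new IdentityOpProp())
  }
}

The ExampleCustomOp program launched by the script below exercises a full operator built the same way.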
19 changes: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
#!/bin/bash
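
# usage: bash <this-script> GPU DATA_PATH
#   GPU        device id to use, -1 for CPU
#   DATA_PATH  directory containing the mnist data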

MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd)
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-gpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*

# MXNET_CPU_WORKER_NTHREADS must be greater than 1 for custom op to work on CPU
export MXNET_CPU_WORKER_NTHREADS=2

# which gpu card to use, -1 means cpu
GPU=$1

# the mnist data path
# you can get the mnist data using the script core/scripts/get_mnist_data.sh
DATA_PATH=$2

java -Xmx4G -cp $CLASS_PATH \
  ml.dmlc.mxnet.examples.customop.ExampleCustomOp \
  --data-path $DATA_PATH \
  --gpu $GPU
16 changes: 16 additions & 0 deletions
scala-package/examples/scripts/customop/run_customopwithrtc.sh
@@ -0,0 +1,16 @@
#!/bin/bash
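
# usage: bash <this-script> DATA_PATH
#   DATA_PATH  directory containing the mnist data (always runs on GPU 0)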

MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd)
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-gpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*

# which gpu card to use
GPU=0

# the mnist data path
# you can get the mnist data using the script core/scripts/get_mnist_data.sh
DATA_PATH=$1

java -Xmx4G -cp $CLASS_PATH \
  ml.dmlc.mxnet.examples.customop.ExampleCustomOpWithRtc \
  --data-path $DATA_PATH \
  --gpu $GPU
33 changes: 33 additions & 0 deletions
scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/customop/Data.scala
@@ -0,0 +1,33 @@
package ml.dmlc.mxnet.examples.customop

import ml.dmlc.mxnet.Shape
import ml.dmlc.mxnet.IO
import ml.dmlc.mxnet.DataIter

/**
 * @author Depeng Liang
 */
object Data {
  // return train and val iterators for mnist
  def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = {
    val flat = if (inputShape.length == 3) "False" else "True"
    val trainParams = Map(
      "image" -> s"$dataPath/train-images-idx3-ubyte",
      "label" -> s"$dataPath/train-labels-idx1-ubyte",
      "input_shape" -> inputShape.toString(),
      "batch_size" -> s"$batchSize",
      "shuffle" -> "True",
      "flat" -> flat
    )
    val trainDataIter = IO.MNISTIter(trainParams)
    val testParams = Map(
      "image" -> s"$dataPath/t10k-images-idx3-ubyte",
      "label" -> s"$dataPath/t10k-labels-idx1-ubyte",
      "input_shape" -> inputShape.toString(),
      "batch_size" -> s"$batchSize",
      "flat" -> flat
    )
    val testDataIter = IO.MNISTIter(testParams)
    (trainDataIter, testDataIter)
  }
}
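
A minimal usage sketch of the helper above, not part of the commit; the data directory, batch size, and the DataUsage wrapper are placeholders.

import ml.dmlc.mxnet.Shape
import ml.dmlc.mxnet.examples.customop.Data

object DataUsage {
  def main(args: Array[String]): Unit = {
    // a non-3D input shape selects the flattened layout ("flat" -> "True")
    val (trainIter, testIter) =
      Data.mnistIterator("./data/mnist", batchSize = 100, inputShape = Shape(784))
  }
}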