load from definition (intel#1587)
* load from definition

* remove dup
wzhongyuan authored Sep 25, 2017
1 parent a5c7465 commit b00fe5d
Showing 2 changed files with 121 additions and 1 deletion.
@@ -19,16 +19,26 @@ import java.io._

import scala.collection.JavaConverters._
import com.google.protobuf.CodedInputStream
import com.intel.analytics.bigdl.nn.Container
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.File
import com.intel.analytics.bigdl.utils.{File, Table}
import serialization.Bigdl.BigDLModule

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object ModuleLoader {

  /**
   * Load a module from `modelPath`
   * @param modelPath path where the protobuf-formatted module is stored
   * @param ev numeric ops
   * @tparam T data type
   * @return the loaded BigDL module
   */
  def loadFromFile[T: ClassTag](modelPath : String)
    (implicit ev: TensorNumeric[T]) : AbstractModule[Activity, Activity, T] = {
    val modelBuilder = BigDLModule.newBuilder
@@ -39,10 +49,97 @@ object ModuleLoader {
    val bigDLModel = modelBuilder.build()
    ModuleSerializer.load(bigDLModel).module
  }
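  // Usage sketch (not part of this diff): load a serialized module from disk.
  // The path "/tmp/lenet.bigdl" is hypothetical.
  //
  //   val model = ModuleLoader.loadFromFile[Float]("/tmp/lenet.bigdl")
  //   val output = model.forward(Tensor[Float](1, 28, 28).rand())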

  /**
   * Load weights and biases from the module stored at `modelPath` and copy them
   * into the pre-defined module `definition`. Only the layers named in `layers`
   * are copied; if `layers` is not specified, all layers are copied.
   * @param definition pre-defined module
   * @param modelPath path where the protobuf-formatted module is stored
   * @param layers names of the layers whose weights and biases are to be copied
   * @param ev numeric ops
   * @tparam T data type
   */
  def loadFromDefinition[T : ClassTag](definition : AbstractModule[Activity, Activity, T],
    modelPath : String, layers : mutable.HashSet[String] = null)(implicit ev: TensorNumeric[T])
  : Unit = {
    val loadedModule = loadFromFile(modelPath)
    val layersToCopy = if (layers == null) {
      val allLayers = new mutable.HashSet[String]()
      getAllLayers(definition, allLayers)
      allLayers
    } else {
      layers
    }
    copyParams(definition, loadedModule, layersToCopy)
  }
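  // Usage sketch (not part of this diff): copy stored weights into a pre-defined
  // module of the same structure; layer names and the path below are hypothetical.
  //
  //   val definition = Sequential[Float]().add(Linear[Float](10, 5).setName("fc1"))
  //   ModuleLoader.loadFromDefinition(definition, "/tmp/model.bigdl")
  //   // or restrict the copy to selected layers:
  //   ModuleLoader.loadFromDefinition(definition, "/tmp/model.bigdl",
  //     mutable.HashSet("fc1"))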

  private def getAllLayers[T : ClassTag](module : AbstractModule[Activity, Activity, T],
    layers : mutable.HashSet[String]) : Unit = {
    layers.add(module.getName)
    if (module.isInstanceOf[Container[_, _, _]]) {
      module.asInstanceOf[Container[_, _, _]].modules.foreach(subModule => {
        getAllLayers(subModule, layers)
      })
    }
  }

  private def copyParams[T : ClassTag](definition : AbstractModule[Activity, Activity, T],
    mirror : AbstractModule[Activity, Activity, T],
    layers : mutable.HashSet[String]) : Unit = {
    val parameterTable = definition.getParametersTable()
    val copiedParameterTable = mirror.getParametersTable()
    layers.foreach(name => {
      if (parameterTable.contains(name)) {
        require(copiedParameterTable.contains(name), s"$name does not exist in loaded module")
        copyParams(parameterTable.get(name).get.asInstanceOf[Table],
          copiedParameterTable.get(name).get.asInstanceOf[Table])
      }
    })
  }
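  // For orientation (not part of this diff): `getParametersTable` returns a Table
  // keyed by layer name, and each entry is itself a Table holding that layer's
  // parameters (e.g. the "weight" and "bias" entries copied below).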

  private def copyParams[T : ClassTag](params : Table, copyParams : Table) : Unit = {
    copyParam(params, copyParams, "weight")
    copyParam(params, copyParams, "bias")
  }

  private def copyParam[T : ClassTag](params : Table,
    copyParams : Table, paraName : String) : Unit = {
    if (params.contains(paraName)) {
      // this is for quantized layers whose params might be a tensor array
      if (copyParams.get(paraName).get.isInstanceOf[Array[Tensor[T]]]) {
        require(params.get(paraName).get.isInstanceOf[Array[Tensor[T]]],
          "param type mismatch!")
        // copy each tensor from the loaded module into the definition
        val copies = copyParams.get(paraName).get.asInstanceOf[Array[Tensor[T]]]
        val origins = params.get(paraName).get.asInstanceOf[Array[Tensor[T]]]
        var i = 0
        while (i < copies.length) {
          origins(i).copy(copies(i))
          i += 1
        }
      } else {
        // for normal layers, the params are just tensors
        params.get(paraName).get.asInstanceOf[Tensor[T]].copy(
          copyParams.get(paraName).get.asInstanceOf[Tensor[T]])
      }
    }
  }
}

object ModulePersister {

  /**
   * Persist a module to the specified path
   * @param modelPath path to persist the module to
   * @param module module to be persisted
   * @param overwrite whether to overwrite the module file if it already exists
   * @param ev numeric ops
   * @tparam T data type
   */
  def saveToFile[T: ClassTag](modelPath: String, module: AbstractModule[Activity, Activity, T],
    overwrite: Boolean = false)(implicit ev: TensorNumeric[T]): Unit = {

@@ -52,6 +149,14 @@ object ModulePersister
    File.saveBytes(bigDLModel.toByteArray, modelPath, overwrite)
  }
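  // Usage sketch (not part of this diff): persist a module, or just its
  // definition, to hypothetical paths; `overwrite = true` replaces an existing file.
  //
  //   val model = Sequential[Float]().add(Linear[Float](10, 5).setName("fc1"))
  //   ModulePersister.saveToFile("/tmp/model.bigdl", model, overwrite = true)
  //   ModulePersister.saveModelDefinitionToFile("/tmp/model.def.bigdl", model, true)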

  /**
   * Save a module definition to the given path
   * @param definitionPath the path to persist the module definition to
   * @param module module whose definition is to be persisted
   * @param overwrite whether to overwrite the definition file if it already exists
   * @param ev numeric ops
   * @tparam T data type
   */
  def saveModelDefinitionToFile[T: ClassTag](definitionPath : String,
    module : AbstractModule[Activity, Activity, T],
    overwrite : Boolean = false)(implicit ev: TensorNumeric[T]) : Unit = {
@@ -1879,6 +1879,21 @@ class ModuleSerializerSpec extends FlatSpec with Matchers {

    res1 should be (res2)
  }

"Load by definition " should " work properly" in {
val linear1 = Linear(2, 2).setName("linear")
val sequential = Sequential().setName("sequential").add(linear1)
ModulePersister.saveToFile("/tmp/loadDef.bigdl", sequential, true)
val linear2 = Linear(2, 2).setName("linear")
val definition = Sequential().setName("sequential").add(linear2)
ModuleLoader.loadFromDefinition(definition, "/tmp/loadDef.bigdl")

val weight1 = linear1.weight

val weight2 = linear2.weight

weight1 should be (weight2)
}
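  // A possible follow-up check (not in this diff): restrict the copy to named
  // layers via the `layers` argument, e.g.
  //
  //   ModuleLoader.loadFromDefinition(definition, "/tmp/loadDef.bigdl",
  //     mutable.HashSet("linear"))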
}

class TestModule[T: ClassTag](val custom: CustomData)
