Skip to content

Commit

Permalink
dist training tested with dmlc_local.py. version 0.1.2-SNAPSHOT
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Apr 11, 2016
1 parent 391ddee commit b84d461
Show file tree
Hide file tree
Showing 19 changed files with 138 additions and 39 deletions.
24 changes: 18 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ endif
endif

ifndef DMLC_CORE
DMLC_CORE = dmlc-core
DMLC_CORE = $(ROOTDIR)/dmlc-core
endif

ifneq ($(USE_OPENMP), 1)
Expand Down Expand Up @@ -81,7 +81,7 @@ ifneq ($(USE_CUDA_PATH), NONE)
endif

# ps-lite
PS_PATH=./ps-lite
PS_PATH=$(ROOTDIR)/ps-lite
DEPS_PATH=$(shell pwd)/deps
include $(PS_PATH)/make/ps.mk
ifeq ($(USE_DIST_KVSTORE), 1)
Expand Down Expand Up @@ -235,16 +235,28 @@ rpkg: roxygen
R CMD build --no-build-vignettes R-package

scalapkg:
(cd $(ROOTDIR)/scala-package; mvn clean package -P$(SCALA_PKG_PROFILE) -Dcxx="$(CXX)" -Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)")
(cd $(ROOTDIR)/scala-package; \
mvn clean package -P$(SCALA_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dlddeps="$(LIB_DEP)")

scalatest:
(cd $(ROOTDIR)/scala-package; mvn verify -P$(SCALA_PKG_PROFILE) -Dcxx="$(CXX)" -Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" $(SCALA_TEST_ARGS))
(cd $(ROOTDIR)/scala-package; \
mvn verify -P$(SCALA_PKG_PROFILE) -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dlddeps="$(LIB_DEP)" $(SCALA_TEST_ARGS))

scalainstall:
(cd $(ROOTDIR)/scala-package; mvn install -P$(SCALA_PKG_PROFILE) -DskipTests -Dcxx="$(CXX)" -Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)")
(cd $(ROOTDIR)/scala-package; \
mvn install -P$(SCALA_PKG_PROFILE) -DskipTests -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dlddeps="$(LIB_DEP)")

scaladeploy:
(cd $(ROOTDIR)/scala-package; mvn deploy -Prelease,$(SCALA_PKG_PROFILE) -DskipTests -Dcxx="$(CXX)" -Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)")
(cd $(ROOTDIR)/scala-package; \
mvn deploy -Prelease,$(SCALA_PKG_PROFILE) -DskipTests -Dcxx="$(CXX)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dlddeps="$(LIB_DEP)")

jnilint:
python2 dmlc-core/scripts/lint.py mxnet-jnicpp cpp scala-package/native/src
Expand Down
6 changes: 3 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1153,18 +1153,18 @@ MXNET_DLL int MXKVStoreBarrier(KVStoreHandle handle);
* \brief the prototype of a server controller
* \param head the head of the command
* \param body the body of the command
* \param helper handle for implementing controller, e.g., a Java object for scala-package
* \param controller_handle helper handle for implementing controller
*/
typedef void (MXKVStoreServerController)(int head,
const char* body,
const char *body,
void *controller_handle);

/**
* \brief Run as server (or scheduler)
*
* \param handle handle to the KVStore
* \param controller the user-defined server controller
* \param helper handle for implementing controller, e.g., a Java object for scala-package
* \param controller_handle helper handle for implementing controller
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
Expand Down
4 changes: 2 additions & 2 deletions scala-package/assembly/linux-x86_64-cpu/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-full-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-full_2.10-linux-x86_64-cpu</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Full Linux-x86_64 CPU-only</name>
<packaging>jar</packaging>

Expand Down
4 changes: 2 additions & 2 deletions scala-package/assembly/linux-x86_64-gpu/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-full-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-full_2.10-linux-x86_64-gpu</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Full Linux-x86_64 GPU</name>
<packaging>jar</packaging>

Expand Down
4 changes: 2 additions & 2 deletions scala-package/assembly/osx-x86_64-cpu/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-full-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-full_2.10-osx-x86_64-cpu</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Full OSX-x86_64 CPU-only</name>
<packaging>jar</packaging>

Expand Down
4 changes: 2 additions & 2 deletions scala-package/assembly/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-full-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Full Parent</name>
<packaging>pom</packaging>

Expand Down
4 changes: 2 additions & 2 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-core_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Core</name>

<profiles>
Expand Down
8 changes: 6 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import org.slf4j.{LoggerFactory, Logger}

/**
* Key value store interface of MXNet for parameter synchronization.
Expand Down Expand Up @@ -29,6 +30,7 @@ object KVStore {

// scalastyle:off finalize
class KVStore(private[mxnet] val handle: KVStoreHandle) {
private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore])
private var updaterFunc: MXKVStoreUpdater = null
private var disposed = false

Expand Down Expand Up @@ -175,9 +177,11 @@ class KVStore(private[mxnet] val handle: KVStoreHandle) {
def setOptimizer(optimizer: Optimizer): Unit = {
val isWorker = new RefInt
checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
if ("dist" == `type` && isWorker.value != 0) {
if (`type`.contains("dist") && isWorker.value != 0) {
val optSerialized = Serializer.getSerializer.serialize(optimizer)
sendCommandToServers(0, Serializer.encodeBase64String(optSerialized))
val cmd = Serializer.encodeBase64String(optSerialized)
logger.debug("Send optimizer to server: {}", cmd)
sendCommandToServers(0, cmd)
} else {
setUpdater(Optimizer.getUpdater(optimizer))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import org.slf4j.{Logger, LoggerFactory}

/**
* Server node for the key value store
* @author Yizhi Liu
*/
/**
 * Wraps a KVStore so that it can be run as a server node.
 * Commands received by the server are routed to `controller`.
 */
class KVStoreServer(private val kvStore: KVStore) {
  private val logger: Logger = LoggerFactory.getLogger(classOf[KVStoreServer])
  private val handle: KVStoreHandle = kvStore.handle

  // Callback object passed to the native library in run().
  // Command 0 carries a Base64-encoded, serialized Optimizer which is
  // installed on the wrapped KVStore; every other command id is only logged.
  private val controller = new KVServerControllerCallback {
    override def invoke(cmdId: Int, cmdBody: String): Unit = {
      logger.debug("Receive cmdId {}, cmdBody: {}", cmdId, cmdBody)
      cmdId match {
        case 0 =>
          val serializer = Serializer.getSerializer
          val optimizer = serializer.deserialize[Optimizer](
            Serializer.decodeBase64String(cmdBody))
          kvStore.setOptimizer(optimizer)
        case _ =>
          logger.warn(s"Server ${kvStore.rank}, unknown command ($cmdId, $cmdBody)")
      }
    }
  }

  /**
   * Run the server. Its behavior is like:
   * {{{
   * while receive(x):
   *   if is_command x: controller(x)
   *   else if is_key_value x: updater(x)
   * }}}
   */
  def run(): Unit = checkCall(_LIB.mxKVStoreRunServer(handle, controller))
}

object KVStoreServer {
  /**
   * Start a server/scheduler when this process is not a worker node.
   *
   * The role is queried from the native library (per the original note it is
   * determined by environment variables). On a worker node this method simply
   * returns. Otherwise a "dist" KVStore is created and run as a server, and
   * the process exits once the server loop returns.
   */
  def start(): Unit = {
    val workerFlag = new RefInt
    checkCall(_LIB.mxKVStoreIsWorkerNode(workerFlag))
    val runsAsWorker = workerFlag.value != 0
    if (!runsAsWorker) {
      new KVStoreServer(KVStore.create("dist")).run()
      sys.exit()
    }
  }
}

/**
 * Callback through which server commands are delivered to Scala code.
 * KVStoreServer passes an implementation to `_LIB.mxKVStoreRunServer`.
 */
trait KVServerControllerCallback {
  // cmdId: numeric id of the received command; cmdBody: its payload as a string.
  def invoke(cmdId: Int, cmdBody: String): Unit
}
4 changes: 4 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class LibInfo {
handles: Array[NDArrayHandle],
keys: Array[String]): Int
@native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int

// KVStore Server
@native def mxKVStoreRunServer(handle: KVStoreHandle, controller: KVServerControllerCallback): Int

// KVStore
@native def mxKVStoreCreate(name: String, handle: KVStoreHandleRef): Int
@native def mxKVStoreInit(handle: KVStoreHandle,
Expand Down
4 changes: 2 additions & 2 deletions scala-package/examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<artifactId>mxnet-examples_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Examples</name>
<packaging>pom</packaging>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ object TrainMnist {
else if (inst.cpus != null) inst.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
else Array(Context.cpu(0))

logger.info("Start KVStoreServer for scheduler & servers")
KVStoreServer.start()

ModelTrain.fit(dataDir = inst.dataDir,
batchSize = inst.batchSize, numExamples = inst.numExamples, devs = devs,
network = net, dataLoader = getIterator(dataShape),
Expand Down
7 changes: 3 additions & 4 deletions scala-package/native/linux-x86_64-cpu/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-scala-native-parent</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>libmxnet-scala-linux-x86_64-cpu</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Native Linux-x86_64 CPU-only</name>
<url>http://maven.apache.org</url>

Expand Down Expand Up @@ -65,9 +65,8 @@
</linkerStartOptions>
<linkerMiddleOptions>
<linkerMiddleOption>-Wl,--whole-archive</linkerMiddleOption>
<linkerMiddleOption>${lddeps}</linkerMiddleOption>
<linkerMiddleOption>../../../lib/libmxnet.a</linkerMiddleOption>
<linkerMiddleOption>../../../dmlc-core/libdmlc.a</linkerMiddleOption>
<linkerMiddleOption>../../../ps-lite/build/libps.a</linkerMiddleOption>
<linkerMiddleOption>-Wl,--no-whole-archive</linkerMiddleOption>
</linkerMiddleOptions>
<linkerEndOptions>
Expand Down
7 changes: 3 additions & 4 deletions scala-package/native/linux-x86_64-gpu/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-scala-native-parent</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>libmxnet-scala-linux-x86_64-gpu</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Native Linux-x86_64 GPU</name>
<url>http://maven.apache.org</url>

Expand Down Expand Up @@ -65,9 +65,8 @@
</linkerStartOptions>
<linkerMiddleOptions>
<linkerMiddleOption>-Wl,--whole-archive</linkerMiddleOption>
<linkerMiddleOption>${lddeps}</linkerMiddleOption>
<linkerMiddleOption>../../../lib/libmxnet.a</linkerMiddleOption>
<linkerMiddleOption>../../../dmlc-core/libdmlc.a</linkerMiddleOption>
<linkerMiddleOption>../../../ps-lite/build/libps.a</linkerMiddleOption>
<linkerMiddleOption>-Wl,--no-whole-archive</linkerMiddleOption>
</linkerMiddleOptions>
<linkerEndOptions>
Expand Down
7 changes: 3 additions & 4 deletions scala-package/native/osx-x86_64-cpu/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-scala-native-parent</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<groupId>ml.dmlc.mxnet</groupId>
<artifactId>libmxnet-scala-osx-x86_64-cpu</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Native OSX-x86_64 CPU-only</name>
<url>http://maven.apache.org</url>

Expand Down Expand Up @@ -67,8 +67,7 @@
<linkerMiddleOption>-framework JavaVM</linkerMiddleOption>
<linkerMiddleOption>-Wl,-exported_symbol,_Java_*</linkerMiddleOption>
<linkerMiddleOption>-Wl,-x</linkerMiddleOption>
<linkerMiddleOption>../../../dmlc-core/libdmlc.a</linkerMiddleOption>
<linkerMiddleOption>../../../ps-lite/build/libps.a</linkerMiddleOption>
<linkerMiddleOption>${lddeps}</linkerMiddleOption>
<linkerMiddleOption>-force_load ../../../lib/libmxnet.a</linkerMiddleOption>
</linkerMiddleOptions>
<linkerEndOptions>
Expand Down
4 changes: 2 additions & 2 deletions scala-package/native/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
<parent>
<groupId>ml.dmlc.mxnet</groupId>
<artifactId>mxnet-parent_2.10</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

<artifactId>mxnet-scala-native-parent</artifactId>
<version>0.1.1</version>
<version>0.1.2-SNAPSHOT</version>
<name>MXNet Scala Package - Native Parent</name>
<packaging>pom</packaging>

Expand Down
Loading

0 comments on commit b84d461

Please sign in to comment.