[jvm-packages] allow training with missing values in xgboost-spark (dmlc#1525)

* allow training with missing values in xgboost-spark

* fix compilation error

* fix bug
CodingCat authored Aug 30, 2016
1 parent 6014839 commit 3f198b9
Showing 2 changed files with 34 additions and 5 deletions.
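
The change threads a new missing parameter (default Float.NaN) from the public train API down to DMatrix construction, so a user-chosen sentinel value can mark absent features. A rough usage sketch of the updated API, assuming an existing SparkContext sc; the toy data, parameter values, and the -999.0f sentinel are illustrative, not part of the commit:

import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// -999.0 stands in for absent features in this toy dataset
val trainingData = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(1.0, -999.0, 3.0)),
  LabeledPoint(0.0, Vectors.dense(-999.0, 2.0, 0.5))))
val params = Map("eta" -> "1", "max_depth" -> "6", "objective" -> "binary:logistic")
// the new trailing argument tells the trainer which value to strip out as missing
val model = XGBoost.train(trainingData, params, round = 5, nWorkers = 2, missing = -999.0f)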
@@ -18,11 +18,13 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.mutable.ListBuffer

import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix, Rabit, RabitTracker, XGBoostError}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
+import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, TaskContext}
@@ -35,12 +37,37 @@ object XGBoost extends Serializable {
    new XGBoostModel(booster)
  }

+  private def fromDenseToSparseLabeledPoints(
+      denseLabeledPoints: Iterator[LabeledPoint],
+      missing: Float): Iterator[LabeledPoint] = {
+    if (!missing.isNaN) {
+      val sparseLabeledPoints = new ListBuffer[LabeledPoint]
+      for (labelPoint <- denseLabeledPoints) {
+        val dVector = labelPoint.features.toDense
+        val indices = new ListBuffer[Int]
+        val values = new ListBuffer[Double]
+        for (i <- dVector.values.indices) {
+          if (dVector.values(i) != missing) {
+            indices += i
+            values += dVector.values(i)
+          }
+        }
+        val sparseVector = new SparseVector(dVector.values.length, indices.toArray,
+          values.toArray)
+        sparseLabeledPoints += LabeledPoint(labelPoint.label, sparseVector)
+      }
+      sparseLabeledPoints.iterator
+    } else {
+      denseLabeledPoints
+    }
+  }
+
  private[spark] def buildDistributedBoosters(
      trainingData: RDD[LabeledPoint],
      xgBoostConfMap: Map[String, Any],
      rabitEnv: mutable.Map[String, String],
      numWorkers: Int, round: Int, obj: ObjectiveTrait, eval: EvalTrait,
-     useExternalMemory: Boolean): RDD[Booster] = {
+     useExternalMemory: Boolean, missing: Float = Float.NaN): RDD[Booster] = {
    import DataUtils._
    val partitionedData = {
      if (numWorkers > trainingData.partitions.length) {
@@ -71,7 +98,8 @@ object XGBoost extends Serializable {
            null
          }
        }
-       val trainingSet = new DMatrix(new JDMatrix(trainingSamples, cacheFileName))
+       val partitionItr = fromDenseToSparseLabeledPoints(trainingSamples, missing)
+       val trainingSet = new DMatrix(new JDMatrix(partitionItr, cacheFileName))
        booster = SXGBoost.train(trainingSet, xgBoostConfMap, round,
          watches = new mutable.HashMap[String, DMatrix] {
            put("train", trainingSet)
@@ -97,13 +125,14 @@
   * @param eval the user-defined evaluation function, null by default
   * @param useExternalMemory indicates whether to use the external memory cache; by setting
   *                          this flag to true, the user may save RAM cost when running
   *                          XGBoost within Spark
+  * @param missing the value treated as missing in the dataset
   * @throws ml.dmlc.xgboost4j.java.XGBoostError when model training fails
   * @return the XGBoostModel upon successful training
   */
  @throws(classOf[XGBoostError])
  def train(trainingData: RDD[LabeledPoint], configMap: Map[String, Any], round: Int,
      nWorkers: Int, obj: ObjectiveTrait = null, eval: EvalTrait = null,
-     useExternalMemory: Boolean = false): XGBoostModel = {
+     useExternalMemory: Boolean = false, missing: Float = Float.NaN): XGBoostModel = {
    require(nWorkers > 0, "you must specify more than 0 workers")
    val tracker = new RabitTracker(nWorkers)
    implicit val sc = trainingData.sparkContext
@@ -119,7 +148,7 @@
    }
    require(tracker.start(), "FAULT: Failed to start tracker")
    val boosters = buildDistributedBoosters(trainingData, overridedConfMap,
-     tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory)
+     tracker.getWorkerEnvs.asScala, nWorkers, round, obj, eval, useExternalMemory, missing)
    val sparkJobThread = new Thread() {
      override def run() {
        // force the job
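
The helper's dense-to-sparse filtering can be reproduced standalone for intuition. A minimal sketch of the same logic (the -999.0 sentinel and the sample vector are made up for illustration):

import org.apache.spark.mllib.linalg.{SparseVector, Vectors}

val missing = -999.0f
// a dense vector whose middle entry carries the missing-value sentinel
val dense = Vectors.dense(1.0, -999.0, 3.0).toDense
// keep only the entries that differ from the sentinel, as fromDenseToSparseLabeledPoints does
val kept = dense.values.indices.filter(i => dense.values(i) != missing)
val sparse = new SparseVector(
  dense.values.length,                // dimensionality is preserved
  kept.toArray,                       // indices of retained entries: Array(0, 2)
  kept.map(dense.values(_)).toArray)  // their values: Array(1.0, 3.0)
// index 1 is absent from the sparse vector, so XGBoost will treat it as missing

When missing stays at its Float.NaN default, the helper returns the dense iterator untouched, preserving the previous behavior.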
@@ -128,7 +128,7 @@ class XGBoostSuite extends FunSuite with BeforeAndAfter {
List("eta" -> "1", "max_depth" -> "6", "silent" -> "0",
"objective" -> "binary:logistic").toMap,
new scala.collection.mutable.HashMap[String, String],
numWorkers = 2, round = 5, null, null, useExternalMemory = false)
numWorkers = 2, round = 5, eval = null, obj = null, useExternalMemory = false)
val boosterCount = boosterRDD.count()
assert(boosterCount === 2)
val boosters = boosterRDD.collect()
