Skip to content

Commit

Permalink
[jvm-packages] fix bug doing rabit call after finalize (dmlc#2079)
Browse files Browse the repository at this point in the history
[jvm-packages]fix bug doing rabit call after finalize
  • Loading branch information
hlsc authored and CodingCat committed Mar 3, 2017
1 parent fd19b7a commit a920933
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,13 @@ abstract class XGBoostModel(protected var _booster: Booster)
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
if (testSamples.hasNext) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
Iterator(broadcastBooster.value.predictLeaf(dMatrix))
val res = broadcastBooster.value.predictLeaf(dMatrix)
Rabit.shutdown()
Iterator(res)
} else {
Iterator()
}
Expand Down Expand Up @@ -145,8 +149,9 @@ abstract class XGBoostModel(protected var _booster: Booster)
flatSampleArray(i) = sampleArray(i / numColumns).values(i % numColumns).toFloat
}
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
val res = broadcastBooster.value.predict(dMatrix)
Rabit.shutdown()
Iterator(broadcastBooster.value.predict(dMatrix))
Iterator(res)
}
}
}
Expand Down

0 comments on commit a920933

Please sign in to comment.