SparkGBM is an implementation of Gradient Boosting Machines (GBM) on top of Apache Spark. It is designed to be scalable and efficient, with the following advantages:
1. Compatible with the existing ML/MLlib pipeline API
2. Written purely in Scala/Spark, with no other dependencies
3. Faster training than ml.GBT
SparkGBM builds on the experience of XGBoost and LightGBM to be an efficient framework.
From XGBoost, we adopted:
1. Second-order approximation of the objective function
2. L1/L2 regularization of leaf weights to prevent overfitting
3. Column subsampling by tree and by node
4. Sparsity awareness
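To make the second-order approximation and the L1/L2 penalties concrete, here is a minimal, self-contained Scala sketch of the standard XGBoost-style formulas: the gradient and hessian of the logistic loss, the regularized optimal leaf weight, and the split gain. It does not use SparkGBM; the object and function names are hypothetical, and SparkGBM's internal code may differ in details (for example, where exactly the L1 term is applied). Here alpha and lambda play the roles of the setRegAlpha and setRegLambda parameters.

```scala
// Hypothetical sketch of XGBoost-style second-order formulas; not SparkGBM's code.
object SecondOrderSketch {
  // Logistic loss for a label y in {0, 1} and raw score s:
  // gradient g = p - y and hessian h = p * (1 - p), where p = sigmoid(s).
  def logisticGradHess(label: Double, score: Double): (Double, Double) = {
    val p = 1.0 / (1.0 + math.exp(-score))
    (p - label, p * (1.0 - p))
  }

  // Soft-thresholding applies the L1 penalty (alpha) to a sum of gradients.
  def softThreshold(g: Double, alpha: Double): Double =
    math.signum(g) * math.max(math.abs(g) - alpha, 0.0)

  // Optimal leaf weight under the second-order approximation:
  // w* = -T(G, alpha) / (H + lambda), with G and H the gradient and hessian sums.
  def leafWeight(sumGrad: Double, sumHess: Double, alpha: Double, lambda: Double): Double =
    -softThreshold(sumGrad, alpha) / (sumHess + lambda)

  // Objective reduction contributed by a leaf: T(G, alpha)^2 / (H + lambda).
  def leafScore(sumGrad: Double, sumHess: Double, alpha: Double, lambda: Double): Double = {
    val t = softThreshold(sumGrad, alpha)
    t * t / (sumHess + lambda)
  }

  // Gain of splitting a node into a left and a right child.
  def splitGain(gl: Double, hl: Double, gr: Double, hr: Double,
                alpha: Double, lambda: Double): Double =
    0.5 * (leafScore(gl, hl, alpha, lambda) + leafScore(gr, hr, alpha, lambda) -
      leafScore(gl + gr, hl + hr, alpha, lambda))
}
```

For example, a leaf with gradient sum -2.0 and hessian sum 1.0 gets weight 1.0 with alpha = 0 and lambda = 1, and 0.75 with alpha = 0.5: the L1 term shrinks the gradient sum before it is divided by the regularized hessian.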
From LightGBM, we adopted:
1. Histogram subtraction to halve communication overhead
2. Feature binning to reduce the memory footprint
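Histogram subtraction relies on the fact that a parent node's gradient/hessian histogram is the bin-wise sum of its two children's histograms. The plain-Scala sketch below (hypothetical names, not SparkGBM's implementation) builds a per-feature histogram from rows whose feature value has already been mapped to a bin index, and derives one child's histogram from the parent's and its sibling's, so only one child has to be aggregated from the data.

```scala
// Hypothetical sketch of histogram construction and subtraction; not SparkGBM's code.
object HistogramSketch {
  // One bin accumulates the gradient and hessian sums of the instances
  // whose (already binned) feature value falls into it.
  final case class Bin(grad: Double, hess: Double) {
    def -(other: Bin): Bin = Bin(grad - other.grad, hess - other.hess)
  }

  // Build the histogram of one feature from (binIndex, grad, hess) triples.
  def build(rows: Seq[(Int, Double, Double)], numBins: Int): Array[Bin] = {
    val grads = new Array[Double](numBins)
    val hessians = new Array[Double](numBins)
    rows.foreach { case (bin, g, h) =>
      grads(bin) += g
      hessians(bin) += h
    }
    Array.tabulate(numBins)(i => Bin(grads(i), hessians(i)))
  }

  // Histogram subtraction: sibling = parent - child, bin by bin.
  def subtract(parent: Array[Bin], child: Array[Bin]): Array[Bin] =
    parent.zip(child).map { case (p, c) => p - c }
}
```

In a distributed setting, only the smaller child's histogram would need to be aggregated across the cluster; the larger one is obtained locally by subtraction, which is where the communication saving comes from.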
GBMClassifier for binary classification:
```scala
val gbmc = new GBMClassifier
gbmc.setBoostType("gbtree")                   // "dart" -> DART, "gbtree" -> gradient boosting
  .setObjectiveFunc("logistic")               // "logistic" -> logloss
  .setEvaluateFunc(Array("auc", "logloss"))   // "auc", "logloss", "error"
  .setMaxIter(10)                             // maximum number of iterations
  .setModelCheckpointInterval(4)              // model checkpoint interval
  .setModelCheckpointPath(path)               // model checkpoint directory

/** training without validation (train: the training DataFrame) */
val model = gbmc.fit(train)

/** load the snapshots saved during training */
val modelSnapshot4 = GBMClassificationModel.load(s"$path/model-4")
val modelSnapshot8 = GBMClassificationModel.load(s"$path/model-8")

/** model save and load */
model.save(savePath)
val model2 = GBMClassificationModel.load(savePath)
```
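The snapshot names in the example follow from the checkpoint settings: with setMaxIter(10) and setModelCheckpointInterval(4), snapshots are loaded for iterations 4 and 8. The hypothetical helper below encodes that reading, assuming a snapshot is written at every positive multiple of the interval under <path>/model-<iteration>; this is inferred from the example above, not taken from SparkGBM's source.

```scala
// Hypothetical helper; the naming scheme "<path>/model-<iteration>" is an
// assumption inferred from the snapshot paths in the example.
object CheckpointSketch {
  // Paths of the snapshots expected after training for maxIter iterations,
  // assuming one snapshot at every positive multiple of the interval.
  def snapshotPaths(path: String, maxIter: Int, interval: Int): Seq[String] =
    (interval to maxIter by interval).map(i => s"$path/model-$i")
}
```

Under this assumption, iteration 10 is not a multiple of 4, so no model-10 snapshot is expected; the final model is the one returned by fit.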
GBMRegressor for regression:
```scala
val gbmr = new GBMRegressor
gbmr.setBoostType("dart")                     // "dart" -> DART, "gbtree" -> gradient boosting
  .setObjectiveFunc("square")                 // "square" -> MSE, "huber" -> Pseudo-Huber loss
  .setEvaluateFunc(Array("rmse", "mae"))      // "rmse", "mse", "mae"
  .setMaxIter(10)                             // maximum number of iterations
  .setMaxDepth(7)                             // maximum depth
  .setMaxBins(32)                             // maximum number of bins
  .setNumericalBinType("width")               // "width" -> equal-width intervals, "depth" -> quantiles
  .setMaxLeaves(100)                          // maximum number of leaves
  .setMinNodeHess(0.001)                      // minimum hessian required in a node
  .setRegAlpha(0.1)                           // L1 regularization
  .setRegLambda(0.5)                          // L2 regularization
  .setDropRate(0.1)                           // dropout rate
  .setDropSkip(0.5)                           // probability of skipping dropout
  .setInitialModelPath(path)                  // path of the initial model
  .setEarlyStopIters(10)                      // early stopping

/** training without validation; early stopping is ignored */
val model1 = gbmr.fit(train)

/** training with validation (test: the validation DataFrame) */
val model2 = gbmr.fit(train, test)

/** use only the first 5 trees for the following feature importance computation, prediction and leaf transformation */

/** feature importance */

/** prediction */
val predictions = model2.transform(test)

/** enable one-hot leaf transform */
```
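The two values of setNumericalBinType choose how the split points of a numerical feature are placed before training. The sketch below contrasts them on a small in-memory sample: "width" divides [min, max] into intervals of equal length, while "depth" puts split points at quantiles so each bin holds roughly the same number of values. The names are hypothetical, and SparkGBM likely computes approximate quantiles over the distributed data rather than the exact, in-memory ones used here.

```scala
// Hypothetical sketch of equal-width vs quantile binning; not SparkGBM's code.
object BinningSketch {
  // "width": split [min, max] into numBins intervals of equal length.
  def widthSplits(values: Seq[Double], numBins: Int): Seq[Double] = {
    val lo = values.min
    val hi = values.max
    val step = (hi - lo) / numBins
    (1 until numBins).map(i => lo + i * step)
  }

  // "depth": place split points at (exact) quantiles so that bins hold
  // roughly equal counts; duplicate split points are merged.
  def depthSplits(values: Seq[Double], numBins: Int): Seq[Double] = {
    val sorted = values.sorted.toIndexedSeq
    (1 until numBins).map(i => sorted(i * sorted.length / numBins)).distinct
  }

  // Bin index of x given ascending split points: the number of splits <= x.
  def binIndex(splits: Seq[Double], x: Double): Int = splits.count(_ <= x)
}
```

With the values 1.0 to 7.0 plus a single outlier at 100.0 and 4 bins, equal-width binning puts 7 of the 8 values into the first bin, while quantile binning puts two values into each bin. This is the usual reason to prefer "depth" for skewed features.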
In addition to all the functionality of the DataFrame-based API, the RDD-based API also supports user-defined objective, evaluation and callback functions.
```scala
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.rdd.RDD

/** User-defined objective function: returns (gradient, hessian) */
val obj = new ScalarObjFunc {
  override def compute(label: Double, score: Double): (Double, Double) = (score - label, 1.0)
  override def name: String = "Another Square"
}

/** User-defined evaluation function for R2 */
val r2Eval = new ScalarEvalFunc {
  override def isLargerBetter: Boolean = true
  override def name: String = "R2 (no weight)"

  // (weight, label, raw, score)
  override def computeImpl(data: RDD[(Double, Double, Double, Double)]): Double = {
    /** ignore weight; RegressionMetrics expects (prediction, observation) pairs */
    new RegressionMetrics(data.map(t => (t._4, t._2))).r2
  }
}

/** User-defined evaluation function for MAE */
val maeEval = new SimpleEvalFunc {
  override def compute(label: Double,
                       score: Double): Double = (label - score).abs
  override def isLargerBetter: Boolean = false
  override def name: String = "Another MAE"
}

/** User-defined callback function */
val lrUpdater = new CallbackFunc {
  override def compute(boostConfig: BoostConfig,
                       model: GBMModel,
                       iteration: Int,
                       trainMetrics: Array[Map[String, Double]],
                       testMetrics: Array[Map[String, Double]]): Boolean = {
    /** learning rate decay */
    if (boostConfig.getStepSize > 0.01) {
      boostConfig.updateStepSize(boostConfig.getStepSize * 0.95)
    }

    println(s"Round ${model.numTrees}: train metrics: ${trainMetrics.last}")
    if (testMetrics.nonEmpty) {
      println(s"Round ${model.numTrees}: test metrics: ${testMetrics.last}")
    }

    false // assumed: returning false lets training continue
  }

  override def name: String = "Learning Rate Updater"
}

val recoder = new MetricRecoder

val gbm = new GBM
gbm.setEvalFunc(Array(r2Eval, maeEval, new R2Eval))
  .setCallbackFunc(Array(lrUpdater, recoder))

/** train with validation (train, test: the training and validation RDDs) */
val model = gbm.fit(train, test)
```
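A user-defined objective only has to supply the gradient and hessian of the loss with respect to the score, as ScalarObjFunc.compute does above for the squared error, whose gradient is score - label and whose hessian is 1.0. As a further illustration, the sketch below derives both quantities for the Pseudo-Huber loss (the loss behind the built-in "huber" option). The object name and the delta parameter are hypothetical, and how SparkGBM parameterizes its own Pseudo-Huber objective is not shown here.

```scala
// Hypothetical sketch: gradient and hessian of the Pseudo-Huber loss.
object PseudoHuberSketch {
  // With residual r = score - label:
  //   L(r)   = delta^2 * (sqrt(1 + (r / delta)^2) - 1)
  //   L'(r)  = r / sqrt(1 + (r / delta)^2)
  //   L''(r) = 1 / (1 + (r / delta)^2)^(3/2)
  // Returns (gradient, hessian), the same shape as ScalarObjFunc.compute.
  def gradHess(label: Double, score: Double, delta: Double = 1.0): (Double, Double) = {
    val r = score - label
    val s = 1.0 + (r / delta) * (r / delta)
    (r / math.sqrt(s), 1.0 / (s * math.sqrt(s)))
  }
}
```

The gradient is bounded by delta in absolute value, so large residuals (outliers) cannot dominate a split the way they do under the squared error; wrapping gradHess in a ScalarObjFunc, in the same way as the "Another Square" objective, would be one way to plug it in.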
Build with Maven:
```
mvn clean package
```
The current master branch works with Spark 2.3.0.