diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 4f3d50f9f06a..85f9b3eaba26 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import ml.dmlc.xgboost4j.java.Rabit import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} +import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import ml.dmlc.xgboost4j.scala.spark.params._ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} @@ -134,6 +135,10 @@ class XGBoostClassifier ( def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) + def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value) + + def setCustomEval(value: EvalTrait): this.type = set(customEval, value) + // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 29f289102372..6b6c635bdabf 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -23,6 +23,7 @@ import ml.dmlc.xgboost4j.java.Rabit import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _} import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} +import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} import org.apache.hadoop.fs.Path import org.apache.spark.TaskContext @@ -136,6 +137,10 @@ class XGBoostRegressor ( def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value) + def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value) + + def setCustomEval(value: EvalTrait): this.type = set(customEval, value) + // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")