adding xgboost support
allwefantasy committed Oct 8, 2018
1 parent e2ebe34 commit b1542a5
Showing 8 changed files with 515 additions and 124 deletions.
@@ -17,7 +17,7 @@ object LocalSparkServiceApp {
"-streaming.ps.enable", "true",
"-spark.sql.hive.thriftServer.singleSession", "true",
"-streaming.rest.intercept.clzz", "streaming.rest.ExampleRestInterceptor",
"-streaming.deploy.rest.api", "true",
"-streaming.deploy.rest.api", "false",
"-spark.driver.maxResultSize", "2g",
"-spark.serializer", "org.apache.spark.serializer.KryoSerializer",
"-spark.sql.codegen.wholeStage", "true",
@@ -16,6 +16,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row, SaveMode, SparkSession, fu
import org.apache.spark.util.{ExternalCommandRunner, ObjPickle, WowMD5, WowXORShiftRandom}
import streaming.common.HDFSOperator
import MetaConst._
+ import net.csdn.common.reflect.ReflectHelper
import org.apache.spark.ps.cluster.Message
import streaming.core.strategy.platform.{PlatformManager, SparkRuntime}
import streaming.log.{Logging, WowLog}
@@ -109,8 +110,9 @@ trait Functions extends SQlBaseFunc with Logging with WowLog with Serializable {
params.filter(f => f._1.startsWith(name + ".")).map { f =>
val Array(name, group, keys@_*) = f._1.split("\\.")
(group, keys.mkString("."), f._2)
- }.groupBy(f => f._1).map { f => f._2.map(k =>
- (k._2, k._3)).toMap
+ }.groupBy(f => f._1).map { f =>
+ f._2.map(k =>
+ (k._2, k._3)).toMap
}.toArray
}
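// For illustration: the grouping above turns flat keys such as
//   fitParam.0.maxDepth -> "5", fitParam.0.eta -> "0.1", fitParam.1.maxDepth -> "10"
// into Array(Map("maxDepth" -> "5", "eta" -> "0.1"), Map("maxDepth" -> "10")),
// one Map per numeric group (the hyperparameter names here are only examples).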

@@ -154,13 +156,81 @@ trait Functions extends SQlBaseFunc with Logging with WowLog with Serializable {
val model = alg.asInstanceOf[Estimator[T]].fit(trainData)
model.asInstanceOf[MLWritable].write.overwrite().save(path + "/" + modelIndex)
}

params.getOrElse("multiModels", "false").toBoolean match {
case true => sampleUnbalanceWithMultiModel(df, path, params, f)
case false =>
f(df, 0)
}
}


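/**
 * Variant of trainModelsWithMultiParamGroup that calls `fit` reflectively
 * (ReflectHelper.method(alg, "fit", trainData)) instead of casting to Estimator,
 * so it can drive estimators such as XGBoost that are not compile-time dependencies.
 * One model is trained per fitParam group; model paths, metrics, status and timings
 * are persisted to a parquet meta table.
 */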
def trainModelsWithMultiParamGroup2(df: DataFrame, path: String, params: Map[String, String],
modelType: () => Params,
evaluate: (Params, Map[String, String]) => List[MetricValue]
) = {

val keepVersion = params.getOrElse("keepVersion", "true").toBoolean

val mf = (trainData: DataFrame, fitParam: Map[String, String], modelIndex: Int) => {
val alg = modelType()
configureModel(alg, fitParam)

logInfo(format(s"[training] [alg=${alg.getClass.getName}] [keepVersion=${keepVersion}]"))

var status = "success"
val modelTrainStartTime = System.currentTimeMillis()
val modelPath = SQLPythonFunc.getAlgModelPath(path, keepVersion) + "/" + modelIndex
var scores: List[MetricValue] = List()
try {
val model = ReflectHelper.method(alg, "fit", trainData)
model.asInstanceOf[MLWritable].write.overwrite().save(modelPath)
scores = evaluate(model.asInstanceOf[Params], fitParam)
logInfo(format(s"[trained] [alg=${alg.getClass.getName}] [metrics=${scores}] [model hyperparameters=${
model.asInstanceOf[Params].explainParams().replaceAll("\n", "\t")
}]"))
} catch {
case e: Exception =>
logInfo(format_exception(e))
status = "fail"
}
val modelTrainEndTime = System.currentTimeMillis()
// if (status == "fail") {
// throw new RuntimeException(s"Fail to train model ${modelIndex}; all runs will fail")
// }
val metrics = scores.map(score => Row.fromSeq(Seq(score.name, score.value))).toArray
Row.fromSeq(Seq(modelPath, modelIndex, alg.getClass.getName, metrics, status, modelTrainStartTime, modelTrainEndTime, fitParam))
}
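// If no fitParam.[group].[key] entries were configured, fall back to a single
// training run (group 0) with default hyperparameters.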
var fitParam = arrayParamsWithIndex("fitParam", params)
if (fitParam.size == 0) {
fitParam = Array((0, Map[String, String]()))
}

val wowRes = fitParam.map { fp =>
mf(df, fp._2, fp._1)
}

val wowRDD = df.sparkSession.sparkContext.parallelize(wowRes, 1)

df.sparkSession.createDataFrame(wowRDD, StructType(Seq(
StructField("modelPath", StringType),
StructField("algIndex", IntegerType),
StructField("alg", StringType),
StructField("metrics", ArrayType(StructType(Seq(
StructField(name = "name", dataType = StringType),
StructField(name = "value", dataType = DoubleType)
)))),

StructField("status", StringType),
StructField("startTime", LongType),
StructField("endTime", LongType),
StructField("trainParams", MapType(StringType, StringType))
))).
write.
mode(SaveMode.Overwrite).
parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")
}
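// The run summary written above can be read back as an ordinary DataFrame, e.g.:
// df.sparkSession.read.parquet(SQLPythonFunc.getAlgMetalPath(path, keepVersion) + "/0")
//   .select("modelPath", "algIndex", "metrics", "status").show(false)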

def trainModelsWithMultiParamGroup[T <: Model[T]](df: DataFrame, path: String, params: Map[String, String],
modelType: () => Params,
evaluate: (Params, Map[String, String]) => List[MetricValue]
@@ -197,9 +267,11 @@ trait Functions extends SQlBaseFunc with Logging with WowLog with Serializable {
val metrics = scores.map(score => Row.fromSeq(Seq(score.name, score.value))).toArray
Row.fromSeq(Seq(modelPath, modelIndex, alg.getClass.getName, metrics, status, modelTrainStartTime, modelTrainEndTime, fitParam))
}
- val fitParam = arrayParamsWithIndex("fitParam", params)
+ var fitParam = arrayParamsWithIndex("fitParam", params)

- require(fitParam.size > 0, "fitParam.[group].[parameter] should be configured at least once")
+ if (fitParam.size == 0) {
+ fitParam = Array((0, Map[String, String]()))
+ }

val wowRes = fitParam.map { fp =>
mf(df, fp._2, fp._1)
@@ -0,0 +1,76 @@
package streaming.dsl.mmlib.algs

import net.csdn.common.reflect.ReflectHelper
import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.{Model, Transformer}
import org.apache.spark.ml.param.Params
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.classfication.BaseClassification
import streaming.dsl.mmlib.algs.param.BaseParams

import scala.collection.mutable.ArrayBuffer


/**
* Created by allwefantasy on 12/9/2018.
*/
class SQLXGBoostExt(override val uid: String) extends SQLAlg with MllibFunctions with Functions with BaseClassification {
def this() = this(BaseParams.randomUID())

override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {


val keepVersion = params.getOrElse("keepVersion", "true").toBoolean
setKeepVersion(keepVersion)

val evaluateTable = params.get("evaluateTable")
setEvaluateTable(evaluateTable.getOrElse("None"))

SQLPythonFunc.incrementVersion(path, keepVersion)
val spark = df.sparkSession
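// The XGBoost wrapper is instantiated reflectively below (Class.forName + ReflectHelper),
// presumably so this class compiles without the XGBoost classes on the classpath.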

trainModelsWithMultiParamGroup2(df, path, params, () => {
val obj = Class.forName("streaming.dsl.mmlib.algs.XGBoostExt").newInstance()
ReflectHelper.method(obj, "WowXGBoostClassifier").asInstanceOf[Params]
}, (_model, fitParam) => {
evaluateTable match {
case Some(etable) =>
val model = _model.asInstanceOf[Transformer]
val evaluateTableDF = spark.table(etable)
val predictions = model.transform(evaluateTableDF)
multiclassClassificationEvaluate(predictions, (evaluator) => {
evaluator.setLabelCol(fitParam.getOrElse("labelCol", "label"))
evaluator.setPredictionCol("prediction")
})

case None => List()
}
}
)

formatOutput(getModelMetaData(spark, path))

}

override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
val (bestModelPath, baseModelPath, metaPath) = mllibModelAndMetaPath(path, params, sparkSession)
val obj = Class.forName("streaming.dsl.mmlib.algs.XGBoostExt").newInstance()
val model = ReflectHelper.method(obj, "load", bestModelPath(0))
ArrayBuffer(model)
}


override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val spark = df.sparkSession
val models = load(spark, path, params).asInstanceOf[ArrayBuffer[Transformer]]
models.head.transform(df)
}

override def predict(sparkSession: _root_.org.apache.spark.sql.SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
predict_classification(sparkSession, _model, name)
}


}
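For a quick sanity check, a hypothetical parameter map for this estimator could look as follows. The fitParam.[group].[key] convention and the keepVersion, evaluateTable and labelCol keys come from the code above; maxDepth is an assumed XGBoost hyperparameter, and trainingDF, the table name and the path are placeholders:

// Hypothetical usage sketch; names marked above are assumptions, not part of the commit.
val params = Map(
  "keepVersion" -> "true",
  "evaluateTable" -> "validationTable",
  "fitParam.0.labelCol" -> "label",
  "fitParam.0.maxDepth" -> "5",
  "fitParam.1.maxDepth" -> "8"
)
// Trains one model per fitParam group and writes the run summary under the given path.
new SQLXGBoostExt().train(trainingDF, "/tmp/xgboost", params)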