Skip to content

Commit

Permalink
adding table cache support
Browse files Browse the repository at this point in the history
  • Loading branch information
allwefantasy committed Nov 6, 2018
1 parent c663066 commit ca432b3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TrainAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdap

val isTrain = ctx.getChild(0).getText match {
case "predict" => false
case "run" => false
case "run" => true
case "train" => true
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package streaming.dsl.mmlib.algs

import org.apache.spark.ml.param.Param
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import streaming.session.MLSQLException


/**
 * MLSQL ET (train/run) operator that caches or uncaches the input table.
 *
 * Usage: pass `execute=cache` (default) to persist the DataFrame with Spark's
 * default storage level, or `execute=uncache` to release it. The input
 * DataFrame is returned unchanged so downstream statements can keep using it.
 */
class SQLCacheExt(override val uid: String) extends SQLAlg with WowParams {

  /**
   * Persists or unpersists `df` depending on the `execute` parameter.
   *
   * @param df     the table to cache or uncache
   * @param path   unused; required by the [[SQLAlg]] contract
   * @param params expects optional key `execute` with value "cache" or "uncache"
   * @return the same `df`, so the operator is transparent to the pipeline
   * @throws MLSQLException if `execute` is neither "cache" nor "uncache"
   */
  override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {

    // Default to "cache" when the parameter is absent.
    val exe = params.getOrElse(execute.name, "cache")

    if (!execute.isValid(exe)) {
      throw new MLSQLException(s"${execute.name} should be cache or uncache")
    }

    if (exe == "cache") {
      // Uses Spark's default storage level (MEMORY_AND_DISK).
      df.persist()
    } else {
      df.unpersist()
    }
    df
  }

  // There is no model to load for a cache operator.
  override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = {
    throw new RuntimeException("load is not supported in SQLCacheExt")
  }

  // No prediction function exists for this operator; callers of the ET
  // framework treat a null UDF as "not predictable".
  override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = {
    null
  }

  // Declares the accepted values for the `execute` parameter.
  final val execute: Param[String] = new Param[String](this, "execute", "cache|uncache", isValid = (m: String) => {
    m == "cache" || m == "uncache"
  })

  def this() = this(BaseParams.randomUID())
}

0 comments on commit ca432b3

Please sign in to comment.