Merge pull request byzer-org#1555 from hellozepp/extractEtMethod
Extract a general function of an algorithm
chncaesar authored Sep 20, 2021
2 parents e493203 + a6e31c4 commit fc993c8
Showing 3 changed files with 86 additions and 27 deletions.
@@ -117,20 +117,9 @@ class ModelList(format: String, path: String, option: Map[String, String])(spark

override def explain: DataFrame = {

def getAlgName(fullName: String) = {
if (fullName.contains(".") && fullName.startsWith("SQL")) {
fullName
} else {
fullName.split("\\.").last.replace("SQL", "")
}
}

val items = ClassPath.from(getClass.getClassLoader).getTopLevelClasses("streaming.dsl.mmlib.algs").map { f =>
getAlgName(f.getName)
}.toSet ++ MLMapping.mapping.keys.toSet

val items = MLMapping.getAllETNames
val rows = sparkSession.sparkContext.parallelize(items.toSeq.sorted, 1)
sparkSession.createDataFrame(rows.filter(f => ModelSelfExplain.findAlg(f).isDefined).map { algName =>
sparkSession.createDataFrame(rows.map { algName =>
val sqlAlg = ModelSelfExplain.findAlg(algName).get
Row.fromSeq(Seq(algName, sqlAlg.modelType.humanFriendlyName,
sqlAlg.coreCompatibility.map(f => f.coreVersion).mkString(","),
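With getAlgName and the classpath scan gone, explain simply asks MLMapping for the already-validated ET names, so the extra rows.filter step above is no longer needed. Below is a minimal standalone sketch of that flow, assuming a local SparkSession; the schema, app name, and show call are illustrative and not part of ModelList itself:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import tech.mlsql.dsl.adaptor.MLMapping

object ExplainSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("et-list").getOrCreate()

    // getAllETNames already excludes ETs whose classes cannot be loaded,
    // so no additional findAlg/filter pass is needed before building the DataFrame.
    val items = MLMapping.getAllETNames
    val rows = spark.sparkContext.parallelize(items.toSeq.sorted, 1).map(Row(_))
    val schema = StructType(Seq(StructField("name", StringType)))

    spark.createDataFrame(rows, schema).show(truncate = false)
    spark.stop()
  }
}
```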
@@ -19,12 +19,12 @@
package tech.mlsql.dsl.adaptor

import java.util.UUID

import org.apache.spark.SparkCoreVersion
import streaming.dsl.ScriptSQLExecListener
import streaming.dsl.mmlib.SQLAlg
import streaming.dsl.parser.DSLSQLParser._
import streaming.dsl.template.TemplateMerge
import streaming.log.WowLog
import tech.mlsql.common.utils.log.Logging
import tech.mlsql.dsl.auth.ETAuth
import tech.mlsql.dsl.auth.dsl.mmlib.ETMethod
import tech.mlsql.ets.register.ETRegister
@@ -116,8 +116,8 @@ class TrainAdaptor(scriptSQLExecListener: ScriptSQLExecListener) extends DslAdap
}
}

object MLMapping {
val mapping = Map[String, String](
object MLMapping extends Logging with WowLog {
private val mapping = Map[String, String](
"Word2vec" -> "streaming.dsl.mmlib.algs.SQLWord2Vec",
"NaiveBayes" -> "streaming.dsl.mmlib.algs.SQLNaiveBayes",
"RandomForest" -> "streaming.dsl.mmlib.algs.SQLRandomForest",
@@ -137,7 +137,6 @@ object MLMapping
"StandardScaler" -> "streaming.dsl.mmlib.algs.SQLStandardScaler",
"DicOrTableToArray" -> "streaming.dsl.mmlib.algs.SQLDicOrTableToArray",
"TableToMap" -> "streaming.dsl.mmlib.algs.SQLTableToMap",
"DL4J" -> "streaming.dsl.mmlib.algs.SQLDL4J",
"TokenExtract" -> "streaming.dsl.mmlib.algs.SQLTokenExtract",
"TokenAnalysis" -> "streaming.dsl.mmlib.algs.SQLTokenAnalysis",
"TfIdfInPlace" -> "streaming.dsl.mmlib.algs.SQLTfIdfInPlace",
@@ -147,8 +146,6 @@
"NormalizeInPlace" -> "streaming.dsl.mmlib.algs.SQLNormalizeInPlace",
"PythonAlg" -> "streaming.dsl.mmlib.algs.SQLPythonAlg",
"ConfusionMatrix" -> "streaming.dsl.mmlib.algs.SQLConfusionMatrix",
"OpenCVImage" -> "streaming.dsl.mmlib.algs.processing.SQLOpenCVImage",
"JavaImage" -> "streaming.dsl.mmlib.algs.processing.SQLJavaImage",
"Discretizer" -> "streaming.dsl.mmlib.algs.SQLDiscretizer",
"SendMessage" -> "streaming.dsl.mmlib.algs.SQLSendMessage",
"JDBC" -> "streaming.dsl.mmlib.algs.SQLJDBC",
@@ -159,16 +156,81 @@
"ScriptUDF" -> "streaming.dsl.mmlib.algs.ScriptUDF",
"MapValues" -> "streaming.dsl.mmlib.algs.SQLMapValues",
"ExternalPythonAlg" -> "streaming.dsl.mmlib.algs.SQLExternalPythonAlg",
"Kill" -> "streaming.dsl.mmlib.algs.SQLMLSQLJobExt"

"Kill" -> "streaming.dsl.mmlib.algs.SQLMLSQLJobExt",
"ALSInPlace" -> "streaming.dsl.mmlib.algs.SQLALSInPlace",
"AutoIncrementKeyExt" -> "streaming.dsl.mmlib.algs.SQLAutoIncrementKeyExt",
"CacheExt" -> "streaming.dsl.mmlib.algs.SQLCacheExt",
"CommunityBasedSimilarityInPlace" -> "streaming.dsl.mmlib.algs.SQLCommunityBasedSimilarityInPlace",
"CorpusExplainInPlace" -> "streaming.dsl.mmlib.algs.SQLCorpusExplainInPlace",
"DataSourceExt" -> "streaming.dsl.mmlib.algs.SQLDataSourceExt",
"DownloadExt" -> "streaming.dsl.mmlib.algs.SQLDownloadExt",
"FeatureExtractInPlace" -> "streaming.dsl.mmlib.algs.SQLFeatureExtractInPlace",
"JDBCUpdatExt" -> "streaming.dsl.mmlib.algs.SQLJDBCUpdatExt",
"ModelExplainInPlace" -> "streaming.dsl.mmlib.algs.SQLModelExplainInPlace",
"PythonEnvExt" -> "streaming.dsl.mmlib.algs.SQLPythonEnvExt",
"PythonParallelExt" -> "streaming.dsl.mmlib.algs.SQLPythonParallelExt",
"RawSimilarInPlace" -> "streaming.dsl.mmlib.algs.SQLRawSimilarInPlace",
"ReduceFeaturesInPlace" -> "streaming.dsl.mmlib.algs.SQLReduceFeaturesInPlace",
"RepartitionExt" -> "streaming.dsl.mmlib.algs.SQLRepartitionExt",
"ShowFunctionsExt" -> "streaming.dsl.mmlib.algs.SQLShowFunctionsExt",
"TreeBuildExt" -> "streaming.dsl.mmlib.algs.SQLTreeBuildExt",
"UploadFileToServerExt" -> "streaming.dsl.mmlib.algs.SQLUploadFileToServerExt",
"WaterMarkInPlace" -> "streaming.dsl.mmlib.algs.SQLWaterMarkInPlace",
"Word2ArrayInPlace" -> "streaming.dsl.mmlib.algs.SQLWord2ArrayInPlace"
)

/**
* Get all ET names, including both locally loaded classes and ETs registered in code.
*
* @return Collection of all ET names, for example: [[scala.collection.immutable.Set("Word2vec", "NaiveBayes")]]
*/
def getAllETNames: Set[String] = {
getETMapping.keys.toSet
}

/**
* Get all ETs, including both locally loaded classes and ETs registered in code.
*
* @return The map from ET name to implementation class, for example:
* [[scala.collection.immutable.Map("Word2vec" -> "streaming.dsl.mmlib.algs.SQLWord2Vec", "Kill" -> "streaming.dsl.mmlib.algs.SQLMLSQLJobExt")]]
*/
def getETMapping: Map[String, String] = {
getRegisteredMapping.filter(f =>
try {
Some(findET(f._1)).isDefined
} catch {
case e: Exception =>
logError(format("Failed to load ET class! " + format_throwable(e)))
false
case e1: NoClassDefFoundError =>
logError(format("Failed to load ET class! " + format_throwable(e1)))
false
case _ => false
})
}

def findAlg(name: String) = {
(ETRegister.getMapping ++ mapping).get(name.capitalize) match {
/**
* This method combines the mappings registered in two different ways: MLMapping and ETRegister.
* It behaves the same as [[tech.mlsql.dsl.adaptor.MLMapping#findET(java.lang.String)]].
*
* @param name name of the algorithm
* @return algorithm instance
*/
def findAlg(name: String): SQLAlg = {
findET(name)
}

/**
* @param name name of ET
* @return ET instance
*/
def findET(name: String): SQLAlg = {
getRegisteredMapping.get(name.capitalize) match {
case Some(clzz) =>
Class.forName(clzz).newInstance().asInstanceOf[SQLAlg]
case None =>
logWarning(format("Do not call an unregistered ET! If you are using a custom ET, " +
"please register it in `ETRegister`."))
if (!name.contains(".") && (name.endsWith("InPlace") || name.endsWith("Ext"))) {
Class.forName(s"streaming.dsl.mmlib.algs.SQL${name}").newInstance().asInstanceOf[SQLAlg]
} else {
Expand All @@ -179,11 +241,13 @@ object MLMapping {
case e: Exception =>
throw new RuntimeException(s"${name} is not found")
}


}
}
}

private def getRegisteredMapping: Map[String, String] = {
MLMapping.mapping ++ ETRegister.getMapping
}
}

case class TrainStatement(raw: String, inputTableName: String, etName: String, path: String, option: Map[String, String], outputTableName: String)
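Taken together, the helpers above give callers a single entry point for ET discovery and lookup. The following is a short usage sketch, assuming the Byzer/MLSQL runtime classes are on the classpath; the RandomForest lookups and printed output are only illustrative:

```scala
import streaming.dsl.mmlib.SQLAlg
import tech.mlsql.dsl.adaptor.MLMapping

object MLMappingUsage {
  def main(args: Array[String]): Unit = {
    // All ET names, whether hard-coded in MLMapping or registered through ETRegister.
    val names: Set[String] = MLMapping.getAllETNames
    println(s"Known ETs: ${names.toSeq.sorted.mkString(", ")}")

    // Name-to-class mapping, restricted to classes that can actually be loaded.
    val mapping: Map[String, String] = MLMapping.getETMapping
    println(s"RandomForest maps to: ${mapping.get("RandomForest")}")

    // findET instantiates the implementation; findAlg is kept as a thin alias for it.
    val et: SQLAlg = MLMapping.findET("RandomForest")
    val sameEt: SQLAlg = MLMapping.findAlg("RandomForest")
    println(et.getClass.getName)      // e.g. streaming.dsl.mmlib.algs.SQLRandomForest
    println(sameEt.getClass.getName)
  }
}
```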
@@ -3,6 +3,7 @@ package tech.mlsql.ets.register
import tech.mlsql.runtime.AppRuntimeStore

import scala.collection.JavaConverters._
import scala.collection.concurrent


/**
@@ -17,7 +18,12 @@ object ETRegister

def remove(name: String) = mapping.remove(name)

def getMapping = {
/**
* @return the ET mapping registered through ETRegister
* @see If you need all ETs, use [[tech.mlsql.dsl.adaptor.MLMapping.getETMapping]] instead,
* because some ETs are still registered directly in MLMapping.
*/
def getMapping: concurrent.Map[String, String] = {
mapping.asScala
}

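The @see note points custom-ET authors at the intended flow: register the implementation through ETRegister, then resolve it through MLMapping. Below is a hedged sketch of that flow; ETRegister.register(name, className) and the com.example class name are assumptions, since only remove and getMapping are visible in this diff:

```scala
import streaming.dsl.mmlib.SQLAlg
import tech.mlsql.dsl.adaptor.MLMapping
import tech.mlsql.ets.register.ETRegister

object CustomEtRegistration {
  def main(args: Array[String]): Unit = {
    // Assumed registration API; check ETRegister for the exact method signature.
    ETRegister.register("MyCustomEt", "com.example.MyCustomEt")

    // ETRegister.getMapping only covers code-registered ETs, so the entry shows up here.
    val registeredOnly = ETRegister.getMapping
    println(registeredOnly.contains("MyCustomEt")) // true

    // MLMapping.getETMapping merges built-ins with ETRegister but drops entries whose
    // classes cannot be loaded, so the hypothetical com.example entry is filtered out.
    val all = MLMapping.getETMapping
    println(all.contains("MyCustomEt")) // false: com.example.MyCustomEt is not on the classpath
    println(all.contains("Word2vec"))   // true: built-in mapping entry

    // Lookups go through findET; unregistered names now trigger the warning added above.
    val et: SQLAlg = MLMapping.findET("Word2vec")
    println(et.getClass.getName)
  }
}
```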
