Commit bf09978
fix some bugs
allwefantasy committed May 9, 2019
1 parent 1c0c899 commit bf09978
Showing 6 changed files with 29 additions and 12 deletions.
File: SaveAuth.scala
@@ -19,9 +19,9 @@
package streaming.dsl.auth

import streaming.core.datasource.{DataAuthConfig, DataSourceRegistry, SourceInfo}
import streaming.dsl.DslTool
import streaming.dsl.parser.DSLSQLParser._
import streaming.dsl.template.TemplateMerge
import streaming.dsl.{DslTool, ScriptSQLExec}
import streaming.log.{Logging, WowLog}
import tech.mlsql.dsl.processor.AuthProcessListener

@@ -51,7 +51,7 @@ class SaveAuth(authProcessListener: AuthProcessListener) extends MLSQLAuth with
ctx.getChild(tokenIndex) match {
case s: FormatContext =>
format = s.getText

case s: PathContext =>
path = TemplateMerge.merge(cleanStr(s.getText), env)

@@ -82,10 +82,14 @@ class SaveAuth(authProcessListener: AuthProcessListener) extends MLSQLAuth with
} getOrElse {
format match {
case "hive" =>
val Array(db, table) = final_path.split("\\.")
val Array(db, table) = final_path.split("\\.") match {
case Array(db, table) => Array(db, table)
case Array(table) => Array("default", table)
}
MLSQLTable(Some(db), Some(table), OperateType.SAVE, Some(format), TableType.HIVE)
case _ =>
MLSQLTable(None, Some(final_path), OperateType.SAVE, Some(format), TableType.from(format).get)
val context = ScriptSQLExec.contextGetOrForTest()
MLSQLTable(None, Some(resourceRealPath(context.execListener, owner, path)), OperateType.SAVE, Some(format), TableType.from(format).get)
}
}

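The SaveAuth fix guards against a MatchError: the old extractor val Array(db, table) = final_path.split("\\.") failed whenever a hive save path carried no explicit database. A minimal, self-contained sketch of the new fallback (demo object and names are mine, not MLSQL code):

object SplitFallbackDemo {
  // "db.tbl" splits into Array("db", "tbl"); a bare "tbl" now falls
  // back to hive's default database instead of crashing the extractor.
  def resolve(path: String): (String, String) = path.split("\\.") match {
    case Array(db, table) => (db, table)
    case Array(table) => ("default", table)
  }

  def main(args: Array[String]): Unit = {
    println(resolve("warehouse.users")) // (warehouse,users)
    println(resolve("users"))           // (default,users)
  }
}

For non-hive formats, the audited table now records the owner's resolved resource path via resourceRealPath and the ScriptSQLExec context, which is why the ScriptSQLExec import was added in the first hunk.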
File: SQLDownloadExt.scala
@@ -20,7 +20,9 @@ package streaming.dsl.mmlib.algs

import java.net.URLEncoder

import net.csdn.common.reflect.ReflectHelper
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream
import org.apache.http.HttpResponse
import org.apache.http.client.fluent.Request
import org.apache.spark.ml.param.Param
import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -99,12 +101,18 @@ class SQLDownloadExt(override val uid: String) extends SQLAlg with Logging with
URLEncoder.encode(name, "utf-8")
}

logInfo(format(s"download file from src:${$(from)} to dst:${$(to)}"))

val getUrl = fromUrl + s"?userName=${urlencode(context.owner)}&fileName=${urlencode($(from))}&auth_secret=${urlencode(auth_secret)}"
val stream = Request.Get(getUrl)

val response = Request.Get(getUrl)
.connectTimeout(60 * 1000)
.socketTimeout(10 * 60 * 1000)
.execute().returnContent().asStream()
.execute()
// returnContent() would consume the whole entity into memory and hand back a new stream, costing too much memory for large files, so pull the live stream off the underlying HttpResponse instead.
val stream = ReflectHelper.field(response, "response").asInstanceOf[HttpResponse].getEntity.getContent
val tarIS = new TarArchiveInputStream(stream)

var downloadResultRes = ArrayBuffer[DownloadResult]()
try {
var entry = tarIS.getNextEntry
@@ -113,6 +121,7 @@ class SQLDownloadExt(override val uid: String) extends SQLAlg with Logging with
if (!entry.isDirectory) {
val dir = entry.getName.split("/").filterNot(f => f.isEmpty).dropRight(1).mkString("/")
downloadResultRes += DownloadResult(PathFun(originalTo).add(dir).add(entry.getName.split("/").last).toPath)
logInfo(format(s"extracting ${downloadResultRes.last.hdfsPath}"))
HDFSOperator.saveStream($(to) + "/" + dir, entry.getName.split("/").last, tarIS)
}
entry = tarIS.getNextEntry
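The download fix targets memory use: HttpClient's fluent returnContent() buffers the entire response entity before handing back a stream, which is fatal for large tar archives. The commit instead reflects into the fluent Response's private "response" field to reach the raw HttpResponse and feeds its live entity stream straight to TarArchiveInputStream. A condensed sketch of the same trick using plain java.lang.reflect in place of ReflectHelper (assumes the HttpClient 4.x fluent API):

import org.apache.http.HttpResponse
import org.apache.http.client.fluent.Request

object StreamingDownloadSketch {
  def openStream(url: String): java.io.InputStream = {
    val response = Request.Get(url)
      .connectTimeout(60 * 1000)
      .socketTimeout(10 * 60 * 1000)
      .execute()
    // Reach the un-buffered HttpResponse held in the fluent Response's private field.
    val field = response.getClass.getDeclaredField("response")
    field.setAccessible(true)
    field.get(response).asInstanceOf[HttpResponse].getEntity.getContent
  }
}

The added logInfo calls make long extractions observable: one line when the download starts and one per extracted tar entry.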
File: SQLRateSampler.scala
@@ -19,9 +19,10 @@
package streaming.dsl.mmlib.algs

import org.apache.spark.Partitioner
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, functions => F}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession, functions => F}
import streaming.dsl.mmlib.SQLAlg

/**
@@ -38,6 +39,7 @@ class SQLRateSampler extends SQLAlg with Functions {
case a: Double => a.asInstanceOf[Double].toInt
case a: Float => a.asInstanceOf[Float].toInt
case a: Long => a.asInstanceOf[Long].toInt
case _ => throw new MLSQLException("The type of labelCol should be int/double/float/long")
}
}

@@ -133,8 +135,7 @@ class SQLRateSampler extends SQLAlg with Functions {

override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val newDF = internal_train(df, params)
newDF.write.mode(SaveMode.Overwrite).parquet(path)
emptyDataFrame()(df)
newDF
}

override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = ???
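RateSampler's train previously wrote its result to path as parquet and returned an empty DataFrame; it now returns the sampled rows directly, so the output can flow into the as-table of the new !split command (registered in CommandCollection below) without a second load. Together with the MLSQLException guard added to getIntFromRow, a non-numeric labelCol now fails fast with a clear message instead of a MatchError. A sketch of the resulting call pattern (column name and rates are illustrative; assumes df holds a numeric label column):

import org.apache.spark.sql.DataFrame
import streaming.dsl.mmlib.algs.SQLRateSampler

def splitByRate(df: DataFrame): DataFrame = {
  // Params mirror the split command template: labelCol names the label
  // column, sampleRate lists the per-split fractions.
  new SQLRateSampler().train(df, "", Map(
    "labelCol" -> "label",
    "sampleRate" -> "0.9,0.1"))
}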
File: SQLTfIdf.scala
@@ -39,7 +39,7 @@ class SQLTfIdf extends SQLAlg with Functions {
configureModel(rfc, params)
rfc.setOutputCol("__SQLTfIdf__")
val featurizedData = rfc.transform(df)
rfc.getBinary

val idf = new IDF()
configureModel(idf, params)
idf.setInputCol("__SQLTfIdf__")
File: SQLTfIdfInPlace.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.linalg.SQLDataTypes._
import org.apache.spark.ml.param.{DoubleParam, Param}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import streaming.core.shared.SharedObjManager
import streaming.dsl.mmlib.algs.MetaConst._
import streaming.dsl.mmlib.algs.classfication.BaseClassification
@@ -41,7 +41,6 @@ class SQLTfIdfInPlace(override val uid: String) extends SQLAlg with MllibFunctio

override def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = {
val newDF = interval_train(df, params + ("path" -> path))
newDF.write.mode(SaveMode.Overwrite).parquet(getDataPath(path))
newDF
}

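Both TfIdf edits are cleanups in the same spirit: SQLTfIdf drops a stray rfc.getBinary call whose result was never used, and SQLTfIdfInPlace (like RateSampler above) stops writing an extra parquet copy of its output, presumably because interval_train already receives path through the params map and persists what it needs there; the now-unused SaveMode import goes away with it.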
File: CommandCollection.scala
@@ -24,6 +24,10 @@ object CommandCollection {
context.addEnv("hdfs",""" run command as HDFSCommand.`` where parameters='''{:all}''' """)
context.addEnv("fs",""" run command as HDFSCommand.`` where parameters='''{:all}''' """)

context.addEnv("split",""" run {0} as RateSampler.`` where labelCol="{2}" and sampleRate="{4}" as {6} """)

context.addEnv("saveUploadFileToHome",""" run command as DownloadExt.`` where from="{}" and to="{}" """)

context.addEnv("show",
"""
|run command as ShowCommand.`{}/{}/{}/{}/{}/{}/{}/{}/{}/{}/{}/{}`
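The new registrations lean on MLSQL's positional template placeholders: numbered slots such as {0}/{2}/{4}/{6} appear to pick whitespace-separated tokens from the user's command line by index, bare {} slots appear to consume tokens in order, and {:all} splices in the whole argument string (as in the hdfs/fs commands above). Assuming the odd positions hold connective keywords, a !split invocation would expand roughly as follows (the command tokens, table, and column names here are guesses for illustration):

-- hypothetical command line:
--   !split trainData by label rate "0.9,0.1" named splitted;
-- would expand to the registered template:
run trainData as RateSampler.`` where labelCol="label" and sampleRate="0.9,0.1" as splitted;

Likewise, !saveUploadFileToHome fills its two bare {} slots with the from and to arguments for DownloadExt.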
