Commit
Merge pull request byzer-org#396 from zml1206/master
Fix the Word2vecInplace merge-predict bug and the word2vec predict null-pointer bug
allwefantasy authored Jul 30, 2018
2 parents 843cca0 + 95670a0 commit 18530ab
Showing 4 changed files with 54 additions and 47 deletions.
SQLWord2Vec.scala
@@ -10,8 +10,8 @@ import streaming.dsl.mmlib.SQLAlg
 import scala.collection.JavaConversions._
 
 /**
-  * Created by allwefantasy on 13/1/2018.
-  */
+ * Created by allwefantasy on 13/1/2018.
+ */
 class SQLWord2Vec extends SQLAlg with Functions {
 
   def train(df: DataFrame, path: String, params: Map[String, String]) = {
@@ -34,13 +34,13 @@ class SQLWord2Vec extends SQLAlg with Functions {
     val f = (co: String) => {
       model.value.get(co) match {
         case Some(vec) => vec.toSeq
-        case None => null
+        case None => Seq[Double]()
       }
 
     }
 
     val f2 = (co: Seq[String]) => {
-      co.map(f(_))
+      co.map(f(_)).filter(x => x.size > 0)
     }
 
     Map((name + "_array") -> f2, name -> f)
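
The None => null branch was the null-pointer bug named in the commit message: callers of internal_predict treat the result as a Seq[Seq[Double]] and invoke methods such as flatten on each element, which throws NullPointerException on a null entry. For illustration, a minimal, self-contained sketch of the new lookup-and-filter pattern; the names EmptyVecDemo and vocab are stand-ins (not from the commit), with vocab playing the role of the broadcast model.value:

object EmptyVecDemo {
  // Stand-in for the broadcast word -> vector map (model.value).
  val vocab = Map("hello" -> Seq(0.1, 0.2), "world" -> Seq(0.3, 0.4))

  // After the fix: an unknown word yields an empty Seq instead of null...
  val f: String => Seq[Double] = w => vocab.getOrElse(w, Seq[Double]())

  // ...and empty vectors are filtered out, so no null entries leak downstream.
  val f2: Seq[String] => Seq[Seq[Double]] = ws => ws.map(f).filter(_.nonEmpty)

  def main(args: Array[String]): Unit = {
    println(f2(Seq("hello", "unknown", "world")))
    // List(List(0.1, 0.2), List(0.3, 0.4)): the unknown word simply drops out
  }
}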
SQLWord2VecInPlace.scala
@@ -1,12 +1,11 @@
 package streaming.dsl.mmlib.algs
 
-import MetaConst._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, functions => F}
 import streaming.core.shared.SharedObjManager
 import streaming.dsl.mmlib.SQLAlg
 import streaming.dsl.mmlib.algs.MetaConst._
 import streaming.dsl.mmlib.algs.feature.StringFeature
 import streaming.dsl.mmlib.algs.feature.StringFeature.loadWordvecs
 import streaming.dsl.mmlib.algs.meta.Word2VecMeta
@@ -27,7 +26,7 @@ class SQLWord2VecInPlace extends SQLAlg with Functions {
     val length = params.getOrElse("length", "100").toInt
     val stopWordPath = params.getOrElse("stopWordPath", "")
     val resultFeature = params.getOrElse("resultFeature", "") //flat,merge,index
-    val minCount = params.getOrElse("minCount","1").toInt
+    val minCount = params.getOrElse("minCount", "1").toInt
     val split = params.getOrElse("split", null)
     require(!inputCol.isEmpty, "inputCol is required when use SQLWord2VecInPlace")
     val metaPath = getMetaPath(params("path"))
@@ -88,7 +87,6 @@ class SQLWord2VecInPlace extends SQLAlg with Functions {
     val resultFeature = trainParams.getOrElse("resultFeature", "")
     val vectorSize = trainParams.getOrElse("vectorSize", "100").toInt
     val length = trainParams.getOrElse("length", "100").toInt
-    val minCount = trainParams.getOrElse("minCount", "1").toInt
     val wordIndexBr = spark.sparkContext.broadcast(word2vecMeta.wordIndex)
     val split = trainParams.getOrElse("split", null)
 
@@ -114,53 +112,59 @@ class SQLWord2VecInPlace extends SQLAlg with Functions {
     }
     val func = (content: String) => {
       val wordArray = wordArrayFunc(content)
-      val wordIntArray = wordArray.filter(f => wordIndexBr.value.contains(f)).map(f => wordIndexBr.value(f).toInt)
-      wordIntArray.map(f => f.toString).toSeq
-    }
-
-    val func2 = (content: String) => {
-      val wordArray = wordArrayFunc(content)
-      val r = new Array[Array[Double]](length)
-      val wordvecsMap = wordVecsBr.value
-      val wSize = wordArray.size
-      for (i <- 0 until length) {
-        if (i < wSize && wordvecsMap.contains(wordArray(i))) {
-          r(i) = wordvecsMap(wordArray(i))
-        } else
-          r(i) = new Array[Double](vectorSize)
+      if (wordVecsBr.value.size > 0) {
+        val r = new Array[Seq[Double]](length)
+        val wordvecsMap = wordVecsBr.value
+        val wSize = wordArray.size
+        for (i <- 0 until length) {
+          if (i < wSize && wordvecsMap.contains(wordArray(i))) {
+            r(i) = wordvecsMap(wordArray(i))
+          } else
+            r(i) = new Array[Double](vectorSize)
+        }
+        r.toSeq
       }
-      r
+      else {
+        val wordIntArray = wordArray.filter(f => wordIndexBr.value.contains(f)).map(f => wordIndexBr.value(f).toInt)
+        word2vecMeta.predictFunc(wordIntArray.map(f => f.toString).toSeq)
+      }
     }
 
-    val func3 = (content: String) => {
+    val funcIndex = (content: String) => {
       val wordArray = wordArrayFunc(content)
       wordArray.filter(f => wordIndexBr.value.contains(f)).map(f => wordIndexBr.value(f).toInt)
     }
 
-    if (wordVecsBr.value.size > 0) {
-      UserDefinedFunction(func2, ArrayType(ArrayType(DoubleType)), Some(Seq(StringType)))
-    } else {
-      resultFeature match {
-        case "index" => UserDefinedFunction(func3, ArrayType(IntegerType), Some(Seq(StringType)))
-        case "flat" => {
-          val f2 = (a: String) => {
-            word2vecMeta.predictFunc(func(a)).flatten
-          }
-          UserDefinedFunction(f2, ArrayType(DoubleType), Some(Seq(StringType)))
-        }
-        case _ => {
-          val f2 = (a: String) => {
-            word2vecMeta.predictFunc(func(a))
-          }
-          UserDefinedFunction(f2, ArrayType(ArrayType(DoubleType)), Some(Seq(StringType)))
-        }
-      }
-    }
+    resultFeature match {
+      case "flat" => {
+        val f2 = (a: String) => {
+          func(a).flatten
+        }
+        UserDefinedFunction(f2, ArrayType(DoubleType), Some(Seq(StringType)))
+      }
+      case "merge" => {
+        val f2 = (a: String) => {
+          val seq = func(a)
+          val r = new Array[Double](vectorSize)
+          for (a1 <- seq) {
+            val b = a1.toList
+            for (i <- 0 until b.size) {
+              r(i) = b(i) + r(i)
+            }
+          }
+          r.toSeq
+        }
+        UserDefinedFunction(f2, ArrayType(DoubleType), Some(Seq(StringType)))
+      }
+      case _ => {
+        if (wordVecsBr.value.size == 0 && resultFeature.equals("index"))
+          UserDefinedFunction(funcIndex, ArrayType(IntegerType), Some(Seq(StringType)))
+        else
+          UserDefinedFunction(func, ArrayType(ArrayType(DoubleType)), Some(Seq(StringType)))
+      }
+    }
  }
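
The new merge branch implements the other half of the commit message: instead of returning a length x vectorSize matrix, it folds all word vectors into one vectorSize-long vector by element-wise addition. A self-contained sketch of that fold, under the same assumption the diff makes (each word vector has at most vectorSize elements); MergeDemo and merge are illustrative names, not code from the commit:

object MergeDemo {
  val vectorSize = 4

  // Element-wise sum of word vectors into one fixed-size vector,
  // mirroring the accumulation loop in the "merge" case above.
  def merge(seq: Seq[Seq[Double]]): Seq[Double] = {
    val r = new Array[Double](vectorSize)
    for (vec <- seq; i <- vec.indices) {
      r(i) += vec(i)
    }
    r.toSeq
  }

  def main(args: Array[String]): Unit = {
    val vecs = Seq(Seq(1.0, 2.0, 0.0, 0.0), Seq(0.5, 0.5, 1.0, 0.0))
    println(merge(vecs)) // List(1.5, 2.5, 1.0, 0.0)
  }
}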
StringFeature.scala
@@ -203,7 +203,7 @@ object StringFeature extends BaseFeatureFunctions {
     val word2vec = new SQLWord2Vec()
     word2vec.train(replaceColumn(newDF, inputCol, F.udf((a: Seq[Int]) => {
       a.map(f => f.toString)
-    })), WORD2VEC_PATH(metaPath, inputCol), Map("inputCol" -> inputCol, "minCount" -> "0", "vectorSize" -> (vectorSize + ""), "minCount" -> minCount.toString))
+    })), WORD2VEC_PATH(metaPath, inputCol), Map("inputCol" -> inputCol, "vectorSize" -> (vectorSize + ""), "minCount" -> minCount.toString))
     val model = word2vec.load(df.sparkSession, WORD2VEC_PATH(metaPath, inputCol), Map())
     if (!resultFeature.equals("index")) {
       val predictFunc = word2vec.internal_predict(df.sparkSession, model, "wow")("wow_array").asInstanceOf[(Seq[String]) => Seq[Seq[Double]]]
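
The removed "minCount" -> "0" entry was dead code: in a Scala Map literal a repeated key keeps only the last binding, so minCount.toString already won before this change. A one-line demonstration (the values here are arbitrary):

val m = Map("minCount" -> "0", "vectorSize" -> "100", "minCount" -> "1")
println(m("minCount")) // prints 1; the later binding wins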
AutoMLSpec.scala
@@ -311,7 +311,7 @@ class AutoMLSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLC
     //execute sql
 
     val wordvecPaths = Seq("/Users/zml/mlsql/word2vec", "")
-    val resultFeatures = Seq("index", "")
+    val resultFeatures = Seq("index", "merge", "flat", "")
     for (wordvecPath <- wordvecPaths) {
       for (resultFeature <- resultFeatures) {
         implicit val spark = runtime.sparkSession
@@ -334,6 +334,8 @@ class AutoMLSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLC
         // we should make sure train vector and predict vector the same
         val trainVector = spark.sql("select * from parquet.`/tmp/william/tmp/word2vecinplace/data`").toJSON.collect()
         val predictVector = spark.sql("select jack(content) as content from orginal_text_corpus").toJSON.collect()
+        println("wordvecPath:" + wordvecPath)
+        println("resultFeature:" + resultFeature)
         predictVector.foreach { f =>
           assume(trainVector.contains(f))
         }
@@ -344,7 +346,7 @@ class AutoMLSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLC
 
   "SQLWord2VecInPlaceMinCount" should "work fine" taggedAs (NotToRunTag) in {
     withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime =>
-      val resultFeatures = Seq("index", "")
+      val resultFeatures = Seq("flat", "merge", "index", "")
       val minCounts = Seq("1", "2")
       for (resultFeature <- resultFeatures) {
         val wordCount = Array(0, 0)
@@ -377,6 +379,7 @@ class AutoMLSpec extends BasicSparkOperation with SpecFunctions with BasicMLSQLC
           wordCount(i) = JSONObject.fromObject(countVector(0)).get("a").toString.toInt
           i = i + 1
         }
+        println("resultFeature:" + resultFeature)
         assume(wordCount(0) > wordCount(1))
       }
     }
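
The widened resultFeatures lists exercise every branch of the new match in SQLWord2VecInPlace. As a summary of the column shape each mode produces, inferred from the UserDefinedFunction return types in the diff above (this helper is illustrative, not code from the commit):

def outputShape(resultFeature: String, hasWordVecs: Boolean): String =
  resultFeature match {
    case "flat"                  => "Seq[Double] of length * vectorSize"    // matrix flattened
    case "merge"                 => "Seq[Double] of vectorSize"             // element-wise sum
    case "index" if !hasWordVecs => "Seq[Int], one index per known word"    // raw word indices
    case _                       => "Seq[Seq[Double]], length x vectorSize" // padded matrix
  }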
