Skip to content

Commit

Permalink
Word2vecInplace merge bug fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
zml1206 committed Aug 1, 2018
1 parent 18530ab commit dfd112b
Showing 1 changed file with 21 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,19 @@ class SQLWord2VecInPlace extends SQLAlg with Functions {
}
if (resultFeature.equals("merge")) {
val flatFeatureUdf = F.udf((a: Seq[Seq[Double]]) => {
val r = new Array[Double](vectorSize)
for (a1 <- a) {
val b = a1.toList
for (i <- 0 until b.size) {
r(i) = b(i) + r(i)
if (a.size == 0) {
Seq[Double]()
}
else {
val r = new Array[Double](vectorSize)
for (a1 <- a) {
val b = a1.toList
for (i <- 0 until b.size) {
r(i) = b(i) + r(i)
}
}
r.toSeq
}
r.toSeq
})
newDF = newDF.withColumn(inputCol, flatFeatureUdf(F.col(inputCol)))
}
Expand Down Expand Up @@ -146,14 +151,18 @@ class SQLWord2VecInPlace extends SQLAlg with Functions {
case "merge" => {
val f2 = (a: String) => {
val seq = func(a)
val r = new Array[Double](vectorSize)
for (a1 <- seq) {
val b = a1.toList
for (i <- 0 until b.size) {
r(i) = b(i) + r(i)
if (seq.size == 0) {
Seq[Double]()
} else {
val r = new Array[Double](vectorSize)
for (a1 <- seq) {
val b = a1.toList
for (i <- 0 until b.size) {
r(i) = b(i) + r(i)
}
}
r.toSeq
}
r.toSeq
}
UserDefinedFunction(f2, ArrayType(DoubleType), Some(Seq(StringType)))
}
Expand Down

0 comments on commit dfd112b

Please sign in to comment.