[SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec
jira: https://issues.apache.org/jira/browse/SPARK-11898
syn0Global and syn1Global in Word2Vec are quite large objects, with size (vocab * vectorSize * 8), yet they are passed to the workers using basic task serialization.
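
As a rough illustration of the sizes involved, here is a minimal sketch of the arithmetic. It assumes the (vocab * vectorSize * 8) figure counts both Float tables together at 4 bytes per element; `TableSize` and `tableBytes` are hypothetical names used only for this example, and array object headers are ignored.

```scala
object TableSize {
  // Bytes taken by the element data of one Array[Float] table.
  // Assumption: JVM array headers and alignment padding are ignored.
  def tableBytes(vocabSize: Long, vectorSize: Long): Long =
    vocabSize * vectorSize * java.lang.Float.BYTES

  def main(args: Array[String]): Unit = {
    val vocabSize = 1000000L // 1M vocabulary
    val vectorSize = 100L    // default vectorSize
    val one = tableBytes(vocabSize, vectorSize)
    val both = 2 * one       // the two global tables together
    println(s"one table: ${one / 1000000} MB, both tables: ${both / 1000000} MB")
  }
}
```

For a 1M vocabulary and vectorSize 100 this comes to roughly 400 MB per table and 800 MB for the pair, which is the payload that basic task serialization would attach to every task.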

Using broadcast can greatly improve performance. My benchmark shows that, for a 1M vocabulary and the default vectorSize of 100, switching to broadcast can:

1. decrease worker memory consumption by 45%.
2. decrease running time by 40%.

This will also help extend the upper limit for Word2Vec.
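
The pattern this change applies can be sketched in isolation. The following is a minimal, non-authoritative example, not the Word2Vec code itself: it assumes spark-core is on the classpath and a local master, and `BroadcastSketch`, `run`, `table`, and the halving step are hypothetical stand-ins. Only `SparkContext.broadcast`, `Broadcast.value`, and `Broadcast.unpersist` are the calls that also appear in the diff.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastSketch {
  // Hypothetical helper: runs `iterations` rounds over a small stand-in table.
  def run(sc: SparkContext, iterations: Int): Array[Float] = {
    // Stand-in for a large driver-side table such as syn0Global.
    val table = Array.fill[Float](8)(1.0f)
    val data = sc.parallelize(0 until 8, numSlices = 4)

    for (_ <- 1 to iterations) {
      // Re-broadcast every round, because the driver changes the table between rounds.
      val bcTable = sc.broadcast(table)

      val partialSums = data.mapPartitions { iter =>
        // The closure captures only the broadcast handle; the array is read on the executor.
        val local = bcTable.value
        Iterator.single(iter.map(i => local(i)).sum)
      }.collect()
      println(s"sum seen by the partitions: ${partialSums.sum}")

      // Stand-in for merging worker results back into the driver copy.
      for (i <- table.indices) table(i) *= 0.5f

      // Drop the executor copies of this round's broadcast without blocking.
      bcTable.unpersist(blocking = false)
    }
    table
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("broadcast-sketch").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try println(run(sc, 3).mkString(", "))
    finally sc.stop()
  }
}
```

A variable referenced directly inside the closure is serialized along with each task, whereas a broadcast value is fetched once per executor and shared by the tasks running there. Unpersisting at the end of each round keeps stale copies of the previous table from accumulating on the executors.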

Author: Yuhao Yang <[email protected]>

Closes apache#9878 from hhbyyh/w2vBC.
hhbyyh authored and srowen committed Dec 1, 2015
1 parent 9693b0d commit a0af0e3
Showing 1 changed file with 6 additions and 1 deletion.
@@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging {
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     val syn1Global = new Array[Float](vocabSize * vectorSize)
     var alpha = learningRate
+
     for (k <- 1 to numIterations) {
+      val bcSyn0Global = sc.broadcast(syn0Global)
+      val bcSyn1Global = sc.broadcast(syn1Global)
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
         val syn0Modify = new Array[Int](vocabSize)
         val syn1Modify = new Array[Int](vocabSize)
-        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
             var wc = wordCount
@@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging {
         }
         i += 1
       }
+      bcSyn0Global.unpersist(false)
+      bcSyn1Global.unpersist(false)
     }
     newSentences.unpersist()
