[SPARK-33518][ML][FOLLOWUP] MatrixFactorizationModel use GEMV
### What changes were proposed in this pull request?
1. Update the related documentation.
2. Switch MatrixFactorizationModel to GEMV.

### Why are the changes needed?
See the performance gains reported in apache#30468.

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
Existing test suites.

Closes apache#31279 from zhengruifeng/als_follow_up.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jan 27, 2021
1 parent abf7e81 commit 56e9426
Showing 2 changed files with 42 additions and 27 deletions.
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -435,7 +435,8 @@ class ALSModel private[ml] (
    * relatively efficient, the approach implemented here is significantly more efficient.
    *
    * This approach groups factors into blocks and computes the top-k elements per block,
-   * using dot product and an efficient [[BoundedPriorityQueue]] (instead of gemm).
+   * using GEMV (it uses less memory than GEMM, and is much faster than DOT) and
+   * an efficient selection based on [[GuavaOrdering]] (instead of [[BoundedPriorityQueue]]).
    * It then computes the global top-k by aggregating the per block top-k elements with
    * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data.
    * This is the DataFrame equivalent to the approach used in
@@ -481,7 +482,7 @@ class ALSModel private[ml] (
         }
 
         Iterator.range(0, m).flatMap { i =>
-          // buffer = i-th vec in srcMat * dstMat
+          // scores = i-th vec in srcMat * dstMat
           BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
             srcMat, i * rank, 1, 0.0F, scores, 0, 1)
 
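For intuition, here is a minimal, self-contained sketch (not part of this commit; the object name and toy values are invented) of what the `sgemv` call above computes. It calls netlib-java's `F2jBLAS` directly: `dstMat` stores one rank-length destination factor per column (column-major), so a single transposed GEMV yields the dot products of one source factor against every destination factor at once.

```scala
import com.github.fommil.netlib.F2jBLAS

object GemvScoreSketch {
  def main(args: Array[String]): Unit = {
    val blas = new F2jBLAS
    val rank = 2 // factor dimension
    val n = 3    // number of destination factors in this block
    // Column-major rank x n matrix: columns are the factors (1,0), (0,1), (1,1).
    val dstMat = Array(1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f)
    val srcVec = Array(2.0f, 3.0f) // one source factor
    val scores = new Array[Float](n)
    // scores = dstMat^T * srcVec, i.e. scores(j) = <srcVec, j-th destination factor>
    blas.sgemv("T", rank, n, 1.0f, dstMat, 0, rank, srcVec, 0, 1, 0.0f, scores, 0, 1)
    println(scores.mkString(", ")) // 2.0, 3.0, 5.0
  }
}
```

One GEMV per source row replaces `n` separate dot-product calls and, unlike GEMM, never materializes an `m x n` score matrix for the whole block pair.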
mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger}
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.google.common.collect.{Ordering => GuavaOrdering}
 import org.apache.hadoop.fs.Path
 import org.json4s._
 import org.json4s.JsonDSL._
@@ -37,7 +38,6 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.BoundedPriorityQueue
 
 /**
  * Model representing the result of matrix factorization.
@@ -277,11 +277,12 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
    * arrays required for intermediate result storage, as well as a high sensitivity to the
    * block size used.
    *
-   * The following approach still groups factors into blocks, but instead computes the
-   * top-k elements per block, using dot product and an efficient [[BoundedPriorityQueue]]
-   * (instead of gemm). This avoids any large intermediate data structures and results
-   * in significantly reduced GC pressure as well as shuffle data, which far outweighs
-   * any cost incurred from not using Level 3 BLAS operations.
+   * This approach groups factors into blocks and computes the top-k elements per block,
+   * using GEMV (it uses less memory than GEMM, and is much faster than DOT) and
+   * an efficient selection based on [[GuavaOrdering]] (instead of [[BoundedPriorityQueue]]).
+   * It then computes the global top-k by aggregating the per block top-k elements with
+   * a [[BoundedPriorityQueue]]. This significantly reduces the size of intermediate and
+   * shuffle data.
    *
    * @param rank rank
    * @param srcFeatures src features to receive recommendations
@@ -295,28 +296,40 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
       srcFeatures: RDD[(Int, Array[Double])],
       dstFeatures: RDD[(Int, Array[Double])],
       num: Int): RDD[(Int, Array[(Int, Double)])] = {
+    import scala.collection.JavaConverters._
     val srcBlocks = blockify(srcFeatures)
     val dstBlocks = blockify(dstFeatures)
-    val ratings = srcBlocks.cartesian(dstBlocks).flatMap { case (srcIter, dstIter) =>
-      val m = srcIter.size
-      val n = math.min(dstIter.size, num)
-      val output = new Array[(Int, (Int, Double))](m * n)
-      var i = 0
-      val pq = new BoundedPriorityQueue[(Int, Double)](n)(Ordering.by(_._2))
-      srcIter.foreach { case (srcId, srcFactor) =>
-        dstIter.foreach { case (dstId, dstFactor) =>
-          // We use F2jBLAS which is faster than a call to native BLAS for vector dot product
-          val score = BLAS.f2jBLAS.ddot(rank, srcFactor, 1, dstFactor, 1)
-          pq += dstId -> score
-        }
-        pq.foreach { case (dstId, score) =>
-          output(i) = (srcId, (dstId, score))
-          i += 1
-        }
-        pq.clear()
-      }
-      output.toSeq
-    }
+
+    val ratings = srcBlocks.cartesian(dstBlocks)
+      .mapPartitions { iter =>
+        var scores: Array[Double] = null
+        var idxOrd: GuavaOrdering[Int] = null
+        iter.flatMap { case ((srcIds, srcMat), (dstIds, dstMat)) =>
+          require(srcMat.length == srcIds.length * rank)
+          require(dstMat.length == dstIds.length * rank)
+          val m = srcIds.length
+          val n = dstIds.length
+          if (scores == null || scores.length < n) {
+            scores = Array.ofDim[Double](n)
+            idxOrd = new GuavaOrdering[Int] {
+              override def compare(left: Int, right: Int): Int = {
+                Ordering[Double].compare(scores(left), scores(right))
+              }
+            }
+          }
+
+          Iterator.range(0, m).flatMap { i =>
+            // scores = i-th vec in srcMat * dstMat
+            BLAS.f2jBLAS.dgemv("T", rank, n, 1.0, dstMat, 0, rank,
+              srcMat, i * rank, 1, 0.0, scores, 0, 1)
+
+            val srcId = srcIds(i)
+            idxOrd.greatestOf(Iterator.range(0, n).asJava, num).asScala
+              .iterator.map { j => (srcId, (dstIds(j), scores(j))) }
+          }
+        }
+      }
+
     ratings.topByKey(num)(Ordering.by(_._2))
   }
 
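To make the selection step concrete, here is a small standalone sketch (illustrative only; the object name and toy data are invented) of the `GuavaOrdering`/`greatestOf` pattern used above: the ordering compares candidate indices by the score they point at, so the top `num` indices are picked without sorting the whole scores array or allocating per-candidate tuples.

```scala
import scala.collection.JavaConverters._
import com.google.common.collect.{Ordering => GuavaOrdering}

object TopKByScoreSketch {
  def main(args: Array[String]): Unit = {
    val scores = Array(0.3, 0.9, 0.1, 0.7)
    // Compare candidate indices by their score.
    val idxOrd = new GuavaOrdering[Int] {
      override def compare(left: Int, right: Int): Int =
        Ordering[Double].compare(scores(left), scores(right))
    }
    // Top-2 indices by score, without fully sorting `scores`.
    val top2 = idxOrd.greatestOf(Iterator.range(0, scores.length).asJava, 2).asScala
    println(top2.map(j => (j, scores(j))).mkString(", ")) // (1,0.9), (3,0.7)
  }
}
```

Because the ordering is defined over indices into a reusable `scores` buffer, the same two objects can serve every source row in a block, which is why the code above allocates them once per partition.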
@@ -326,9 +339,10 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
    */
   private def blockify(
       features: RDD[(Int, Array[Double])],
-      blockSize: Int = 4096): RDD[Seq[(Int, Array[Double])]] = {
+      blockSize: Int = 4096): RDD[(Array[Int], Array[Double])] = {
     features.mapPartitions { iter =>
       iter.grouped(blockSize)
+        .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
     }
   }

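As a local illustration of the new `blockify` contract (a sketch with made-up values, not Spark code): each group of up to `blockSize` records collapses into a pair of an id array and a flattened factor array, which `recommendForAll` then reads as a column-major `rank x ids.length` matrix for GEMV.

```scala
object BlockifySketch {
  def main(args: Array[String]): Unit = {
    // Three (id, factor) records with rank = 2.
    val features = Iterator(1 -> Array(0.1, 0.2), 2 -> Array(0.3, 0.4), 3 -> Array(0.5, 0.6))
    // Same per-partition transformation as the new blockify, with blockSize = 2.
    val blocks = features.grouped(2)
      .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
    blocks.foreach { case (ids, mat) =>
      // mat is the column-major 2 x ids.length factor matrix for this block.
      println(s"ids=[${ids.mkString(",")}] mat=[${mat.mkString(",")}]")
    }
    // ids=[1,2] mat=[0.1,0.2,0.3,0.4]
    // ids=[3] mat=[0.5,0.6]
  }
}
```

Returning parallel arrays instead of a `Seq` of tuples is what lets the caller hand each block's factors straight to BLAS without further copying.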
