
Commit 594dccd
reduce heap memory consumption by avoiding the creation of a single large array for input data
kiszk committed Dec 10, 2015
1 parent bb72343 commit 594dccd
Showing 2 changed files with 16 additions and 11 deletions.
@@ -47,13 +47,14 @@ object SparkGPULR {
   def ddotvv(x: Array[Double], y: Array[Double]) : Double =
     (x zip y).foldLeft(0.0)((a, b) => a + (b._1 * b._2))
 
-  def generateData(N: Int, D: Int, R: Double): Array[DataPoint] = {
+  def generateData(seed: Int, N: Int, D: Int, R: Double): DataPoint = {
+    val r = new Random(seed)
     def generatePoint(i: Int): DataPoint = {
       val y = if (i % 2 == 0) -1 else 1
-      val x = Array.fill(D){rand.nextGaussian + y * R}
+      val x = Array.fill(D){r.nextGaussian + y * R}
       DataPoint(x, y)
     }
-    Array.tabulate(N)(generatePoint)
+    generatePoint(seed)
   }
 
   def showWarning() {
@@ -100,9 +101,10 @@ object SparkGPULR {
       Some((size: Long) => 1),
       Some(dimensions)))
 
-    val points = sc.parallelize(generateData(N, D, R), numSlices)
-    points.cacheGpu()
-    val pointsColumnCached = points.convert(ColumnFormat).cache()
+    val skelton = sc.parallelize((1 to N), numSlices)
+    val points = skelton.map(i => generateData(i, N, D, R))
+    val pointsColumnCached = points.convert(ColumnFormat).cache().cacheGpu()
+    pointsColumnCached.count()
 
     // Initialize w to a random value
     var w = Array.fill(D){2 * rand.nextDouble - 1}
13 changes: 8 additions & 5 deletions examples/src/main/scala/org/apache/spark/examples/SparkLR.scala

@@ -39,13 +39,14 @@ object SparkLR {
 
   case class DataPoint(x: Vector[Double], y: Double)
 
-  def generateData(N: Int, D: Int, R: Double): Array[DataPoint] = {
+  def generateData(seed: Int, N: Int, D: Int, R: Double): DataPoint = {
+    val r = new Random(seed)
     def generatePoint(i: Int): DataPoint = {
       val y = if (i % 2 == 0) -1 else 1
-      val x = DenseVector.fill(D){rand.nextGaussian + y * R}
+      val x = DenseVector.fill(D){r.nextGaussian + y * R}
       DataPoint(x, y)
     }
-    Array.tabulate(N)(generatePoint)
+    generatePoint(seed)
   }
 
   def showWarning() {
@@ -70,8 +71,10 @@ object SparkLR {
     val R = 0.7 // Scaling factor
     val ITERATIONS = if (args.length > 3) args(3).toInt else 5
 
-    val points = sc.parallelize(generateData(N, D, R), numSlices).cache()
-
+    val skelton = sc.parallelize((1 to N), numSlices)
+    val points = skelton.map(i => generateData(i, N, D, R)).cache()
+    points.count()
+
     // Initialize w to a random value
     var w = DenseVector.fill(D){2 * rand.nextDouble - 1}
    printf("numSlices=%d, N=%d, D=%d, ITERATIONS=%d\n", numSlices, N, D, ITERATIONS)
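
Both files apply the same pattern, which the sketch below restates in plain Spark as a self-contained example (the object name GenerateOnExecutorsSketch, the simplified generateData signature, and the constant values are illustrative, not code from this repository): parallelize only the point indices, generate each DataPoint inside map() on the executors, cache the result, and call count() to materialize it up front, so the driver never allocates the full N-element input array on its heap.

import scala.util.Random

import org.apache.spark.{SparkConf, SparkContext}

object GenerateOnExecutorsSketch {
  case class DataPoint(x: Array[Double], y: Double)

  // Each point is built from its own seeded Random, so a given index
  // always produces the same point, no matter which executor computes it.
  def generateData(seed: Int, D: Int, R: Double): DataPoint = {
    val r = new Random(seed)
    val y = if (seed % 2 == 0) -1 else 1
    DataPoint(Array.fill(D)(r.nextGaussian + y * R), y)
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GenerateOnExecutorsSketch"))
    val N = 100000   // number of points
    val D = 10       // dimensions per point
    val R = 0.7      // scaling factor
    val numSlices = 4

    // Driver-side pattern (what the old code did):
    //   sc.parallelize(Array.tabulate(N)(i => generateData(i, D, R)), numSlices)
    // Executor-side pattern (what the new code does):
    val points = sc.parallelize(1 to N, numSlices)
      .map(i => generateData(i, D, R))
      .cache()
    points.count()   // action: forces generation and populates the cache up front

    sc.stop()
  }
}

Seeding a fresh Random per index also keeps the data deterministic: if a cached partition is evicted and recomputed, the same points are regenerated.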
