Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rezazadeh committed Jan 4, 2014
1 parent 35adc72 commit 8bfcce1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 20 deletions.
28 changes: 19 additions & 9 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.jblas.{DoubleMatrix, Singular, MatrixFunctions}
* @param m number of rows
* @param n number of columns
*/
class GradientDescent(var data: RDD[MatrixEntry], var m: Int, var n: Int) {
class SVD(var data: RDD[MatrixEntry], var m: Int, var n: Int) {
private var k: Int = 1

/**
Expand Down Expand Up @@ -65,6 +65,13 @@ class GradientDescent(var data: RDD[MatrixEntry], var m: Int, var n: Int) {
this.n = n
this
}

/**
* Compute SVD using the current set parameters
*/
def computeSVD() : SVDecomposedMatrix = {
SVD.sparseSVD(data, m, n, k)
}
}


Expand Down Expand Up @@ -169,33 +176,36 @@ object SVD {
SVDecomposedMatrix(retU, retS, retV)
}

/*

def main(args: Array[String]) {
if (args.length < 8) {
println("Usage: SVD <master> <matrix_file> <m> <n>
<minimum_singular_value> <output_U_file> <output_S_file> <output_V_file>")
println("Usage: SVD <master> <matrix_file> <m> <n>" +
"<k> <output_U_file> <output_S_file> <output_V_file>")
System.exit(1)
}
val (master, inputFile, m, n, min_svalue, output_u, output_s, output_v) =
val (master, inputFile, m, n, k, output_u, output_s, output_v) =
(args(0), args(1), args(2).toInt, args(3).toInt,
args(4).toDouble, args(5), args(6), args(7))
args(4).toInt, args(5), args(6), args(7))

val sc = new SparkContext(master, "SVD")

val rawdata = sc.textFile(inputFile)
val data = rawdata.map { line =>
val parts = line.split(',')
((parts(0).toInt, parts(1).toInt), parts(2).toDouble)
MatrixEntry(parts(0).toInt, parts(1).toInt, parts(2).toDouble)
}

val (u, s, v) = SVD.sparseSVD(data, m, n, min_svalue)
val decomposed = SVD.sparseSVD(data, m, n, k)
val u = decomposed.U
val s = decomposed.S
val v = decomposed.V

println("Computed " + s.toArray.length + " singular values and vectors")
u.saveAsTextFile(output_u)
s.saveAsTextFile(output_s)
v.saveAsTextFile(output_v)
System.exit(0)
}
*/
}


28 changes: 17 additions & 11 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,18 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
// check multiplication guarantee
assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), densea)
}
/*

test("rank one matrix svd") {
val m = 10
val n = 3
val data = sc.makeRDD(Array.tabulate(m,n){ (a,b)=>
((a+1,b+1), 1.0) }.flatten )
val min_svalue = 1.0e-4
val data = sc.makeRDD(Array.tabulate(m, n){ (a,b) =>
MatrixEntry(a + 1, b + 1, 1.0) }.flatten )
val k = 1

val (u, s, v) = SVD.sparseSVD(data, m, n, min_svalue)
val decomposed = SVD.sparseSVD(data, m, n, k)
val u = decomposed.U
val s = decomposed.S
val v = decomposed.V
val retrank = s.toArray.length

assert(retrank == 1, "rank returned not one")
Expand All @@ -116,15 +119,18 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
assertMatrixEquals(retu.mmul(rets).mmul(retv.transpose), densea)
}

test("truncated with min singular value") {
test("truncated with k") {
val m = 10
val n = 3
val data = sc.makeRDD(Array.tabulate(m,n){ (a,b)=>
((a+1,b+1), (a+2).toDouble*(b+1)/(1+a+b)) }.flatten )
val data = sc.makeRDD(Array.tabulate(m,n){ (a, b) =>
MatrixEntry(a + 1, b + 1, (a + 2).toDouble*(b + 1)/(1 + a + b)) }.flatten )

val min_svalue = 5.0 // only one svalue above this
val k = 1 // only one svalue above this

val (u, s, v) = SVD.sparseSVD(data, m, n, min_svalue)
val decomposed = SVD.sparseSVD(data, m, n, k)
val u = decomposed.U
val s = decomposed.S
val v = decomposed.V
val retrank = s.toArray.length

val densea = getDenseMatrix(data, m, n)
Expand All @@ -140,5 +146,5 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll {
assertMatrixEquals(retu, svd(0).getColumn(0))
assertMatrixEquals(rets, DoubleMatrix.diag(svd(1).getRow(0)))
assertMatrixEquals(retv, svd(2).getColumn(0))
}*/
}
}

0 comments on commit 8bfcce1

Please sign in to comment.