Skip to content

Commit

Permalink
Clean up the test set-up using beforeAll
Browse files Browse the repository at this point in the history
`sc.parallelize` can only be used once the SparkContext is available — i.e., inside test blocks or in `beforeAll`, which runs after the context is set up.
  • Loading branch information
karlhigley committed May 22, 2016
1 parent 21bbe26 commit 76f4201
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@ package com.github.karlhigley.spark.neighbors

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.SparseVector

import com.github.karlhigley.spark.neighbors.lsh.HashTableEntry

class ANNModelSuite extends FunSuite with TestSparkContext {
val numPoints = 1000
val dimensions = 100
val density = 0.5

val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
var points: RDD[(Int, SparseVector)] = _

test("average selectivity is between zero and one") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))
// Build the shared points RDD once for the whole suite instead of per test.
// NOTE(review): per the commit description, `sc.parallelize` is only usable once
// the SparkContext exists — super.beforeAll() presumably initializes it; confirm
// against TestSparkContext.
override def beforeAll() {
super.beforeAll()
val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))
}

test("average selectivity is between zero and one") {
val ann =
new ANN(dimensions, "cosine")
.setTables(1)
Expand All @@ -27,8 +34,6 @@ class ANNModelSuite extends FunSuite with TestSparkContext {
}

test("average selectivity increases with more tables") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val ann =
new ANN(dimensions, "cosine")
.setTables(1)
Expand All @@ -43,8 +48,6 @@ class ANNModelSuite extends FunSuite with TestSparkContext {
}

test("average selectivity decreases with signature length") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val ann =
new ANN(dimensions, "cosine")
.setTables(1)
Expand Down
28 changes: 13 additions & 15 deletions src/test/scala/com/github/karlhigley/spark/neighbors/ANNSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package com.github.karlhigley.spark.neighbors

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.SparseVector

import com.github.karlhigley.spark.neighbors.lsh.HashTableEntry

class ANNSuite extends FunSuite with TestSparkContext {
Expand All @@ -11,11 +14,15 @@ class ANNSuite extends FunSuite with TestSparkContext {
val dimensions = 100
val density = 0.5

val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
var points: RDD[(Int, SparseVector)] = _

test("compute hamming nearest neighbors as a batch") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))
// One-time suite set-up: generate random sparse points locally, then
// parallelize them into the `points` RDD shared by all tests below.
// NOTE(review): `sc` appears to come from TestSparkContext and is assumed to be
// live only after super.beforeAll() — confirm before reordering these calls.
override def beforeAll() {
super.beforeAll()
val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))
}

test("compute hamming nearest neighbors as a batch") {
val ann =
new ANN(dimensions, "hamming")
.setTables(1)
Expand All @@ -31,8 +38,6 @@ class ANNSuite extends FunSuite with TestSparkContext {
}

test("compute cosine nearest neighbors as a batch") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val ann =
new ANN(dimensions, "cosine")
.setTables(1)
Expand All @@ -48,8 +53,6 @@ class ANNSuite extends FunSuite with TestSparkContext {
}

test("compute euclidean nearest neighbors as a batch") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val ann =
new ANN(dimensions, "euclidean")
.setTables(1)
Expand All @@ -66,8 +69,6 @@ class ANNSuite extends FunSuite with TestSparkContext {
}

test("compute manhattan nearest neighbors as a batch") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val ann =
new ANN(dimensions, "manhattan")
.setTables(1)
Expand All @@ -84,8 +85,6 @@ class ANNSuite extends FunSuite with TestSparkContext {
}

test("compute jaccard nearest neighbors as a batch") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val ann =
new ANN(dimensions, "jaccard")
.setTables(1)
Expand All @@ -103,6 +102,7 @@ class ANNSuite extends FunSuite with TestSparkContext {
}

test("with multiple hash tables neighbors don't contain duplicates") {
val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
val withDuplicates = localPoints ++ localPoints
val points = sc.parallelize(withDuplicates.zipWithIndex.map(_.swap))

Expand All @@ -128,10 +128,8 @@ class ANNSuite extends FunSuite with TestSparkContext {
}

test("find neighbors for a set of query points") {
val points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val localTestPoints = TestHelpers.generateRandomPoints(100, dimensions, density)
val testPoints = sc.parallelize(localTestPoints.zipWithIndex.map(_.swap))
val localPoints = TestHelpers.generateRandomPoints(100, dimensions, density)
val testPoints = sc.parallelize(localPoints.zipWithIndex.map(_.swap))

val ann =
new ANN(dimensions, "hamming")
Expand Down

0 comments on commit 76f4201

Please sign in to comment.