Skip to content

Commit

Permalink
Merge pull request karlhigley#50 from karlhigley/fix/query
Browse files Browse the repository at this point in the history
Fix the # of neighbor sets returned by the query version of `neighbors`
  • Loading branch information
karlhigley authored Jul 4, 2016
2 parents 666eb07 + 6e4f322 commit c83b583
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ANNModel private[neighbors] (
val queryHashTables = ANNModel.generateHashTable(queryPoints, hashFunctions)
val queryEntries = collisionStrategy.apply(queryHashTables)

val candidateGroups = modelEntries.cogroup(queryEntries).values
val candidateGroups = queryEntries.cogroup(modelEntries).values
val neighbors = computeBipartiteDistances(candidateGroups)
neighbors.topByKey(quantity)(ANNModel.ordering)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,17 @@ class ANNSuite extends FunSuite with TestSparkContext {
}

test("find neighbors for a set of query points") {
val localPoints = TestHelpers.generateRandomPoints(100, dimensions, density)
val testPoints = sc.parallelize(localPoints).zipWithIndex.map(_.swap)

val ann =
new ANN(dimensions, "hamming")
new ANN(dimensions, "cosine")
.setTables(1)
.setSignatureLength(16)
.setSignatureLength(4)

val model = ann.train(points)
val neighbors = model.neighbors(testPoints, 10)
val neighborIds = neighbors.map(_._1)

val queryPoints = points.sample(withReplacement = false, fraction = 0.01)
val approxNeighbors = model.neighbors(queryPoints, 10)

assert(approxNeighbors.count() == queryPoints.count())
}
}

Expand Down

0 comments on commit c83b583

Please sign in to comment.