Skip to content

Commit

Permalink
Merge pull request karlhigley#38 from karlhigley/feature/point-ids
Browse files Browse the repository at this point in the history
Convert point ids from Int to Long
  • Loading branch information
karlhigley committed May 22, 2016
2 parents f8ffc4c + 139c10f commit 8d28ccf
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class ANN private (
* @return ANNModel containing computed hash tables
*/
def train(
points: RDD[(Int, SparseVector)],
points: RDD[(Long, SparseVector)],
persistenceLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
): ANNModel = {
var hashFunctions: Array[LSHFunction[_]] = Array()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class ANNModel private[neighbors] (
private[neighbors] val numPoints: Int
) extends Serializable {

type Point = (Int, SparseVector)
type Point = (Long, SparseVector)
type CandidateGroup = Iterable[Point]

/**
* Identify pairs of nearest neighbors by applying a
* collision strategy to the hash tables and then computing
* the actual distance between candidate pairs.
*/
def neighbors(quantity: Int): RDD[(Int, Array[(Int, Double)])] = {
def neighbors(quantity: Int): RDD[(Long, Array[(Long, Double)])] = {
val candidates = collisionStrategy.apply(hashTables).groupByKey(hashTables.getNumPartitions).values
val neighbors = computeDistances(candidates)
neighbors.topByKey(quantity)(ANNModel.ordering)
Expand All @@ -41,7 +41,7 @@ class ANNModel private[neighbors] (
* only potential matches, cogrouping the two RDDs, and
* computing candidate distances in the "normal" fashion.
*/
def neighbors(queryPoints: RDD[Point], quantity: Int): RDD[(Int, Array[(Int, Double)])] = {
def neighbors(queryPoints: RDD[Point], quantity: Int): RDD[(Long, Array[(Long, Double)])] = {
val modelEntries = collisionStrategy.apply(hashTables)

val queryHashTables = ANNModel.generateHashTable(queryPoints, hashFunctions)
Expand Down Expand Up @@ -80,7 +80,7 @@ class ANNModel private[neighbors] (
* Compute the actual distance between candidate pairs
* using the supplied distance measure.
*/
private def computeDistances(candidates: RDD[CandidateGroup]): RDD[(Int, (Int, Double))] = {
private def computeDistances(candidates: RDD[CandidateGroup]): RDD[(Long, (Long, Double))] = {
candidates
.flatMap {
case group => {
Expand All @@ -101,7 +101,7 @@ class ANNModel private[neighbors] (
* Compute the actual distance between candidate pairs, where
* the candidates arrive as pairs of candidate groups, using
* the supplied distance measure.
*/
private def computeBipartiteDistances(candidates: RDD[(CandidateGroup, CandidateGroup)]): RDD[(Int, (Int, Double))] = {
private def computeBipartiteDistances(candidates: RDD[(CandidateGroup, CandidateGroup)]): RDD[(Long, (Long, Double))] = {
candidates
.flatMap {
case (groupA, groupB) => {
Expand All @@ -119,14 +119,14 @@ class ANNModel private[neighbors] (
}

object ANNModel {
// Orders (id, distance) pairs by their Double component, reversed.
// It is handed to `topByKey` in `neighbors`; since `topByKey` keeps the
// largest elements under the given ordering, reversing it retains the
// pairs with the smallest Double values.
// The two lines below are the pre-change (Int id) and post-change (Long id)
// versions of the same definition as rendered by the diff.
private val ordering = Ordering[Double].on[(Int, Double)](_._2).reverse
private val ordering = Ordering[Double].on[(Long, Double)](_._2).reverse

/**
* Train a model by computing signatures for the supplied
* points
*/
def train(
points: RDD[(Int, SparseVector)],
points: RDD[(Long, SparseVector)],
hashFunctions: Array[_ <: LSHFunction[_]],
collisionStrategy: CollisionStrategy,
measure: DistanceMeasure,
Expand All @@ -144,7 +144,7 @@ object ANNModel {
}

def generateHashTable(
points: RDD[(Int, SparseVector)],
points: RDD[(Long, SparseVector)],
hashFunctions: Array[_ <: LSHFunction[_]]
): RDD[_ <: HashTableEntry[_]] = {
val indHashFunctions: Array[(_ <: LSHFunction[_], Int)] = hashFunctions.zipWithIndex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import com.github.karlhigley.spark.neighbors.lsh.HashTableEntry
* and banding (for minhash LSH).
*/
private[neighbors] abstract class CollisionStrategy {
// A point is an identifier paired with its sparse vector. The diff shows
// the identifier type changing from Int (first line) to Long (second line).
type Point = (Int, SparseVector)
type Point = (Long, SparseVector)

/**
* Turn hash table entries into (key, point) pairs.
*
* ANNModel groups the result by key (`groupByKey`) and keeps only the
* grouped values, so points sharing a key end up in the same candidate
* group.
*
* @param hashTables hash table entries to process
* @return pairs of a Product key and the point associated with it
*/
def apply(hashTables: RDD[_ <: HashTableEntry[_]]): RDD[(Product, Point)]
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private[neighbors] class BitSamplingFunction(
/**
* Build a hash table entry for the supplied vector.
*
* @param id identifier of the point (this change widens it from Int to Long)
* @param table passed through unchanged to the entry; presumably the index
* of the hash table this function belongs to -- TODO confirm
* @param v vector to hash; stored in the entry alongside its signature
* @return a BitHashTableEntry built from `id`, `table`, `signature(v)` and `v`
*/
def hashTableEntry(id: Int, table: Int, v: SparseVector): BitHashTableEntry = {
def hashTableEntry(id: Long, table: Int, v: SparseVector): BitHashTableEntry = {
BitHashTableEntry(id, table, signature(v), v)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ private[neighbors] abstract class LSHFunction[+T <: Signature[_]] {
/**
* Build a hash table entry for the supplied vector.
*
* Abstract: each concrete LSH function supplies the implementation and
* the concrete entry type.
*
* @param id identifier of the point (this change widens it from Int to Long)
* @param table Int stored on the resulting entry; presumably the index of
* the hash table -- TODO confirm
* @param v vector the entry is built for
* @return a hash table entry whose signature type is `T`
*/
def hashTableEntry(id: Int, table: Int, v: SparseVector): HashTableEntry[T]
def hashTableEntry(id: Long, table: Int, v: SparseVector): HashTableEntry[T]
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private[neighbors] class MinhashFunction(
/**
* Build a hash table entry for the supplied vector.
*
* @param id identifier of the point (this change widens it from Int to Long)
* @param table passed through unchanged to the entry; presumably the index
* of the hash table this function belongs to -- TODO confirm
* @param v vector to hash; stored in the entry alongside its signature
* @return an IntHashTableEntry built from `id`, `table`, `signature(v)` and `v`
*/
def hashTableEntry(id: Int, table: Int, v: SparseVector): IntHashTableEntry = {
def hashTableEntry(id: Long, table: Int, v: SparseVector): IntHashTableEntry = {
IntHashTableEntry(id, table, signature(v), v)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ private[neighbors] class ScalarRandomProjectionFunction(
/**
* Build a hash table entry for the supplied vector.
*
* @param id identifier of the point (this change widens it from Int to Long)
* @param table passed through unchanged to the entry; presumably the index
* of the hash table this function belongs to -- TODO confirm
* @param v vector to hash; stored in the entry alongside its signature
* @return an IntHashTableEntry built from `id`, `table`, `signature(v)` and `v`
*/
def hashTableEntry(id: Int, table: Int, v: SparseVector): IntHashTableEntry = {
def hashTableEntry(id: Long, table: Int, v: SparseVector): IntHashTableEntry = {
IntHashTableEntry(id, table, signature(v), v)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private[neighbors] class SignRandomProjectionFunction(
/**
* Build a hash table entry for the supplied vector.
*
* @param id identifier of the point (this change widens it from Int to Long)
* @param table passed through unchanged to the entry; presumably the index
* of the hash table this function belongs to -- TODO confirm
* @param v vector to hash; stored in the entry alongside its signature
* @return a BitHashTableEntry built from `id`, `table`, `signature(v)` and `v`
*/
def hashTableEntry(id: Int, table: Int, v: SparseVector): BitHashTableEntry = {
def hashTableEntry(id: Long, table: Int, v: SparseVector): BitHashTableEntry = {
BitHashTableEntry(id, table, signature(v), v)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private[neighbors] final case class IntSignature(
* in a single RDD.
*/
private[neighbors] sealed abstract class HashTableEntry[+S <: Signature[_]] {
val id: Int
val id: Long
val table: Int
val signature: S
val point: SparseVector
Expand All @@ -41,7 +41,7 @@ private[neighbors] sealed abstract class HashTableEntry[+S <: Signature[_]] {
}

private[neighbors] final case class BitHashTableEntry(
id: Int,
id: Long,
table: Int,
signature: BitSignature,
point: SparseVector
Expand All @@ -52,7 +52,7 @@ private[neighbors] final case class BitHashTableEntry(
}

private[neighbors] final case class IntHashTableEntry(
id: Int,
id: Long,
table: Int,
signature: IntSignature,
point: SparseVector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ class ANNModelSuite extends FunSuite with TestSparkContext {
val dimensions = 100
val density = 0.5

var points: RDD[(Int, SparseVector)] = _
var points: RDD[(Long, SparseVector)] = _

// Suite-wide fixture: populates `points` once, after the parent setup has run.
override def beforeAll() {
super.beforeAll()
val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
// Pre-change line: indexes the local collection (Int indices) before
// parallelizing, then swaps each pair so the index comes first.
points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))
// Post-change line: parallelizes first and indexes the RDD itself, which
// yields Long indices and so matches RDD[(Long, SparseVector)].
points = sc.parallelize(localPoints).zipWithIndex.map(_.swap)
}

test("average selectivity is between zero and one") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class ANNSuite extends FunSuite with TestSparkContext {
val dimensions = 100
val density = 0.5

var points: RDD[(Int, SparseVector)] = _
var points: RDD[(Long, SparseVector)] = _

// Suite-wide fixture: populates `points` once, after the parent setup has run.
override def beforeAll() {
super.beforeAll()
val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
// Pre-change line: indexes the local collection (Int indices) before
// parallelizing, then swaps each pair so the index comes first.
points = sc.parallelize(localPoints.zipWithIndex.map(_.swap))
// Post-change line: parallelizes first and indexes the RDD itself, which
// yields Long indices and so matches RDD[(Long, SparseVector)].
points = sc.parallelize(localPoints).zipWithIndex.map(_.swap)
}

test("compute hamming nearest neighbors as a batch") {
Expand Down Expand Up @@ -104,7 +104,7 @@ class ANNSuite extends FunSuite with TestSparkContext {
test("with multiple hash tables neighbors don't contain duplicates") {
val localPoints = TestHelpers.generateRandomPoints(numPoints, dimensions, density)
val withDuplicates = localPoints ++ localPoints
val points = sc.parallelize(withDuplicates.zipWithIndex.map(_.swap))
val points = sc.parallelize(withDuplicates).zipWithIndex.map(_.swap)

val ann =
new ANN(dimensions, "hamming")
Expand All @@ -129,7 +129,7 @@ class ANNSuite extends FunSuite with TestSparkContext {

test("find neighbors for a set of query points") {
val localPoints = TestHelpers.generateRandomPoints(100, dimensions, density)
val testPoints = sc.parallelize(localPoints.zipWithIndex.map(_.swap))
val testPoints = sc.parallelize(localPoints).zipWithIndex.map(_.swap)

val ann =
new ANN(dimensions, "hamming")
Expand All @@ -145,7 +145,7 @@ class ANNSuite extends FunSuite with TestSparkContext {
object ANNSuite {
def runAssertions(
hashTables: Array[_ <: HashTableEntry[_]],
neighbors: Array[(Int, Array[(Int, Double)])]
neighbors: Array[(Long, Array[(Long, Double)])]
) = {

// At least some neighbors are found
Expand Down

0 comments on commit 8d28ccf

Please sign in to comment.