forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-2511][MLLIB] add HashingTF and IDF
This is roughly the TF-IDF implementation used in the Databricks Cloud Demo: http://databricks.com/cloud/ . Both `HashingTF` and `IDF` are implemented as transformers, similar to scikit-learn. Author: Xiangrui Meng <[email protected]> Closes apache#1671 from mengxr/tfidf and squashes the following commits: 7d65888 [Xiangrui Meng] use JavaConverters._ 5fe9ec4 [Xiangrui Meng] fix unit test 6e214ec [Xiangrui Meng] add apache header cfd9aed [Xiangrui Meng] add Java-friendly methods move classes to mllib.feature 3814440 [Xiangrui Meng] add HashingTF and IDF
- Loading branch information
Showing
5 changed files
with
454 additions
and
0 deletions.
There are no files selected for viewing
79 changes: 79 additions & 0 deletions
79
mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.feature | ||
|
||
import java.lang.{Iterable => JavaIterable} | ||
|
||
import scala.collection.JavaConverters._ | ||
import scala.collection.mutable | ||
|
||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.api.java.JavaRDD | ||
import org.apache.spark.mllib.linalg.{Vector, Vectors} | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.util.Utils | ||
|
||
/**
 * :: Experimental ::
 * Maps a sequence of terms to their term frequencies using the hashing trick.
 *
 * Every term is assigned the feature index returned by `indexOf`. Distinct terms may be
 * assigned the same index, in which case their counts are added together.
 *
 * @param numFeatures number of features (default: 1000000); must be positive
 */
@Experimental
class HashingTF(val numFeatures: Int) extends Serializable {

  // No index can lie in [0, numFeatures) when numFeatures <= 0, so fail fast at construction
  // time instead of failing later inside indexOf/transform.
  require(numFeatures > 0, s"numFeatures must be positive but got $numFeatures.")

  /** Creates a transformer with the default number of features (1000000). */
  def this() = this(1000000)

  /**
   * Returns the index of the input term.
   *
   * The index is derived from the term's hash code (`##`) reduced by `numFeatures` through
   * `Utils.nonNegativeMod`.
   *
   * @param term the term to hash
   * @return the feature index assigned to `term`
   */
  def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures)

  /**
   * Transforms the input document into a sparse term frequency vector.
   *
   * @param document the terms of one document; repeated terms are counted once per occurrence
   * @return a sparse vector of size `numFeatures` holding the count of each hashed index
   */
  def transform(document: Iterable[_]): Vector = {
    val termFrequencies = mutable.HashMap.empty[Int, Double]
    document.foreach { term =>
      val i = indexOf(term)
      // Terms that hash to the same index accumulate into the same slot.
      termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)
    }
    Vectors.sparse(numFeatures, termFrequencies.toSeq)
  }

  /**
   * Transforms the input document into a sparse term frequency vector (Java version).
   *
   * @param document the terms of one document
   * @return a sparse vector of size `numFeatures` holding the count of each hashed index
   */
  def transform(document: JavaIterable[_]): Vector = {
    transform(document.asScala)
  }

  /**
   * Transforms the input document to term frequency vectors.
   *
   * @param dataset an RDD of documents, each an iterable of terms
   * @return an RDD with one term frequency vector per input document
   */
  def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = {
    dataset.map(this.transform)
  }

  /**
   * Transforms the input document to term frequency vectors (Java version).
   *
   * @param dataset a JavaRDD of documents, each an iterable of terms
   * @return a JavaRDD with one term frequency vector per input document
   */
  def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = {
    dataset.rdd.map(this.transform).toJavaRDD()
  }
}
194 changes: 194 additions & 0 deletions
194
mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.feature | ||
|
||
import breeze.linalg.{DenseVector => BDV} | ||
|
||
import org.apache.spark.annotation.Experimental | ||
import org.apache.spark.api.java.JavaRDD | ||
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} | ||
import org.apache.spark.mllib.rdd.RDDFunctions._ | ||
import org.apache.spark.rdd.RDD | ||
|
||
/**
 * :: Experimental ::
 * Inverse document frequency (IDF).
 * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
 * number of documents and `d(t)` is the number of documents that contain term `t`.
 *
 * Call `fit` first to learn the IDF weights from term frequency vectors; `transform` and `idf`
 * throw `IllegalStateException` until that has happened.
 */
@Experimental
class IDF {

  // TODO: Allow different IDF formulations.

  /** Learned IDF weights; remains null until `fit` has been called. */
  private var brzIdf: BDV[Double] = _

  /**
   * Computes the inverse document frequency.
   * @param dataset an RDD of term frequency vectors
   * @return this instance, now holding the learned IDF weights
   */
  def fit(dataset: RDD[Vector]): this.type = {
    val aggregated = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
      seqOp = (agg, doc) => agg.add(doc),
      combOp = (left, right) => left.merge(right))
    brzIdf = aggregated.idf()
    this
  }

  /**
   * Computes the inverse document frequency.
   * @param dataset a JavaRDD of term frequency vectors
   * @return this instance, now holding the learned IDF weights
   */
  def fit(dataset: JavaRDD[Vector]): this.type = fit(dataset.rdd)

  /**
   * Transforms term frequency (TF) vectors to TF-IDF vectors.
   * @param dataset an RDD of term frequency vectors
   * @return an RDD of TF-IDF vectors
   */
  def transform(dataset: RDD[Vector]): RDD[Vector] = {
    if (!initialized) {
      throw new IllegalStateException("Haven't learned IDF yet. Call fit first.")
    }
    // The closure below touches only the broadcast handle, never a member of this class,
    // so this IDF instance (which is not Serializable) is not pulled into the task.
    val bcIdf = dataset.context.broadcast(brzIdf)
    dataset.mapPartitions { vectors =>
      val weights = bcIdf.value
      vectors.map {
        case sparse: SparseVector =>
          // Scale only the stored entries; the index array is reused unchanged.
          val indices = sparse.indices
          val scaled = new Array[Double](indices.size)
          var pos = 0
          while (pos < scaled.length) {
            scaled(pos) = sparse.values(pos) * weights(indices(pos))
            pos += 1
          }
          Vectors.sparse(sparse.size, indices, scaled)
        case dense: DenseVector =>
          val scaled = new Array[Double](dense.size)
          var pos = 0
          while (pos < scaled.length) {
            scaled(pos) = dense.values(pos) * weights(pos)
            pos += 1
          }
          Vectors.dense(scaled)
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }
    }
  }

  /**
   * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
   * @param dataset a JavaRDD of term frequency vectors
   * @return a JavaRDD of TF-IDF vectors
   */
  def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
    val tfIdf = transform(dataset.rdd)
    tfIdf.toJavaRDD()
  }

  /** Returns the IDF vector. */
  def idf(): Vector = {
    if (initialized) {
      Vectors.fromBreeze(brzIdf)
    } else {
      throw new IllegalStateException("Haven't learned IDF yet. Call fit first.")
    }
  }

  /** True once `fit` has stored the IDF weights. */
  private def initialized: Boolean = brzIdf != null
}
|
||
private object IDF {

  /**
   * Document frequency aggregator.
   *
   * Accumulates the number of documents seen and, for every vector index, the number of
   * documents whose value at that index is positive. The dimension is fixed by the first
   * document added (or the first non-empty aggregator merged in); later inputs must match it.
   */
  class DocumentFrequencyAggregator extends Serializable {

    /** number of documents */
    private var m = 0L
    /** document frequency vector; null until a document is added or merged in */
    private var df: BDV[Long] = _

    /**
     * Adds a new document.
     *
     * @param doc a term frequency vector (sparse or dense)
     * @return this aggregator
     * @throws IllegalArgumentException if `doc` has a different size than earlier documents
     * @throws UnsupportedOperationException if `doc` is neither sparse nor dense
     */
    def add(doc: Vector): this.type = {
      if (isEmpty) {
        df = BDV.zeros(doc.size)
      } else {
        // Without this check a shorter dense vector would be accepted silently and a longer
        // one would fail with an opaque index error.
        require(doc.size == df.length,
          s"Expected a vector of size ${df.length} but got one of size ${doc.size}.")
      }
      doc match {
        case sv: SparseVector =>
          val nnz = sv.indices.size
          var k = 0
          while (k < nnz) {
            // Explicitly stored non-positive entries do not count as term occurrences.
            if (sv.values(k) > 0) {
              df(sv.indices(k)) += 1L
            }
            k += 1
          }
        case dv: DenseVector =>
          val n = dv.size
          var j = 0
          while (j < n) {
            if (dv.values(j) > 0.0) {
              df(j) += 1L
            }
            j += 1
          }
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }
      m += 1L
      this
    }

    /**
     * Merges another aggregator into this one.
     *
     * @param other the aggregator to merge in; it is left unmodified
     * @return this aggregator
     * @throws IllegalArgumentException if both aggregators hold counts of different sizes
     */
    def merge(other: DocumentFrequencyAggregator): this.type = {
      if (!other.isEmpty) {
        m += other.m
        if (df == null) {
          // Copy so that later updates to this aggregator do not alias other's counts.
          df = other.df.copy
        } else {
          require(df.length == other.df.length,
            s"Cannot merge document frequencies of size ${other.df.length} into size " +
              s"${df.length}.")
          df += other.df
        }
      }
      this
    }

    /** True while no document has been added or merged in. */
    private def isEmpty: Boolean = m == 0L

    /**
     * Returns the current IDF vector, computed as `log((m + 1) / (df(j) + 1))` per index.
     *
     * @throws IllegalStateException if no document has been seen yet
     */
    def idf(): BDV[Double] = {
      if (isEmpty) {
        throw new IllegalStateException("Haven't seen any document yet.")
      }
      val n = df.length
      val inv = BDV.zeros[Double](n)
      var j = 0
      while (j < n) {
        inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
        j += 1
      }
      inv
    }
  }
}
66 changes: 66 additions & 0 deletions
66
mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.feature; | ||
|
||
import java.io.Serializable; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import org.junit.After; | ||
import org.junit.Assert; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import com.google.common.collect.Lists; | ||
|
||
import org.apache.spark.api.java.JavaRDD; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.mllib.linalg.Vector; | ||
|
||
public class JavaTfIdfSuite implements Serializable { | ||
private transient JavaSparkContext sc; | ||
|
||
@Before | ||
public void setUp() { | ||
sc = new JavaSparkContext("local", "JavaTfIdfSuite"); | ||
} | ||
|
||
@After | ||
public void tearDown() { | ||
sc.stop(); | ||
sc = null; | ||
} | ||
|
||
@Test | ||
public void tfIdf() { | ||
// The tests are to check Java compatibility. | ||
HashingTF tf = new HashingTF(); | ||
JavaRDD<ArrayList<String>> documents = sc.parallelize(Lists.newArrayList( | ||
Lists.newArrayList("this is a sentence".split(" ")), | ||
Lists.newArrayList("this is another sentence".split(" ")), | ||
Lists.newArrayList("this is still a sentence".split(" "))), 2); | ||
JavaRDD<Vector> termFreqs = tf.transform(documents); | ||
termFreqs.collect(); | ||
IDF idf = new IDF(); | ||
JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs); | ||
List<Vector> localTfIdfs = tfIdfs.collect(); | ||
int indexOfThis = tf.indexOf("this"); | ||
for (Vector v: localTfIdfs) { | ||
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); | ||
} | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.feature | ||
|
||
import org.scalatest.FunSuite | ||
|
||
import org.apache.spark.mllib.linalg.Vectors | ||
import org.apache.spark.mllib.util.LocalSparkContext | ||
|
||
/** Unit tests for HashingTF on a single local document and on an RDD of documents. */
class HashingTFSuite extends FunSuite with LocalSparkContext {

  test("hashing tf on a single doc") {
    val tf = new HashingTF(1000)
    val tokens = "a a b b c d".split(" ")
    val numFeatures = tf.numFeatures
    // Expected (index, count) pairs, built from the known per-term counts of the document.
    val expectedCounts = Seq("a" -> 2.0, "b" -> 2.0, "c" -> 1.0, "d" -> 1.0).map {
      case (term, count) => (tf.indexOf(term), count)
    }
    val indices = expectedCounts.map(_._1)
    assert(indices.forall(i => i >= 0 && i < numFeatures),
      "index must be in range [0, #features)")
    assert(indices.toSet.size === 4, "expecting perfect hashing")
    val expected = Vectors.sparse(numFeatures, expectedCounts)
    assert(tf.transform(tokens) === expected)
  }

  test("hashing tf on an RDD") {
    val tf = new HashingTF
    val localDocs: Seq[Seq[String]] = Seq(
      "a a b b b c d".split(" "),
      "a b c d a b c".split(" "),
      "c b a c b a a".split(" "))
    val distributed = sc.parallelize(localDocs, 2)
    // The distributed transform must agree with transforming each document locally.
    val expected = localDocs.map(tf.transform).toSet
    assert(tf.transform(distributed).collect().toSet === expected)
  }
}
Oops, something went wrong.