Skip to content

Commit

Permalink
[SPARK-2511][MLLIB] add HashingTF and IDF
Browse files Browse the repository at this point in the history
This is roughly the TF-IDF implementation used in the Databricks Cloud Demo (http://databricks.com/cloud/).

Both `HashingTF` and `IDF` are implemented as transformers, similar to scikit-learn.

Author: Xiangrui Meng <[email protected]>

Closes apache#1671 from mengxr/tfidf and squashes the following commits:

7d65888 [Xiangrui Meng] use JavaConverters._
5fe9ec4 [Xiangrui Meng] fix unit test
6e214ec [Xiangrui Meng] add apache header
cfd9aed [Xiangrui Meng] add Java-friendly methods move classes to mllib.feature
3814440 [Xiangrui Meng] add HashingTF and IDF
  • Loading branch information
mengxr committed Jul 31, 2014
1 parent e5749a1 commit dc0865b
Show file tree
Hide file tree
Showing 5 changed files with 454 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import java.lang.{Iterable => JavaIterable}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

/**
 * :: Experimental ::
 * Maps a sequence of terms to their term frequencies using the hashing trick.
 *
 * Every term is mapped to a feature index derived from its hash code and `numFeatures`.
 * Terms that land on the same index share one slot, so their counts are added together.
 *
 * @param numFeatures number of features (default: 1000000); must be positive
 * @throws IllegalArgumentException if `numFeatures` is not positive
 */
@Experimental
class HashingTF(val numFeatures: Int) extends Serializable {

  // Reject a non-positive dimension up front instead of letting it surface later as an
  // arithmetic or indexing failure while hashing terms or building the output vector.
  require(numFeatures > 0, s"numFeatures must be positive but got $numFeatures.")

  /** Creates a `HashingTF` with the default number of features (1000000). */
  def this() = this(1000000)

  /**
   * Returns the index of the input term.
   *
   * @param term the term to hash; `##` is used so that `null` and boxed numerics hash safely
   * @return the feature index computed from the term's hash code and `numFeatures`
   */
  def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures)

  /**
   * Transforms the input document into a sparse term frequency vector.
   *
   * @param document the terms of one document; repeated terms are counted once per occurrence
   * @return a sparse vector of size `numFeatures` holding the count stored at each term index
   */
  def transform(document: Iterable[_]): Vector = {
    // A local mutable map keeps counting cheap; it never escapes this method.
    val termFrequencies = mutable.HashMap.empty[Int, Double]
    document.foreach { term =>
      val i = indexOf(term)
      termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)
    }
    Vectors.sparse(numFeatures, termFrequencies.toSeq)
  }

  /**
   * Transforms the input document into a sparse term frequency vector (Java version).
   *
   * @param document the terms of one document as a Java iterable
   * @return a sparse vector of size `numFeatures` holding the count stored at each term index
   */
  def transform(document: JavaIterable[_]): Vector = {
    transform(document.asScala)
  }

  /**
   * Transforms the input document to term frequency vectors.
   *
   * @param dataset an RDD in which every element is one document
   * @return an RDD with one term frequency vector per input document
   */
  def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = {
    dataset.map(this.transform)
  }

  /**
   * Transforms the input document to term frequency vectors (Java version).
   *
   * @param dataset a JavaRDD in which every element is one document
   * @return a JavaRDD with one term frequency vector per input document
   */
  def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = {
    dataset.rdd.map(this.transform).toJavaRDD()
  }
}
194 changes: 194 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd.RDD

/**
 * :: Experimental ::
 * Inverse document frequency (IDF).
 * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
 * number of documents and `d(t)` is the number of documents that contain term `t`.
 * `fit` has to be called before `transform` or `idf`.
 */
@Experimental
class IDF {

  // TODO: Allow different IDF formulations.

  /** Learned IDF weights; stays `null` until `fit` has been called. */
  private var brzIdf: BDV[Double] = _

  /**
   * Computes the inverse document frequency.
   * @param dataset an RDD of term frequency vectors
   * @return this instance, now holding the learned IDF vector
   */
  def fit(dataset: RDD[Vector]): this.type = {
    val aggregated = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
      seqOp = (agg, doc) => agg.add(doc),
      combOp = (left, right) => left.merge(right))
    brzIdf = aggregated.idf()
    this
  }

  /**
   * Computes the inverse document frequency.
   * @param dataset a JavaRDD of term frequency vectors
   * @return this instance, now holding the learned IDF vector
   */
  def fit(dataset: JavaRDD[Vector]): this.type = {
    fit(dataset.rdd)
  }

  /**
   * Transforms term frequency (TF) vectors to TF-IDF vectors.
   * @param dataset an RDD of term frequency vectors
   * @return an RDD of TF-IDF vectors
   * @throws IllegalStateException if `fit` has not been called yet
   */
  def transform(dataset: RDD[Vector]): RDD[Vector] = {
    // Read the field into a local before building the closure, so that the function shipped
    // to the tasks only refers to locals (presumably to keep this instance out of it).
    val learned = learnedIdf
    val idfBroadcast = dataset.context.broadcast(learned)
    dataset.mapPartitions { docs =>
      val weights = idfBroadcast.value
      docs.map {
        case sparse: SparseVector =>
          // Only the stored entries need scaling; the index array is reused as is.
          val indices = sparse.indices
          val values = sparse.values
          val scaled = new Array[Double](indices.length)
          var pos = 0
          while (pos < scaled.length) {
            scaled(pos) = values(pos) * weights(indices(pos))
            pos += 1
          }
          Vectors.sparse(sparse.size, indices, scaled)
        case dense: DenseVector =>
          val values = dense.values
          val dim = dense.size
          val scaled = new Array[Double](dim)
          var pos = 0
          while (pos < dim) {
            scaled(pos) = values(pos) * weights(pos)
            pos += 1
          }
          Vectors.dense(scaled)
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }
    }
  }

  /**
   * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
   * @param dataset a JavaRDD of term frequency vectors
   * @return a JavaRDD of TF-IDF vectors
   */
  def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
    transform(dataset.rdd).toJavaRDD()
  }

  /**
   * Returns the IDF vector.
   * @throws IllegalStateException if `fit` has not been called yet
   */
  def idf(): Vector = Vectors.fromBreeze(learnedIdf)

  /** Returns the learned IDF weights, failing if nothing has been learned yet. */
  private def learnedIdf: BDV[Double] = {
    if (!initialized) {
      throw new IllegalStateException("Haven't learned IDF yet. Call fit first.")
    }
    brzIdf
  }

  private def initialized: Boolean = brzIdf != null
}

private object IDF {

  /**
   * Document frequency aggregator.
   * Counts, for every term index, the number of added documents whose value at that index
   * is strictly positive, and keeps track of how many documents were added in total.
   */
  class DocumentFrequencyAggregator extends Serializable {

    /** number of documents */
    private var m = 0L
    /** document frequency vector; allocated when a document is added to an empty aggregator */
    private var df: BDV[Long] = _

    /**
     * Adds a new document.
     * @param doc term frequency vector of the document (sparse or dense)
     * @return this aggregator
     */
    def add(doc: Vector): this.type = {
      if (isEmpty) {
        df = BDV.zeros(doc.size)
      }
      doc match {
        case dense: DenseVector =>
          val values = dense.values
          val dim = dense.size
          var pos = 0
          while (pos < dim) {
            if (values(pos) > 0.0) {
              df(pos) += 1L
            }
            pos += 1
          }
        case sparse: SparseVector =>
          val indices = sparse.indices
          val values = sparse.values
          val stored = indices.length
          var pos = 0
          while (pos < stored) {
            // Explicitly stored non-positive entries do not count as an occurrence.
            if (values(pos) > 0) {
              df(indices(pos)) += 1L
            }
            pos += 1
          }
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }
      m += 1L
      this
    }

    /**
     * Merges another.
     * @param other aggregator whose counts are folded into this one; left untouched
     * @return this aggregator
     */
    def merge(other: DocumentFrequencyAggregator): this.type = {
      if (!other.isEmpty) {
        if (df != null) {
          df += other.df
        } else {
          // Nothing collected here yet: take a private copy so `other` is not aliased.
          df = other.df.copy
        }
        m += other.m
      }
      this
    }

    private def isEmpty: Boolean = m == 0L

    /**
     * Returns the current IDF vector.
     * @throws IllegalStateException if no document has been added or merged in yet
     */
    def idf(): BDV[Double] = {
      if (isEmpty) {
        throw new IllegalStateException("Haven't seen any document yet.")
      }
      val dim = df.length
      val result = BDV.zeros[Double](dim)
      val numDocsPlusOne = m + 1.0
      var pos = 0
      while (pos < dim) {
        result(pos) = math.log(numDocsPlusOne / (df(pos) + 1.0))
        pos += 1
      }
      result
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import com.google.common.collect.Lists;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;

/** Checks that HashingTF and IDF can be driven end to end from Java. */
public class JavaTfIdfSuite implements Serializable {
  private transient JavaSparkContext sc;

  /** Starts a local Spark context for the test. */
  @Before
  public void setUp() {
    sc = new JavaSparkContext("local", "JavaTfIdfSuite");
  }

  /** Stops the Spark context and drops the reference to it. */
  @After
  public void tearDown() {
    sc.stop();
    sc = null;
  }

  @Test
  public void tfIdf() {
    // The tests are to check Java compatibility.
    HashingTF hashingTF = new HashingTF();
    List<ArrayList<String>> corpus = Lists.newArrayList(
      Lists.newArrayList("this is a sentence".split(" ")),
      Lists.newArrayList("this is another sentence".split(" ")),
      Lists.newArrayList("this is still a sentence".split(" ")));
    JavaRDD<ArrayList<String>> documents = sc.parallelize(corpus, 2);

    JavaRDD<Vector> termFreqs = hashingTF.transform(documents);
    // Force evaluation of the term frequencies on their own before fitting the IDF.
    termFreqs.collect();

    IDF idfModel = new IDF();
    List<Vector> weighted = idfModel.fit(termFreqs).transform(termFreqs).collect();

    // "this" occurs in all three documents, so the test expects its TF-IDF weight to be 0.
    int positionOfThis = hashingTF.indexOf("this");
    for (int i = 0; i < weighted.size(); i++) {
      Assert.assertEquals(0.0, weighted.get(i).apply(positionOfThis), 1e-15);
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext

/** Tests HashingTF on a single local document and on an RDD of documents. */
class HashingTFSuite extends FunSuite with LocalSparkContext {

  test("hashing tf on a single doc") {
    val tf = new HashingTF(1000)
    val numFeatures = tf.numFeatures
    val doc = "a a b b c d".split(" ")
    // Expected count per term, translated to the index each term hashes to.
    val expectedCounts = Seq("a" -> 2.0, "b" -> 2.0, "c" -> 1.0, "d" -> 1.0)
    val termFreqs = expectedCounts.map { case (term, count) => (tf.indexOf(term), count) }
    val indices = termFreqs.map(_._1)
    assert(indices.forall(i => i >= 0 && i < numFeatures),
      "index must be in range [0, #features)")
    assert(indices.toSet.size === 4, "expecting perfect hashing")
    assert(tf.transform(doc) === Vectors.sparse(numFeatures, termFreqs))
  }

  test("hashing tf on an RDD") {
    val tf = new HashingTF
    val localDocs: Seq[Seq[String]] = Seq(
      "a a b b b c d".split(" "),
      "a b c d a b c".split(" "),
      "c b a c b a a".split(" "))
    // The distributed transform must agree with transforming each document locally.
    val distributed = tf.transform(sc.parallelize(localDocs, 2)).collect().toSet
    val local = localDocs.map(tf.transform).toSet
    assert(distributed === local)
  }
}
Loading

0 comments on commit dc0865b

Please sign in to comment.