Skip to content

Commit

Permalink
[SPARK-20040][ML][PYTHON] pyspark wrapper for ChiSquareTest
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

A pyspark wrapper for spark.ml.stat.ChiSquareTest

## How was this patch tested?

unit tests
doctests

Author: Bago Amirbekian <[email protected]>

Closes apache#17421 from MrBago/chiSquareTestWrapper.
  • Loading branch information
MrBago authored and jkbradley committed Mar 29, 2017
1 parent 7d432af commit a5c8770
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 12 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def __hash__(self):
"pyspark.ml.linalg.__init__",
"pyspark.ml.recommendation",
"pyspark.ml.regression",
"pyspark.ml.stat",
"pyspark.ml.tuning",
"pyspark.ml.tests",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ object ChiSquareTest {
statistics: Vector)

/**
* Conduct Pearson's independence test for every feature against the label across the input RDD.
* For each feature, the (feature, label) pairs are converted into a contingency matrix for which
* the Chi-squared statistic is computed. All label and feature values must be categorical.
* Conduct Pearson's independence test for every feature against the label. For each feature, the
* (feature, label) pairs are converted into a contingency matrix for which the Chi-squared
* statistic is computed. All label and feature values must be categorical.
*
* The null hypothesis is that the occurrence of the outcomes is statistically independent.
*
Expand Down
8 changes: 8 additions & 0 deletions python/docs/pyspark.ml.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ pyspark.ml.regression module
:undoc-members:
:inherited-members:

pyspark.ml.stat module
----------------------

.. automodule:: pyspark.ml.stat
:members:
:undoc-members:
:inherited-members:

pyspark.ml.tuning module
------------------------

Expand Down
93 changes: 93 additions & 0 deletions python/pyspark/ml/stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import since, SparkContext
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.wrapper import _jvm


class ChiSquareTest(object):
"""
.. note:: Experimental
Conduct Pearson's independence test for every feature against the label. For each feature,
the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared
statistic is computed. All label and feature values must be categorical.
The null hypothesis is that the occurrence of the outcomes is statistically independent.
:param dataset:
DataFrame of categorical labels and categorical features.
Real-valued features will be treated as categorical for each distinct value.
:param featuresCol:
Name of features column in dataset, of type `Vector` (`VectorUDT`).
:param labelCol:
Name of label column in dataset, of any numerical type.
:return:
DataFrame containing the test result for every feature against the label.
This DataFrame will contain a single Row with the following fields:
- `pValues: Vector`
- `degreesOfFreedom: Array[Int]`
- `statistics: Vector`
Each of these fields has one value per feature.
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.stat import ChiSquareTest
>>> dataset = [[0, Vectors.dense([0, 0, 1])],
... [0, Vectors.dense([1, 0, 1])],
... [1, Vectors.dense([2, 1, 1])],
... [1, Vectors.dense([3, 1, 1])]]
>>> dataset = spark.createDataFrame(dataset, ["label", "features"])
>>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
>>> chiSqResult.select("degreesOfFreedom").collect()[0]
Row(degreesOfFreedom=[3, 1, 0])
.. versionadded:: 2.2.0
"""
@staticmethod
@since("2.2.0")
def test(dataset, featuresCol, labelCol):
"""
Perform a Pearson's independence test using dataset.
"""
sc = SparkContext._active_spark_context
javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest
args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)]
return _java2py(sc, javaTestObj.test(*args))


if __name__ == "__main__":
import doctest
import pyspark.ml.stat
from pyspark.sql import SparkSession

globs = pyspark.ml.stat.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
spark = SparkSession.builder \
.master("local[2]") \
.appName("ml.stat tests") \
.getOrCreate()
sc = spark.sparkContext
globs['sc'] = sc
globs['spark'] = spark

failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
exit(-1)
31 changes: 22 additions & 9 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
import tempfile
import array as pyarray
import numpy as np
from numpy import (
abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros)
from numpy import sum as array_sum
from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
import inspect

from pyspark import keyword_only, SparkContext
Expand All @@ -54,20 +52,19 @@
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.fpm import FPGrowth, FPGrowthModel
from pyspark.ml.linalg import (
DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT,
SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector)
from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \
SparseMatrix, SparseVector, Vector, VectorUDT, Vectors
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
from pyspark.ml.recommendation import ALS
from pyspark.ml.regression import (
DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression)
from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \
LinearRegression
from pyspark.ml.stat import ChiSquareTest
from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams, JavaWrapper
from pyspark.serializers import PickleSerializer
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
from pyspark.sql.utils import IllegalArgumentException
from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase

Expand Down Expand Up @@ -1741,6 +1738,22 @@ def test_new_java_array(self):
self.assertEqual(_java2py(self.sc, java_array), [])


class ChiSquareTestTests(SparkSessionTestCase):

def test_chisquaretest(self):
data = [[0, Vectors.dense([0, 1, 2])],
[1, Vectors.dense([1, 1, 1])],
[2, Vectors.dense([2, 1, 0])]]
df = self.spark.createDataFrame(data, ['label', 'feat'])
res = ChiSquareTest.test(df, 'feat', 'label')
# This line is hitting the collect bug described in #17218, commented for now.
# pValues = res.select("degreesOfFreedom").collect())
self.assertIsInstance(res, DataFrame)
fieldNames = set(field.name for field in res.schema.fields)
expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
self.assertTrue(all(field in fieldNames for field in expectedFields))


if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
Expand Down

0 comments on commit a5c8770

Please sign in to comment.