[SPARK-37412][PYTHON][ML] Inline typehints for pyspark.ml.stat
### What changes were proposed in this pull request?

This PR migrates the `pyspark.ml.stat` type annotations from the stub file to inline type hints.

(second take, after an issue resulted in the reversion of apache#35401)
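
For a concrete sense of the change, here is the shape of both styles, using `Correlation.corr` from this module (the stub fragment is a simplified sketch, not the verbatim contents of the deleted `stat.pyi`):

```python
from pyspark.sql.dataframe import DataFrame

# Old approach: signatures declared in a separate stub file
# (python/pyspark/ml/stat.pyi) that shadowed the untyped implementation.
class Correlation:
    @staticmethod
    def corr(dataset: DataFrame, column: str, method: str = ...) -> DataFrame: ...

# New approach: the same annotations written inline in stat.py, so the
# implementation itself is the single source of truth for signatures.
class Correlation:  # noqa: F811 (illustrative redefinition)
    @staticmethod
    def corr(dataset: DataFrame, column: str, method: str = "pearson") -> DataFrame:
        ...
```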

### Why are the changes needed?

Part of the ongoing migration of type hints.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes apache#35437 from zero323/SPARK-37412.

Authored-by: zero323 <[email protected]>
Signed-off-by: zero323 <[email protected]>
zero323 committed Feb 8, 2022
1 parent 6b62c30 commit 3d736d9
Showing 2 changed files with 41 additions and 94 deletions.
python/pyspark/ml/stat.py (62 changes: 41 additions & 21 deletions)
@@ -17,12 +17,20 @@

 import sys

+from typing import Optional, Tuple, TYPE_CHECKING
+
+
 from pyspark import since, SparkContext
 from pyspark.ml.common import _java2py, _py2java
-from pyspark.ml.wrapper import JavaWrapper, _jvm
+from pyspark.ml.linalg import Matrix, Vector
+from pyspark.ml.wrapper import JavaWrapper, _jvm  # type: ignore[attr-defined]
 from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions import lit

+if TYPE_CHECKING:
+    from py4j.java_gateway import JavaObject  # type: ignore[import]
+

 class ChiSquareTest:
     """
@@ -37,7 +45,9 @@ class ChiSquareTest:
     """

     @staticmethod
-    def test(dataset, featuresCol, labelCol, flatten=False):
+    def test(
+        dataset: DataFrame, featuresCol: str, labelCol: str, flatten: bool = False
+    ) -> DataFrame:
         """
         Perform a Pearson's independence test using dataset.

@@ -95,6 +105,8 @@ def test(dataset, featuresCol, labelCol, flatten=False):
         4.0
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None
+
         javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest
         args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol, flatten)]
         return _java2py(sc, javaTestObj.test(*args))
@@ -116,7 +128,7 @@ class Correlation:
     """

     @staticmethod
-    def corr(dataset, column, method="pearson"):
+    def corr(dataset: DataFrame, column: str, method: str = "pearson") -> DataFrame:
         """
         Compute the correlation matrix with specified method using dataset.

@@ -162,6 +174,8 @@ def corr(dataset, column, method="pearson"):
         [ 0.4 , 0.9486... , NaN, 1. ]])
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None
+
         javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
         args = [_py2java(sc, arg) for arg in (dataset, column, method)]
         return _java2py(sc, javaCorrObj.corr(*args))
@@ -181,7 +195,7 @@ class KolmogorovSmirnovTest:
     """

     @staticmethod
-    def test(dataset, sampleCol, distName, *params):
+    def test(dataset: DataFrame, sampleCol: str, distName: str, *params: float) -> DataFrame:
         """
         Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution
         equality. Currently supports the normal distribution, taking as parameters the mean and
@@ -228,9 +242,11 @@ def test(dataset, sampleCol, distName, *params):
         0.175
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None
+
         javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest
         dataset = _py2java(sc, dataset)
-        params = [float(param) for param in params]
+        params = [float(param) for param in params]  # type: ignore[assignment]
         return _java2py(
             sc, javaTestObj.test(dataset, sampleCol, distName, _jvm().PythonUtils.toSeq(params))
         )
@@ -284,94 +300,94 @@ class Summarizer:

     @staticmethod
     @since("2.4.0")
-    def mean(col, weightCol=None):
+    def mean(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of mean summary
         """
         return Summarizer._get_single_metric(col, weightCol, "mean")

     @staticmethod
     @since("3.0.0")
-    def sum(col, weightCol=None):
+    def sum(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of sum summary
         """
         return Summarizer._get_single_metric(col, weightCol, "sum")

     @staticmethod
     @since("2.4.0")
-    def variance(col, weightCol=None):
+    def variance(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of variance summary
         """
         return Summarizer._get_single_metric(col, weightCol, "variance")

     @staticmethod
     @since("3.0.0")
-    def std(col, weightCol=None):
+    def std(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of std summary
         """
         return Summarizer._get_single_metric(col, weightCol, "std")

     @staticmethod
     @since("2.4.0")
-    def count(col, weightCol=None):
+    def count(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of count summary
         """
         return Summarizer._get_single_metric(col, weightCol, "count")

     @staticmethod
     @since("2.4.0")
-    def numNonZeros(col, weightCol=None):
+    def numNonZeros(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of numNonZero summary
         """
         return Summarizer._get_single_metric(col, weightCol, "numNonZeros")

     @staticmethod
     @since("2.4.0")
-    def max(col, weightCol=None):
+    def max(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of max summary
         """
         return Summarizer._get_single_metric(col, weightCol, "max")

     @staticmethod
     @since("2.4.0")
-    def min(col, weightCol=None):
+    def min(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of min summary
         """
         return Summarizer._get_single_metric(col, weightCol, "min")

     @staticmethod
     @since("2.4.0")
-    def normL1(col, weightCol=None):
+    def normL1(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of normL1 summary
         """
         return Summarizer._get_single_metric(col, weightCol, "normL1")

     @staticmethod
     @since("2.4.0")
-    def normL2(col, weightCol=None):
+    def normL2(col: Column, weightCol: Optional[Column] = None) -> Column:
         """
         return a column of normL2 summary
         """
         return Summarizer._get_single_metric(col, weightCol, "normL2")

     @staticmethod
-    def _check_param(featuresCol, weightCol):
+    def _check_param(featuresCol: Column, weightCol: Optional[Column]) -> Tuple[Column, Column]:
         if weightCol is None:
             weightCol = lit(1.0)
         if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column):
             raise TypeError("featureCol and weightCol should be a Column")
         return featuresCol, weightCol

     @staticmethod
-    def _get_single_metric(col, weightCol, metric):
+    def _get_single_metric(col: Column, weightCol: Optional[Column], metric: str) -> Column:
         col, weightCol = Summarizer._check_param(col, weightCol)
         return Column(
             JavaWrapper._new_java_obj(
@@ -380,7 +396,7 @@ def _get_single_metric(col, weightCol, metric):
         )

     @staticmethod
-    def metrics(*metrics):
+    def metrics(*metrics: str) -> "SummaryBuilder":
         """
         Given a list of metrics, provides a builder that it turns computes metrics from a column.

@@ -415,6 +431,8 @@ def metrics(*metrics):
         :py:class:`pyspark.ml.stat.SummaryBuilder`
         """
         sc = SparkContext._active_spark_context
+        assert sc is not None
+
         js = JavaWrapper._new_java_obj(
             "org.apache.spark.ml.stat.Summarizer.metrics", _to_seq(sc, metrics)
         )
@@ -432,10 +450,10 @@ class SummaryBuilder(JavaWrapper):
     """

-    def __init__(self, jSummaryBuilder):
+    def __init__(self, jSummaryBuilder: "JavaObject"):
         super(SummaryBuilder, self).__init__(jSummaryBuilder)

-    def summary(self, featuresCol, weightCol=None):
+    def summary(self, featuresCol: Column, weightCol: Optional[Column] = None) -> Column:
         """
         Returns an aggregate object that contains the summary of the column with the requested
         metrics.

@@ -456,6 +474,8 @@ def summary(self, featuresCol, weightCol=None):
         structure is determined during the creation of the builder.
         """
         featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol)
+        assert self._java_obj is not None
+
         return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc))


@@ -474,7 +494,7 @@ class MultivariateGaussian:
         [ 3., 2.]]))
     """

-    def __init__(self, mean, cov):
+    def __init__(self, mean: Vector, cov: Matrix):
         self.mean = mean
         self.cov = cov
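
A pattern that recurs in the hunks above is the new `assert sc is not None` after fetching `SparkContext._active_spark_context`: the attribute is an Optional, so the assert both fails fast when no context is active and narrows the type so mypy accepts the calls that follow. A minimal self-contained sketch of the same narrowing idiom (generic names, not Spark's code):

```python
from typing import Optional


class Registry:
    # A class-level slot that is legitimately None until initialized,
    # mirroring SparkContext._active_spark_context.
    active: Optional["Registry"] = None

    def lookup(self, key: str) -> str:
        return key.upper()


def current() -> "Registry":
    reg = Registry.active
    # Without this assert, mypy rejects the function: reg is
    # Optional[Registry], which does not match the declared return type.
    assert reg is not None
    return reg


Registry.active = Registry()
print(current().lookup("mean"))  # prints: MEAN
```
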
python/pyspark/ml/stat.pyi (73 changes: 0 additions & 73 deletions)

This file was deleted.
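
The deletion also explains the new `if TYPE_CHECKING:` block at the top of `stat.py`: with the stub gone, py4j's `JavaObject` is needed only as an annotation, so it is imported solely for static type checkers and never at runtime. A generic sketch of the idiom (standard-library names, not Spark's code):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # This branch is seen by mypy/pyright but never executed at runtime,
    # so the module still imports even if the dependency were missing.
    from decimal import Decimal


def double(value: "Decimal") -> "Decimal":
    # The quoted annotations are not evaluated at runtime, so it is fine
    # that Decimal was never actually imported here.
    return value * 2
```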
