Skip to content

Commit

Permalink
[SPARK-38993][PYTHON] Impl DataFrame.boxplot and DataFrame.plot.box
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Implement `DataFrame.boxplot` and `DataFrame.plot.box` for pandas-on-Spark DataFrames.

### Why are the changes needed?
To increase pandas API coverage in PySpark.

### Does this PR introduce _any_ user-facing change?
yes

```
In [2]: df = ps.DataFrame([[5.1, 3.5, 0], [4.9, 3.0, 0], [7.0, 3.2, 1], [6.4, 3.2, 1], [5.9, 3.0, 2]], columns=['length', 'width', 'species'])

In [3]: df.boxplot()
Out[3]:
In [4]: df.plot.box()
Out[4]:
```

![image](https://user-images.githubusercontent.com/7322292/164674307-d7622e22-bbfb-45d0-9fd8-318a1a11258f.png)

### How was this patch tested?
Added unit tests and tested manually.

Closes apache#36317 from zhengruifeng/impl_box_plot.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Apr 29, 2022
1 parent 973283c commit 9b16579
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 44 deletions.
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.pandas/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,13 @@ specific plotting methods of the form ``DataFrame.plot.<kind>``.
DataFrame.plot.barh
DataFrame.plot.bar
DataFrame.plot.hist
DataFrame.plot.box
DataFrame.plot.line
DataFrame.plot.pie
DataFrame.plot.scatter
DataFrame.plot.density
DataFrame.hist
DataFrame.boxplot
DataFrame.kde

Pandas-on-Spark specific
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Supported DataFrame APIs
+--------------------------------------------+-------------+--------------------------------------+
| :func:`bool` | Y | |
+--------------------------------------------+-------------+--------------------------------------+
| boxplot | N | |
| boxplot | Y | |
+--------------------------------------------+-------------+--------------------------------------+
| :func:`clip` | P | ``axis``, ``inplace`` |
+--------------------------------------------+-------------+--------------------------------------+
Expand Down Expand Up @@ -315,7 +315,7 @@ Supported DataFrame APIs
+--------------------------------------------+-------------+--------------------------------------+
| :func:`plot.barh` | Y | |
+--------------------------------------------+-------------+--------------------------------------+
| :func:`plot.box` | N | |
| :func:`plot.box` | Y | |
+--------------------------------------------+-------------+--------------------------------------+
| :func:`plot.density` | Y | |
+--------------------------------------------+-------------+--------------------------------------+
Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,12 @@ def hist(self, bins=10, **kwds):

hist.__doc__ = PandasOnSparkPlotAccessor.hist.__doc__

@no_type_check
def boxplot(self, **kwds):
    # DataFrame-level shortcut: forward every keyword argument unchanged to
    # the ``box`` method of this object's plot accessor.
    plot_accessor = self.plot
    return plot_accessor.box(**kwds)


# Share the accessor's docstring so both entry points are documented identically.
boxplot.__doc__ = PandasOnSparkPlotAccessor.box.__doc__

@no_type_check
def kde(self, bw_method=None, ind=None, **kwds):
return self.plot.kde(bw_method, ind, **kwds)
Expand Down
1 change: 0 additions & 1 deletion python/pyspark/pandas/missing/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class _MissingPandasLikeDataFrame:
# Functions
asfreq = _unsupported_function("asfreq")
asof = _unsupported_function("asof")
boxplot = _unsupported_function("boxplot")
combine = _unsupported_function("combine")
compare = _unsupported_function("compare")
convert_dtypes = _unsupported_function("convert_dtypes")
Expand Down
85 changes: 82 additions & 3 deletions python/pyspark/pandas/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,45 @@ def compute_hist(psdf, bins):


class BoxPlotBase:
@staticmethod
def compute_multicol_stats(data, colnames, whis, precision):
    """
    Compute box-plot summary statistics for several columns with one ``select``.

    Parameters
    ----------
    data
        Object exposing ``_internal.resolved_copy.spark_frame`` (the plotting
        code passes a pandas-on-Spark DataFrame).
    colnames : list of str
        Names of the Spark columns to summarize. Each name is backtick-quoted
        before being handed to Spark.
    whis : float
        Multiplier applied to the inter-quartile range to obtain the fences.
    precision : float
        ``int(1.0 / precision)`` is passed as the third argument of
        ``F.percentile_approx``.

    Returns
    -------
    dict
        Maps each column name to a dict with the keys ``"mean"``, ``"med"``,
        ``"q1"``, ``"q3"``, ``"lfence"`` and ``"ufence"``.
    """
    # Two aggregate expressions per column, in a fixed order:
    # [percentiles, mean, percentiles, mean, ...]
    aggregations = []
    for name in colnames:
        quoted = "`%s`" % name
        aggregations += [
            F.percentile_approx(quoted, [0.25, 0.50, 0.75], int(1.0 / precision)).alias(
                "{}_percentiles%".format(name)
            ),
            F.mean(quoted).alias("{}_mean".format(name)),
        ]

    # The aggregation collapses to one row, e.g.
    #      a_percentiles  a_mean    b_percentiles  b_mean
    # 0  [3.0, 3.2, 3.2]    3.18  [5.1, 5.9, 6.4]    5.86
    summary = data._internal.resolved_copy.spark_frame.select(*aggregations).toPandas()

    multicol_stats = {}
    for position, name in enumerate(colnames):
        # Column 2 * position holds the percentile triple, the next one the mean.
        q1, med, q3 = summary.iloc[0, 2 * position]
        iqr = q3 - q1
        lfence = q1 - whis * iqr
        ufence = q3 + whis * iqr
        mean = summary.iloc[0, 2 * position + 1]

        multicol_stats[name] = {
            "mean": mean,
            "med": med,
            "q1": q1,
            "q3": q3,
            "lfence": lfence,
            "ufence": ufence,
        }

    return multicol_stats

@staticmethod
def compute_stats(data, colname, whis, precision):
# Computes mean, median, Q1 and Q3 with approx_percentile and precision
Expand Down Expand Up @@ -307,6 +346,15 @@ def compute_stats(data, colname, whis, precision):

return stats, (lfence.values[0], ufence.values[0])

@staticmethod
def multicol_outliers(data, multicol_stats):
    """
    Append one boolean outlier-flag column per summarized column.

    Parameters
    ----------
    data
        Object exposing ``_internal.resolved_copy.spark_frame``.
    multicol_stats : dict
        Maps a column name to a dict containing at least ``"lfence"`` and
        ``"ufence"`` (the shape produced by ``compute_multicol_stats``).

    Returns
    -------
    The Spark frame with an extra ``__<colname>_outlier`` column for every
    key of ``multicol_stats``; the flag is the negation of
    ``between(lfence, ufence)`` on that column.
    """
    outlier_flags = {
        "__{}_outlier".format(name): ~F.col("`%s`" % name).between(
            col_stats["lfence"], col_stats["ufence"]
        )
        for name, col_stats in multicol_stats.items()
    }
    sdf = data._internal.resolved_copy.spark_frame
    return sdf.withColumns(outlier_flags)

@staticmethod
def outliers(data, colname, lfence, ufence):
# Builds expression to identify outliers
Expand All @@ -316,6 +364,39 @@ def outliers(data, colname, lfence, ufence):
"__{}_outlier".format(colname), ~expression
)

@staticmethod
def calc_multicol_whiskers(colnames, multicol_outliers):
    """
    Compute the whiskers: min and max of the non-outlier values of each column.

    Parameters
    ----------
    colnames : list of str
        Names of the value columns. For each one, ``multicol_outliers`` must
        also contain a boolean ``__<colname>_outlier`` column.
    multicol_outliers
        Spark frame carrying the value columns and their outlier flags
        (the shape produced by ``multicol_outliers``).

    Returns
    -------
    dict
        Maps each column name to ``{"min": ..., "max": ...}``.
    """
    aggregations = []
    for colname in colnames:
        # Backtick-quote both references, consistent with
        # compute_multicol_stats / multicol_outliers, so that a column name
        # containing a dot is not parsed as nested-field access.
        value = F.col("`%s`" % colname)
        is_inlier = ~F.col("`__%s_outlier`" % colname)

        # Outliers become NULL, which min/max aggregation ignores.
        inlier_value = F.when(is_inlier, value).otherwise(SF.lit(None))

        aggregations.append(F.min(inlier_value).alias("__{}_min".format(colname)))
        aggregations.append(F.max(inlier_value).alias("__{}_max".format(colname)))

    # One row; columns are laid out as [min, max] pairs in ``colnames`` order.
    pdf = multicol_outliers.select(*aggregations).toPandas()

    whiskers = {}
    for position, colname in enumerate(colnames):
        whiskers[colname] = {
            "min": pdf.iloc[0, 2 * position],
            "max": pdf.iloc[0, 2 * position + 1],
        }

    return whiskers

@staticmethod
def calc_whiskers(colname, outliers):
# Computes min and max values of non-outliers - the whiskers
Expand Down Expand Up @@ -815,10 +896,8 @@ def box(self, **kwds):
"""
from pyspark.pandas import DataFrame, Series

if isinstance(self.data, Series):
if isinstance(self.data, (Series, DataFrame)):
return self(kind="box", **kwds)
elif isinstance(self.data, DataFrame):
return unsupported_function(class_name="pd.DataFrame", method_name="box")()

def hist(self, bins=10, **kwds):
"""
Expand Down
110 changes: 74 additions & 36 deletions python/pyspark/pandas/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,7 @@ def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
import plotly.graph_objs as go
import pyspark.pandas as ps

if isinstance(data, ps.DataFrame):
raise RuntimeError(
"plotly does not support a box plot with pandas-on-Spark DataFrame. Use Series instead."
)
from pyspark.sql.types import NumericType

# 'whis' isn't actually an argument in plotly (but in matplotlib). But seems like
# plotly doesn't expose the reach of the whiskers to the beyond the first and
Expand All @@ -150,40 +146,82 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
"Set to False." % notched
)

colname = name_like_string(data.name)
spark_column_name = data._internal.spark_column_name_for(data._column_label)

# Computes mean, median, Q1 and Q3 with approx_percentile and precision
col_stats, col_fences = BoxPlotBase.compute_stats(data, spark_column_name, whis, precision)

# Creates a column to flag rows as outliers or not
outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)
fig = go.Figure()
if isinstance(data, ps.Series):
colname = name_like_string(data.name)
spark_column_name = data._internal.spark_column_name_for(data._column_label)

# Computes mean, median, Q1 and Q3 with approx_percentile and precision
col_stats, col_fences = BoxPlotBase.compute_stats(data, spark_column_name, whis, precision)

# Creates a column to flag rows as outliers or not
outliers = BoxPlotBase.outliers(data, spark_column_name, *col_fences)

# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)

fliers = None
if boxpoints:
fliers = BoxPlotBase.get_fliers(spark_column_name, outliers, whiskers[0])
fliers = [fliers] if len(fliers) > 0 else None

fig.add_trace(
go.Box(
name=colname,
q1=[col_stats["q1"]],
median=[col_stats["med"]],
q3=[col_stats["q3"]],
mean=[col_stats["mean"]],
lowerfence=[whiskers[0]],
upperfence=[whiskers[1]],
y=fliers,
boxpoints=boxpoints,
notched=notched,
**kwargs, # this is for workarounds. Box takes different options from express.box.
)
)
fig["layout"]["xaxis"]["title"] = colname

# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_whiskers(spark_column_name, outliers)
else:
numeric_column_names = []
for column_label in data._internal.column_labels:
if isinstance(data._internal.spark_type_for(column_label), NumericType):
numeric_column_names.append(name_like_string(column_label))

# Computes mean, median, Q1 and Q3 with approx_percentile and precision
multicol_stats = BoxPlotBase.compute_multicol_stats(
data, numeric_column_names, whis, precision
)

fliers = None
if boxpoints:
fliers = BoxPlotBase.get_fliers(spark_column_name, outliers, whiskers[0])
fliers = [fliers] if len(fliers) > 0 else None
# Creates a column to flag rows as outliers or not
outliers = BoxPlotBase.multicol_outliers(data, multicol_stats)

# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, outliers)

i = 0
for colname in numeric_column_names:
col_stats = multicol_stats[colname]
col_whiskers = whiskers[colname]

fig.add_trace(
go.Box(
x=[i],
name=colname,
q1=[col_stats["q1"]],
median=[col_stats["med"]],
q3=[col_stats["q3"]],
mean=[col_stats["mean"]],
lowerfence=[col_whiskers["min"]],
upperfence=[col_whiskers["max"]],
y=None, # todo: support y=fliers
boxpoints=boxpoints,
notched=notched,
**kwargs,
)
)
i += 1

fig = go.Figure()
fig.add_trace(
go.Box(
name=colname,
q1=[col_stats["q1"]],
median=[col_stats["med"]],
q3=[col_stats["q3"]],
mean=[col_stats["mean"]],
lowerfence=[whiskers[0]],
upperfence=[whiskers[1]],
y=fliers,
boxpoints=boxpoints,
notched=notched,
**kwargs, # this is for workarounds. Box takes different options from express.box.
)
)
fig["layout"]["xaxis"]["title"] = colname
fig["layout"]["yaxis"]["title"] = "value"
return fig

Expand Down
46 changes: 44 additions & 2 deletions python/pyspark/pandas/tests/plot/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option, option_context
from pyspark.pandas.plot import TopNPlotBase, SampledPlotBase, HistogramPlotBase
from pyspark.pandas.plot import TopNPlotBase, SampledPlotBase, HistogramPlotBase, BoxPlotBase
from pyspark.pandas.exceptions import PandasNotImplementedError
from pyspark.testing.pandasutils import PandasOnSparkTestCase

Expand All @@ -41,7 +41,7 @@ def tearDownClass(cls):
def test_missing(self):
psdf = ps.DataFrame(np.random.rand(2500, 4), columns=["a", "b", "c", "d"])

unsupported_functions = ["box", "hexbin"]
unsupported_functions = ["hexbin"]

for name in unsupported_functions:
with self.assertRaisesRegex(
Expand Down Expand Up @@ -110,6 +110,48 @@ def test_compute_hist_multi_columns(self):
pd.Series(expected_histogram, name=expected_name), histogram, almost=True
)

def test_compute_box_multi_columns(self):
    """The multi-column box helpers must agree with the single-column ones."""
    columns = ["a", "b", "c"]
    whis = 1.5
    precision = 0.01

    def check_box_multi_columns(psdf):
        # Multi-column path: stats -> outlier flags -> whiskers.
        multicol_stats = BoxPlotBase.compute_multicol_stats(
            psdf, columns, whis=whis, precision=precision
        )
        flagged = BoxPlotBase.multicol_outliers(psdf, multicol_stats)
        multicol_whiskers = BoxPlotBase.calc_multicol_whiskers(columns, flagged)

        for col in columns:
            expected_stats = multicol_stats[col]
            expected_whiskers = multicol_whiskers[col]

            # Single-column path on the same data.
            stats, fences = BoxPlotBase.compute_stats(
                psdf[col], col, whis=whis, precision=precision
            )
            outliers = BoxPlotBase.outliers(psdf[col], col, *fences)
            whiskers = BoxPlotBase.calc_whiskers(col, outliers)

            for key in ("mean", "med", "q1", "q3"):
                self.assertEqual(stats[key], expected_stats[key])
            self.assertEqual(fences[0], expected_stats["lfence"])
            self.assertEqual(fences[1], expected_stats["ufence"])
            self.assertEqual(whiskers[0], expected_whiskers["min"])
            self.assertEqual(whiskers[1], expected_whiskers["max"])

    pdf = pd.DataFrame(
        {
            "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 50],
            "b": [3, 2, 5, 4, 5, 6, 8, 8, 11, 60, 90],
            "c": [-30, -2, 5, 4, 5, 6, -8, 8, 11, 12, 18],
        },
        index=[0, 1, 3, 5, 6, 8, 9, 9, 9, 10, 10],
    )
    psdf = ps.from_pandas(pdf)

    # Exercise both the original values and their negation.
    check_box_multi_columns(psdf)
    check_box_multi_columns(-psdf)


if __name__ == "__main__":
import unittest
Expand Down

0 comments on commit 9b16579

Please sign in to comment.