Skip to content

Commit

Permalink
[SPARK-49764][PYTHON][CONNECT] Support area plots
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support area plots with plotly backend on both Spark Connect and Spark classic.

### Why are the changes needed?
While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments.

See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress.

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?
Yes. Area plots are supported as shown below.

```py
>>> from datetime import datetime
>>> data = [
...     (3, 5, 20, datetime(2018, 1, 31)),
...     (2, 5, 42, datetime(2018, 2, 28)),
...     (3, 6, 28, datetime(2018, 3, 31)),
...     (9, 12, 62, datetime(2018, 4, 30))]
>>> columns = ["sales", "signups", "visits", "date"]
>>> df = spark.createDataFrame(data, columns)
>>> fig = df.plot.area(x="date", y=["sales", "signups", "visits"])  # df.plot(kind="area", x="date", y=["sales", "signups", "visits"])
>>> fig.show()
```
![newplot (7)](https://github.com/user-attachments/assets/e603cd99-ce8b-4448-8e1f-cbc093097c45)

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48236 from xinrong-meng/plot_area.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
xinrong-meng authored and HyukjinKwon committed Sep 25, 2024
1 parent 46c5acc commit 7f0ecd4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
35 changes: 35 additions & 0 deletions python/pyspark/sql/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":

class PySparkPlotAccessor:
plot_data_map = {
"area": PySparkSampledPlotBase().get_sampled,
"bar": PySparkTopNPlotBase().get_top_n,
"barh": PySparkTopNPlotBase().get_top_n,
"line": PySparkSampledPlotBase().get_sampled,
Expand Down Expand Up @@ -264,3 +265,37 @@ def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure":
>>> df.plot.scatter(x='length', y='width') # doctest: +SKIP
"""
return self(kind="scatter", x=x, y=y, **kwargs)

def area(self, x: str, y: str, **kwargs: Any) -> "Figure":
"""
Draw a stacked area plot.
An area plot displays quantitative data visually.
Parameters
----------
x : str
Name of column to use for the horizontal axis.
y : str or list of str
Name(s) of the column(s) to plot.
**kwargs: Optional
Additional keyword arguments.
Returns
-------
:class:`plotly.graph_objs.Figure`
Examples
--------
>>> from datetime import datetime
>>> data = [
... (3, 5, 20, datetime(2018, 1, 31)),
... (2, 5, 42, datetime(2018, 2, 28)),
... (3, 6, 28, datetime(2018, 3, 31)),
... (9, 12, 62, datetime(2018, 4, 30))
... ]
>>> columns = ["sales", "signups", "visits", "date"]
>>> df = spark.createDataFrame(data, columns)
>>> df.plot.area(x='date', y=['sales', 'signups', 'visits']) # doctest: +SKIP
"""
return self(kind="area", x=x, y=y, **kwargs)
35 changes: 35 additions & 0 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#

import unittest
from datetime import datetime

import pyspark.sql.plot # noqa: F401
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message

Expand All @@ -34,6 +36,17 @@ def sdf2(self):
columns = ["length", "width", "species"]
return self.spark.createDataFrame(data, columns)

@property
def sdf3(self):
data = [
(3, 5, 20, datetime(2018, 1, 31)),
(2, 5, 42, datetime(2018, 2, 28)),
(3, 6, 28, datetime(2018, 3, 31)),
(9, 12, 62, datetime(2018, 4, 30)),
]
columns = ["sales", "signups", "visits", "date"]
return self.spark.createDataFrame(data, columns)

def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""):
if kind == "line":
self.assertEqual(fig_data["mode"], "lines")
Expand All @@ -46,6 +59,11 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=
elif kind == "scatter":
self.assertEqual(fig_data["type"], "scatter")
self.assertEqual(fig_data["orientation"], "v")
self.assertEqual(fig_data["mode"], "markers")
elif kind == "area":
self.assertEqual(fig_data["type"], "scatter")
self.assertEqual(fig_data["orientation"], "v")
self.assertEqual(fig_data["mode"], "lines")

self.assertEqual(fig_data["xaxis"], "x")
self.assertEqual(list(fig_data["x"]), expected_x)
Expand Down Expand Up @@ -98,6 +116,23 @@ def test_scatter_plot(self):
"scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 7.0, 6.4, 5.9]
)

def test_area_plot(self):
# single column as vertical axis
fig = self.sdf3.plot(kind="area", x="date", y="sales")
expected_x = [
datetime(2018, 1, 31, 0, 0),
datetime(2018, 2, 28, 0, 0),
datetime(2018, 3, 31, 0, 0),
datetime(2018, 4, 30, 0, 0),
]
self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9])

# multiple columns as vertical axis
fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"])
self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9], "sales")
self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups")
self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits")


class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
pass
Expand Down

0 comments on commit 7f0ecd4

Please sign in to comment.