Skip to content

Commit

Permalink
[SPARK-38937][PYTHON] interpolate support param limit_direction
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Add support for the `limit_direction` parameter to `interpolate`.

### Why are the changes needed?
pandas' `interpolate` already supports `limit_direction`, so pandas-on-Spark should support it as well.

### Does this PR introduce _any_ user-facing change?
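Yes. `interpolate` now accepts a new `limit_direction` parameter (one of `'forward'`, `'backward'`, `'both'`), and the default of `method` changes from `None` to `"linear"`.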

### How was this patch tested?
Added unit tests.

Closes apache#36246 from zhengruifeng/linear_interpolate_support_limit_direction.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Apr 25, 2022
1 parent 9440590 commit 5046b8c
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,7 @@ Supported DataFrame APIs
| :func:`insert` | Y | |
+--------------------------------------------+-------------+--------------------------------------+
| :func:`interpolate` | P | ``axis``, ``inplace``, |
| | | ``limit_direction``, ``limit_area``, |
| | | ``downcast`` |
| | | |
| | | ``limit_area``, ``downcast`` |
+--------------------------------------------+-------------+--------------------------------------+
| :func:`isin` | Y | |
+--------------------------------------------+-------------+--------------------------------------+
Expand Down Expand Up @@ -877,8 +875,7 @@ Supported Series APIs
| infer_objects | N | |
+---------------------------------+-------------------+-------------------------------------------+
| :func:`interpolate` | P | ``axis``, ``inplace``, |
| | | ``limit_direction``, ``limit_area``, |
| | | ``downcast`` |
| | | ``limit_area``, ``downcast`` |
+---------------------------------+-------------------+-------------------------------------------+
| :func:`is_monotonic` | Y | |
+---------------------------------+-------------------+-------------------------------------------+
Expand Down
18 changes: 15 additions & 3 deletions python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5500,11 +5500,20 @@ def op(psser: ps.Series) -> ps.Series:
else:
return psdf

def interpolate(self, method: Optional[str] = None, limit: Optional[int] = None) -> "DataFrame":
if (method is not None) and (method not in ["linear"]):
def interpolate(
self,
method: str = "linear",
limit: Optional[int] = None,
limit_direction: Optional[str] = None,
) -> "DataFrame":
if method not in ["linear"]:
raise NotImplementedError("interpolate currently works only for method='linear'")
if (limit is not None) and (not limit > 0):
raise ValueError("limit must be > 0.")
if (limit_direction is not None) and (
limit_direction not in ["forward", "backward", "both"]
):
raise ValueError("invalid limit_direction: '{}'".format(limit_direction))

numeric_col_names = []
for label in self._internal.column_labels:
Expand All @@ -5514,7 +5523,10 @@ def interpolate(self, method: Optional[str] = None, limit: Optional[int] = None)

psdf = self[numeric_col_names]
return psdf._apply_series_op(
lambda psser: psser._interpolate(method=method, limit=limit), should_resolve=True
lambda psser: psser._interpolate(
method=method, limit=limit, limit_direction=limit_direction
),
should_resolve=True,
)

def replace(
Expand Down
13 changes: 9 additions & 4 deletions python/pyspark/pandas/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3253,16 +3253,17 @@ def ffill(

pad = ffill

# TODO: add 'axis', 'inplace', 'limit_direction', 'limit_area', 'downcast'
# TODO: add 'axis', 'inplace', 'limit_area', 'downcast'
def interpolate(
self: FrameLike,
method: Optional[str] = None,
method: str = "linear",
limit: Optional[int] = None,
limit_direction: Optional[str] = None,
) -> FrameLike:
"""
Fill NaN values using an interpolation method.
.. note:: the current implementation of rank uses Spark's Window without
.. note:: the current implementation of interpolate uses Spark's Window without
specifying partition specification. This leads to move all data into
single partition in single machine and could cause serious
performance degradation. Avoid this method against very large dataset.
Expand All @@ -3281,6 +3282,10 @@ def interpolate(
Maximum number of consecutive NaNs to fill. Must be greater than
0.
limit_direction : str, default None
Consecutive NaNs will be filled in this direction.
One of {{'forward', 'backward', 'both'}}.
Returns
-------
Series or DataFrame or None
Expand Down Expand Up @@ -3335,7 +3340,7 @@ def interpolate(
2 2.0 3.0 -3.0 9.0
3 2.0 4.0 -4.0 16.0
"""
return self.interpolate(method=method, limit=limit)
return self.interpolate(method=method, limit=limit, limit_direction=limit_direction)

@property
def at(self) -> AtIndexer:
Expand Down
63 changes: 54 additions & 9 deletions python/pyspark/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,18 +2169,28 @@ def _fillna(
)
)._psser_for(self._column_label)

def interpolate(self, method: Optional[str] = None, limit: Optional[int] = None) -> "Series":
return self._interpolate(method=method, limit=limit)
def interpolate(
self,
method: str = "linear",
limit: Optional[int] = None,
limit_direction: Optional[str] = None,
) -> "Series":
return self._interpolate(method=method, limit=limit, limit_direction=limit_direction)

def _interpolate(
self,
method: Optional[str] = None,
method: str = "linear",
limit: Optional[int] = None,
limit_direction: Optional[str] = None,
) -> "Series":
if (method is not None) and (method not in ["linear"]):
if method not in ["linear"]:
raise NotImplementedError("interpolate currently works only for method='linear'")
if (limit is not None) and (not limit > 0):
raise ValueError("limit must be > 0.")
if (limit_direction is not None) and (
limit_direction not in ["forward", "backward", "both"]
):
raise ValueError("invalid limit_direction: '{}'".format(limit_direction))

if not self.spark.nullable and not isinstance(
self.spark.data_type, (FloatType, DoubleType)
Expand Down Expand Up @@ -2209,15 +2219,50 @@ def _interpolate(
) * null_index_forward + last_non_null_forward

fill_cond = ~F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
pad_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
if limit is not None:
fill_cond = fill_cond & (null_index_forward <= F.lit(limit))
pad_cond = pad_cond & (null_index_forward <= F.lit(limit))

pad_head = SF.lit(None)
pad_head_cond = SF.lit(False)
pad_tail = SF.lit(None)
pad_tail_cond = SF.lit(False)

# inputs -> NaN, NaN, 1.0, NaN, NaN, NaN, 5.0, NaN, NaN
if limit_direction is None or limit_direction == "forward":
# outputs -> NaN, NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
pad_tail = last_non_null_forward
pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
if limit is not None:
# outputs (limit=1) -> NaN, NaN, 1.0, 2.0, NaN, NaN, 5.0, 5.0, NaN
fill_cond = fill_cond & (null_index_forward <= F.lit(limit))
pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit))

elif limit_direction == "backward":
# outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, NaN, NaN
pad_head = last_non_null_backward
pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward)
if limit is not None:
# outputs (limit=1) -> NaN, 1.0, 1.0, NaN, NaN, 4.0, 5.0, NaN, NaN
fill_cond = fill_cond & (null_index_backward <= F.lit(limit))
pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit))

else:
# outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0
pad_head = last_non_null_backward
pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward)
pad_tail = last_non_null_forward
pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward)
if limit is not None:
# outputs (limit=1) -> NaN, 1.0, 1.0, 2.0, NaN, 4.0, 5.0, 5.0, NaN
fill_cond = fill_cond & (
(null_index_forward <= F.lit(limit)) | (null_index_backward <= F.lit(limit))
)
pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit))
pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit))

cond = self.isnull().spark.column
scol = (
F.when(cond & fill_cond, fill)
.when(cond & pad_cond, last_non_null_forward)
.when(cond & pad_head_cond, pad_head)
.when(cond & pad_tail_cond, pad_tail)
.otherwise(scol)
)

Expand Down
33 changes: 17 additions & 16 deletions python/pyspark/pandas/tests/test_generic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@ def test_interpolate_error(self):
with self.assertRaisesRegex(ValueError, "limit must be > 0"):
psdf.interpolate(limit=0)

def _test_series_interpolate(self, pser):
psser = ps.from_pandas(pser)
self.assert_eq(psser.interpolate(), pser.interpolate())
for l1 in range(1, 5):
self.assert_eq(psser.interpolate(limit=l1), pser.interpolate(limit=l1))

def _test_dataframe_interpolate(self, pdf):
psdf = ps.from_pandas(pdf)
self.assert_eq(psdf.interpolate(), pdf.interpolate())
for l2 in range(1, 5):
self.assert_eq(psdf.interpolate(limit=l2), pdf.interpolate(limit=l2))
with self.assertRaisesRegex(ValueError, "invalid limit_direction"):
psdf.interpolate(limit_direction="jump")

def _test_interpolate(self, pobj):
psobj = ps.from_pandas(pobj)
self.assert_eq(psobj.interpolate(), pobj.interpolate())
for limit in range(1, 5):
for limit_direction in [None, "forward", "backward", "both"]:
self.assert_eq(
psobj.interpolate(limit=limit, limit_direction=limit_direction),
pobj.interpolate(limit=limit, limit_direction=limit_direction),
)

def test_interpolate(self):
pser = pd.Series(
Expand All @@ -54,7 +55,7 @@ def test_interpolate(self):
],
name="a",
)
self._test_series_interpolate(pser)
self._test_interpolate(pser)

pser = pd.Series(
[
Expand All @@ -64,7 +65,7 @@ def test_interpolate(self):
],
name="a",
)
self._test_series_interpolate(pser)
self._test_interpolate(pser)

pser = pd.Series(
[
Expand All @@ -84,7 +85,7 @@ def test_interpolate(self):
],
name="a",
)
self._test_series_interpolate(pser)
self._test_interpolate(pser)

pdf = pd.DataFrame(
[
Expand All @@ -96,7 +97,7 @@ def test_interpolate(self):
],
columns=list("abc"),
)
self._test_dataframe_interpolate(pdf)
self._test_interpolate(pdf)

pdf = pd.DataFrame(
[
Expand All @@ -108,7 +109,7 @@ def test_interpolate(self):
],
columns=list("abcde"),
)
self._test_dataframe_interpolate(pdf)
self._test_interpolate(pdf)


if __name__ == "__main__":
Expand Down

0 comments on commit 5046b8c

Please sign in to comment.