[SPARK-49609][PYTHON][FOLLOWUP] Correct the typehint for `filter` and `where`

### What changes were proposed in this pull request?
Correct the typehint for `filter` and `where`

### Why are the changes needed?
A `str` passed to `filter` or `where` is a SQL expression string, not a column name, so the `ColumnOrName` hint is misleading.
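
A minimal sketch of why this is effectively a documentation change, assuming `ColumnOrName` is an alias of `Union[Column, str]` (as it is in `pyspark.sql._typing`, to the best of my knowledge). The `Column` class and the `where_old`/`where_new` functions below are hypothetical stand-ins, not PySpark code. The resolved types are identical; only the name shown in signatures and docs differs.

```python
from typing import Union, get_type_hints


class Column:
    """Hypothetical stand-in for pyspark.sql.Column (illustration only)."""


# Assumption: mirrors the alias defined in pyspark.sql._typing.
ColumnOrName = Union[Column, str]


def where_old(condition: "ColumnOrName") -> None:
    """Old hint: the alias name suggests a str is a column name."""


def where_new(condition: Union[Column, str]) -> None:
    """New hint: a str is just a str (here, a SQL expression)."""


# Both annotations resolve to the same type for a type checker.
old_hint = get_type_hints(where_old)["condition"]
new_hint = get_type_hints(where_new)["condition"]
same_type = old_hint == new_hint

# The unresolved annotation is what readers of the signature see.
raw_old = where_old.__annotations__["condition"]
```

Here `same_type` is `True`, while `raw_old` is the string `"ColumnOrName"`, which is the misleading part this patch removes.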

### Does this PR introduce _any_ user-facing change?
doc change

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#48244 from zhengruifeng/py_filter_where.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Sep 25, 2024
1 parent c362d50 commit 0ccf53a
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/classic/dataframe.py
@@ -1787,7 +1787,7 @@ def semanticHash(self) -> int:
     def inputFiles(self) -> List[str]:
         return list(self._jdf.inputFiles())

-    def where(self, condition: "ColumnOrName") -> ParentDataFrame:
+    def where(self, condition: Union[Column, str]) -> ParentDataFrame:
         return self.filter(condition)

     # Two aliases below were added for pandas compatibility many years ago.
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/dataframe.py
@@ -1260,7 +1260,7 @@ def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
         res._cached_schema = self._merge_cached_schema(other)
         return res

-    def where(self, condition: "ColumnOrName") -> ParentDataFrame:
+    def where(self, condition: Union[Column, str]) -> ParentDataFrame:
         if not isinstance(condition, (str, Column)):
             raise PySparkTypeError(
                 errorClass="NOT_COLUMN_OR_STR",
4 changes: 2 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -3351,7 +3351,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
         ...

     @dispatch_df_method
-    def filter(self, condition: "ColumnOrName") -> "DataFrame":
+    def filter(self, condition: Union[Column, str]) -> "DataFrame":
         """Filters rows using the given condition.
         :func:`where` is an alias for :func:`filter`.
@@ -5902,7 +5902,7 @@ def inputFiles(self) -> List[str]:
         ...

     @dispatch_df_method
-    def where(self, condition: "ColumnOrName") -> "DataFrame":
+    def where(self, condition: Union[Column, str]) -> "DataFrame":
         """
         :func:`where` is an alias for :func:`filter`.
