Skip to content

Commit

Permalink
[Datasets] Allow MultiHotEncoder to encode arrays (ray-project#31365)
Browse files Browse the repository at this point in the history
MultiHotEncoder doesn't work with Arrow datasets. This PR updates the MultiHotEncoder implementation to fix the issue.

Signed-off-by: Balaji Veeramani <[email protected]>
  • Loading branch information
bveeramani authored Jan 4, 2023
1 parent 11e92e2 commit a34fa71
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/ray/data/preprocessors/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Dict, Optional

from collections import Counter, OrderedDict
import numpy as np
import pandas as pd
import pandas.api.types

Expand Down Expand Up @@ -324,7 +325,9 @@ def _transform_pandas(self, df: pd.DataFrame):
_validate_df(df, *self.columns)

def encode_list(element: list, *, name: str):
if not isinstance(element, list):
if isinstance(element, np.ndarray):
element = element.tolist()
elif not isinstance(element, list):
element = [element]
stats = self.stats_[f"unique_values({name})"]
counter = Counter(element)
Expand Down Expand Up @@ -509,6 +512,7 @@ def _get_unique_value_indices(
encode_lists: bool = True,
) -> Dict[str, Dict[str, int]]:
"""If drop_na_values is True, will silently drop NA values."""

if max_categories is None:
max_categories = {}
for column in max_categories:
Expand Down Expand Up @@ -601,5 +605,5 @@ def _is_series_composed_of_lists(series: pd.Series) -> bool:
(element for element in series if element is not None), None
)
return pandas.api.types.is_object_dtype(series.dtype) and isinstance(
first_not_none_element, list
first_not_none_element, (list, np.ndarray)
)
8 changes: 8 additions & 0 deletions python/ray/data/tests/preprocessors/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,14 @@ def test_multi_hot_encoder():
null_encoder.transform_batch(null_df)
null_encoder.transform_batch(nonnull_df)

# Verify that `fit` and `transform` work with ndarrays.
df = pd.DataFrame({"column": [np.array(["A"]), np.array(["A", "B"])]})
ds = ray.data.from_pandas(df)
encoder = MultiHotEncoder(["column"])
transformed = encoder.fit_transform(ds)
encodings = [record["column"] for record in transformed.take_all()]
assert encodings == [[1, 0], [1, 1]]


def test_multi_hot_encoder_with_max_categories():
"""Tests basic MultiHotEncoder functionality with limit."""
Expand Down

0 comments on commit a34fa71

Please sign in to comment.