Skip to content

Commit

Permalink
[data] [strict mode] Allow returning lists instead of arrays for nump…
Browse files Browse the repository at this point in the history
…y batches (ray-project#34734)

Allow map_batches UDFs to return {"foo": [1, 2, 3]} in addition to {"foo": np.array([1, 2, 3])} by implicitly casting lists to arrays.
  • Loading branch information
ericl authored Apr 25, 2023
1 parent 0abda20 commit 74ddaaa
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
20 changes: 16 additions & 4 deletions python/ray/data/_internal/planner/map_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,29 @@ def validate_batch(batch: Block) -> None:
)

if isinstance(batch, collections.abc.Mapping):
for key, value in batch.items():
if not isinstance(value, np.ndarray):
for key, value in list(batch.items()):
if not isinstance(value, (np.ndarray, list)):
raise ValueError(
f"Error validating {_truncated_repr(batch)}: "
"The `fn` you passed to `map_batches` returned a "
f"`dict`. `map_batches` expects all `dict` values "
f"to be of type `numpy.ndarray`, but the value "
f"to be `list` or `np.ndarray` type, but the value "
f"corresponding to key {key!r} is of type "
f"{type(value)}. To fix this issue, convert "
f"the {type(value)} to a `numpy.ndarray`."
f"the {type(value)} to a `np.ndarray`."
)
if isinstance(value, list):
# Try to convert list values into an numpy array via
# np.array(), so users don't need to manually cast.
# NOTE: we don't cast generic iterables, since types like
# `str` are also Iterable.
try:
batch[key] = np.array(value)
except Exception:
raise ValueError(
"Failed to convert column values to numpy array: "
f"({_truncated_repr(value)})."
)

def process_next_batch(batch: DataBatch) -> Iterator[Block]:
# Apply UDF.
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def batch_to_block(batch: DataBatch) -> Block:
return ArrowBlockAccessor.numpy_to_block(
batch, passthrough_arrow_not_implemented_errors=True
)
except pa.ArrowNotImplementedError:
except (pa.ArrowNotImplementedError, pa.ArrowInvalid):
import pandas as pd

# TODO(ekl) once we support Python objects within Arrow blocks, we
Expand Down
20 changes: 20 additions & 0 deletions python/ray/data/tests/test_strict_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,26 @@ def test_strict_map_output(ray_start_regular_shared, enable_strict_mode):
ds.map(lambda x: UserDict({"x": object()})).materialize()


def test_strict_convert_map_output(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.range(1).map_batches(lambda x: {"id": [0, 1, 2, 3]}).materialize()
assert ds.take_batch()["id"].tolist() == [0, 1, 2, 3]

with pytest.raises(ValueError):
# Strings not converted into array.
ray.data.range(1).map_batches(lambda x: {"id": "string"}).materialize()

class UserObj:
def __eq__(self, other):
return isinstance(other, UserObj)

ds = (
ray.data.range(1)
.map_batches(lambda x: {"id": [0, 1, 2, UserObj()]})
.materialize()
)
assert ds.take_batch()["id"].tolist() == [0, 1, 2, UserObj()]


def test_strict_default_batch_format(ray_start_regular_shared, enable_strict_mode):
ds = ray.data.range(1)

Expand Down

0 comments on commit 74ddaaa

Please sign in to comment.