diff --git a/python/ray/data/_internal/table_block.py b/python/ray/data/_internal/table_block.py index 809137f25fda..0971ecea6e1a 100644 --- a/python/ray/data/_internal/table_block.py +++ b/python/ray/data/_internal/table_block.py @@ -1,5 +1,5 @@ import collections -from typing import Dict, Iterator, List, Union, Any, TypeVar, TYPE_CHECKING +from typing import Dict, Iterator, List, Union, Any, TypeVar, Mapping, TYPE_CHECKING import numpy as np @@ -180,7 +180,8 @@ def is_tensor_wrapper(self) -> bool: return False return _is_tensor_schema(self.column_names()) - def iter_rows(self) -> Iterator[Union[TableRow, np.ndarray]]: + def iter_rows(self) -> Iterator[Union[Mapping, np.ndarray]]: + ctx = ray.data.DataContext.get_current() outer = self class Iter: @@ -193,7 +194,11 @@ def __iter__(self): def __next__(self): self._cur += 1 if self._cur < outer.num_rows(): - return outer._get_row(self._cur) + row = outer._get_row(self._cur) + if ctx.strict_mode and isinstance(row, TableRow): + return row.as_pydict() + else: + return row raise StopIteration return Iter() diff --git a/python/ray/data/datastream.py b/python/ray/data/datastream.py index f01ca22c0d0a..3240b29e3f1a 100644 --- a/python/ray/data/datastream.py +++ b/python/ray/data/datastream.py @@ -17,6 +17,7 @@ Optional, Tuple, Union, + Mapping, ) from uuid import uuid4 @@ -128,7 +129,6 @@ _wrap_arrow_serialization_workaround, ) from ray.data.random_access_dataset import RandomAccessDataset -from ray.data.row import TableRow from ray.types import ObjectRef from ray.util.annotations import DeveloperAPI, PublicAPI, Deprecated from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -3027,12 +3027,12 @@ def iterator(self) -> DataIterator: return DataIteratorImpl(self) @ConsumptionAPI - def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]: + def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, Mapping]]: """Return a local row iterator over the datastream. - If the datastream is a tabular datastream (Arrow/Pandas blocks), dict-like - mappings :py:class:`~ray.data.row.TableRow` are yielded for each row by the - iterator. If the datastream is not tabular, the raw row is yielded. + If the datastream is a tabular datastream (Arrow/Pandas blocks), dicts + are yielded for each row by the iterator. If the datastream is not tabular, + the raw row is yielded. Examples: >>> import ray @@ -4488,7 +4488,7 @@ def _build_multicolumn_aggs( on = [on] return [agg_cls(on_, *args, ignore_nulls=ignore_nulls, **kwargs) for on_ in on] - def _aggregate_result(self, result: Union[Tuple, TableRow]) -> U: + def _aggregate_result(self, result: Union[Tuple, Mapping]) -> U: if result is not None and len(result) == 1: if isinstance(result, tuple): return result[0] diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index f7841673cf03..4957512da212 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -11,12 +11,12 @@ Tuple, Union, Iterator, + Mapping, ) from ray.types import ObjectRef from ray.data.block import BlockAccessor, Block, BlockMetadata, DataBatch, T from ray.data.context import DataContext -from ray.data.row import TableRow from ray.util.annotations import PublicAPI from ray.data._internal.block_batching import batch_block_refs from ray.data._internal.block_batching.iter_batches import iter_batches @@ -200,12 +200,12 @@ def drop_metadata(block_iterator): if stats: stats.iter_total_s.add(time.perf_counter() - time_start) - def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]: + def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, Mapping]]: """Return a local row iterator over the datastream. - If the datastream is a tabular datastream (Arrow/Pandas blocks), dict-like - mappings :py:class:`~ray.data.row.TableRow` are yielded for each row by the - iterator. If the datastream is not tabular, the raw row is yielded. + If the datastream is a tabular datastream (Arrow/Pandas blocks), dicts + are yielded for each row by the iterator. If the datastream is not tabular, + the raw row is yielded. Examples: >>> import ray diff --git a/python/ray/data/tests/test_strict_mode.py b/python/ray/data/tests/test_strict_mode.py index 5d7920e41b01..100097c91f6b 100644 --- a/python/ray/data/tests/test_strict_mode.py +++ b/python/ray/data/tests/test_strict_mode.py @@ -182,6 +182,17 @@ def test_strict_schema(ray_start_regular_shared): assert isinstance(schema.base_schema, PandasBlockSchema) +def test_use_raw_dicts(ray_start_regular_shared): + assert type(ray.data.range(10).take(1)[0]) is dict + assert type(ray.data.from_items([1]).take(1)[0]) is dict + + def checker(x): + assert type(x) is dict + return x + + ray.data.range(10).map(checker).show() + + if __name__ == "__main__": import sys