Skip to content

Commit

Permalink
[Data] [strict-mode] Remove internal TableRow abstractions and instead use Dict[str, Any] as the row format
Browse files Browse the repository at this point in the history

Signed-off-by: Eric Liang <[email protected]>
  • Loading branch information
ericl authored Apr 20, 2023
1 parent bc78ac6 commit 0acef19
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 14 deletions.
11 changes: 8 additions & 3 deletions python/ray/data/_internal/table_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import collections
from typing import Dict, Iterator, List, Union, Any, TypeVar, TYPE_CHECKING
from typing import Dict, Iterator, List, Union, Any, TypeVar, Mapping, TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -180,7 +180,8 @@ def is_tensor_wrapper(self) -> bool:
return False
return _is_tensor_schema(self.column_names())

def iter_rows(self) -> Iterator[Union[TableRow, np.ndarray]]:
def iter_rows(self) -> Iterator[Union[Mapping, np.ndarray]]:
ctx = ray.data.DataContext.get_current()
outer = self

class Iter:
Expand All @@ -193,7 +194,11 @@ def __iter__(self):
def __next__(self):
self._cur += 1
if self._cur < outer.num_rows():
return outer._get_row(self._cur)
row = outer._get_row(self._cur)
if ctx.strict_mode and isinstance(row, TableRow):
return row.as_pydict()
else:
return row
raise StopIteration

return Iter()
Expand Down
12 changes: 6 additions & 6 deletions python/ray/data/datastream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Optional,
Tuple,
Union,
Mapping,
)
from uuid import uuid4

Expand Down Expand Up @@ -128,7 +129,6 @@
_wrap_arrow_serialization_workaround,
)
from ray.data.random_access_dataset import RandomAccessDataset
from ray.data.row import TableRow
from ray.types import ObjectRef
from ray.util.annotations import DeveloperAPI, PublicAPI, Deprecated
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
Expand Down Expand Up @@ -3027,12 +3027,12 @@ def iterator(self) -> DataIterator:
return DataIteratorImpl(self)

@ConsumptionAPI
def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]:
def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, Mapping]]:
"""Return a local row iterator over the datastream.
If the datastream is a tabular datastream (Arrow/Pandas blocks), dict-like
mappings :py:class:`~ray.data.row.TableRow` are yielded for each row by the
iterator. If the datastream is not tabular, the raw row is yielded.
If the datastream is a tabular datastream (Arrow/Pandas blocks), dicts
are yielded for each row by the iterator. If the datastream is not tabular,
the raw row is yielded.
Examples:
>>> import ray
Expand Down Expand Up @@ -4488,7 +4488,7 @@ def _build_multicolumn_aggs(
on = [on]
return [agg_cls(on_, *args, ignore_nulls=ignore_nulls, **kwargs) for on_ in on]

def _aggregate_result(self, result: Union[Tuple, TableRow]) -> U:
def _aggregate_result(self, result: Union[Tuple, Mapping]) -> U:
if result is not None and len(result) == 1:
if isinstance(result, tuple):
return result[0]
Expand Down
10 changes: 5 additions & 5 deletions python/ray/data/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
Tuple,
Union,
Iterator,
Mapping,
)

from ray.types import ObjectRef
from ray.data.block import BlockAccessor, Block, BlockMetadata, DataBatch, T
from ray.data.context import DataContext
from ray.data.row import TableRow
from ray.util.annotations import PublicAPI
from ray.data._internal.block_batching import batch_block_refs
from ray.data._internal.block_batching.iter_batches import iter_batches
Expand Down Expand Up @@ -200,12 +200,12 @@ def drop_metadata(block_iterator):
if stats:
stats.iter_total_s.add(time.perf_counter() - time_start)

def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]:
def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, Mapping]]:
"""Return a local row iterator over the datastream.
If the datastream is a tabular datastream (Arrow/Pandas blocks), dict-like
mappings :py:class:`~ray.data.row.TableRow` are yielded for each row by the
iterator. If the datastream is not tabular, the raw row is yielded.
If the datastream is a tabular datastream (Arrow/Pandas blocks), dicts
are yielded for each row by the iterator. If the datastream is not tabular,
the raw row is yielded.
Examples:
>>> import ray
Expand Down
11 changes: 11 additions & 0 deletions python/ray/data/tests/test_strict_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,17 @@ def test_strict_schema(ray_start_regular_shared):
assert isinstance(schema.base_schema, PandasBlockSchema)


def test_use_raw_dicts(ray_start_regular_shared):
    """Strict mode yields plain ``dict`` rows, not TableRow wrappers.

    NOTE(review): ``type(...) is dict`` (rather than ``isinstance``) is
    deliberate here — the point of this test is that rows are exactly the
    builtin ``dict``, not a dict-like subclass such as TableRow.
    """
    # take() on tabular datastreams must return plain dicts.
    assert type(ray.data.range(10).take(1)[0]) is dict
    assert type(ray.data.from_items([1]).take(1)[0]) is dict

    def checker(x):
        # map() UDFs must also receive each row as a plain dict.
        assert type(x) is dict
        return x

    # show() forces execution so checker actually runs on every row.
    ray.data.range(10).map(checker).show()


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 0acef19

Please sign in to comment.