diff --git a/python/ray/data/_internal/execution/interfaces/op_runtime_metrics.py b/python/ray/data/_internal/execution/interfaces/op_runtime_metrics.py index 5ea4f4cdc71e..51b7ab9f30e3 100644 --- a/python/ray/data/_internal/execution/interfaces/op_runtime_metrics.py +++ b/python/ray/data/_internal/execution/interfaces/op_runtime_metrics.py @@ -512,12 +512,10 @@ def on_task_finished(self, task_index: int, exception: Optional[Exception]): input_size, ) - blocks = [input[0] for input in inputs.blocks] - metadata = [input[1] for input in inputs.blocks] ctx = ray.data.context.DataContext.get_current() if ctx.enable_get_object_locations_for_metrics: - locations = ray.experimental.get_object_locations(blocks) - for block, meta in zip(blocks, metadata): + locations = ray.experimental.get_object_locations(inputs.block_refs) + for block, meta in inputs.blocks: if locations[block].get("did_spill", False): assert meta.size_bytes is not None self.obj_store_mem_spilled += meta.size_bytes diff --git a/python/ray/data/_internal/execution/interfaces/ref_bundle.py b/python/ray/data/_internal/execution/interfaces/ref_bundle.py index 6304f48432a8..c8996a1c422b 100644 --- a/python/ray/data/_internal/execution/interfaces/ref_bundle.py +++ b/python/ray/data/_internal/execution/interfaces/ref_bundle.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import List, Optional, Tuple import ray from .common import NodeIdStr @@ -58,19 +58,29 @@ def __setattr__(self, key, value): raise ValueError(f"The `{key}` field of RefBundle cannot be updated.") object.__setattr__(self, key, value) + @property + def block_refs(self) -> List[ray.ObjectRef]: + """List of block references in this bundle.""" + return [block_ref for block_ref, _ in self.blocks] + + @property + def metadata(self) -> List[BlockMetadata]: + """List of block metadata in this bundle.""" + return [metadata for _, metadata in self.blocks] + def num_rows(self) -> Optional[int]: """Number 
of rows present in this bundle, if known.""" total = 0 - for b in self.blocks: - if b[1].num_rows is None: + for m in self.metadata: + if m.num_rows is None: return None else: - total += b[1].num_rows + total += m.num_rows return total def size_bytes(self) -> int: """Size of the blocks of this bundle in bytes.""" - return sum(b[1].size_bytes for b in self.blocks) + return sum(m.size_bytes for m in self.metadata) def destroy_if_owned(self) -> int: """Clears the object store memory for these blocks if owned. @@ -79,8 +89,10 @@ def destroy_if_owned(self) -> int: The number of bytes freed. """ should_free = self.owns_blocks and DataContext.get_current().eager_free - for b in self.blocks: - trace_deallocation(b[0], "RefBundle.destroy_if_owned", free=should_free) + for block_ref in self.block_refs: + trace_deallocation( + block_ref, "RefBundle.destroy_if_owned", free=should_free + ) return self.size_bytes() if should_free else 0 def get_cached_location(self) -> Optional[NodeIdStr]: @@ -91,7 +103,7 @@ def get_cached_location(self) -> Optional[NodeIdStr]: if self._cached_location is None: # Only consider the first block in the bundle for now. TODO(ekl) consider # taking into account other blocks. - ref = self.blocks[0][0] + ref = self.block_refs[0] # This call is pretty fast for owned objects (~5k/s), so we don't need to # batch it for now. 
locs = ray.experimental.get_object_locations([ref]) diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index c216617eeb0f..aa4e7ba64ee5 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ b/python/ray/data/_internal/execution/legacy_compat.py @@ -204,9 +204,8 @@ def _bundles_to_block_list(bundles: Iterator[RefBundle]) -> BlockList: for ref_bundle in bundles: if not ref_bundle.owns_blocks: owns_blocks = False - for block, meta in ref_bundle.blocks: - blocks.append(block) - metadata.append(meta) + blocks.extend(ref_bundle.block_refs) + metadata.extend(ref_bundle.metadata) return BlockList(blocks, metadata, owned_by_consumer=owns_blocks) diff --git a/python/ray/data/_internal/execution/operators/input_data_buffer.py b/python/ray/data/_internal/execution/operators/input_data_buffer.py index ba781b660713..cc0447d8d26e 100644 --- a/python/ray/data/_internal/execution/operators/input_data_buffer.py +++ b/python/ray/data/_internal/execution/operators/input_data_buffer.py @@ -80,7 +80,7 @@ def _initialize_metadata(self): self._num_output_bundles = len(self._input_data) block_metadata = [] for bundle in self._input_data: - block_metadata.extend([m for (_, m) in bundle.blocks]) + block_metadata.extend(bundle.metadata) self._stats = { "input": block_metadata, } diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index 075e28bc6c1c..931dbd4aeb7c 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -370,8 +370,7 @@ def _get_next_inner(self) -> RefBundle: assert self._started bundle = self._output_queue.get_next() self._metrics.on_output_dequeued(bundle) - for _, meta in bundle.blocks: - self._output_metadata.append(meta) + self._output_metadata.extend(bundle.metadata) return bundle @abstractmethod diff --git 
a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py index 8f4fc7fee2b8..42fa68030f48 100644 --- a/python/ray/data/_internal/execution/operators/task_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/task_pool_map_operator.py @@ -63,7 +63,6 @@ def __init__( def _add_bundled_input(self, bundle: RefBundle): # Submit the task as a normal Ray task. map_task = cached_remote_fn(_map_task, num_returns="streaming") - input_blocks = [block for block, _ in bundle.blocks] ctx = TaskContext( task_idx=self._next_data_task_idx, target_max_block_size=self.actual_target_max_block_size, @@ -84,7 +83,7 @@ def _add_bundled_input(self, bundle: RefBundle): self._map_transformer_ref, data_context, ctx, - *input_blocks, + *bundle.block_refs, ) self._submit_data_task(gen, bundle) diff --git a/python/ray/data/_internal/planner/aggregate.py b/python/ray/data/_internal/planner/aggregate.py index 6cd86ad3938d..d31afc5a7c00 100644 --- a/python/ray/data/_internal/planner/aggregate.py +++ b/python/ray/data/_internal/planner/aggregate.py @@ -39,9 +39,8 @@ def fn( blocks = [] metadata = [] for ref_bundle in refs: - for block, block_metadata in ref_bundle.blocks: - blocks.append(block) - metadata.append(block_metadata) + blocks.extend(ref_bundle.block_refs) + metadata.extend(ref_bundle.metadata) if len(blocks) == 0: return (blocks, {}) unified_schema = unify_block_metadata_schema(metadata) diff --git a/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py b/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py index 1176202b9758..b2cf448f030d 100644 --- a/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py +++ b/python/ray/data/_internal/planner/exchange/pull_based_shuffle_task_scheduler.py @@ -41,8 +41,7 @@ def execute( # eagerly release the blocks' memory. 
input_blocks_list = [] for ref_bundle in refs: - for block, _ in ref_bundle.blocks: - input_blocks_list.append(block) + input_blocks_list.extend(ref_bundle.block_refs) input_num_blocks = len(input_blocks_list) input_owned = all(b.owns_blocks for b in refs) diff --git a/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py b/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py index 2c5130fe8933..cafd5f31f4da 100644 --- a/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py +++ b/python/ray/data/_internal/planner/exchange/push_based_shuffle_task_scheduler.py @@ -457,8 +457,7 @@ def execute( # processed concurrently. input_blocks_list = [] for ref_bundle in refs: - for block, _ in ref_bundle.blocks: - input_blocks_list.append(block) + input_blocks_list.extend(ref_bundle.block_refs) input_owned = all(b.owns_blocks for b in refs) if map_ray_remote_args is None: diff --git a/python/ray/data/_internal/planner/randomize_blocks.py b/python/ray/data/_internal/planner/randomize_blocks.py index b2fc9c950c0e..835017f2cafd 100644 --- a/python/ray/data/_internal/planner/randomize_blocks.py +++ b/python/ray/data/_internal/planner/randomize_blocks.py @@ -22,8 +22,7 @@ def fn( nonlocal op blocks_with_metadata = [] for ref_bundle in refs: - for block, meta in ref_bundle.blocks: - blocks_with_metadata.append((block, meta)) + blocks_with_metadata.extend(ref_bundle.blocks) if len(blocks_with_metadata) == 0: return refs, {op._name: []} diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py index fddb80a72dde..3ca5af144006 100644 --- a/python/ray/data/_internal/planner/sort.py +++ b/python/ray/data/_internal/planner/sort.py @@ -32,9 +32,8 @@ def fn( blocks = [] metadata = [] for ref_bundle in refs: - for block, block_metadata in ref_bundle.blocks: - blocks.append(block) - metadata.append(block_metadata) + blocks.extend(ref_bundle.block_refs) + 
metadata.extend(ref_bundle.metadata) if len(blocks) == 0: return (blocks, {}) sort_key.validate_schema(unify_block_metadata_schema(metadata)) diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index b48d94c79fb7..54ea30c17234 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -44,8 +44,8 @@ def _get_blocks(bundle: RefBundle, output_list: List[Block]): - for block, _ in bundle.blocks: - output_list.append(list(ray.get(block)["id"])) + for block_ref in bundle.block_refs: + output_list.append(list(ray.get(block_ref)["id"])) def _mul2_transform(block_iter: Iterable[Block], ctx) -> Iterable[Block]: @@ -196,9 +196,11 @@ def test_split_operator(ray_start_regular_shared, equal, chunk_size): while op.has_next(): ref = op.get_next() assert ref.owns_blocks, ref - for block, _ in ref.blocks: + for block_ref in ref.block_refs: assert ref.output_split_idx is not None - output_splits[ref.output_split_idx].extend(list(ray.get(block)["id"])) + output_splits[ref.output_split_idx].extend( + list(ray.get(block_ref)["id"]) + ) op.all_inputs_done() expected_splits = [[] for _ in range(num_splits)] @@ -234,8 +236,8 @@ def test_split_operator_random(ray_start_regular_shared, equal, random_seed): while op.has_next(): ref = op.get_next() assert ref.owns_blocks, ref - for block, _ in ref.blocks: - output_splits[ref.output_split_idx].extend(list(ray.get(block)["id"])) + for block_ref in ref.block_refs: + output_splits[ref.output_split_idx].extend(list(ray.get(block_ref)["id"])) if equal: actual = [len(output_splits[i]) for i in range(3)] expected = [num_inputs // 3] * 3 @@ -271,8 +273,8 @@ def get_bundle_loc(bundle): while op.has_next(): ref = op.get_next() assert ref.owns_blocks, ref - for block, _ in ref.blocks: - output_splits[ref.output_split_idx].extend(list(ray.get(block)["id"])) + for block_ref in ref.block_refs: + output_splits[ref.output_split_idx].extend(list(ray.get(block_ref)["id"])) 
total = 0 for i in range(2): @@ -583,8 +585,8 @@ def test_limit_operator(ray_start_regular_shared): def _get_bundles(bundle: RefBundle): output = [] - for block, _ in bundle.blocks: - output.extend(list(ray.get(block)["id"])) + for block_ref in bundle.block_refs: + output.extend(list(ray.get(block_ref)["id"])) return output diff --git a/python/ray/data/tests/test_split.py b/python/ray/data/tests/test_split.py index 8fe481338baf..db729ba93f24 100644 --- a/python/ray/data/tests/test_split.py +++ b/python/ray/data/tests/test_split.py @@ -667,7 +667,7 @@ def equalize_helper(input_block_lists: List[List[List[Any]]]): result_block_lists = [] for bundle in result: block_list = [] - for block_ref, _ in bundle.blocks: + for block_ref in bundle.block_refs: block = ray.get(block_ref) block_accessor = BlockAccessor.for_block(block) block_list.append(list(block_accessor.to_default()["id"])) diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index a317af32ef83..9f2672123c06 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -43,8 +43,8 @@ def map_fn(block_iter, _): def ref_bundles_to_list(bundles: List[RefBundle]) -> List[List[Any]]: output = [] for bundle in bundles: - for block, _ in bundle.blocks: - output.append(list(ray.get(block)["id"])) + for block_ref in bundle.block_refs: + output.append(list(ray.get(block_ref)["id"])) return output @@ -144,8 +144,8 @@ def run(self): def get_outputs(out: List[RefBundle]): outputs = [] for bundle in out: - for block, _ in bundle.blocks: - ids: pd.Series = ray.get(block)["id"] + for block_ref in bundle.block_refs: + ids: pd.Series = ray.get(block_ref)["id"] outputs.extend(ids.values) return outputs