Skip to content

Commit

Permalink
[Data] Add metadata and block_refs properties to RefBundle (ray-project#45567)
Browse files Browse the repository at this point in the history

RefBundle stores data as a List[Tuple[ObjectRef, BlockMetadata]]. Often, we'll need to access either just the object references (List[ObjectRef]) or the metadata (List[BlockMetadata]). To avoid boilerplate code to access this data, this PR adds properties to separately access the object references and metadata.

Signed-off-by: Balaji Veeramani <[email protected]>
  • Loading branch information
bveeramani authored May 28, 2024
1 parent ce47bca commit c3e6eca
Show file tree
Hide file tree
Showing 14 changed files with 51 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -512,12 +512,10 @@ def on_task_finished(self, task_index: int, exception: Optional[Exception]):
input_size,
)

blocks = [input[0] for input in inputs.blocks]
metadata = [input[1] for input in inputs.blocks]
ctx = ray.data.context.DataContext.get_current()
if ctx.enable_get_object_locations_for_metrics:
locations = ray.experimental.get_object_locations(blocks)
for block, meta in zip(blocks, metadata):
locations = ray.experimental.get_object_locations(inputs.block_refs)
for block, meta in inputs.blocks:
if locations[block].get("did_spill", False):
assert meta.size_bytes is not None
self.obj_store_mem_spilled += meta.size_bytes
Expand Down
28 changes: 20 additions & 8 deletions python/ray/data/_internal/execution/interfaces/ref_bundle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import ray
from .common import NodeIdStr
Expand Down Expand Up @@ -58,19 +58,29 @@ def __setattr__(self, key, value):
raise ValueError(f"The `{key}` field of RefBundle cannot be updated.")
object.__setattr__(self, key, value)

@property
def block_refs(self) -> List[ray.ObjectRef]:
    """List of block references in this bundle.

    Returns:
        The first element (the object reference) of every
        ``(block_ref, metadata)`` pair in ``self.blocks``, in the same
        order as ``self.blocks``.
    """
    # The return type is a list of object references, not of
    # ``BlockMetadata``; the metadata half of each pair is discarded here.
    return [block_ref for block_ref, _ in self.blocks]

@property
def metadata(self) -> List[BlockMetadata]:
    """Metadata of every block in this bundle.

    Returns:
        The second element of each ``(block_ref, metadata)`` pair in
        ``self.blocks``, in the same order as ``self.blocks``.
    """
    return [block_meta for _ref, block_meta in self.blocks]

def num_rows(self) -> Optional[int]:
"""Number of rows present in this bundle, if known."""
total = 0
for b in self.blocks:
if b[1].num_rows is None:
for m in self.metadata:
if m.num_rows is None:
return None
else:
total += b[1].num_rows
total += m.num_rows
return total

def size_bytes(self) -> int:
"""Size of the blocks of this bundle in bytes."""
return sum(b[1].size_bytes for b in self.blocks)
return sum(m.size_bytes for m in self.metadata)

def destroy_if_owned(self) -> int:
"""Clears the object store memory for these blocks if owned.
Expand All @@ -79,8 +89,10 @@ def destroy_if_owned(self) -> int:
The number of bytes freed.
"""
should_free = self.owns_blocks and DataContext.get_current().eager_free
for b in self.blocks:
trace_deallocation(b[0], "RefBundle.destroy_if_owned", free=should_free)
for block_ref in self.block_refs:
trace_deallocation(
block_ref, "RefBundle.destroy_if_owned", free=should_free
)
return self.size_bytes() if should_free else 0

def get_cached_location(self) -> Optional[NodeIdStr]:
Expand All @@ -91,7 +103,7 @@ def get_cached_location(self) -> Optional[NodeIdStr]:
if self._cached_location is None:
# Only consider the first block in the bundle for now. TODO(ekl) consider
# taking into account other blocks.
ref = self.blocks[0][0]
ref = self.block_refs[0]
# This call is pretty fast for owned objects (~5k/s), so we don't need to
# batch it for now.
locs = ray.experimental.get_object_locations([ref])
Expand Down
5 changes: 2 additions & 3 deletions python/ray/data/_internal/execution/legacy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,8 @@ def _bundles_to_block_list(bundles: Iterator[RefBundle]) -> BlockList:
for ref_bundle in bundles:
if not ref_bundle.owns_blocks:
owns_blocks = False
for block, meta in ref_bundle.blocks:
blocks.append(block)
metadata.append(meta)
blocks.extend(ref_bundle.block_refs)
metadata.extend(ref_bundle.metadata)
return BlockList(blocks, metadata, owned_by_consumer=owns_blocks)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _initialize_metadata(self):
self._num_output_bundles = len(self._input_data)
block_metadata = []
for bundle in self._input_data:
block_metadata.extend([m for (_, m) in bundle.blocks])
block_metadata.extend(bundle.metadata)
self._stats = {
"input": block_metadata,
}
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,7 @@ def _get_next_inner(self) -> RefBundle:
assert self._started
bundle = self._output_queue.get_next()
self._metrics.on_output_dequeued(bundle)
for _, meta in bundle.blocks:
self._output_metadata.append(meta)
self._output_metadata.extend(bundle.metadata)
return bundle

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __init__(
def _add_bundled_input(self, bundle: RefBundle):
# Submit the task as a normal Ray task.
map_task = cached_remote_fn(_map_task, num_returns="streaming")
input_blocks = [block for block, _ in bundle.blocks]
ctx = TaskContext(
task_idx=self._next_data_task_idx,
target_max_block_size=self.actual_target_max_block_size,
Expand All @@ -84,7 +83,7 @@ def _add_bundled_input(self, bundle: RefBundle):
self._map_transformer_ref,
data_context,
ctx,
*input_blocks,
*bundle.block_refs,
)
self._submit_data_task(gen, bundle)

Expand Down
5 changes: 2 additions & 3 deletions python/ray/data/_internal/planner/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ def fn(
blocks = []
metadata = []
for ref_bundle in refs:
for block, block_metadata in ref_bundle.blocks:
blocks.append(block)
metadata.append(block_metadata)
blocks.extend(ref_bundle.block_refs)
metadata.extend(ref_bundle.metadata)
if len(blocks) == 0:
return (blocks, {})
unified_schema = unify_block_metadata_schema(metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def execute(
# eagerly release the blocks' memory.
input_blocks_list = []
for ref_bundle in refs:
for block, _ in ref_bundle.blocks:
input_blocks_list.append(block)
input_blocks_list.extend(ref_bundle.block_refs)
input_num_blocks = len(input_blocks_list)
input_owned = all(b.owns_blocks for b in refs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,7 @@ def execute(
# processed concurrently.
input_blocks_list = []
for ref_bundle in refs:
for block, _ in ref_bundle.blocks:
input_blocks_list.append(block)
input_blocks_list.extend(ref_bundle.block_refs)
input_owned = all(b.owns_blocks for b in refs)

if map_ray_remote_args is None:
Expand Down
3 changes: 1 addition & 2 deletions python/ray/data/_internal/planner/randomize_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def fn(
nonlocal op
blocks_with_metadata = []
for ref_bundle in refs:
for block, meta in ref_bundle.blocks:
blocks_with_metadata.append((block, meta))
blocks_with_metadata.extend(ref_bundle.blocks)

if len(blocks_with_metadata) == 0:
return refs, {op._name: []}
Expand Down
5 changes: 2 additions & 3 deletions python/ray/data/_internal/planner/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ def fn(
blocks = []
metadata = []
for ref_bundle in refs:
for block, block_metadata in ref_bundle.blocks:
blocks.append(block)
metadata.append(block_metadata)
blocks.extend(ref_bundle.block_refs)
metadata.extend(ref_bundle.metadata)
if len(blocks) == 0:
return (blocks, {})
sort_key.validate_schema(unify_block_metadata_schema(metadata))
Expand Down
22 changes: 12 additions & 10 deletions python/ray/data/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@


def _get_blocks(bundle: RefBundle, output_list: List[Block]):
for block, _ in bundle.blocks:
output_list.append(list(ray.get(block)["id"]))
for block_ref in bundle.block_refs:
output_list.append(list(ray.get(block_ref)["id"]))


def _mul2_transform(block_iter: Iterable[Block], ctx) -> Iterable[Block]:
Expand Down Expand Up @@ -196,9 +196,11 @@ def test_split_operator(ray_start_regular_shared, equal, chunk_size):
while op.has_next():
ref = op.get_next()
assert ref.owns_blocks, ref
for block, _ in ref.blocks:
for block_ref in ref.block_refs:
assert ref.output_split_idx is not None
output_splits[ref.output_split_idx].extend(list(ray.get(block)["id"]))
output_splits[ref.output_split_idx].extend(
list(ray.get(block_ref)["id"])
)
op.all_inputs_done()

expected_splits = [[] for _ in range(num_splits)]
Expand Down Expand Up @@ -234,8 +236,8 @@ def test_split_operator_random(ray_start_regular_shared, equal, random_seed):
while op.has_next():
ref = op.get_next()
assert ref.owns_blocks, ref
for block, _ in ref.blocks:
output_splits[ref.output_split_idx].extend(list(ray.get(block)["id"]))
for block_ref in ref.block_refs:
output_splits[ref.output_split_idx].extend(list(ray.get(block_ref)["id"]))
if equal:
actual = [len(output_splits[i]) for i in range(3)]
expected = [num_inputs // 3] * 3
Expand Down Expand Up @@ -271,8 +273,8 @@ def get_bundle_loc(bundle):
while op.has_next():
ref = op.get_next()
assert ref.owns_blocks, ref
for block, _ in ref.blocks:
output_splits[ref.output_split_idx].extend(list(ray.get(block)["id"]))
for block_ref in ref.block_refs:
output_splits[ref.output_split_idx].extend(list(ray.get(block_ref)["id"]))

total = 0
for i in range(2):
Expand Down Expand Up @@ -583,8 +585,8 @@ def test_limit_operator(ray_start_regular_shared):

def _get_bundles(bundle: RefBundle):
output = []
for block, _ in bundle.blocks:
output.extend(list(ray.get(block)["id"]))
for block_ref in bundle.block_refs:
output.extend(list(ray.get(block_ref)["id"]))
return output


Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def equalize_helper(input_block_lists: List[List[List[Any]]]):
result_block_lists = []
for bundle in result:
block_list = []
for block_ref, _ in bundle.blocks:
for block_ref in bundle.block_refs:
block = ray.get(block_ref)
block_accessor = BlockAccessor.for_block(block)
block_list.append(list(block_accessor.to_default()["id"]))
Expand Down
8 changes: 4 additions & 4 deletions python/ray/data/tests/test_streaming_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def map_fn(block_iter, _):
def ref_bundles_to_list(bundles: List[RefBundle]) -> List[List[Any]]:
output = []
for bundle in bundles:
for block, _ in bundle.blocks:
output.append(list(ray.get(block)["id"]))
for block_ref in bundle.block_refs:
output.append(list(ray.get(block_ref)["id"]))
return output


Expand Down Expand Up @@ -144,8 +144,8 @@ def run(self):
def get_outputs(out: List[RefBundle]):
outputs = []
for bundle in out:
for block, _ in bundle.blocks:
ids: pd.Series = ray.get(block)["id"]
for block_ref in bundle.block_refs:
ids: pd.Series = ray.get(block_ref)["id"]
outputs.extend(ids.values)
return outputs

Expand Down

0 comments on commit c3e6eca

Please sign in to comment.