Skip to content

Commit

Permalink
[Data] Fix write_results type (ray-project#33936)
Browse files Browse the repository at this point in the history
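`Dataset.write_datasource` passed a `list[list[WriteResult]]` to `Datasource.on_write_complete`, even though that method is declared to take a `list[WriteResult]`. Each write task returns its `WriteResult` wrapped in a single-element list (a bare `WriteResult` isn't a valid block type), and those wrapper lists were forwarded unchanged. This PR fixes the bug by unwrapping each block before calling `on_write_complete`.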
  • Loading branch information
bveeramani authored Apr 3, 2023
1 parent c6385fe commit a74e563
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
8 changes: 5 additions & 3 deletions python/ray/data/_internal/planner/write.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Callable, Iterator
from typing import Callable, Iterator, List

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block
from ray.data.datasource import Datasource
from ray.data.datasource import Datasource, WriteResult


def generate_write_fn(
Expand All @@ -13,6 +13,8 @@ def generate_write_fn(
# be raised. The Datasource can handle execution outcomes with the
# on_write_complete() and on_write_failed().
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
return [[datasource.write(blocks, ctx, **write_args)]]
# NOTE: `WriteResult` isn't a valid block type, so we need to wrap it in a list.
block: List[WriteResult] = [datasource.write(blocks, ctx, **write_args)]
return [block]

return fn
7 changes: 5 additions & 2 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2907,9 +2907,12 @@ def write_fn_wrapper(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
self._write_ds = Dataset(
plan, self._epoch, self._lazy, logical_plan
).cache()
datasource.on_write_complete(
ray.get(self._write_ds._plan.execute().get_blocks())
blocks = ray.get(self._write_ds._plan.execute().get_blocks())
assert all(
isinstance(block, list) and len(block) == 1 for block in blocks
)
write_results = [block[0] for block in blocks]
datasource.on_write_complete(write_results)
except Exception as e:
datasource.on_write_failed([], e)
raise
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/datasource/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def write(
return "ok"

def on_write_complete(self, write_results: List[WriteResult]) -> None:
assert all(w == ["ok"] for w in write_results), write_results
assert all(w == "ok" for w in write_results), write_results
self.num_ok += 1

def on_write_failed(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_dataset_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def write(b):
return "ok"

def on_write_complete(self, write_results: List[WriteResult]) -> None:
assert all(w == ["ok"] for w in write_results), write_results
assert all(w == "ok" for w in write_results), write_results
self.num_ok += 1

def on_write_failed(
Expand Down

0 comments on commit a74e563

Please sign in to comment.