Revert "[dataset] Use polars for sorting (ray-project#24523)" (ray-project#24781)

This reverts commit c62e00e.

See if reverting this resolves the linux://python/ray/tests:test_actor_advanced failure.
scv119 authored May 13, 2022
1 parent cc21979 commit 2be45fe
Showing 9 changed files with 37 additions and 138 deletions.
7 changes: 0 additions & 7 deletions python/ray/data/context.py
@@ -5,7 +5,6 @@
import ray
from ray.util.annotations import DeveloperAPI


# The context singleton on this process.
_default_context: "Optional[DatasetContext]" = None
_context_lock = threading.Lock()
@@ -38,9 +37,6 @@
os.environ.get("RAY_DATASET_PUSH_BASED_SHUFFLE", None)
)

# Whether to use Polars for tabular dataset sorts, groupbys, and aggregations.
DEFAULT_USE_POLARS = False


@DeveloperAPI
class DatasetContext:
@@ -61,7 +57,6 @@ def __init__(
optimize_fuse_shuffle_stages: bool,
actor_prefetcher_enabled: bool,
use_push_based_shuffle: bool,
use_polars: bool,
):
"""Private constructor (use get_current() instead)."""
self.block_owner = block_owner
@@ -73,7 +68,6 @@ def __init__(
self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
self.actor_prefetcher_enabled = actor_prefetcher_enabled
self.use_push_based_shuffle = use_push_based_shuffle
self.use_polars = use_polars

@staticmethod
def get_current() -> "DatasetContext":
@@ -97,7 +91,6 @@ def get_current() -> "DatasetContext":
optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
actor_prefetcher_enabled=DEFAULT_ACTOR_PREFETCHER_ENABLED,
use_push_based_shuffle=DEFAULT_USE_PUSH_BASED_SHUFFLE,
use_polars=DEFAULT_USE_POLARS,
)

if (
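For reference, a minimal sketch (not part of this commit) of how a DatasetContext flag is toggled and restored around a piece of work, using use_push_based_shuffle, which survives this revert; the range/random_shuffle calls are illustrative only:

    import ray
    from ray.data.context import DatasetContext

    ctx = DatasetContext.get_current()
    original = ctx.use_push_based_shuffle
    try:
        # Enable push-based shuffle for this block of work only.
        ctx.use_push_based_shuffle = True
        ds = ray.data.range(1000).random_shuffle()
    finally:
        # Restore the previous setting so other jobs/tests are unaffected.
        ctx.use_push_based_shuffle = original

After this revert, use_polars is no longer such a flag; the pyarrow path is used unconditionally.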
70 changes: 31 additions & 39 deletions python/ray/data/impl/arrow_block.py
@@ -32,8 +32,6 @@
from ray.data.row import TableRow
from ray.data.impl.table_block import TableBlockAccessor, TableBlockBuilder
from ray.data.aggregate import AggregateFn
from ray.data.context import DatasetContext
from ray.data.impl.arrow_ops import transform_polars, transform_pyarrow

if TYPE_CHECKING:
import pandas
@@ -42,21 +40,6 @@
T = TypeVar("T")


# We offload some transformations to polars for performance.
def get_sort_transform(context: DatasetContext) -> Callable:
if context.use_polars:
return transform_polars.sort
else:
return transform_pyarrow.sort


def get_concat_and_sort_transform(context: DatasetContext) -> Callable:
if context.use_polars:
return transform_polars.concat_and_sort
else:
return transform_pyarrow.concat_and_sort


class ArrowRow(TableRow):
"""
Row of a tabular Dataset backed by an Arrow Table block.
@@ -282,35 +265,45 @@ def sort_and_partition(
# so calling sort_indices() will raise an error.
return [self._empty_table() for _ in range(len(boundaries) + 1)]

context = DatasetContext.get_current()
sort = get_sort_transform(context)
col, _ = key[0]
table = sort(self._table, key, descending)
import pyarrow.compute as pac

indices = pac.sort_indices(self._table, sort_keys=key)
table = self._table.take(indices)
if len(boundaries) == 0:
return [table]

partitions = []
# For each boundary value, count the number of items that are less
# than it. Since the block is sorted, these counts partition the items
# such that boundaries[i] <= x < boundaries[i + 1] for each x in
# partition[i]. If `descending` is true, `boundaries` would also be
# in descending order and we only need to count the number of items
# *greater than* the boundary value instead.
if descending:
num_rows = len(table[col])
bounds = num_rows - np.searchsorted(
table[col], boundaries, sorter=np.arange(num_rows - 1, -1, -1)
)
else:
bounds = np.searchsorted(table[col], boundaries)
last_idx = 0
for idx in bounds:
col, _ = key[0]
comp_fn = pac.greater if descending else pac.less

# TODO(ekl) this is O(n^2) but in practice it's much faster than the
# O(n) algorithm, could be optimized.
boundary_indices = [pac.sum(comp_fn(table[col], b)).as_py() for b in boundaries]
### Compute the boundary indices in O(n) time via scan. # noqa
# boundary_indices = []
# remaining = boundaries.copy()
# values = table[col]
# for i, x in enumerate(values):
# while remaining and not comp_fn(x, remaining[0]).as_py():
# remaining.pop(0)
# boundary_indices.append(i)
# for _ in remaining:
# boundary_indices.append(len(values))

ret = []
prev_i = 0
for i in boundary_indices:
# Slices need to be copied to avoid including the base table
# during serialization.
partitions.append(_copy_table(table.slice(last_idx, idx - last_idx)))
last_idx = idx
partitions.append(_copy_table(table.slice(last_idx)))
return partitions
ret.append(_copy_table(table.slice(prev_i, i - prev_i)))
prev_i = i
ret.append(_copy_table(table.slice(prev_i)))
return ret

def combine(self, key: KeyFn, aggs: Tuple[AggregateFn]) -> Block[ArrowRow]:
"""Combine rows with the same key into an accumulator.
@@ -398,10 +391,9 @@ def merge_sorted_blocks(
if len(blocks) == 0:
ret = ArrowBlockAccessor._empty_table()
else:
concat_and_sort = get_concat_and_sort_transform(
DatasetContext.get_current()
)
ret = concat_and_sort(blocks, key, _descending)
ret = pyarrow.concat_tables(blocks, promote=True)
indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
ret = ret.take(indices)
return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())

@staticmethod
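The two hunks above restore the pure-pyarrow code path for sort_and_partition and merge_sorted_blocks. A standalone sketch of the same idea (the toy data, column name, and boundary values are illustrative, not taken from the commit):

    import pyarrow as pa
    import pyarrow.compute as pac

    table = pa.table({"a": [5, 1, 4, 2, 3]})
    key = [("a", "ascending")]

    # Sort: compute the permutation of row indices, then materialize it.
    indices = pac.sort_indices(table, sort_keys=key)
    table = table.take(indices)  # column "a" is now 1, 2, 3, 4, 5

    # Partition at boundary values by counting rows strictly below each boundary
    # (the descending case uses pac.greater instead).
    boundaries = [3, 5]
    cuts = [pac.sum(pac.less(table["a"], b)).as_py() for b in boundaries]  # [2, 4]

    parts, prev = [], 0
    for i in cuts:
        parts.append(table.slice(prev, i - prev))
        prev = i
    parts.append(table.slice(prev))  # row groups: [1, 2], [3, 4], [5]

    # Merging sorted blocks (as in merge_sorted_blocks): concatenate, then re-sort.
    merged = pa.concat_tables(parts, promote=True)
    merged = merged.take(pac.sort_indices(merged, sort_keys=key))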
Empty file.
40 changes: 0 additions & 40 deletions python/ray/data/impl/arrow_ops/transform_polars.py

This file was deleted.

24 changes: 0 additions & 24 deletions python/ray/data/impl/arrow_ops/transform_pyarrow.py

This file was deleted.

1 change: 0 additions & 1 deletion python/ray/data/tests/test_optimize.py
@@ -81,7 +81,6 @@ def prepare_read(
return read_tasks


@pytest.mark.skip(reason="failing after #24523")
@pytest.mark.parametrize("lazy_input", [True, False])
def test_memory_release_pipeline(shutdown_only, lazy_input):
context = DatasetContext.get_current()
24 changes: 6 additions & 18 deletions python/ray/data/tests/test_sort.py
@@ -76,17 +76,12 @@ def test_sort_partition_same_key_to_same_block(

@pytest.mark.parametrize("num_items,parallelism", [(100, 1), (1000, 4)])
@pytest.mark.parametrize("use_push_based_shuffle", [False, True])
@pytest.mark.parametrize("use_polars", [False, True])
def test_sort_arrow(
ray_start_regular, num_items, parallelism, use_push_based_shuffle, use_polars
):
def test_sort_arrow(ray_start_regular, num_items, parallelism, use_push_based_shuffle):
ctx = ray.data.context.DatasetContext.get_current()

try:
original_push_based_shuffle = ctx.use_push_based_shuffle
original = ctx.use_push_based_shuffle
ctx.use_push_based_shuffle = use_push_based_shuffle
original_use_polars = ctx.use_polars
ctx.use_polars = use_polars

a = list(reversed(range(num_items)))
b = [f"{x:03}" for x in range(num_items)]
@@ -117,22 +112,16 @@ def assert_sorted(sorted_ds, expected_rows):
assert_sorted(ds.sort(key="b"), zip(a, b))
assert_sorted(ds.sort(key="a", descending=True), zip(a, b))
finally:
ctx.use_push_based_shuffle = original_push_based_shuffle
ctx.use_polars = original_use_polars
ctx.use_push_based_shuffle = original


@pytest.mark.parametrize("use_push_based_shuffle", [False, True])
@pytest.mark.parametrize("use_polars", [False, True])
def test_sort_arrow_with_empty_blocks(
ray_start_regular, use_push_based_shuffle, use_polars
):
def test_sort_arrow_with_empty_blocks(ray_start_regular, use_push_based_shuffle):
ctx = ray.data.context.DatasetContext.get_current()

try:
original_push_based_shuffle = ctx.use_push_based_shuffle
original = ctx.use_push_based_shuffle
ctx.use_push_based_shuffle = use_push_based_shuffle
original_use_polars = ctx.use_polars
ctx.use_polars = use_polars

assert (
BlockAccessor.for_block(pa.Table.from_pydict({})).sample(10, "A").num_rows
@@ -173,8 +162,7 @@ def test_sort_arrow_with_empty_blocks(
)
assert ds.sort("value").count() == 0
finally:
ctx.use_push_based_shuffle = original_push_based_shuffle
ctx.use_polars = original_use_polars
ctx.use_push_based_shuffle = original


def test_push_based_shuffle_schedule():
2 changes: 0 additions & 2 deletions python/requirements.txt
@@ -45,8 +45,6 @@ aiorwlock

# Requirements for running tests
pyarrow >= 6.0.1, < 7.0.0
# Used for Dataset tests.
polars
azure-cli-core==2.29.1
azure-identity==1.7.0
azure-mgmt-compute==23.1.0
7 changes: 0 additions & 7 deletions release/nightly_tests/dataset/sort.py
@@ -10,7 +10,6 @@
from ray.data.impl.arrow_block import ArrowRow
from ray.data.impl.util import _check_pyarrow_version
from ray.data.block import Block, BlockMetadata
from ray.data.context import DatasetContext

from ray.data.datasource import Datasource, ReadTask
from ray.internal.internal_api import memory_summary
@@ -86,15 +85,9 @@ def make_block(count: int, num_columns: int) -> Block:
parser.add_argument(
"--shuffle", help="shuffle instead of sort", action="store_true"
)
parser.add_argument("--use-polars", action="store_true")

args = parser.parse_args()

if args.use_polars and not args.shuffle:
print("Using polars for sort")
ctx = DatasetContext.get_current()
ctx.use_polars = True

num_partitions = int(args.num_partitions)
partition_size = int(float(args.partition_size))
print(
