Skip to content

Commit

Permalink
[data] register arrow extension types in __init__ (ray-project#41689)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Xue <[email protected]>
  • Loading branch information
Zandew authored Dec 7, 2023
1 parent 5453965 commit 6a30439
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/ray/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,14 @@
# PyArrow is mocked in documentation builds. In this case, we don't need to do
# anything.
pass
elif parse_version(pyarrow_version) >= parse_version("14.0.1"):
pa.PyExtensionType.set_auto_load(True)
else:
if parse_version(pyarrow_version) >= parse_version("14.0.1"):
pa.PyExtensionType.set_auto_load(True)
# Import these arrow extension types to ensure that they are registered.
from ray.air.util.tensor_extensions.arrow import ( # noqa
ArrowTensorType,
ArrowVariableShapedTensorType,
)
except ModuleNotFoundError:
pass

Expand Down
26 changes: 26 additions & 0 deletions python/ray/data/tests/test_arrow_block.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pyarrow as pa
import pytest

import ray
from ray._private.test_utils import run_string_as_driver
from ray.data._internal.arrow_block import ArrowBlockAccessor


Expand All @@ -16,6 +19,29 @@ def test_append_column(ray_start_regular_shared):
assert actual_block.equals(expected_block)


def test_register_arrow_types(tmp_path):
# Test that our custom arrow extension types are registered on initialization.
ds = ray.data.from_items(np.zeros((8, 8, 8), dtype=np.int64))
tmp_file = f"{tmp_path}/test.parquet"
ds.write_parquet(tmp_file)

ds = ray.data.read_parquet(tmp_file)
schema = (
"Column Type\n------ ----\nitem numpy.ndarray(shape=(8, 8), dtype=int64)"
)
assert str(ds.schema()) == schema

# Also run in driver script to eliminate existing imports.
driver_script = """import ray
ds = ray.data.read_parquet("{0}")
schema = ds.schema()
assert str(schema) == \"\"\"{1}\"\"\"
""".format(
tmp_file, schema
)
run_string_as_driver(driver_script)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit 6a30439

Please sign in to comment.