register ownership (oap-project#349)
Signed-off-by: Zhi Lin <[email protected]>
kira-lin authored Jun 26, 2023
1 parent b9e1b99 commit 7571312
Showing 5 changed files with 38 additions and 10 deletions.
10 changes: 8 additions & 2 deletions RayDPUtils.java
@@ -18,7 +18,9 @@
package org.apache.spark.raydp;

import io.ray.api.ObjectRef;
+import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
+import io.ray.runtime.AbstractRayRuntime;
import io.ray.runtime.object.ObjectRefImpl;

public class RayDPUtils {
@@ -40,9 +42,13 @@ public static <T> ObjectRefImpl<T> convert(ObjectRef<T> obj) {
   * Create ObjectRef from Array[Byte] and register ownership.
   * We can't import the ObjectRefImpl in scala code, so we do the conversion at here.
   */
-  public static <T> ObjectRef<T> readBinary(byte[] obj, Class<T> clazz) {
+  public static <T> ObjectRef<T> readBinary(byte[] obj, Class<T> clazz, byte[] ownerAddress) {
    ObjectId id = new ObjectId(obj);
-    ObjectRefImpl<T> ref = new ObjectRefImpl<>(id, clazz);
+    ObjectRefImpl<T> ref = new ObjectRefImpl<>(id, clazz, false);
+    AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
+    runtime.getObjectStore().registerOwnershipInfoAndResolveFuture(
+        id, null, ownerAddress
+    );
    return ref;
  }
}
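Note: the core of this change is that an ObjectRef rebuilt from bare ObjectId bytes carries no ownership metadata, so the JVM core worker cannot resolve it until registerOwnershipInfoAndResolveFuture tells it which worker owns the object. The Python snippet below is only an illustration of the same rule from the Python API and is not part of this commit; it uses only public Ray calls (ray.put, ObjectRef.binary, the ray.ObjectRef constructor).

import ray

ray.init()
ref = ray.put(b"payload")   # the creating worker becomes the object's owner
id_bytes = ref.binary()     # bare ObjectId bytes; ownership info is not included

# Rebuilding a ref from raw bytes restores the same id, but a different worker
# holding only these bytes cannot fetch the value until it also learns the
# owner's address, which is what the new readBinary registers on the JVM side.
rehydrated = ray.ObjectRef(id_bytes)
assert rehydrated.hex() == ref.hex()

ray.shutdown()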
5 changes: 3 additions & 2 deletions RayDatasetRDD.scala
@@ -47,10 +47,11 @@ class RayDatasetRDD(

  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    val ref = split.asInstanceOf[RayDatasetRDDPartition].ref
-    ObjectStoreReader.getBatchesFromStream(ref)
+    ObjectStoreReader.getBatchesFromStream(ref, locations.get(split.index))
  }

  override def getPreferredLocations(split: Partition): Seq[String] = {
-    Seq(Address.parseFrom(locations.get(split.index)).getIpAddress())
+    val address = Address.parseFrom(locations.get(split.index))
+    Seq(address.getIpAddress())
  }
}
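Note: after this change the locations array does double duty: each entry is the serialized owner Address protobuf of a partition's backing Ray object, passed to getBatchesFromStream for ownership registration and parsed here for its IP to report a Spark preferred location. For reference, a hypothetical Python-side sketch of the same decoding; the generated-protobuf module path and field name are assumptions about Ray internals and are not part of this commit.

# Assumed internal module path; Ray ships generated protobufs for common.proto.
from ray.core.generated.common_pb2 import Address

def owner_ip(owner_address_bytes: bytes) -> str:
    # Mirrors the Scala side: Address.parseFrom(locations.get(i)).getIpAddress()
    address = Address()
    address.ParseFromString(owner_address_bytes)
    return address.ip_address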
5 changes: 3 additions & 2 deletions ObjectStoreReader.scala
@@ -48,8 +48,9 @@ object ObjectStoreReader {
  }

  def getBatchesFromStream(
-      ref: Array[Byte]): Iterator[Array[Byte]] = {
-    val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]])
+      ref: Array[Byte],
+      ownerAddress: Array[Byte]): Iterator[Array[Byte]] = {
+    val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]], ownerAddress)
    ArrowConverters.getBatchesFromStream(
      Channels.newChannel(new ByteArrayInputStream(objectRef.get)))
  }
20 changes: 20 additions & 0 deletions python/raydp/tests/conftest.py
@@ -72,6 +72,26 @@ def stop_all():
    request.addfinalizer(stop_all)
    return spark

+@pytest.fixture(scope="function", params=["local", "ray://localhost:10001"])
+def spark_on_ray_2_executors(request):
+    ray.shutdown()
+    if request.param == "local":
+        ray.init(address="local", num_cpus=6, include_dashboard=False)
+    else:
+        ray.init(address=request.param)
+    node_ip = ray.util.get_node_ip_address()
+    spark = raydp.init_spark("test", 2, 1, "500M", configs= {
+        "spark.driver.host": node_ip,
+        "spark.driver.bindAddress": node_ip
+    })
+
+    def stop_all():
+        raydp.stop_spark()
+        ray.shutdown()
+
+    request.addfinalizer(stop_all)
+    return spark
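Note: a minimal sketch of how this fixture is consumed; the updated tests in test_spark_cluster.py below do the same. Only the fixture name and the raydp-backed SparkSession come from this diff, the test body itself is illustrative.

def test_two_executor_session(spark_on_ray_2_executors):
    spark = spark_on_ray_2_executors
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["one", "two"])
    assert df.count() == 2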


@pytest.fixture(scope="function")
def spark_on_ray_fraction_custom_resource(request):
8 changes: 4 additions & 4 deletions python/raydp/tests/test_spark_cluster.py
@@ -93,12 +93,12 @@ def test_spark_driver_and_executor_hostname(spark_on_ray_small):
    assert node_ip_address == driver_bind_address


-def test_ray_dataset_roundtrip(spark_on_ray_small):
+def test_ray_dataset_roundtrip(spark_on_ray_2_executors):
    # skipping this to be compatible with ray 2.4.0
    # see issue #343
    if not ray.worker.global_worker.connected:
        pytest.skip("Skip this test if using ray client")
-    spark = spark_on_ray_small
+    spark = spark_on_ray_2_executors
    spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["one", "two"])
    rows = [(r.one, r.two) for r in spark_df.take(3)]
    ds = ray.data.from_spark(spark_df)
@@ -110,12 +110,12 @@ def test_ray_dataset_roundtrip(spark_on_ray_small):
    assert values == rows_2


-def test_ray_dataset_to_spark(spark_on_ray_small):
+def test_ray_dataset_to_spark(spark_on_ray_2_executors):
    # skipping this to be compatible with ray 2.4.0
    # see issue #343
    if not ray.worker.global_worker.connected:
        pytest.skip("Skip this test if using ray client")
-    spark = spark_on_ray_small
+    spark = spark_on_ray_2_executors
    n = 5
    data = {"value": list(range(n))}
    ds = ray.data.from_arrow(pyarrow.Table.from_pydict(data))
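Note: switching these tests from the single-executor spark_on_ray_small session to spark_on_ray_2_executors presumably makes an executor read object-store blocks owned by a different worker, which is the path the ownership registration above exists for. A sketch of the round trip the tests exercise follows; it assumes Dataset.to_spark is available in the installed Ray version and otherwise mirrors the calls visible in the diff.

import ray
import raydp

ray.init()
spark = raydp.init_spark("roundtrip", num_executors=2, executor_cores=1, executor_memory="500M")

spark_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["one", "two"])
ds = ray.data.from_spark(spark_df)   # Spark -> Ray Dataset; blocks land in the object store
spark_df_2 = ds.to_spark(spark)      # Ray Dataset -> Spark; executors resolve blocks owned elsewhere
assert sorted(spark_df_2.collect()) == sorted(spark_df.collect())

raydp.stop_spark()
ray.shutdown()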
