Skip to content

Commit

Permalink
[Core][C++ Worker] Cross language call support bytes[] type (ray-proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
larrylian authored Feb 16, 2023
1 parent 8264285 commit cd3a7a7
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 15 deletions.
26 changes: 21 additions & 5 deletions cpp/src/ray/runtime/object/native_object_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,31 @@ std::vector<std::shared_ptr<msgpack::sbuffer>> NativeObjectStore::GetRaw(
for (size_t i = 0; i < results.size(); i++) {
const auto &meta = results[i]->GetMetadata();
const auto &data_buffer = results[i]->GetData();
std::string meta_str = "";
if (meta != nullptr) {
std::string meta_str((char *)meta->Data(), meta->Size());
meta_str = std::string((char *)meta->Data(), meta->Size());
CheckException(meta_str, data_buffer);
}

auto sbuffer = std::make_shared<msgpack::sbuffer>(data_buffer->Size());
sbuffer->write(reinterpret_cast<const char *>(data_buffer->Data()),
data_buffer->Size());
result_sbuffers.push_back(sbuffer);
const char *data = nullptr;
size_t data_size = 0;
if (data_buffer) {
data = reinterpret_cast<const char *>(data_buffer->Data());
data_size = data_buffer->Size();
}
if (meta_str == METADATA_STR_RAW) {
// TODO(LarryLian) In order to minimize the modification,
// there is an extra serialization here, but the performance will be a little worse.
// This code can be optimized later to improve performance
auto raw_buffer = Serializer::Serialize(data, data_size);
auto sbuffer = std::make_shared<msgpack::sbuffer>(raw_buffer.size());
sbuffer->write(raw_buffer.data(), raw_buffer.size());
result_sbuffers.push_back(sbuffer);
} else {
auto sbuffer = std::make_shared<msgpack::sbuffer>(data_size);
sbuffer->write(data, data_size);
result_sbuffers.push_back(sbuffer);
}
}
return result_sbuffers;
}
Expand Down
26 changes: 22 additions & 4 deletions cpp/src/ray/runtime/task/task_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,30 @@ Status TaskExecutor::ExecuteTask(
ArgsBufferList ray_args_buffer;
for (size_t i = 0; i < args_buffer.size(); i++) {
auto &arg = args_buffer.at(i);
std::string meta_str = "";
if (arg->GetMetadata() != nullptr) {
meta_str = std::string((const char *)arg->GetMetadata()->Data(),
arg->GetMetadata()->Size());
}
msgpack::sbuffer sbuf;
if (cross_lang) {
sbuf.write((const char *)(arg->GetData()->Data()) + XLANG_HEADER_LEN,
arg->GetData()->Size() - XLANG_HEADER_LEN);
const char *arg_data = nullptr;
size_t arg_data_size = 0;
if (arg->GetData()) {
arg_data = reinterpret_cast<const char *>(arg->GetData()->Data());
arg_data_size = arg->GetData()->Size();
}
if (meta_str == METADATA_STR_RAW) {
// TODO(LarryLian) In order to minimize the modification,
// there is an extra serialization here, but the performance will be a little worse.
// This code can be optimized later to improve performance
const auto &raw_buffer = Serializer::Serialize(arg_data, arg_data_size);
sbuf.write(raw_buffer.data(), raw_buffer.size());
} else if (cross_lang) {
RAY_CHECK(arg_data != nullptr)
<< "Task " << task_name << " no." << i << " arg data is null.";
sbuf.write(arg_data + XLANG_HEADER_LEN, arg_data_size - XLANG_HEADER_LEN);
} else {
sbuf.write((const char *)(arg->GetData()->Data()), arg->GetData()->Size());
sbuf.write(arg_data, arg_data_size);
}

ray_args_buffer.push_back(std::move(sbuf));
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/ray/test/cluster/cluster_mode_xlang_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ TEST(RayClusterModeXLangTest, JavaInvocationTest) {
named_actor_handle.Task(ray::JavaActorMethod<int>{"getValue"}).Remote();
EXPECT_EQ(0, *named_actor_obj1.Get());

std::vector<std::byte> bytes = {std::byte{1}, std::byte{2}, std::byte{3}};
auto ref_bytes = java_class_actor_handle
.Task(ray::JavaActorMethod<std::vector<std::byte>>{"echoBytes"})
.Remote(bytes);
EXPECT_EQ(*ref_bytes.Get(), bytes);

auto ref_bytes2 = java_class_actor_handle
.Task(ray::JavaActorMethod<std::vector<std::byte>>{"echoBytes"})
.Remote(std::vector<std::byte>());
EXPECT_EQ(*ref_bytes2.Get(), std::vector<std::byte>());

// Test get other java actor by actor name.
auto ref_1 =
java_class_actor_handle.Task(ray::JavaActorMethod<std::string>{"createChildActor"})
Expand All @@ -63,6 +74,14 @@ TEST(RayClusterModeXLangTest, JavaInvocationTest) {
ray::ActorHandleXlang &child_actor = *child_actor_optional;
auto ref_2 = child_actor.Task(ray::JavaActorMethod<int>{"getValue"}).Remote();
EXPECT_EQ(0, *ref_2.Get());

auto ref_3 =
child_actor.Task(ray::JavaActorMethod<std::string>{"echo"}).Remote("C++ worker");
EXPECT_EQ("C++ worker", *ref_3.Get());

auto ref_4 = child_actor.Task(ray::JavaActorMethod<std::vector<std::byte>>{"echoBytes"})
.Remote(bytes);
EXPECT_EQ(*ref_4.Get(), bytes);
}

TEST(RayClusterModeXLangTest, GetXLangActorByNameTest) {
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/ray/test/cluster/counter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ RAY_REMOTE(RAY_FUNC(Counter::FactoryCreate),
&Counter::Plus1ForActor,
&Counter::GetCount,
&Counter::CreateNestedChildActor,
&Counter::GetBytes);
&Counter::GetBytes,
&Counter::echoBytes,
&Counter::echoString);

RAY_REMOTE(ActorConcurrentCall::FactoryCreate, &ActorConcurrentCall::CountDown);

Expand Down
7 changes: 4 additions & 3 deletions cpp/src/ray/test/cluster/counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class Counter {
return bytes;
}

std::vector<std::byte> echoBytes(const std::vector<std::byte> &bytes) { return bytes; }

std::string echoString(const std::string &str) { return str; }

int GetIntVal(ray::ObjectRef<ray::ObjectRef<int>> obj) {
auto val = *obj.Get();
return *val.Get();
Expand All @@ -77,9 +81,6 @@ class Counter {

std::string GetEnvVar(std::string key);

inline Counter *CreateCounter() { return new Counter(0); }
RAY_REMOTE(CreateCounter);

class CountDownLatch {
public:
explicit CountDownLatch(size_t count) : m_count(count) {}
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/ray/test/serialization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,24 @@ TEST(SerializationTest, TypeHybridTest) {

EXPECT_EQ(in_arg1, out_arg1);
EXPECT_EQ(in_arg2, out_arg2);
}

TEST(SerializationTest, BoundaryValueTest) {
std::string in_arg1 = "", out_arg1;
msgpack::sbuffer buffer1 = ray::internal::Serializer::Serialize(in_arg1);
out_arg1 =
ray::internal::Serializer::Deserialize<std::string>(buffer1.data(), buffer1.size());
EXPECT_EQ(in_arg1, out_arg1);

std::vector<std::byte> in_arg2, out_arg2;
msgpack::sbuffer buffer2 = ray::internal::Serializer::Serialize(in_arg2);
out_arg2 = ray::internal::Serializer::Deserialize<std::vector<std::byte>>(
buffer1.data(), buffer1.size());
EXPECT_EQ(in_arg2, out_arg2);

char *in_arg3 = nullptr;
msgpack::sbuffer buffer3 = ray::internal::Serializer::Serialize(in_arg3, 0);
auto out_arg3 = ray::internal::Serializer::Deserialize<std::vector<std::byte>>(
buffer1.data(), buffer1.size());
EXPECT_EQ(std::vector<std::byte>(), out_arg3);
}
4 changes: 3 additions & 1 deletion cpp/test_python_call_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def test_cross_language_cpp():


def test_cross_language_cpp_actor():
actor = ray.cross_language.cpp_actor_class("CreateCounter", "Counter").remote()
actor = ray.cross_language.cpp_actor_class(
"RAY_FUNC(Counter::FactoryCreate)", "Counter"
).remote()
obj = actor.Plus1.remote()
assert 1 == ray.get(obj)

Expand Down
4 changes: 4 additions & 0 deletions java/test/src/main/java/io/ray/test/Counter.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ public String echo(String str) {
return str;
}

public byte[] echoBytes(byte[] bytes) {
return bytes;
}

public String createChildActor(String actorName) {
childActor = Ray.actor(Counter::new, 0).setName(actorName).remote();
Assert.assertEquals(Integer.valueOf(0), childActor.task(Counter::getValue).remote().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,26 @@ public void testCallingCppFunction() {
public void testCallingCppActor() {
String actorName = "actor_name";
CppActorHandle actor =
Ray.actor(CppActorClass.of("CreateCounter", "Counter")).setName(actorName).remote();
Ray.actor(CppActorClass.of("RAY_FUNC(Counter::FactoryCreate)", "Counter"))
.setName(actorName)
.remote();
ObjectRef<Integer> res = actor.task(CppActorMethod.of("Plus1", Integer.class)).remote();
Assert.assertEquals(res.get(), Integer.valueOf(1));
ObjectRef<byte[]> b =
actor.task(CppActorMethod.of("GetBytes", byte[].class), "C++ Worker").remote();
Assert.assertEquals(b.get(), "C++ Worker".getBytes());

ObjectRef<byte[]> b2 =
actor.task(CppActorMethod.of("echoBytes", byte[].class), "C++ Worker".getBytes()).remote();
Assert.assertEquals(b2.get(), "C++ Worker".getBytes());

ObjectRef<byte[]> b3 =
actor.task(CppActorMethod.of("echoBytes", byte[].class), new byte[0]).remote();
Assert.assertEquals(b3.get(), new byte[0]);

ObjectRef<byte[]> b4 = actor.task(CppActorMethod.of("echoBytes", byte[].class), null).remote();
Assert.assertThrows(CrossLanguageException.class, () -> b4.get());

// Test get cpp actor by actor name.
Optional<CppActorHandle> optional = Ray.getActor(actorName);
Assert.assertTrue(optional.isPresent());
Expand Down

0 comments on commit cd3a7a7

Please sign in to comment.