From a3ce780d29fb7ab233d9e5e031f99ef848f9e38f Mon Sep 17 00:00:00 2001 From: Jinjing Zhou Date: Mon, 6 Dec 2021 20:03:33 +0800 Subject: [PATCH] [RPC] Use tensorpipe for rpc communication (#3335) * doesn't know whether works * add change * fix * fix * fix * remove * revert * lint * lint * fix * revert * lint * fix * only build rpc on linux * lint * lint * fix build on windows * fix windows * remove old test * fix cmake * Revert "remove old test" This reverts commit f1ea75c777c34cdc1f08c0589676ba6aee1feb29. * fix windows * fix * fix * fix indent * fix indent * address comment * fix * fix * fix * fix * fix * lint * fix indent * fix lint * add introduction * fix * lint * lint * add more logs * fix * update xbyak for C++14 with gcc5 * Remove channels * fix * add test script * fix * remove unused file * fix lint * add timeout --- .gitmodules | 3 + CMakeLists.txt | 21 +- Jenkinsfile | 2 +- python/dgl/distributed/rpc_server.py | 4 +- src/graph/serialize/zerocopy_serializer.cc | 2 +- src/rpc/rpc.cc | 322 +++++++++++---------- src/rpc/rpc.h | 88 ++---- src/rpc/rpc_msg.h | 68 +++++ src/rpc/tensorpipe/README.md | 104 +++++++ src/rpc/tensorpipe/queue.h | 53 ++++ src/rpc/tensorpipe/tp_communicator.cc | 168 +++++++++++ src/rpc/tensorpipe/tp_communicator.h | 186 ++++++++++++ tests/distributed/test_rpc.py | 47 ++- third_party/tensorpipe | 1 + third_party/xbyak | 2 +- 15 files changed, 850 insertions(+), 221 deletions(-) create mode 100644 src/rpc/rpc_msg.h create mode 100644 src/rpc/tensorpipe/README.md create mode 100644 src/rpc/tensorpipe/queue.h create mode 100644 src/rpc/tensorpipe/tp_communicator.cc create mode 100644 src/rpc/tensorpipe/tp_communicator.h create mode 160000 third_party/tensorpipe diff --git a/.gitmodules b/.gitmodules index 3bf2c4ef6465..515d34ef31a0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -32,6 +32,9 @@ [submodule "third_party/libxsmm"] path = third_party/libxsmm url = https://github.com/hfp/libxsmm.git +[submodule "third_party/tensorpipe"] + path = third_party/tensorpipe + url = https://github.com/pytorch/tensorpipe [submodule "third_party/thrust"] path = third_party/thrust url = https://github.com/NVIDIA/thrust.git diff --git a/CMakeLists.txt b/CMakeLists.txt index af0998deb7ff..f32f20d6362e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -166,11 +166,17 @@ file(GLOB_RECURSE DGL_SRC_1 src/api/*.cc src/graph/*.cc src/scheduler/*.cc - src/rpc/*.cc ) list(APPEND DGL_SRC ${DGL_SRC_1}) +if (NOT MSVC) + file(GLOB_RECURSE DGL_RPC_SRC src/rpc/*.cc) +else() + file(GLOB_RECURSE DGL_RPC_SRC src/rpc/network/*.cc) +endif() +list(APPEND DGL_SRC ${DGL_RPC_SRC}) + # Configure cuda if(USE_CUDA) dgl_config_cuda(DGL_CUDA_SRC) @@ -198,6 +204,8 @@ else(USE_CUDA) add_library(dgl SHARED ${DGL_SRC}) endif(USE_CUDA) +set_property(TARGET dgl PROPERTY CXX_STANDARD 14) + # include directories target_include_directories(dgl PRIVATE "include") target_include_directories(dgl PRIVATE "third_party/dlpack/include") @@ -209,6 +217,7 @@ target_include_directories(dgl PRIVATE "tensoradapter/include") target_include_directories(dgl PRIVATE "third_party/nanoflann/include") target_include_directories(dgl PRIVATE "third_party/libxsmm/include") + # For serialization if (USE_HDFS) option(DMLC_HDFS_SHARED "dgl has to build with dynamic hdfs library" ON) @@ -242,6 +251,16 @@ if((NOT MSVC) AND USE_LIBXSMM) list(APPEND DGL_LINKER_LIBS -L${CMAKE_SOURCE_DIR}/third_party/libxsmm/lib/ xsmm) endif((NOT MSVC) AND USE_LIBXSMM) +if(NOT MSVC) + # Only build tensorpipe on linux + string(REPLACE "-pedantic" "" CMAKE_C_FLAGS 
${CMAKE_C_FLAGS}) + set(TP_BUILD_LIBUV ON) + set(TP_STATIC_OR_SHARED STATIC) + add_subdirectory(third_party/tensorpipe) + list(APPEND DGL_LINKER_LIBS tensorpipe) + target_include_directories(dgl PRIVATE third_party/tensorpipe) +endif(NOT MSVC) + # Compile TVM Runtime and Featgraph # (NOTE) We compile a dynamic library called featgraph_runtime, which the DGL library links to. # Kernels are packed in a separate dynamic library called featgraph_kernels, which DGL diff --git a/Jenkinsfile b/Jenkinsfile index bddd9ea4a2c0..f75261c0450e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -65,7 +65,7 @@ def unit_test_linux(backend, dev) { def unit_test_win64(backend, dev) { init_git_win64() unpack_lib("dgl-${dev}-win64", dgl_win64_libs) - timeout(time: 10, unit: 'MINUTES') { + timeout(time: 20, unit: 'MINUTES') { bat "CALL tests\\scripts\\task_unit_test.bat ${backend}" } } diff --git a/python/dgl/distributed/rpc_server.py b/python/dgl/distributed/rpc_server.py index ec53e99aad55..ee19eb7d63ec 100644 --- a/python/dgl/distributed/rpc_server.py +++ b/python/dgl/distributed/rpc_server.py @@ -35,8 +35,8 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \ """ assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id assert num_servers > 0, 'num_servers (%d) must be a positive number.' % num_servers - assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_client - assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % queue_size + assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients + assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size assert net_type in ('socket'), 'net_type (%s) can only be \'socket\'' % net_type # Register signal handler. rpc.register_sig_handler() diff --git a/src/graph/serialize/zerocopy_serializer.cc b/src/graph/serialize/zerocopy_serializer.cc index 6cf2cd7080ab..cc3f59684ae9 100644 --- a/src/graph/serialize/zerocopy_serializer.cc +++ b/src/graph/serialize/zerocopy_serializer.cc @@ -21,7 +21,7 @@ struct RawDataTensorCtx { void RawDataTensoDLPackDeleter(DLManagedTensor* tensor) { auto ctx = static_cast(tensor->manager_ctx); - free(ctx->tensor.dl_tensor.data); + delete[] ctx->tensor.dl_tensor.data; delete ctx; } diff --git a/src/rpc/rpc.cc b/src/rpc/rpc.cc index e21421e0dc58..9fafeca816ac 100644 --- a/src/rpc/rpc.cc +++ b/src/rpc/rpc.cc @@ -3,21 +3,23 @@ * \file rpc/rpc.cc * \brief Implementation of RPC utilities used by both server and client sides. 
*/ -#include "./rpc.h" - -#include #if defined(__linux__) -#include -#endif +#include "./rpc.h" -#include -#include -#include #include +#include #include +#include +#include #include -#include "../runtime/resource_manager.h" +#include +#include + +#include +#include + #include "../c_api_common.h" +#include "../runtime/resource_manager.h" using dgl::network::StringPrintf; using namespace dgl::runtime; @@ -25,239 +27,248 @@ using namespace dgl::runtime; namespace dgl { namespace rpc { -RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) { - std::shared_ptr zerocopy_blob(new std::string()); - StreamWithBuffer zc_write_strm(zerocopy_blob.get(), true); - zc_write_strm.Write(msg); - int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size(); - zerocopy_blob->append(reinterpret_cast(&nonempty_ndarray_count), - sizeof(int32_t)); - network::Message rpc_meta_msg; - rpc_meta_msg.data = const_cast(zerocopy_blob->data()); - rpc_meta_msg.size = zerocopy_blob->size(); - rpc_meta_msg.deallocator = [zerocopy_blob](network::Message*) {}; - CHECK_EQ(RPCContext::ThreadLocal()->sender->Send( - rpc_meta_msg, target_id), ADD_SUCCESS); - // send real ndarray data - for (auto ptr : zc_write_strm.buffer_list()) { - network::Message ndarray_data_msg; - ndarray_data_msg.data = reinterpret_cast(ptr.data); - if (ptr.size == 0) { - LOG(FATAL) << "Cannot send a empty NDArray."; +using namespace tensorpipe; + +// Borrow from PyTorch + +const char kSocketIfnameEnvVar[] = "TP_SOCKET_IFNAME"; +const char kDefaultUvAddress[] = "127.0.0.1"; + +const std::string& guessAddress() { + static const std::string uvAddress = []() { + tensorpipe::Error error; + std::string result; + char* ifnameEnv = std::getenv(kSocketIfnameEnvVar); + if (ifnameEnv != nullptr) { + std::tie(error, result) = + tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv); + if (error) { + LOG(WARNING) << "Failed to look up the IP address for interface " + << ifnameEnv << " (" << error.what() << "), defaulting to " + << kDefaultUvAddress; + return std::string(kDefaultUvAddress); + } + } else { + std::tie(error, result) = + tensorpipe::transport::uv::lookupAddrForHostname(); + if (error) { + LOG(WARNING) << "Failed to look up the IP address for the hostname (" + << error.what() << "), defaulting to " + << kDefaultUvAddress; + return std::string(kDefaultUvAddress); + } } - ndarray_data_msg.size = ptr.size; - NDArray tensor = ptr.tensor; - ndarray_data_msg.deallocator = [tensor](network::Message*) {}; - CHECK_EQ(RPCContext::ThreadLocal()->sender->Send( - ndarray_data_msg, target_id), ADD_SUCCESS); - } + return result; + }(); + return uvAddress; +} + +RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) { + RPCContext::getInstance()->sender->Send(msg, target_id); return kRPCSuccess; } RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) { // ignore timeout now CHECK_EQ(timeout, 0) << "rpc cannot support timeout now."; - network::Message rpc_meta_msg; - int send_id; - CHECK_EQ(RPCContext::ThreadLocal()->receiver->Recv( - &rpc_meta_msg, &send_id), REMOVE_SUCCESS); - char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t); - int32_t nonempty_ndarray_count = *(reinterpret_cast(count_ptr)); - // Recv real ndarray data - std::vector buffer_list(nonempty_ndarray_count); - for (int i = 0; i < nonempty_ndarray_count; ++i) { - network::Message ndarray_data_msg; - CHECK_EQ(RPCContext::ThreadLocal()->receiver->RecvFrom( - &ndarray_data_msg, send_id), REMOVE_SUCCESS); - buffer_list[i] = 
ndarray_data_msg.data; - } - StreamWithBuffer zc_read_strm(rpc_meta_msg.data, rpc_meta_msg.size-sizeof(int32_t), buffer_list); - zc_read_strm.Read(msg); - rpc_meta_msg.deallocator(&rpc_meta_msg); + RPCContext::getInstance()->receiver->Recv(msg); return kRPCSuccess; } +void InitGlobalTpContext() { + if (!RPCContext::getInstance()->ctx) { + RPCContext::getInstance()->ctx = std::make_shared(); + auto context = RPCContext::getInstance()->ctx; + auto transportContext = tensorpipe::transport::uv::create(); + auto shmtransport = tensorpipe::transport::shm::create(); + context->registerTransport(0 /* priority */, "tcp", transportContext); + // Register basic uv channel + auto basicChannel = tensorpipe::channel::basic::create(); + context->registerChannel(0 /* low priority */, "basic", basicChannel); + + char* numUvThreads_str = std::getenv("DGL_SOCKET_NTHREADS"); + if (numUvThreads_str) { + int numUvThreads = std::atoi(numUvThreads_str); + CHECK(numUvThreads > 0) << "DGL_SOCKET_NTHREADS should be positive integer if set"; + // Register multiplex uv channel + std::vector> contexts; + std::vector> listeners; + for (int i = 0; i < numUvThreads; i++) { + auto context = tensorpipe::transport::uv::create(); + std::string address = guessAddress(); + contexts.push_back(std::move(context)); + listeners.push_back(contexts.back()->listen(address)); + } + auto mptChannel = tensorpipe::channel::mpt::create(std::move(contexts), + std::move(listeners)); + context->registerChannel(20 /* high priority */, "mpt", mptChannel); + } + } +} + //////////////////////////// C APIs //////////////////////////// DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - RPCContext::Reset(); -}); +.set_body([](DGLArgs args, DGLRetValue* rv) { RPCContext::Reset(); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { int64_t msg_queue_size = args[0]; std::string type = args[1]; - int max_thread_count = args[2]; - if (type.compare("socket") == 0) { - RPCContext::ThreadLocal()->sender = - std::make_shared(msg_queue_size, max_thread_count); - } else { - LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type; - } + InitGlobalTpContext(); + RPCContext::getInstance()->sender = + std::make_shared(RPCContext::getInstance()->ctx); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { int64_t msg_queue_size = args[0]; std::string type = args[1]; - int max_thread_count = args[2]; - if (type.compare("socket") == 0) { - RPCContext::ThreadLocal()->receiver = - std::make_shared(msg_queue_size, max_thread_count); - } else { - LOG(FATAL) << "Unknown communicator type for rpc sender: " << type; - } + InitGlobalTpContext(); + RPCContext::getInstance()->receiver = + std::make_shared(RPCContext::getInstance()->ctx); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - RPCContext::ThreadLocal()->sender->Finalize(); +.set_body([](DGLArgs args, DGLRetValue* rv) { + RPCContext::getInstance()->sender->Finalize(); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - RPCContext::ThreadLocal()->receiver->Finalize(); +.set_body([](DGLArgs args, DGLRetValue* rv) { + RPCContext::getInstance()->receiver->Finalize(); }); 
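+
+// For reference, the flow that these C APIs wire together, as a minimal
+// illustrative sketch (address and IDs are made up; in practice the two
+// halves run in different processes, each with the context set up by
+// InitGlobalTpContext()):
+//
+//   auto ctx = RPCContext::getInstance()->ctx;
+//   TPReceiver receiver(ctx);                      // server side
+//   receiver.Wait("tcp://127.0.0.1:50051", 1 /* num_sender */);
+//
+//   TPSender sender(ctx);                          // client side
+//   sender.AddReceiver("tcp://127.0.0.1:50051", 0 /* recv_id */);
+//   sender.Connect();
+//
+//   RPCMessage msg;
+//   msg.data = "hello";
+//   sender.Send(msg, 0);   // serialized and written to the pipe
+//
+//   RPCMessage out;
+//   receiver.Recv(&out);   // blocks until ReceiveFromPipe() pushes a message
+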
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { std::string ip = args[0]; int port = args[1]; int num_sender = args[2]; std::string addr; - if (RPCContext::ThreadLocal()->receiver->Type() == "socket") { - addr = StringPrintf("socket://%s:%d", ip.c_str(), port); - } else { - LOG(FATAL) << "Unknown communicator type: " << RPCContext::ThreadLocal()->receiver->Type(); - } - if (RPCContext::ThreadLocal()->receiver->Wait(addr.c_str(), num_sender) == false) { + addr = StringPrintf("tcp://%s:%d", ip.c_str(), port); + if (RPCContext::getInstance()->receiver->Wait(addr, num_sender) == false) { LOG(FATAL) << "Wait sender socket failed."; } }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCAddReceiver") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { std::string ip = args[0]; int port = args[1]; int recv_id = args[2]; std::string addr; - if (RPCContext::ThreadLocal()->sender->Type() == "socket") { - addr = StringPrintf("socket://%s:%d", ip.c_str(), port); - } else { - LOG(FATAL) << "Unknown communicator type: " << RPCContext::ThreadLocal()->sender->Type(); - } - RPCContext::ThreadLocal()->sender->AddReceiver(addr.c_str(), recv_id); + addr = StringPrintf("tcp://%s:%d", ip.c_str(), port); + RPCContext::getInstance()->sender->AddReceiver(addr, recv_id); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSenderConnect") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - if (RPCContext::ThreadLocal()->sender->Connect() == false) { +.set_body([](DGLArgs args, DGLRetValue* rv) { + if (RPCContext::getInstance()->sender->Connect() == false) { LOG(FATAL) << "Sender connection failed."; } }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t rank = args[0]; - RPCContext::ThreadLocal()->rank = rank; + RPCContext::getInstance()->rank = rank; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->rank; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->rank; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_servers = args[0]; - *rv = RPCContext::ThreadLocal()->num_servers = num_servers; + *rv = RPCContext::getInstance()->num_servers = num_servers; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->num_servers; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->num_servers; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumClient") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_clients = args[0]; - *rv = RPCContext::ThreadLocal()->num_clients = num_clients; + *rv = RPCContext::getInstance()->num_clients = num_clients; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumClient") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->num_clients; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->num_clients; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServerPerMachine") -.set_body([] (DGLArgs args, DGLRetValue* 
rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_servers = args[0]; - *rv = RPCContext::ThreadLocal()->num_servers_per_machine = num_servers; + *rv = RPCContext::getInstance()->num_servers_per_machine = num_servers; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServerPerMachine") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->num_servers_per_machine; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->num_servers_per_machine; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = (RPCContext::ThreadLocal()->msg_seq)++; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = (RPCContext::getInstance()->msg_seq)++; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->msg_seq; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->msg_seq; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int64_t msg_seq = args[0]; - RPCContext::ThreadLocal()->msg_seq = msg_seq; + RPCContext::getInstance()->msg_seq = msg_seq; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetBarrierCount") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->barrier_count; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->barrier_count; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetBarrierCount") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t count = args[0]; - RPCContext::ThreadLocal()->barrier_count = count; + RPCContext::getInstance()->barrier_count = count; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->machine_id; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->machine_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t machine_id = args[0]; - RPCContext::ThreadLocal()->machine_id = machine_id; + RPCContext::getInstance()->machine_id = machine_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - *rv = RPCContext::ThreadLocal()->num_machines; +.set_body([](DGLArgs args, DGLRetValue* rv) { + *rv = RPCContext::getInstance()->num_machines; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const int32_t num_machines = args[0]; - RPCContext::ThreadLocal()->num_machines = num_machines; + RPCContext::getInstance()->num_machines = num_machines; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { RPCMessageRef msg = args[0]; const int32_t target_id = args[1]; *rv = SendRPCMessage(*(msg.sptr()), target_id); }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { int32_t timeout = args[0]; RPCMessageRef msg = args[1]; *rv = 
RecvRPCMessage(msg.sptr().get(), timeout); @@ -266,57 +277,67 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage") //////////////////////////// RPCMessage //////////////////////////// DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { + std::shared_ptr rst(new RPCMessage); + *rv = rst; +}); + +DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessageWithSize") +.set_body([](DGLArgs args, DGLRetValue* rv) { + int64_t message_size = args[0]; + std::shared_ptr rst(new RPCMessage); *rv = rst; }); + DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { std::shared_ptr rst(new RPCMessage); rst->service_id = args[0]; rst->msg_seq = args[1]; rst->client_id = args[2]; rst->server_id = args[3]; - const std::string data = args[4]; // directly assigning string value raises errors :( + const std::string data = + args[4]; // directly assigning string value raises errors :( rst->data = data; rst->tensors = ListValueToVector(args[5]); *rv = rst; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->service_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->msg_seq; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->client_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; *rv = msg->server_id; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; DGLByteArray barr{msg->data.c_str(), msg->data.size()}; *rv = barr; }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { const RPCMessageRef msg = args[0]; List ret; for (size_t i = 0; i < msg->tensors.size(); ++i) { @@ -337,7 +358,7 @@ void SigHandler(int s) { } DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { // Ctrl+C handler struct sigaction sigHandler; sigHandler.sa_handler = SigHandler; @@ -351,10 +372,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCHandleSignal") //////////////////////////// ServerState //////////////////////////// DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState") -.set_body([] (DGLArgs args, DGLRetValue* rv) { - auto st = RPCContext::ThreadLocal()->server_state; +.set_body([](DGLArgs args, DGLRetValue* rv) { + auto st = RPCContext::getInstance()->server_state; if (st.get() == nullptr) { - RPCContext::ThreadLocal()->server_state = std::make_shared(); + RPCContext::getInstance()->server_state = std::make_shared(); } *rv = st; }); @@ -362,7 +383,7 @@ 
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState") //////////////////////////// KVStore //////////////////////////// DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { NDArray ID = args[0]; NDArray part_id = args[1]; int local_machine_id = args[2]; @@ -380,7 +401,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition") }); DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") -.set_body([] (DGLArgs args, DGLRetValue* rv) { +.set_body([](DGLArgs args, DGLRetValue* rv) { // Input std::string name = args[0]; int local_machine_id = args[1]; @@ -403,8 +424,8 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") std::vector local_ids; std::vector local_ids_orginal; std::vector local_data_shape; - std::vector > remote_ids(machine_count); - std::vector > remote_ids_original(machine_count); + std::vector> remote_ids(machine_count); + std::vector> remote_ids_original(machine_count); // Get row size (in bytes) int row_size = 1; for (int i = 0; i < local_data->ndim; ++i) { @@ -443,8 +464,8 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") msg.service_id = service_id; msg.msg_seq = msg_seq; msg.client_id = client_id; - int lower = i*group_count; - int upper = (i+1)*group_count; + int lower = i * group_count; + int upper = (i + 1) * group_count; msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper); msg.data = pickle_data; NDArray tensor = dgl::aten::VecToIdArray(remote_ids[i]); @@ -461,12 +482,12 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") // Copy local data parallel_for(0, local_ids.size(), [&](size_t b, size_t e) { for (auto i = b; i < e; ++i) { - CHECK_GE(ID_size*row_size, local_ids_orginal[i]*row_size+row_size); + CHECK_GE(ID_size * row_size, + local_ids_orginal[i] * row_size + row_size); CHECK_GE(data_size, local_ids[i] * row_size + row_size); CHECK_GE(local_ids[i], 0); memcpy(return_data + local_ids_orginal[i] * row_size, - local_data_char + local_ids[i] * row_size, - row_size); + local_data_char + local_ids[i] * row_size, row_size); } }); // Recv remote message @@ -478,8 +499,7 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") dgl_id_t id_size = remote_ids[part_id].size(); for (size_t n = 0; n < id_size; ++n) { memcpy(return_data + remote_ids_original[part_id][n] * row_size, - data_char + n * row_size, - row_size); + data_char + n * row_size, row_size); } } *rv = res_tensor; @@ -487,3 +507,5 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull") } // namespace rpc } // namespace dgl + +#endif diff --git a/src/rpc/rpc.h b/src/rpc/rpc.h index aa2f99289400..050272f7a5fb 100644 --- a/src/rpc/rpc.h +++ b/src/rpc/rpc.h @@ -12,17 +12,21 @@ #include #include #include +#include #include #include -#include "./network/communicator.h" -#include "./network/socket_communicator.h" -#include "./network/msg_queue.h" +#include + +#include "./rpc_msg.h" +#include "./tensorpipe/tp_communicator.h" #include "./network/common.h" #include "./server_state.h" namespace dgl { namespace rpc { +struct RPCContext; + // Communicator handler type typedef void* CommunicatorHandle; @@ -31,8 +35,8 @@ struct RPCContext { /*! * \brief Rank of this process. * - * If the process is a client, this is equal to client ID. Otherwise, the process - * is a server and this is equal to server ID. + * If the process is a client, this is equal to client ID. 
Otherwise, the + * process is a server and this is equal to server ID. */ int32_t rank = -1; @@ -49,7 +53,7 @@ struct RPCContext { /*! * \brief Message sequence number. */ - int64_t msg_seq = 0; + std::atomic msg_seq{0}; /*! * \brief Total number of server. @@ -74,12 +78,17 @@ struct RPCContext { /*! * \brief Sender communicator. */ - std::shared_ptr sender; + std::shared_ptr sender; /*! * \brief Receiver communicator. */ - std::shared_ptr receiver; + std::shared_ptr receiver; + + /*! + * \brief Tensorpipe global context + */ + std::shared_ptr ctx; /*! * \brief Server state data. @@ -92,74 +101,27 @@ struct RPCContext { */ std::shared_ptr server_state; - /*! \brief Get the thread-local RPC context structure */ - static RPCContext *ThreadLocal() { - return dmlc::ThreadLocalStore::Get(); + /*! \brief Get the RPC context singleton */ + static RPCContext* getInstance() { + static RPCContext ctx; + return &ctx; } /*! \brief Reset the RPC context */ static void Reset() { - auto* t = ThreadLocal(); + auto* t = getInstance(); t->rank = -1; t->machine_id = -1; t->num_machines = 0; t->num_clients = 0; t->barrier_count = 0; t->num_servers_per_machine = 0; - t->sender = std::shared_ptr(); - t->receiver = std::shared_ptr(); - } -}; - -/*! \brief RPC message data structure - * - * This structure is exposed to Python and can be used as argument or return value - * in C API. - */ -struct RPCMessage : public runtime::Object { - /*! \brief Service ID */ - int32_t service_id; - - /*! \brief Sequence number of this message. */ - int64_t msg_seq; - - /*! \brief Client ID. */ - int32_t client_id; - - /*! \brief Server ID. */ - int32_t server_id; - - /*! \brief Payload buffer carried by this request.*/ - std::string data; - - /*! \brief Extra payloads in the form of tensors.*/ - std::vector tensors; - - bool Load(dmlc::Stream* stream) { - stream->Read(&service_id); - stream->Read(&msg_seq); - stream->Read(&client_id); - stream->Read(&server_id); - stream->Read(&data); - stream->Read(&tensors); - return true; - } - - void Save(dmlc::Stream* stream) const { - stream->Write(service_id); - stream->Write(msg_seq); - stream->Write(client_id); - stream->Write(server_id); - stream->Write(data); - stream->Write(tensors); + t->sender.reset(); + t->receiver.reset(); + t->ctx.reset(); } - - static constexpr const char* _type_key = "rpc.RPCMessage"; - DGL_DECLARE_OBJECT_TYPE_INFO(RPCMessage, runtime::Object); }; -DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage); - /*! \brief RPC status flag */ enum RPCStatus { kRPCSuccess = 0, diff --git a/src/rpc/rpc_msg.h b/src/rpc/rpc_msg.h new file mode 100644 index 000000000000..6031eb047f7f --- /dev/null +++ b/src/rpc/rpc_msg.h @@ -0,0 +1,68 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file rpc/rpc_msg.h + * \brief Common headers for remote process call (RPC). + */ +#ifndef DGL_RPC_RPC_MSG_H_ +#define DGL_RPC_RPC_MSG_H_ + +#include + +#include +#include + +namespace dgl { +namespace rpc { + +/*! \brief RPC message data structure + * + * This structure is exposed to Python and can be used as argument or return + * value in C API. + */ +struct RPCMessage : public runtime::Object { + /*! \brief Service ID */ + int32_t service_id; + + /*! \brief Sequence number of this message. */ + int64_t msg_seq; + + /*! \brief Client ID. */ + int32_t client_id; + + /*! \brief Server ID. */ + int32_t server_id; + + /*! \brief Payload buffer carried by this request.*/ + std::string data; + + /*! 
\brief Extra payloads in the form of tensors.*/
+  std::vector<runtime::NDArray> tensors;
+
+  bool Load(dmlc::Stream* stream) {
+    stream->Read(&service_id);
+    stream->Read(&msg_seq);
+    stream->Read(&client_id);
+    stream->Read(&server_id);
+    stream->Read(&data);
+    stream->Read(&tensors);
+    return true;
+  }
+
+  void Save(dmlc::Stream* stream) const {
+    stream->Write(service_id);
+    stream->Write(msg_seq);
+    stream->Write(client_id);
+    stream->Write(server_id);
+    stream->Write(data);
+    stream->Write(tensors);
+  }
+
+  static constexpr const char* _type_key = "rpc.RPCMessage";
+  DGL_DECLARE_OBJECT_TYPE_INFO(RPCMessage, runtime::Object);
+};
+
+DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage);
+
+} // namespace rpc
+} // namespace dgl
+#endif // DGL_RPC_RPC_MSG_H_
diff --git a/src/rpc/tensorpipe/README.md b/src/rpc/tensorpipe/README.md
new file mode 100644
index 000000000000..6b356d1f03e6
--- /dev/null
+++ b/src/rpc/tensorpipe/README.md
@@ -0,0 +1,104 @@
+# Introduction to tensorpipe
+
+## Setting up communication
+```cpp
+auto context = std::make_shared<tensorpipe::Context>();
+// For the receiver:
+// create a listener to accept join requests,
+listener = context->listen({addr});
+// then accept a join request, which produces a pipe.
+std::promise<std::shared_ptr<Pipe>> pipeProm;
+listener->accept([&](const Error& error, std::shared_ptr<Pipe> pipe) {
+  if (error) {
+    LOG(WARNING) << error.what();
+  }
+  pipeProm.set_value(std::move(pipe));
+});
+std::shared_ptr<Pipe> pipe = pipeProm.get_future().get();
+
+// For the sender:
+pipe = context->connect(addr);
+// Note that the pipe may not actually be usable at this point.
+// For example, if no listener is listening at the address, no error is raised
+// here; the error surfaces at the first write/read operation. Thus we need to
+// check manually, retrying until a handshake message goes through:
+for (;;) {
+  std::promise<bool> done;
+  tensorpipe::Message tpmsg;
+  tpmsg.metadata = "dglconnect";
+  pipe->write(tpmsg, [&done](const tensorpipe::Error& error) {
+    if (error) {
+      done.set_value(false);
+    } else {
+      done.set_value(true);
+    }
+  });
+  if (done.get_future().get()) {
+    break;
+  } else {
+    sleep(5);
+    LOG(INFO) << "Cannot connect to remote server. Wait to retry";
+  }
+}
+```
+
+## Read and Write
+
+Message structure: https://github.com/pytorch/tensorpipe/blob/master/tensorpipe/core/message.h
+
+There are three concepts: Message, Descriptor, and Allocation.
+Message is the core struct for communication. It contains three major fields: metadata (a string), payloads (CPU memory buffers), and tensors (CPU/GPU memory buffers, with the device as an attribute).
+
+Descriptor and Allocation are used on the read side. A typical read operation looks as follows (an expanded version of the allocation step is sketched below):
+
+```cpp
+pipe->readDescriptor(
+    [](const Error& error, Descriptor descriptor) {
+      // The Descriptor carries the metadata of the message, the data size of
+      // each payload, and the device information of each tensor, i.e.,
+      // everything except the actual buffers.
+      // The user should allocate proper memory based on the descriptor and
+      // hand the allocated memory back through an Allocation object.
+      Allocation allocation;
+      // Then call pipe->read() to receive the real buffers into the allocation.
+      pipe->read(allocation, [](const Error& error) {});
+    });
+```
+
+Sending a message is much simpler:
+```cpp
+// Resource cleanup should be handled in the callback.
+pipe->write(message, callback_fn);
+```
+
+## Register the underlying communication channel
+There are two concepts: transport and channel.
+A transport is the basic component for communication, such as sockets, and it only supports CPU buffers.
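+
+Expanding the read example from the previous section: a minimal sketch of the
+allocation step, assuming CPU-only tensors and no payloads, mirroring what
+`TPReceiver::ReceiveFromPipe` in this patch does:
+
+```cpp
+pipe->readDescriptor([pipe](const tensorpipe::Error& error,
+                            tensorpipe::Descriptor descriptor) {
+  if (error) {
+    return;  // e.g., the pipe was closed
+  }
+  tensorpipe::Allocation allocation;
+  allocation.tensors.resize(descriptor.tensors.size());
+  for (size_t i = 0; i < descriptor.tensors.size(); ++i) {
+    tensorpipe::CpuBuffer cpu_buffer;
+    // Allocate exactly as many bytes as the descriptor announces.
+    cpu_buffer.ptr = new char[descriptor.tensors[i].length];
+    allocation.tensors[i].buffer = cpu_buffer;
+  }
+  // The buffers in `allocation` are filled in by the time the callback fires;
+  // ownership of the new[]-ed memory passes to whoever consumes the message.
+  pipe->read(allocation, [allocation](const tensorpipe::Error& error) {});
+});
+```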
+A channel is a higher abstraction over transports; it can support GPU buffers, or utilize multiple transports to accelerate communication.
+
+Tensorpipe will try to set up channels based on priority.
+
+```cpp
+// Register a transport.
+auto context = std::make_shared<tensorpipe::Context>();
+// "uv" is short for libuv, which uses epoll with sockets to communicate.
+auto transportContext = tensorpipe::transport::uv::create();
+context->registerTransport(0 /* priority */, "tcp", transportContext);
+// The basic channel just uses the bare transport to communicate.
+auto basicChannel = tensorpipe::channel::basic::create();
+context->registerChannel(0, "basic", basicChannel);
+// Below is the mpt (multiplexed transport) channel, which can use multiple uv
+// transports to increase throughput.
+std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts = {
+    tensorpipe::transport::uv::create(), tensorpipe::transport::uv::create(),
+    tensorpipe::transport::uv::create()};
+std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners = {
+    contexts[0]->listen("127.0.0.1"), contexts[1]->listen("127.0.0.1"),
+    contexts[2]->listen("127.0.0.1")};
+auto mptChannel = tensorpipe::channel::mpt::create(
+    std::move(contexts), std::move(listeners));
+context->registerChannel(10, "mpt", mptChannel);
+```
+
+Tensorpipe supports more channels, such as CUDA IPC (for CUDA communication on the same machine), CMA (cross-memory attach, for processes on the same machine), CUDA GDR (using InfiniBand with CUDA GPUDirect for GPU buffers), and CUDA Basic (using a socket plus a separate thread to copy buffers to and from CUDA memory).
+
+Quoting the tensorpipe README:
+
+Backends come in two flavors:
+
+Transports are the connections used by the pipes to transfer control messages, and the (smallish) core payloads. They are meant to be lightweight and low-latency. The most basic transport is a simple TCP one, which should work in all scenarios. A more optimized one, for example, is based on a ring buffer allocated in shared memory, which two processes on the same machine can use to communicate by performing just a memory copy, without passing through the kernel.
+
+Channels are where the heavy lifting takes place, as they take care of copying the (larger) tensor data. High bandwidths are a requirement. Examples include multiplexing chunks of data across multiple TCP sockets and processes, so to saturate the NIC's bandwidth. Or using a CUDA memcpy call to transfer memory from one GPU to another using NVLink.
\ No newline at end of file
diff --git a/src/rpc/tensorpipe/queue.h b/src/rpc/tensorpipe/queue.h
new file mode 100644
index 000000000000..3c265379cac3
--- /dev/null
+++ b/src/rpc/tensorpipe/queue.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ +#ifndef DGL_RPC_TENSORPIPE_QUEUE_H_ +#define DGL_RPC_TENSORPIPE_QUEUE_H_ + +#include +#include +#include + +namespace dgl { +namespace rpc { + +template +class Queue { + public: + // Capacity isn't used actually + explicit Queue(int capacity = 1) : capacity_(capacity) {} + + void push(T t) { + std::unique_lock lock(mutex_); + // while (items_.size() >= capacity_) { + // cv_.wait(lock); + // } + items_.push_back(std::move(t)); + cv_.notify_all(); + } + + T pop() { + std::unique_lock lock(mutex_); + while (items_.size() == 0) { + cv_.wait(lock); + } + T t(std::move(items_.front())); + items_.pop_front(); + cv_.notify_all(); + return t; + } + + private: + std::mutex mutex_; + std::condition_variable cv_; + const int capacity_; + std::deque items_; +}; +} // namespace rpc +} // namespace dgl + +#endif // DGL_RPC_TENSORPIPE_QUEUE_H_ diff --git a/src/rpc/tensorpipe/tp_communicator.cc b/src/rpc/tensorpipe/tp_communicator.cc new file mode 100644 index 000000000000..b78d990a57e3 --- /dev/null +++ b/src/rpc/tensorpipe/tp_communicator.cc @@ -0,0 +1,168 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tp_communicator.cc + * \brief Tensorpipe Communicator for DGL distributed training. + */ + +#include "tp_communicator.h" + +#include +#include + +#include +#include +#include + +#include "../rpc.h" + +namespace dgl { +namespace rpc { + +using namespace tensorpipe; + +void TPSender::AddReceiver(const std::string& addr, int recv_id) { + receiver_addrs_[recv_id] = addr; +} + +bool TPSender::Connect() { + for (const auto& kv : receiver_addrs_) { + std::shared_ptr pipe; + for (;;) { + pipe = context->connect(kv.second); + std::promise done; + tensorpipe::Message tpmsg; + tpmsg.metadata = "dglconnect"; + pipe->write(tpmsg, [&done](const tensorpipe::Error& error) { + if (error) { + done.set_value(false); + } else { + done.set_value(true); + } + }); + if (done.get_future().get()) { + break; + } else { + sleep(5); + LOG(INFO) << "Cannot connect to remove server " << kv.second + << ". Wait to retry"; + } + } + pipes_[kv.first] = pipe; + } + return true; +} + +void TPSender::Send(const RPCMessage& msg, int recv_id) { + auto pipe = pipes_[recv_id]; + tensorpipe::Message tp_msg; + std::string* zerocopy_blob_ptr = &tp_msg.metadata; + StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true); + zc_write_strm.Write(msg); + int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size(); + zerocopy_blob_ptr->append(reinterpret_cast(&nonempty_ndarray_count), + sizeof(int32_t)); + tp_msg.tensors.resize(nonempty_ndarray_count); + // Hold the NDArray that ensure it's valid until write operation completes + auto ndarray_holder = std::make_shared>(); + ndarray_holder->resize(nonempty_ndarray_count); + auto& buffer_list = zc_write_strm.buffer_list(); + for (int i = 0; i < buffer_list.size(); i++) { + auto& ptr = buffer_list[i]; + (*ndarray_holder.get())[i] = ptr.tensor; + tensorpipe::CpuBuffer cpu_buffer; + cpu_buffer.ptr = ptr.data; + tp_msg.tensors[i].buffer = cpu_buffer; + tp_msg.tensors[i].length = ptr.size; + if (ptr.size == 0) { + LOG(FATAL) << "Cannot send a empty NDArray."; + } + } + pipe->write(tp_msg, + [ndarray_holder, recv_id](const tensorpipe::Error& error) { + if (error) { + LOG(FATAL) << "Failed to send message to " << recv_id + << ". 
Details: " << error.what(); + } + }); +} + +void TPSender::Finalize() {} +void TPReceiver::Finalize() {} + +bool TPReceiver::Wait(const std::string& addr, int num_sender) { + listener = context->listen({addr}); + for (int i = 0; i < num_sender; i++) { + std::promise> pipeProm; + listener->accept([&](const Error& error, std::shared_ptr pipe) { + if (error) { + LOG(WARNING) << error.what(); + } + pipeProm.set_value(std::move(pipe)); + }); + std::shared_ptr pipe = pipeProm.get_future().get(); + std::promise checkConnect; + pipe->readDescriptor( + [pipe, &checkConnect](const Error& error, Descriptor descriptor) { + Allocation allocation; + checkConnect.set_value(descriptor.metadata == "dglconnect"); + pipe->read(allocation, [](const Error& error) {}); + }); + CHECK(checkConnect.get_future().get()) << "Invalid connect message."; + pipes_[i] = pipe; + ReceiveFromPipe(pipe, queue_); + } + return true; +} + +void TPReceiver::ReceiveFromPipe(std::shared_ptr pipe, + std::shared_ptr queue) { + pipe->readDescriptor([pipe, queue = std::move(queue)](const Error& error, + Descriptor descriptor) { + if (error) { + // Error may happen when the pipe is closed + return; + } + Allocation allocation; + CHECK_EQ(descriptor.payloads.size(), 0) << "Invalid DGL RPC Message"; + + int tensorsize = descriptor.tensors.size(); + if (tensorsize > 0) { + allocation.tensors.resize(tensorsize); + for (int i = 0; i < descriptor.tensors.size(); i++) { + tensorpipe::CpuBuffer cpu_buffer; + cpu_buffer.ptr = new char[descriptor.tensors[i].length]; + allocation.tensors[i].buffer = cpu_buffer; + } + } + pipe->read( + allocation, [allocation, descriptor = std::move(descriptor), + queue = std::move(queue), pipe](const Error& error) { + if (error) { + // Because we always have a read event posted to the epoll, + // Therefore when pipe is closed, error will be raised. + // But this error is expected. + // Other error is not expected. But we cannot identify the error with each + // Other for now. Thus here we skip handling for all errors + return; + } + + char* meta_msg_begin = const_cast(&descriptor.metadata[0]); + std::vector buffer_list(descriptor.tensors.size()); + for (int i = 0; i < descriptor.tensors.size(); i++) { + buffer_list[i] = allocation.tensors[i].buffer.unwrap().ptr; + } + StreamWithBuffer zc_read_strm( + meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t), + buffer_list); + RPCMessage msg; + zc_read_strm.Read(&msg); + queue->push(msg); + TPReceiver::ReceiveFromPipe(pipe, queue); + }); + }); +} + +void TPReceiver::Recv(RPCMessage* msg) { *msg = std::move(queue_->pop()); } + +} // namespace rpc +} // namespace dgl diff --git a/src/rpc/tensorpipe/tp_communicator.h b/src/rpc/tensorpipe/tp_communicator.h new file mode 100644 index 000000000000..3650f0d44145 --- /dev/null +++ b/src/rpc/tensorpipe/tp_communicator.h @@ -0,0 +1,186 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tp_communicator.h + * \brief Tensorpipe Communicator for DGL distributed training. + */ +#ifndef DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_ +#define DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "./queue.h" + +namespace dgl { +namespace rpc { + +class RPCMessage; + +typedef Queue RPCMessageQueue; + +/*! + * \brief TPSender for DGL distributed training. + * + * TPSender is the communicator implemented by tcp socket. + */ +class TPSender { + public: + /*! 
+ * \brief Sender constructor + * \param queue_size size of message queue + */ + explicit TPSender(std::shared_ptr ctx) { + CHECK(ctx) << "Context is not initialized"; + this->context = ctx; + } + + /*! + * \brief Add receiver's address and ID to the sender's namebook + * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091' + * \param id receiver's ID + * + * AddReceiver() is not thread-safe and only one thread can invoke this API. + */ + void AddReceiver(const std::string& addr, int recv_id); + + /*! + * \brief Connect with all the Receivers + * \return True for success and False for fail + * + * Connect() is not thread-safe and only one thread can invoke this API. + */ + bool Connect(); + + /*! + * \brief Send RPCMessage to specified Receiver. + * \param msg data message \param recv_id receiver's ID + */ + void Send(const RPCMessage& msg, int recv_id); + + /*! + * \brief Finalize TPSender + */ + void Finalize(); + + /*! + * \brief Communicator type: 'tp' + */ + inline std::string Type() const { return std::string("tp"); } + + private: + /*! + * \brief global context of tensorpipe + */ + std::shared_ptr context; + + /*! + * \brief pipe for each connection of receiver + */ + std::unordered_map> + pipes_; + + /*! + * \brief receivers' listening address + */ + std::unordered_map receiver_addrs_; +}; + +/*! + * \brief TPReceiver for DGL distributed training. + * + * Tensorpipe Receiver is the communicator implemented by tcp socket. + */ +class TPReceiver { + public: + /*! + * \brief Receiver constructor + * \param queue_size size of message queue. + */ + explicit TPReceiver(std::shared_ptr ctx) { + CHECK(ctx) << "Context is not initialized"; + this->context = ctx; + queue_ = std::make_shared(); + } + + /*! + * \brief Wait for all the Senders to connect + * \param addr Networking address, e.g., 'tcp://127.0.0.1:50051' + * \param num_sender total number of Senders + * \return True for success and False for fail + * + * Wait() is not thread-safe and only one thread can invoke this API. + */ + bool Wait(const std::string& addr, int num_sender); + + /*! + * \brief Recv RPCMessage from Sender. Actually removing data from queue. + * \param msg pointer of RPCmessage + * \param send_id which sender current msg comes from + * \return Status code + * + * (1) The Recv() API is blocking, which will not + * return until getting data from message queue. + * (2) The Recv() API is thread-safe. + * (3) Memory allocated by communicator but will not own it after the function + * returns. + */ + void Recv(RPCMessage* msg); + + /*! + * \brief Finalize SocketReceiver + * + * Finalize() is not thread-safe and only one thread can invoke this API. + */ + void Finalize(); + + /*! + * \brief Communicator type: 'tp' (tensorpipe) + */ + inline std::string Type() const { return std::string("tp"); } + + /*! + * \brief Issue a receive request on pipe, and push the result into queue + */ + static void ReceiveFromPipe(std::shared_ptr pipe, + std::shared_ptr queue); + + private: + /*! + * \brief number of sender + */ + int num_sender_; + + /*! + * \brief listener to build pipe + */ + std::shared_ptr listener; + + /*! + * \brief global context of tensorpipe + */ + std::shared_ptr context; + + /*! + * \brief pipe for each client connections + */ + std::unordered_map> + pipes_; + + /*! 
+ * \brief RPCMessage queue + */ + std::shared_ptr queue_; +}; + +} // namespace rpc +} // namespace dgl + +#endif // DGL_RPC_TENSORPIPE_TP_COMMUNICATOR_H_ diff --git a/tests/distributed/test_rpc.py b/tests/distributed/test_rpc.py index f68905a36094..b53d80c501bf 100644 --- a/tests/distributed/test_rpc.py +++ b/tests/distributed/test_rpc.py @@ -107,12 +107,13 @@ def process_request(self, server_state): res = HelloResponse(self.hello_str, self.integer, new_tensor) return res -def start_server(num_clients, ip_config): +def start_server(num_clients, ip_config, server_id=0): print("Sleep 5 seconds to test client re-connect.") time.sleep(5) server_state = dgl.distributed.ServerState(None, local_g=None, partition_book=None) dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse) - dgl.distributed.start_server(server_id=0, + print("Start server {}".format(server_id)) + dgl.distributed.start_server(server_id=server_id, ip_config=ip_config, num_servers=1, num_clients=num_clients, @@ -224,8 +225,50 @@ def test_multi_client(): pserver.join() +@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') +def test_multi_thread_rpc(): + os.environ['DGL_DIST_MODE'] = 'distributed' + ip_config = open("rpc_ip_config_multithread.txt", "w") + num_servers = 2 + for _ in range(num_servers): # 3 servers + ip_config.write('{}\n'.format(get_local_usable_addr())) + ip_config.close() + ctx = mp.get_context('spawn') + pserver_list = [] + for i in range(num_servers): + pserver = ctx.Process(target=start_server, args=(1, "rpc_ip_config_multithread.txt", i)) + pserver.start() + pserver_list.append(pserver) + def start_client_multithread(ip_config): + import threading + dgl.distributed.connect_to_server(ip_config=ip_config, num_servers=1) + dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse) + + req = HelloRequest(STR, INTEGER, TENSOR, simple_func) + dgl.distributed.send_request(0, req) + + def subthread_call(server_id): + req = HelloRequest(STR, INTEGER, TENSOR+ server_id, simple_func) + dgl.distributed.send_request(server_id, req) + + + subthread = threading.Thread(target=subthread_call, args=(1,)) + subthread.start() + subthread.join() + + res0 = dgl.distributed.recv_response() + res1 = dgl.distributed.recv_response() + assert_array_equal(F.asnumpy(res0.tensor), F.asnumpy(TENSOR)) + assert_array_equal(F.asnumpy(res1.tensor), F.asnumpy(TENSOR+1)) + dgl.distributed.exit_client() + + start_client_multithread("rpc_ip_config_multithread.txt") + pserver.join() + + if __name__ == '__main__': test_serialize() test_rpc_msg() test_rpc() test_multi_client() + test_multi_thread_rpc() \ No newline at end of file diff --git a/third_party/tensorpipe b/third_party/tensorpipe new file mode 160000 index 000000000000..6042f1a4cbce --- /dev/null +++ b/third_party/tensorpipe @@ -0,0 +1 @@ +Subproject commit 6042f1a4cbce8eef997f11ed0012de137b317361 diff --git a/third_party/xbyak b/third_party/xbyak index 0140eeff1fff..757e4063f646 160000 --- a/third_party/xbyak +++ b/third_party/xbyak @@ -1 +1 @@ -Subproject commit 0140eeff1fffcf5069dea3abb57095695320971c +Subproject commit 757e4063f6464740b8ff4a2cae9136d2f8458020
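
A closing note on the queue that backs `TPReceiver::Recv`: the `Queue<T>` added in `src/rpc/tensorpipe/queue.h` above is a plain condition-variable queue where `push()` never blocks (the capacity argument is currently unused) and `pop()` blocks until an item arrives. A minimal usage sketch (include path illustrative):

```cpp
#include <thread>
#include "src/rpc/tensorpipe/queue.h"  // dgl::rpc::Queue

int main() {
  dgl::rpc::Queue<int> q;
  // Producer thread pushes one item; push() notifies the condition variable.
  std::thread producer([&q] { q.push(42); });
  int v = q.pop();  // blocks until the producer has pushed
  producer.join();
  return v == 42 ? 0 : 1;
}
```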