Skip to content

Commit

Permalink
[Dist] Enable maximum try times for socket backend via DGL_DIST_MAX_T… (
Browse files Browse the repository at this point in the history
dmlc#3977)

* [Dist] Enable maximum try times for socket backend via DGL_DIST_MAX_TRY_TIMES

* reset env before/after test

* print log for info when trying to connect

* fix

* print log in python instead of cpp
  • Loading branch information
Rhett-Ying authored May 11, 2022
1 parent 74f0140 commit 22e218d
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 26 deletions.
28 changes: 25 additions & 3 deletions python/dgl/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
'wait_for_senders', 'connect_receiver', 'read_ip_config', 'get_group_id', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', 'DistConnectError', \
'get_num_client', 'set_num_client', 'client_barrier', 'copy_data_to_shared_memory']

REQUEST_CLASS_TO_SERVICE_ID = {}
Expand Down Expand Up @@ -175,14 +175,19 @@ def connect_receiver(ip_addr, port, recv_id, group_id=-1):
raise DGLError("Invalid target id: {}".format(target_id))
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(target_id))

def connect_receiver_finalize():
def connect_receiver_finalize(max_try_times):
"""Finalize the action to connect to receivers. Make sure that either all connections are
successfully established or connection fails.
When "socket" network backend is in use, the function issues actual requests to receiver
sockets to establish connections.
Parameters
----------
max_try_times : int
maximum try times
"""
_CAPI_DGLRPCConnectReceiverFinalize()
return _CAPI_DGLRPCConnectReceiverFinalize(max_try_times)

def set_rank(rank):
"""Set the rank of this process.
Expand Down Expand Up @@ -1231,4 +1236,21 @@ def get_client(client_id, group_id):
"""
return _CAPI_DGLRPCGetClient(int(client_id), int(group_id))

class DistConnectError(DGLError):
"""Exception raised for errors if fail to connect peer.
Attributes
----------
kv_store : KVServer
reference for KVServer
"""

def __init__(self, max_try_times, ip='', port=''):
peer_str = "peer[{}:{}]".format(ip, port) if ip != '' else "peer"
self.message = "Failed to build conncetion with {} after {} retries. " \
"Please check network availability or increase max try " \
"times via 'DGL_DIST_MAX_TRY_TIMES'.".format(
peer_str, max_try_times)
super().__init__(self.message)

_init_api("dgl.distributed.rpc")
11 changes: 10 additions & 1 deletion python/dgl/distributed/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,21 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE,
rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type)
# Get connected with all server nodes
max_try_times = int(os.environ.get('DGL_DIST_MAX_TRY_TIMES', 1024))
for server_id, addr in server_namebook.items():
server_ip = addr[1]
server_port = addr[2]
try_times = 0
while not rpc.connect_receiver(server_ip, server_port, server_id):
try_times += 1
if try_times % 200 == 0:
print("Client is trying to connect server receiver: {}:{}".format(
server_ip, server_port))
if try_times >= max_try_times:
raise rpc.DistConnectError(max_try_times, server_ip, server_port)
time.sleep(3)
rpc.connect_receiver_finalize()
if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
# Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':')
Expand Down
12 changes: 6 additions & 6 deletions python/dgl/distributed/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
try_times = 0
while not rpc.connect_receiver(client_ip, client_port, client_id, group_id):
try_times += 1
if try_times % 200 == 0:
print("Server~{} is trying to connect client receiver: {}:{}".format(
server_id, client_ip, client_port))
if try_times >= max_try_times:
raise DGLError("Failed to connect to receiver [{}:{}] after {} "
"retries. Please check availability of this target "
"receiver or change the max retry times via "
"'DGL_DIST_MAX_TRY_TIMES'.".format(
client_ip, client_port, max_try_times))
raise rpc.DistConnectError(max_try_times, client_ip, client_port)
time.sleep(1)
rpc.connect_receiver_finalize()
if not rpc.connect_receiver_finalize(max_try_times):
raise rpc.DistConnectError(max_try_times)
if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id)
Expand Down
3 changes: 2 additions & 1 deletion src/graph/network.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle);
if (sender->ConnectReceiverFinalize() == false) {
const int max_try_times = 1024;
if (sender->ConnectReceiverFinalize(max_try_times) == false) {
LOG(FATAL) << "Sender connection failed.";
}
});
Expand Down
4 changes: 3 additions & 1 deletion src/rpc/net_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ struct RPCSender : RPCBase {
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
virtual bool ConnectReceiverFinalize() { return true; }
virtual bool ConnectReceiverFinalize(const int max_try_times) {
return true;
}

/*!
* \brief Send RPCMessage to specified Receiver.
Expand Down
10 changes: 5 additions & 5 deletions src/rpc/network/socket_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ bool SocketSender::ConnectReceiver(const std::string& addr, int recv_id) {
return true;
}

bool SocketSender::ConnectReceiverFinalize() {
bool SocketSender::ConnectReceiverFinalize(const int max_try_times) {
// Create N sockets for Receiver
int receiver_count = static_cast<int>(receiver_addrs_.size());
if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
Expand All @@ -71,16 +71,16 @@ bool SocketSender::ConnectReceiverFinalize() {
int try_count = 0;
const char* ip = r.second.ip.c_str();
int port = r.second.port;
while (bo == false && try_count < kMaxTryCount) {
while (bo == false && try_count < max_try_times) {
if (client_socket->Connect(ip, port)) {
bo = true;
} else {
if (try_count % 200 == 0 && try_count != 0) {
// every 1000 seconds show this message
LOG(INFO) << "Try to connect to: " << ip << ":" << port;
// every 600 seconds show this message
LOG(INFO) << "Trying to connect receiver: " << ip << ":" << port;
}
try_count++;
std::this_thread::sleep_for(std::chrono::seconds(5));
std::this_thread::sleep_for(std::chrono::seconds(3));
}
}
if (bo == false) {
Expand Down
3 changes: 1 addition & 2 deletions src/rpc/network/socket_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
namespace dgl {
namespace network {

static constexpr int kMaxTryCount = 1024; // maximal connection: 1024
static constexpr int kTimeOut = 10 * 60; // 10 minutes (in seconds) for socket timeout
static constexpr int kMaxConnection = 1024; // maximal connection: 1024

Expand Down Expand Up @@ -70,7 +69,7 @@ class SocketSender : public Sender {
*
* The function is *not* thread-safe; only one thread can invoke this API.
*/
bool ConnectReceiverFinalize() override;
bool ConnectReceiverFinalize(const int max_try_times) override;

/*!
* \brief Send RPCMessage to specified Receiver.
Expand Down
3 changes: 2 additions & 1 deletion src/rpc/rpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiverFinalize")
.set_body([](DGLArgs args, DGLRetValue* rv) {
RPCContext::getInstance()->sender->ConnectReceiverFinalize();
const int max_try_times = args[0];
*rv = RPCContext::getInstance()->sender->ConnectReceiverFinalize(max_try_times);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
Expand Down
12 changes: 9 additions & 3 deletions src/rpc/tensorpipe/tp_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,21 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {

void TPSender::Finalize() {
for (auto &&p : pipes_) {
p.second->close();
if (p.second) {
p.second->close();
}
}
pipes_.clear();
}

void TPReceiver::Finalize() {
listener_->close();
if (listener_) {
listener_->close();
}
for (auto &&p : pipes_) {
p.second->close();
if (p.second) {
p.second->close();
}
}
pipes_.clear();
}
Expand Down
5 changes: 3 additions & 2 deletions tests/cpp/socket_communicator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ using dgl::network::DefaultMessageDeleter;

const int64_t kQueueSize = 500 * 1024;
const int kThreadNum = 2;
const int kMaxTryTimes = 1024;

#ifndef WIN32

Expand Down Expand Up @@ -66,7 +67,7 @@ void start_client() {
for (int i = 0; i < kNumReceiver; ++i) {
sender.ConnectReceiver(ip_addr[i], i);
}
sender.ConnectReceiverFinalize();
sender.ConnectReceiverFinalize(kMaxTryTimes);
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9];
Expand Down Expand Up @@ -171,7 +172,7 @@ static void start_client() {
t.close();
SocketSender sender(kQueueSize, kThreadNum);
sender.ConnectReceiver(ip_addr.c_str(), 0);
sender.ConnectReceiverFinalize();
sender.ConnectReceiverFinalize(kMaxTryTimes);
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
Expand Down
31 changes: 31 additions & 0 deletions tests/distributed/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,35 @@ def test_multi_client_groups():
for p in pserver_list:
p.join()

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("net_type", ['socket', 'tensorpipe'])
def test_multi_client_connect(net_type):
reset_envs()
os.environ['DGL_DIST_MODE'] = 'distributed'
ip_config = "rpc_ip_config_mul_client.txt"
generate_ip_config(ip_config, 1, 1)
ctx = mp.get_context('spawn')
num_clients = 1
pserver = ctx.Process(target=start_server, args=(num_clients, ip_config, 0, False, 1, net_type))

# small max try times
os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1'
expect_except = False
try:
start_client(ip_config, 0, 1, net_type)
except dgl.distributed.DistConnectError as err:
print("Expected error: {}".format(err))
expect_except = True
assert expect_except

# large max try times
os.environ['DGL_DIST_MAX_TRY_TIMES'] = '1024'
pclient = ctx.Process(target=start_client, args=(ip_config, 0, 1, net_type))
pclient.start()
pserver.start()
pclient.join()
pserver.join()
reset_envs()

if __name__ == '__main__':
test_serialize()
Expand All @@ -286,3 +315,5 @@ def test_multi_client_groups():
test_multi_client('socket')
test_multi_client('tesnsorpipe')
test_multi_thread_rpc()
test_multi_client_connect('socket')
test_multi_client_connect('tensorpipe')
3 changes: 2 additions & 1 deletion tests/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def generate_ip_config(file_name, num_machines, num_servers):

def reset_envs():
"""Reset common environment variable which are set in tests. """
for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', 'DGL_DIST_MODE', 'DGL_NUM_CLIENT']:
for key in ['DGL_ROLE', 'DGL_NUM_SAMPLER', 'DGL_NUM_SERVER', \
'DGL_DIST_MODE', 'DGL_NUM_CLIENT', 'DGL_DIST_MAX_TRY_TIMES']:
if key in os.environ:
os.environ.pop(key)

0 comments on commit 22e218d

Please sign in to comment.