Skip to content

Commit

Permalink
[Feature][Dist] change TP::Receiver/TP::Sender for multiple connectio…
Browse files Browse the repository at this point in the history
…ns (dmlc#3574)

* [Feature] enable TP::Receiver wait for any numbers of senders

* fix random unit test failure

* avoid endless future wait

* fix unit test failure

* fix seg fault when finalize wait in receiver

* [Feature] refactor sender connect logic and remove unnecessary sleeps in unit tests

* fix lint

* release RPCContext resources before process exits

* [Debug] TPReceiver wait start log

* [Debug] add log in get port

* [Debug] add log

* [ReDebug] revert time sleep in unit tests

* [Debug] remove sleep for test_distri,test_mp

* [debug] add more log

* [debug] add listen_booted_ flag

* [debug] restore commented code for queue

* [debug] sleep more in rpc_client

* restore change in tests

* Revert "restore change in tests"

This reverts commit 41a1892.

* Revert "[debug] sleep more in rpc_client"

This reverts commit a908e75.

* Revert "[debug] restore commented code for queue"

This reverts commit d3f993b.

* Revert "[debug] add listen_booted_ flag"

This reverts commit 244b216.

* Revert "[debug] add more log"

This reverts commit 4b78447.

* Revert "[Debug] remove sleep for test_distri,test_mp"

This reverts commit e1df1aa.

* remove debug code

* revert unnecessary change

* revert unnecessary changes

* always reset RPCContext when get started and reset all data

* remove time.sleep in dist tests

* fix lint

* reset envs before each dist test

* reset env properly

* add time sleep when start each server

* sleep for a while when boot server

* replace wait_thread with callback

* fix lint

* add dglconnect handshake check

Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
Rhett-Ying and VoVAllen authored Jan 11, 2022
1 parent 95c0ff6 commit 37467e2
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 155 deletions.
1 change: 1 addition & 0 deletions python/dgl/distributed/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
'Please define DGL_CONF_PATH to run DistGraph server'
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
rpc.reset()
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
Expand Down
19 changes: 8 additions & 11 deletions python/dgl/distributed/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
'receiver_wait', 'connect_receiver', 'read_ip_config', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
Expand Down Expand Up @@ -138,7 +138,7 @@ def finalize_receiver():
"""
_CAPI_DGLRPCFinalizeReceiver()

def receiver_wait(ip_addr, port, num_senders):
def receiver_wait(ip_addr, port, num_senders, blocking=True):
"""Wait all of the senders' connections.
This api will be blocked until all the senders connect to the receiver.
Expand All @@ -151,11 +151,13 @@ def receiver_wait(ip_addr, port, num_senders):
receiver's port
num_senders : int
total number of senders
blocking : bool
whether to wait blockingly
"""
_CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders))
_CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders), blocking)

def add_receiver_addr(ip_addr, port, recv_id):
"""Add Receiver's IP address to sender's namebook.
def connect_receiver(ip_addr, port, recv_id):
"""Connect to target receiver
Parameters
----------
Expand All @@ -166,12 +168,7 @@ def add_receiver_addr(ip_addr, port, recv_id):
recv_id : int
receiver's ID
"""
_CAPI_DGLRPCAddReceiver(ip_addr, int(port), int(recv_id))

def sender_connect():
"""Connect to all the receivers.
"""
_CAPI_DGLRPCSenderConnect()
return _CAPI_DGLRPCConnectReceiver(ip_addr, int(port), int(recv_id))

def set_rank(rank):
"""Set the rank of this process.
Expand Down
9 changes: 5 additions & 4 deletions python/dgl/distributed/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import socket
import atexit
import logging
import time

from . import rpc
from .constants import MAX_QUEUE_SIZE
Expand Down Expand Up @@ -161,17 +162,17 @@ def connect_to_server(ip_config, num_servers, max_queue_size=MAX_QUEUE_SIZE, net
for server_id, addr in server_namebook.items():
server_ip = addr[1]
server_port = addr[2]
rpc.add_receiver_addr(server_ip, server_port, server_id)
rpc.sender_connect()
while not rpc.connect_receiver(server_ip, server_port, server_id):
time.sleep(1)
# Get local usable IP address and port
ip_addr = get_local_usable_addr(server_ip)
client_ip, client_port = ip_addr.split(':')
# wait server connect back
rpc.receiver_wait(client_ip, client_port, num_servers, blocking=False)
# Register client on server
register_req = rpc.ClientRegisterRequest(ip_addr)
for server_id in range(num_servers):
rpc.send_request(server_id, register_req)
# wait server connect back
rpc.receiver_wait(client_ip, client_port, num_servers)
# recv client ID from server
res = rpc.recv_response()
rpc.set_rank(res.client_id)
Expand Down
13 changes: 5 additions & 8 deletions python/dgl/distributed/rpc_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Functions used by server."""

import time

from . import rpc
from .constants import MAX_QUEUE_SIZE

Expand Down Expand Up @@ -64,24 +62,23 @@ def start_server(server_id, ip_config, num_servers, num_clients, server_state, \
# wait all the senders connect to server.
# Once all the senders connect to server, server will not
# accept new sender's connection
print("Wait connections ...")
rpc.receiver_wait(ip_addr, port, num_clients)
print("%d clients connected!" % num_clients)
print("Wait connections non-blockingly...")
rpc.receiver_wait(ip_addr, port, num_clients, blocking=False)
rpc.set_num_client(num_clients)
# Recv all the client's IP and assign ID to clients
addr_list = []
client_namebook = {}
for _ in range(num_clients):
# blocked until request is received
req, _ = rpc.recv_request()
assert isinstance(req, rpc.ClientRegisterRequest)
addr_list.append(req.ip_addr)
addr_list.sort()
for client_id, addr in enumerate(addr_list):
client_namebook[client_id] = addr
for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':')
rpc.add_receiver_addr(client_ip, client_port, client_id)
time.sleep(3) # wait client's socket ready. 3 sec is enough.
rpc.sender_connect()
assert rpc.connect_receiver(client_ip, client_port, client_id)
if rpc.get_rank() == 0: # server_0 send all the IDs
for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id)
Expand Down
14 changes: 4 additions & 10 deletions src/rpc/rpc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,28 +143,22 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
std::string ip = args[0];
int port = args[1];
int num_sender = args[2];
bool blocking = args[3];
std::string addr;
addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
if (RPCContext::getInstance()->receiver->Wait(addr, num_sender) == false) {
if (RPCContext::getInstance()->receiver->Wait(addr, num_sender, blocking) == false) {
LOG(FATAL) << "Wait sender socket failed.";
}
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCAddReceiver")
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCConnectReceiver")
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
int port = args[1];
int recv_id = args[2];
std::string addr;
addr = StringPrintf("tcp://%s:%d", ip.c_str(), port);
RPCContext::getInstance()->sender->AddReceiver(addr, recv_id);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSenderConnect")
.set_body([](DGLArgs args, DGLRetValue* rv) {
if (RPCContext::getInstance()->sender->Connect() == false) {
LOG(FATAL) << "Sender connection failed.";
}
*rv = RPCContext::getInstance()->sender->ConnectReceiver(addr, recv_id);
});

DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
Expand Down
3 changes: 3 additions & 0 deletions src/rpc/rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,15 @@ struct RPCContext {
t->rank = -1;
t->machine_id = -1;
t->num_machines = 0;
t->msg_seq = 0;
t->num_servers = 0;
t->num_clients = 0;
t->barrier_count = 0;
t->num_servers_per_machine = 0;
t->sender.reset();
t->receiver.reset();
t->ctx.reset();
t->server_state.reset();
}
};

Expand Down
Loading

0 comments on commit 37467e2

Please sign in to comment.