[Sampler] Change Distributed Sampler API (dmlc#499)
* Change Distributed Sampler API

* fix lint

* fix lint

* update demo

* update

* update

* update

* update demo

* update demo
aksnzhy authored Apr 22, 2019
1 parent 3f46459 commit fe7d5e9
Showing 7 changed files with 140 additions and 61 deletions.
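
At a glance, the API change replaces the trainer's explicit recv() calls, which required knowing the number of nodeflows per epoch in advance, with an iterator that stops once every sender has sent an end-of-epoch signal. A minimal before/after sketch; ip, n, n_epochs, total_count, and train() are illustrative placeholders rather than part of the commit:

    # Old API: per-epoch nodeflow count hard-coded on the trainer side
    receiver = dgl.contrib.sampling.SamplerReceiver(addr=ip, num_sender=n)
    for epoch in range(n_epochs):
        for _ in range(total_count):      # magic number, e.g. 153
            nf = receiver.recv(g)         # parent graph passed on every call
            train(nf)

    # New API: receiver bound to the parent graph, iterated once per epoch
    sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr=ip, num_sender=n)
    for epoch in range(n_epochs):
        for nf in sampler:                # StopIteration after n end signals
            train(nf)
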
4 changes: 2 additions & 2 deletions examples/mxnet/sampling/dis_sampling/gcn_ns_sampler.py
@@ -32,8 +32,7 @@ def worker(self, args):
# create GCN model
g = DGLGraph(data.graph, readonly=True)

for epoch in range(args.n_epochs):
# Here we only send nodeflow for training
while True:
idx = 0
for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
args.num_neighbors,
@@ -44,6 +43,7 @@ def worker(self, args):
print("send train nodeflow: %d" %(idx))
sender.send(nf, 0)
idx += 1
sender.signal(0)

def main(args):
pool = MySamplerPool()
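Assembled from the two hunks above, the worker's send loop now runs indefinitely and marks each epoch boundary with signal(); the NeighborSampler keyword arguments elided by the diff are unchanged from the original script:

    while True:
        idx = 0
        # remaining NeighborSampler arguments as in the original script
        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                       args.num_neighbors):
            print("send train nodeflow: %d" % (idx))
            sender.send(nf, 0)   # push one nodeflow to receiver 0
            idx += 1
        sender.signal(0)         # tell receiver 0 this epoch is finished
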
14 changes: 7 additions & 7 deletions examples/mxnet/sampling/dis_sampling/gcn_trainer.py
@@ -122,9 +122,6 @@ def main(args):
if args.self_loop and not args.dataset.startswith('reddit'):
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])

# Create sampler receiver
receiver = dgl.contrib.sampling.SamplerReceiver(addr=args.ip, num_sender=args.num_sender)

train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx)
test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx)

@@ -163,6 +160,9 @@ def main(args):
norm = mx.nd.expand_dims(1./degs, 1)
g.ndata['norm'] = norm

# Create sampler receiver
sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr=args.ip, num_sender=args.num_sender)

model = GCNSampling(in_feats,
args.n_hidden,
n_classes,
@@ -191,11 +191,11 @@ def main(args):

# initialize graph
dur = []
total_count = 153
for epoch in range(args.n_epochs):
for subg_count in range(total_count):
print(subg_count)
nf = receiver.recv(g)
idx = 0
for nf in sampler:
print("epoch: %d, subgraph: %d" %(epoch, idx))
idx += 1
nf.copy_from_parent()
# forward
with mx.autograd.record():
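Two details of this move matter: the receiver is now constructed only after g and its ndata are fully prepared, since the constructor stores the parent graph and calls _receiver_wait to wait for all num_sender samplers to connect, and the hard-coded total_count = 153 disappears because the iteration length is determined by the senders. A condensed sketch of the reordered trainer, with names following the script:

    g = DGLGraph(data.graph, readonly=True)
    g.ndata['norm'] = norm                  # graph features ready first
    # construction waits until all senders have connected
    sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr=args.ip,
                                                   num_sender=args.num_sender)
    for epoch in range(args.n_epochs):
        idx = 0
        for nf in sampler:                  # one full pass == one epoch
            print("epoch: %d, subgraph: %d" % (epoch, idx))
            idx += 1
            nf.copy_from_parent()           # pull features from the parent graph
            # forward/backward as before
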
47 changes: 33 additions & 14 deletions python/dgl/contrib/sampling/dis_sampler.py
@@ -2,7 +2,8 @@
from ...network import _send_nodeflow, _recv_nodeflow
from ...network import _create_sender, _create_receiver
from ...network import _finalize_sender, _finalize_receiver
from ...network import _add_receiver_addr, _sender_connect, _receiver_wait
from ...network import _add_receiver_addr, _sender_connect
from ...network import _receiver_wait, _send_end_signal

from multiprocessing import Pool
from abc import ABCMeta, abstractmethod
@@ -103,6 +104,17 @@ def send(self, nodeflow, recv_id):
"""
_send_nodeflow(self._sender, nodeflow, recv_id)

def signal(self, recv_id):
"""Whene samplling of each epoch is finished, users can
invoke this API to tell SamplerReceiver it has finished its job.
Parameters
----------
recv_id : int
receiver ID
"""
_send_end_signal(self._sender, recv_id)

class SamplerReceiver(object):
"""SamplerReceiver for DGL distributed training.
@@ -114,14 +126,18 @@ class SamplerReceiver(object):
Parameters
----------
graph : DGLGraph
The parent graph
addr : str
address of SamplerReceiver, e.g., '127.0.0.1:50051'
num_sender : int
total number of SamplerSender
"""
def __init__(self, addr, num_sender):
def __init__(self, graph, addr, num_sender):
self._graph = graph
self._addr = addr
self._num_sender = num_sender
self._tmp_count = 0
self._receiver = _create_receiver()
vec = self._addr.split(':')
_receiver_wait(self._receiver, vec[0], int(vec[1]), self._num_sender);
@@ -131,17 +147,20 @@ def __del__(self):
"""
_finalize_receiver(self._receiver)

def recv(self, graph):
"""Receive a NodeFlow object from remote sampler.
Parameters
----------
graph : DGLGraph
The parent graph
def __iter__(self):
"""Iterator
"""
return self

Returns
-------
NodeFlow
received NodeFlow object
def __next__(self):
"""Return sampled NodeFlow object
"""
return _recv_nodeflow(self._receiver, graph)
while True:
res = _recv_nodeflow(self._receiver, self._graph)
if isinstance(res, int):
self._tmp_count += 1
if self._tmp_count == self._num_sender:
self._tmp_count = 0
raise StopIteration
else:
return res
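
The _tmp_count bookkeeping in __next__ is what makes several senders feeding one receiver work: nodeflows are returned as they arrive from any sender, each int result counts one end signal, and StopIteration is raised only after all num_sender signals arrive, at which point the counter resets so the same object can be iterated again next epoch. A sketch with two senders; the address and train() are illustrative:

    # trainer process, fed by two sampler processes
    sampler = dgl.contrib.sampling.SamplerReceiver(
        graph=g, addr='127.0.0.1:50051', num_sender=2)
    for epoch in range(n_epochs):
        for nf in sampler:
            # nodeflows from both senders arrive interleaved; the loop
            # ends only after both senders have called signal(0)
            train(nf)
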
27 changes: 24 additions & 3 deletions python/dgl/network.py
@@ -8,6 +8,9 @@

_init_api("dgl.network")

_CONTROL_NODEFLOW = 0
_CONTROL_END_SIGNAL = 1

def _create_sender():
"""Create a Sender communicator via C api
"""
@@ -74,6 +77,18 @@ def _send_nodeflow(sender, nodeflow, recv_id):
layers_offsets,
flows_offsets)

def _send_end_signal(sender, recv_id):
"""Send an epoch-end signal to remote Receiver.
Parameters
----------
sender : ctypes.c_void_p
C sender handle
recv_id : int
Receiver ID
"""
_CAPI_SenderSendEndSignal(sender, recv_id)

def _create_receiver():
"""Create a Receiver communicator via C api
"""
@@ -115,6 +130,12 @@ def _recv_nodeflow(receiver, graph):
NodeFlow
Sampled NodeFlow object
"""
# hdl is a list of ptr
hdl = unwrap_to_ptr_list(_CAPI_ReceiverRecvSubgraph(receiver))
return NodeFlow(graph, hdl[0])
res = _CAPI_ReceiverRecvSubgraph(receiver)
if isinstance(res, int):
if res == _CONTROL_END_SIGNAL:
return _CONTROL_END_SIGNAL
else:
raise RuntimeError('Got unexpected control code {}'.format(res))
else:
hdl = unwrap_to_ptr_list(res)
return NodeFlow(graph, hdl[0])
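
At this private layer the two send helpers and the dual-typed return value of _recv_nodeflow form a small protocol. A hedged sketch of one sender's epoch seen through these wrappers; the sender and receiver handles come from _create_sender/_create_receiver, and nf and graph from the caller:

    # sender side
    _send_nodeflow(sender, nf, 0)      # framed with _CONTROL_NODEFLOW
    _send_end_signal(sender, 0)        # framed with _CONTROL_END_SIGNAL

    # receiver side
    res = _recv_nodeflow(receiver, graph)
    if isinstance(res, int):           # control code: this sender is done
        pass                           # caller counts signals (see dis_sampler.py)
    else:                              # otherwise a reconstructed NodeFlow
        nf = res
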
77 changes: 56 additions & 21 deletions src/graph/network.cc
@@ -23,6 +23,27 @@ namespace network {
static char* SEND_BUFFER = nullptr;
static char* RECV_BUFFER = nullptr;

// Wrapper for Send api
static void SendData(network::Sender* sender,
const char* data,
int64_t size,
int recv_id) {
int64_t send_size = sender->Send(data, size, recv_id);
if (send_size <= 0) {
LOG(FATAL) << "Send error (size: " << send_size << ")";
}
}

// Wrapper for Recv api
static void RecvData(network::Receiver* receiver,
char* dest,
int64_t max_size) {
int64_t recv_size = receiver->Recv(dest, max_size);
if (recv_size <= 0) {
LOG(FATAL) << "Receive error (size: " << recv_size << ")";
}
}

DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
try {
@@ -74,20 +95,30 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
ImmutableGraph *ptr = static_cast<ImmutableGraph*>(ghandle);
network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR();
// Write control message
*SEND_BUFFER = CONTROL_NODEFLOW;
// Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph(
SEND_BUFFER,
SEND_BUFFER+sizeof(CONTROL_NODEFLOW),
csr,
node_mapping,
edge_mapping,
layer_offsets,
flow_offsets);
CHECK_GT(data_size, 0);
data_size += sizeof(CONTROL_NODEFLOW);
// Send msg via network
int64_t size = sender->Send(SEND_BUFFER, data_size, recv_id);
if (size <= 0) {
LOG(FATAL) << "Send message error (size: " << size << ")";
}
SendData(sender, SEND_BUFFER, data_size, recv_id);
});

DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
network::Sender* sender = static_cast<network::Sender*>(chandle);
*SEND_BUFFER = CONTROL_END_SIGNAL;
// Send msg via network
SendData(sender, SEND_BUFFER, sizeof(CONTROL_END_SIGNAL), recv_id);
});

DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
@@ -125,23 +156,27 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
// Recv data from network
int64_t size = receiver->Recv(RECV_BUFFER, kMaxBufferSize);
if (size <= 0) {
LOG(FATAL) << "Receive error: (size: " << size << ")";
RecvData(receiver, RECV_BUFFER, kMaxBufferSize);
int control = *RECV_BUFFER;
if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(RECV_BUFFER+sizeof(CONTROL_NODEFLOW),
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
&(nf->layer_offsets),
&(nf->flow_offsets));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr, false));
std::vector<NodeFlow*> subgs(1);
subgs[0] = nf;
*rv = WrapVectorReturn(subgs);
} else if (control == CONTROL_END_SIGNAL) {
*rv = CONTROL_END_SIGNAL;
} else {
LOG(FATAL) << "Unknow control number: " << control;
}
NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(RECV_BUFFER,
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
&(nf->layer_offsets),
&(nf->flow_offsets));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr, false));
std::vector<NodeFlow*> subgs(1);
subgs[0] = nf;
*rv = WrapVectorReturn(subgs);
});

} // namespace network
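On the wire, every message now begins with a control code: a nodeflow message carries the serialized subgraph starting at offset sizeof(CONTROL_NODEFLOW), while an end signal is the code alone. Note the C++ above stores the code with a single char write but advances by sizeof(int); the Python sketch below packs a full int, which reads back identically on little-endian platforms. It is illustrative only; the real payload encoding lives in network::SerializeSampledSubgraph:

    import struct

    CONTROL_NODEFLOW = 0      # mirrors src/graph/network.h
    CONTROL_END_SIGNAL = 1

    def frame_nodeflow(payload: bytes) -> bytes:
        # control code occupies sizeof(int) bytes, payload follows
        return struct.pack('@i', CONTROL_NODEFLOW) + payload

    def frame_end_signal() -> bytes:
        # an end signal is the control code alone, no payload
        return struct.pack('@i', CONTROL_END_SIGNAL)

    def parse(buf: bytes):
        (control,) = struct.unpack_from('@i', buf)
        if control == CONTROL_NODEFLOW:
            return buf[struct.calcsize('@i'):]   # serialized subgraph bytes
        if control == CONTROL_END_SIGNAL:
            return None
        raise RuntimeError('Unknown control number: %d' % control)
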
4 changes: 4 additions & 0 deletions src/graph/network.h
@@ -23,6 +23,10 @@ const int64_t kQueueSize = 1024 * 1024 * 1024;
// Maximal try count of connection
const int kMaxTryCount = 500;

// Control number
const int CONTROL_NODEFLOW = 0;
const int CONTROL_END_SIGNAL = 1;

} // namespace network
} // namespace dgl

28 changes: 14 additions & 14 deletions tests/compute/test_dis_sampler.py
@@ -13,20 +13,20 @@ def generate_rand_graph(n):

def start_trainer():
g = generate_rand_graph(100)
recv = dgl.contrib.sampling.SamplerReceiver(addr='127.0.0.1:50051', num_sender=1)
subg = recv.recv(g)
seed_ids = subg.layer_parent_nid(-1)
assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all')
assert subg.number_of_nodes() == len(src) + 1
assert subg.number_of_edges() == len(src)
sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr='127.0.0.1:50051', num_sender=1)
for subg in sampler:
seed_ids = subg.layer_parent_nid(-1)
assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all')
assert subg.number_of_nodes() == len(src) + 1
assert subg.number_of_edges() == len(src)

assert seed_ids == subg.layer_parent_nid(-1)
child_src, child_dst, child_eid = subg.in_edges(subg.layer_nid(-1), form='all')
assert F.array_equal(child_src, subg.layer_nid(0))
assert seed_ids == subg.layer_parent_nid(-1)
child_src, child_dst, child_eid = subg.in_edges(subg.layer_nid(-1), form='all')
assert F.array_equal(child_src, subg.layer_nid(0))

src1 = subg.map_to_parent_nid(child_src)
assert F.array_equal(src1, src)
src1 = subg.map_to_parent_nid(child_src)
assert F.array_equal(src1, src)

def start_sampler():
g = generate_rand_graph(100)
@@ -35,12 +35,12 @@ def start_sampler():
for i, subg in enumerate(dgl.contrib.sampling.NeighborSampler(
g, 1, 100, neighbor_type='in', num_workers=4)):
sender.send(subg, 0)
break
sender.signal(0)

if __name__ == '__main__':
pid = os.fork()
if pid == 0:
start_trainer()
else:
time.sleep(1)
time.sleep(2)  # wait for the trainer to start
start_sampler()
