Skip to content

Commit

Permalink
[Sampler] extend sampler (dmlc#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da authored and jermainewang committed Dec 4, 2018
1 parent 455ea48 commit c99f423
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions python/dgl/backend/mxnet/immutable_graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def node_subgraphs(self, vs_arr):

def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
node_prob, max_subgraph_size):
assert node_prob is None
if neighbor_type == 'in':
g = self._in_csr
elif neighbor_type == 'out':
Expand All @@ -280,9 +279,14 @@ def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
raise NotImplementedError
num_nodes = []
num_subgs = len(seed_ids)
res = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(g, *seed_ids, num_hops=num_hops,
num_neighbor=expand_factor,
max_num_vertices=max_subgraph_size)
if node_prob is None:
res = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(g, *seed_ids, num_hops=num_hops,
num_neighbor=expand_factor,
max_num_vertices=max_subgraph_size)
else:
res = mx.nd.contrib.dgl_csr_neighbor_non_uniform_sample(g, node_prob, *seed_ids, num_hops=num_hops,
num_neighbor=expand_factor,
max_num_vertices=max_subgraph_size)

vertices, subgraphs = res[0:num_subgs], res[num_subgs:(2*num_subgs)]
num_nodes = [subg_v[-1].asnumpy()[0] for subg_v in vertices]
Expand Down

0 comments on commit c99f423

Please sign in to comment.