diff --git a/python/dgl/backend/mxnet/immutable_graph_index.py b/python/dgl/backend/mxnet/immutable_graph_index.py index d4b93fc9c4ad..b857faf971fd 100644 --- a/python/dgl/backend/mxnet/immutable_graph_index.py +++ b/python/dgl/backend/mxnet/immutable_graph_index.py @@ -271,7 +271,6 @@ def node_subgraphs(self, vs_arr): def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type, node_prob, max_subgraph_size): - assert node_prob is None if neighbor_type == 'in': g = self._in_csr elif neighbor_type == 'out': @@ -280,9 +279,14 @@ def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type, raise NotImplementedError num_nodes = [] num_subgs = len(seed_ids) - res = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(g, *seed_ids, num_hops=num_hops, - num_neighbor=expand_factor, - max_num_vertices=max_subgraph_size) + if node_prob is None: + res = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(g, *seed_ids, num_hops=num_hops, + num_neighbor=expand_factor, + max_num_vertices=max_subgraph_size) + else: + res = mx.nd.contrib.dgl_csr_neighbor_non_uniform_sample(g, node_prob, *seed_ids, num_hops=num_hops, + num_neighbor=expand_factor, + max_num_vertices=max_subgraph_size) vertices, subgraphs = res[0:num_subgs], res[num_subgs:(2*num_subgs)] num_nodes = [subg_v[-1].asnumpy()[0] for subg_v in vertices]