diff --git a/python/dgl/contrib/sampling/sampler.py b/python/dgl/contrib/sampling/sampler.py index aaf441e51aef..b616780acca6 100644 --- a/python/dgl/contrib/sampling/sampler.py +++ b/python/dgl/contrib/sampling/sampler.py @@ -554,10 +554,11 @@ def fetch(self, current_index): return [subgraph.DGLSubGraph(self.g, subg) for subg in subgs] else: rets = [] - assert self._num_workers * 2 == len(subgs) - for i in range(self._num_workers): + assert len(subgs) % 2 == 0 + num_pos = int(len(subgs) / 2) + for i in range(num_pos): pos_subg = subgraph.DGLSubGraph(self.g, subgs[i]) - neg_subg = subgraph.DGLSubGraph(self.g, subgs[i + self._num_workers]) + neg_subg = subgraph.DGLSubGraph(self.g, subgs[i + num_pos]) rets.append((pos_subg, neg_subg)) return rets