Skip to content

Commit

Permalink
[Doc] Update NodeDataLoader and EdgeDataLoader for GPU-based neighbor sampling (dmlc#3046)
Browse files: browse the repository at this point in the history

* update docstrings and tidy code

* add docs

* address comments

* Update __init__.py

* address comments
  • Loading branch information
BarclayII authored Jun 25, 2021
1 parent acd21a6 commit 427a5a9
Show file tree
Hide file tree
Showing 6 changed files with 938 additions and 895 deletions.
8 changes: 4 additions & 4 deletions examples/pytorch/graphsage/train_sampling_unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def run(proc_id, n_gpus, args, devices, data):

# Create PyTorch DataLoader for constructing blocks
n_edges = g.num_edges()
train_seeds = np.arange(n_edges)
train_seeds = th.arange(n_edges)

# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler(
Expand All @@ -85,13 +85,13 @@ def run(proc_id, n_gpus, args, devices, data):
# For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
reverse_eids=th.cat([
th.arange(n_edges // 2, n_edges),
th.arange(0, n_edges // 2)]),
th.arange(0, n_edges // 2)]).to(train_seeds),
negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
device=device,
use_ddp=n_gpus > 1,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
pin_memory=True,
num_workers=args.num_workers)

# Define model and optimizer
Expand Down Expand Up @@ -174,7 +174,7 @@ def main(args, devices):
test_mask = g.ndata['test_mask']

# Create csr/coo/csc formats before launching training processes with multi-gpu.
# This avoids creating certain formats in each sub-process, which saves momory and CPU.
# This avoids creating certain formats in each sub-process, which saves memory and CPU.
g.create_formats_()
# Pack data
data = train_mask, val_mask, test_mask, n_classes, g
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ class EdgeCollator(Collator):
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> collator = dgl.dataloading.EdgeCollator(
... g, train_eid, sampler, exclude='reverse_id',
... reverse_eids=reverse_eids, negative_sampler=neg_sampler,
... reverse_eids=reverse_eids, negative_sampler=neg_sampler)
>>> dataloader = torch.utils.data.DataLoader(
... collator.dataset, collate_fn=collator.collate,
... batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
Expand Down
Loading

0 comments on commit 427a5a9

Please sign in to comment.