From 701b4fccc2eed979ae3db801fabb6bf7bc03940c Mon Sep 17 00:00:00 2001 From: "Quan (Andy) Gan" Date: Sun, 30 Jan 2022 16:13:00 +0800 Subject: [PATCH] [Sampling] New sampling pipeline plus asynchronous prefetching (#3665) * initial update * more * more * multi-gpu example * cluster gcn, finalize homogeneous * more explanation * fix * bunch of fixes * fix * RGAT example and more fixes * shadow-gnn sampler and some changes in unit test * fix * wth * more fixes * remove shadow+node/edge dataloader tests for possible ux changes * lints * add legacy dataloading import just in case * fix * update pylint for f-strings * fix * lint * lint * lint again * cherry-picking commit fa9f494 * oops * fix * add sample_neighbors in dist_graph * fix * lint * fix * fix * fix * fix tutorial * fix * fix * fix * fix warning * remove debug * add get_foo_storage apis * lint --- docs/source/api/python/dgl.dataloading.rst | 2 + docs/source/guide/distributed-apis.rst | 26 +- docs/source/guide_cn/distributed-apis.rst | 20 +- docs/source/guide_ko/distributed-apis.rst | 10 +- .../__temporary__/cluster_gcn/cluster_gcn.py | 88 + .../pytorch/__temporary__/dglnew/__init__.py | 2 + .../__temporary__/dglnew/graph/__init__.py | 3 + .../__temporary__/dglnew/graph/graph.py | 150 ++ .../dglnew/graph/other_feature.py | 40 + .../pytorch/__temporary__/graphsage/ddp.py | 106 ++ .../__temporary__/graphsage/disk_storage.py | 86 + .../__temporary__/graphsage/link_pred.py | 94 + .../pytorch/__temporary__/graphsage/normal.py | 79 + examples/pytorch/__temporary__/rgat/rgat.py | 105 ++ .../experimental/train_dist_unsupervised.py | 2 +- .../rgcn/experimental/entity_classify_dist.py | 6 +- examples/pytorch/tgn/dataloading.py | 3 + examples/pytorch/tgn/train.py | 2 +- python/dgl/__init__.py | 2 + python/dgl/_dataloading/__init__.py | 28 + .../async_transferer.py | 2 +- python/dgl/_dataloading/cluster_gcn.py | 83 + python/dgl/_dataloading/dataloader.py | 982 +++++++++++ python/dgl/_dataloading/negative_sampler.py | 125 ++ .../{dataloading => _dataloading}/neighbor.py | 0 .../pytorch/__init__.py | 0 .../pytorch/dataloader.py | 192 +-- python/dgl/_dataloading/shadow.py | 98 ++ python/dgl/dataloading/__init__.py | 13 +- python/dgl/dataloading/base.py | 275 +++ python/dgl/dataloading/cluster_gcn.py | 115 +- python/dgl/dataloading/dataloader.py | 1525 +++++++---------- python/dgl/dataloading/dist_dataloader.py | 103 ++ python/dgl/dataloading/negative_sampler.py | 21 +- python/dgl/dataloading/neighbor_sampler.py | 124 ++ python/dgl/dataloading/shadow.py | 45 +- python/dgl/distributed/dist_graph.py | 15 + python/dgl/frame.py | 31 +- python/dgl/geometry/capi.py | 2 +- python/dgl/heterograph.py | 8 + python/dgl/sampling/__init__.py | 1 + python/dgl/sampling/negative.py | 4 + python/dgl/sampling/neighbor.py | 50 +- python/dgl/sampling/utils.py | 79 + python/dgl/storages/__init__.py | 9 + python/dgl/storages/base.py | 83 + python/dgl/storages/numpy.py | 18 + python/dgl/storages/pytorch_tensor.py | 32 + python/dgl/storages/tensor.py | 17 + python/dgl/subgraph.py | 60 +- python/dgl/utils/__init__.py | 1 + python/dgl/utils/exception.py | 57 + python/dgl/utils/internal.py | 29 +- python/dgl/view.py | 87 +- ...ransferer.py => _test_async_transferer.py} | 0 tests/compute/test_subgraph.py | 3 + tests/compute/test_transform.py | 1 + tests/distributed/test_mp_dataloader.py | 20 +- tests/pytorch/test_dataloader.py | 283 +-- tutorials/dist/1_node_classification.py | 9 +- tutorials/large/L2_large_link_prediction.py | 41 +- tutorials/multi/2_node_classification.py | 1 - 
62 files changed, 4011 insertions(+), 1487 deletions(-) create mode 100644 examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py create mode 100644 examples/pytorch/__temporary__/dglnew/__init__.py create mode 100644 examples/pytorch/__temporary__/dglnew/graph/__init__.py create mode 100644 examples/pytorch/__temporary__/dglnew/graph/graph.py create mode 100644 examples/pytorch/__temporary__/dglnew/graph/other_feature.py create mode 100644 examples/pytorch/__temporary__/graphsage/ddp.py create mode 100644 examples/pytorch/__temporary__/graphsage/disk_storage.py create mode 100644 examples/pytorch/__temporary__/graphsage/link_pred.py create mode 100644 examples/pytorch/__temporary__/graphsage/normal.py create mode 100644 examples/pytorch/__temporary__/rgat/rgat.py create mode 100644 python/dgl/_dataloading/__init__.py rename python/dgl/{dataloading => _dataloading}/async_transferer.py (97%) create mode 100644 python/dgl/_dataloading/cluster_gcn.py create mode 100644 python/dgl/_dataloading/dataloader.py create mode 100644 python/dgl/_dataloading/negative_sampler.py rename python/dgl/{dataloading => _dataloading}/neighbor.py (100%) rename python/dgl/{dataloading => _dataloading}/pytorch/__init__.py (100%) rename python/dgl/{dataloading => _dataloading}/pytorch/dataloader.py (87%) create mode 100644 python/dgl/_dataloading/shadow.py create mode 100644 python/dgl/dataloading/base.py create mode 100644 python/dgl/dataloading/dist_dataloader.py create mode 100644 python/dgl/dataloading/neighbor_sampler.py create mode 100644 python/dgl/sampling/utils.py create mode 100644 python/dgl/storages/__init__.py create mode 100644 python/dgl/storages/base.py create mode 100644 python/dgl/storages/numpy.py create mode 100644 python/dgl/storages/pytorch_tensor.py create mode 100644 python/dgl/storages/tensor.py create mode 100644 python/dgl/utils/exception.py rename tests/compute/{test_async_transferer.py => _test_async_transferer.py} (100%) diff --git a/docs/source/api/python/dgl.dataloading.rst b/docs/source/api/python/dgl.dataloading.rst index 7429297a5ff0..e0502227bedb 100644 --- a/docs/source/api/python/dgl.dataloading.rst +++ b/docs/source/api/python/dgl.dataloading.rst @@ -17,6 +17,8 @@ and an ``EdgeDataLoader`` for edge/link prediction task. .. autoclass:: NodeDataLoader .. autoclass:: EdgeDataLoader .. autoclass:: GraphDataLoader +.. autoclass:: DistNodeDataLoader +.. autoclass:: DistEdgeDataLoader .. _api-dataloading-neighbor-sampling: diff --git a/docs/source/guide/distributed-apis.rst b/docs/source/guide/distributed-apis.rst index fdbd502a2385..3ccc0f19f49f 100644 --- a/docs/source/guide/distributed-apis.rst +++ b/docs/source/guide/distributed-apis.rst @@ -202,20 +202,20 @@ DGL provides two levels of APIs for sampling nodes and edges to generate mini-ba (see the section of mini-batch training). The low-level APIs require users to write code to explicitly define how a layer of nodes are sampled (e.g., using :func:`dgl.sampling.sample_neighbors` ). The high-level sampling APIs implement a few popular sampling algorithms for node classification -and link prediction tasks (e.g., :class:`~dgl.dataloading.pytorch.NodeDataloader` and -:class:`~dgl.dataloading.pytorch.EdgeDataloader` ). +and link prediction tasks (e.g., :class:`~dgl.dataloading.pytorch.NodeDataLoader` and +:class:`~dgl.dataloading.pytorch.EdgeDataLoader` ). The distributed sampling module follows the same design and provides two levels of sampling APIs. 
For the lower-level sampling API, it provides :func:`~dgl.distributed.sample_neighbors` for
distributed neighborhood sampling on :class:`~dgl.distributed.DistGraph`. In addition, DGL provides
-a distributed Dataloader (:class:`~dgl.distributed.DistDataLoader` ) for distributed sampling.
-The distributed Dataloader has the same interface as Pytorch DataLoader except that users cannot
+a distributed DataLoader (:class:`~dgl.distributed.DistDataLoader` ) for distributed sampling.
+The distributed DataLoader has the same interface as PyTorch DataLoader except that users cannot
 specify the number of worker processes when creating a dataloader. The worker processes are created
 in :func:`dgl.distributed.initialize`.
 
 **Note**: When running :func:`dgl.distributed.sample_neighbors` on :class:`~dgl.distributed.DistGraph`,
-the sampler cannot run in Pytorch Dataloader with multiple worker processes. The main reason is that
-Pytorch Dataloader creates new sampling worker processes in every epoch, which leads to creating and
+the sampler cannot run in PyTorch DataLoader with multiple worker processes. The main reason is that
+PyTorch DataLoader creates new sampling worker processes in every epoch, which leads to creating and
 destroying :class:`~dgl.distributed.DistGraph` objects many times.
 
 When using the low-level API, the sampling code is similar to single-process sampling. The only
@@ -240,17 +240,17 @@ difference is that users need to use :func:`dgl.distributed.sample_neighbors` an
     for batch in dataloader:
         ...
 
-The same high-level sampling APIs (:class:`~dgl.dataloading.pytorch.NodeDataloader` and
-:class:`~dgl.dataloading.pytorch.EdgeDataloader` ) work for both :class:`~dgl.DGLGraph`
-and :class:`~dgl.distributed.DistGraph`. When using :class:`~dgl.dataloading.pytorch.NodeDataloader`
-and :class:`~dgl.dataloading.pytorch.EdgeDataloader`, the distributed sampling code is exactly
-the same as single-process sampling.
+The high-level sampling APIs (:class:`~dgl.dataloading.pytorch.NodeDataLoader` and
+:class:`~dgl.dataloading.pytorch.EdgeDataLoader` ) have distributed counterparts
+(:class:`~dgl.dataloading.pytorch.DistNodeDataLoader` and
+:class:`~dgl.dataloading.pytorch.DistEdgeDataLoader`). Otherwise, the code is exactly the
+same as single-process sampling.
 
 .. code:: python
 
     sampler = dgl.sampling.MultiLayerNeighborSampler([10, 25])
-    dataloader = dgl.sampling.NodeDataLoader(g, train_nid, sampler,
-                                             batch_size=batch_size, shuffle=True)
+    dataloader = dgl.sampling.DistNodeDataLoader(g, train_nid, sampler,
+                                                 batch_size=batch_size, shuffle=True)
 
     for batch in dataloader:
         ...
diff --git a/docs/source/guide_cn/distributed-apis.rst b/docs/source/guide_cn/distributed-apis.rst index 0510710353c8..b851e63f18eb 100644 --- a/docs/source/guide_cn/distributed-apis.rst +++ b/docs/source/guide_cn/distributed-apis.rst @@ -177,9 +177,9 @@ DGL提供了一个稀疏的Adagrad优化器 :class:`~dgl.distributed.SparseAdagr DGL提供了两个级别的API,用于对节点和边进行采样以生成小批次训练数据(请参阅小批次训练的章节)。 底层API要求用户编写代码以明确定义如何对节点层进行采样(例如,使用 :func:`dgl.sampling.sample_neighbors` )。 高层采样API为节点分类和链接预测任务实现了一些流行的采样算法(例如 -:class:`~dgl.dataloading.pytorch.NodeDataloader` +:class:`~dgl.dataloading.pytorch.NodeDataLoader` 和 -:class:`~dgl.dataloading.pytorch.EdgeDataloader` )。 +:class:`~dgl.dataloading.pytorch.EdgeDataLoader` )。 分布式采样模块遵循相同的设计,也提供两个级别的采样API。对于底层的采样API,它为 :class:`~dgl.distributed.DistGraph` 上的分布式邻居采样提供了 @@ -188,7 +188,7 @@ DGL提供了两个级别的API,用于对节点和边进行采样以生成小 分布式数据加载器具有与PyTorch DataLoader相同的接口。其中的工作进程(worker)在 :func:`dgl.distributed.initialize` 中创建。 **Note**: 在 :class:`~dgl.distributed.DistGraph` 上运行 :func:`dgl.distributed.sample_neighbors` 时, -采样器无法在具有多个工作进程的PyTorch Dataloader中运行。主要原因是PyTorch Dataloader在每个训练周期都会创建新的采样工作进程, +采样器无法在具有多个工作进程的PyTorch DataLoader中运行。主要原因是PyTorch DataLoader在每个训练周期都会创建新的采样工作进程, 从而导致多次创建和删除 :class:`~dgl.distributed.DistGraph` 对象。 使用底层API时,采样代码类似于单进程采样。唯一的区别是用户需要使用 @@ -214,19 +214,19 @@ DGL提供了两个级别的API,用于对节点和边进行采样以生成小 for batch in dataloader: ... -:class:`~dgl.DGLGraph` 和 :class:`~dgl.distributed.DistGraph` 都可以使用相同的高级采样API( -:class:`~dgl.dataloading.pytorch.NodeDataloader` +:class:`~dgl.dataloading.pytorch.NodeDataLoader` 和 -:class:`~dgl.dataloading.pytorch.EdgeDataloader`)。使用 -:class:`~dgl.dataloading.pytorch.NodeDataloader` +:class:`~dgl.dataloading.pytorch.EdgeDataLoader` 有分布式的版本 +:class:`~dgl.dataloading.pytorch.DistNodeDataLoader` 和 -:class:`~dgl.dataloading.pytorch.EdgeDataloader` 时,分布式采样代码与单进程采样完全相同。 +:class:`~dgl.dataloading.pytorch.DistEdgeDataLoader` 。使用 +时分布式采样代码与单进程采样几乎完全相同。 .. code:: python sampler = dgl.sampling.MultiLayerNeighborSampler([10, 25]) - dataloader = dgl.sampling.NodeDataLoader(g, train_nid, sampler, - batch_size=batch_size, shuffle=True) + dataloader = dgl.sampling.DistNodeDataLoader(g, train_nid, sampler, + batch_size=batch_size, shuffle=True) for batch in dataloader: ... diff --git a/docs/source/guide_ko/distributed-apis.rst b/docs/source/guide_ko/distributed-apis.rst index 63c6b31cbb85..60d3a67cdf01 100644 --- a/docs/source/guide_ko/distributed-apis.rst +++ b/docs/source/guide_ko/distributed-apis.rst @@ -132,8 +132,8 @@ DGL은 노드 임베딩들을 필요로 하는 변환 모델(transductive models 분산 샘플링 ~~~~~~~~ -DGL은 미니-배치를 생성하기 위해 노드 및 에지 샘플링을 하는 두 수준의 API를 제공한다 (미니-배치 학습 섹션 참조). Low-level API는 노드들의 레이어가 어떻게 샘플링될지를 명시적으로 정의하는 코드를 직접 작성해야한다 (예를 들면, :func:`dgl.sampling.sample_neighbors` 사용해서). High-level API는 노드 분류 및 링크 예측(예, :class:`~dgl.dataloading.pytorch.NodeDataloader` 와 -:class:`~dgl.dataloading.pytorch.EdgeDataloader`) 에 사용되는 몇 가지 유명한 샘플링 알고리즘을 구현하고 있다. +DGL은 미니-배치를 생성하기 위해 노드 및 에지 샘플링을 하는 두 수준의 API를 제공한다 (미니-배치 학습 섹션 참조). Low-level API는 노드들의 레이어가 어떻게 샘플링될지를 명시적으로 정의하는 코드를 직접 작성해야한다 (예를 들면, :func:`dgl.sampling.sample_neighbors` 사용해서). High-level API는 노드 분류 및 링크 예측(예, :class:`~dgl.dataloading.pytorch.NodeDataLoader` 와 +:class:`~dgl.dataloading.pytorch.EdgeDataLoader`) 에 사용되는 몇 가지 유명한 샘플링 알고리즘을 구현하고 있다. 분산 샘플링 모듈도 같은 디자인을 따르고 있고, 두 level의 샘플링 API를 제공한다. Low-level 샘플링 API의 경우, :class:`~dgl.distributed.DistGraph` 에 대한 분산 이웃 샘플링을 위해 :func:`~dgl.distributed.sample_neighbors` 가 있다. 또한, DGL은 분산 샘플링을 위해 분산 데이터 로더, :class:`~dgl.distributed.DistDataLoader` 를 제공한다. 
분산 DataLoader는 PyTorch DataLoader와 같은 인터페이스를 갖는데, 다른 점은 사용자가 데이터 로더를 생성할 때 worker 프로세스의 개수를 지정할 수 없다는 점이다. Worker 프로세스들은 :func:`dgl.distributed.initialize` 에서 만들어진다. @@ -159,13 +159,13 @@ Low-level API를 사용할 때, 샘플링 코드는 단일 프로세스 샘플 for batch in dataloader: ... -동일한 high-level 샘플링 API들(:class:`~dgl.dataloading.pytorch.NodeDataloader` 와 :class:`~dgl.dataloading.pytorch.EdgeDataloader` )이 :class:`~dgl.DGLGraph` 와 :class:`~dgl.distributed.DistGraph` 에 대해서 동작한다. :class:`~dgl.dataloading.pytorch.NodeDataloader` 과 :class:`~dgl.dataloading.pytorch.EdgeDataloader` 를 사용할 때, 분산 샘플링 코드는 싱글-프로세스 샘플링 코드와 정확하게 같다. +동일한 high-level 샘플링 API들(:class:`~dgl.dataloading.pytorch.NodeDataLoader` 와 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` )이 :class:`~dgl.DGLGraph` 와 :class:`~dgl.distributed.DistGraph` 에 대해서 동작한다. :class:`~dgl.dataloading.pytorch.NodeDataLoader` 과 :class:`~dgl.dataloading.pytorch.EdgeDataLoader` 를 사용할 때, 분산 샘플링 코드는 싱글-프로세스 샘플링 코드와 정확하게 같다. .. code:: python sampler = dgl.sampling.MultiLayerNeighborSampler([10, 25]) - dataloader = dgl.sampling.NodeDataLoader(g, train_nid, sampler, - batch_size=batch_size, shuffle=True) + dataloader = dgl.sampling.DistNodeDataLoader(g, train_nid, sampler, + batch_size=batch_size, shuffle=True) for batch in dataloader: ... diff --git a/examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py b/examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py new file mode 100644 index 000000000000..7f23626731a4 --- /dev/null +++ b/examples/pytorch/__temporary__/cluster_gcn/cluster_gcn.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics.functional as MF +import dgl +import dgl.nn as dglnn +import time +import numpy as np +from ogb.nodeproppred import DglNodePropPredDataset + +USE_WRAPPER = True + +class SAGE(nn.Module): + def __init__(self, in_feats, n_hidden, n_classes): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) + self.dropout = nn.Dropout(0.5) + + def forward(self, sg, x): + h = x + for l, layer in enumerate(self.layers): + h = layer(sg, h) + if l != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + return h + +dataset = DglNodePropPredDataset('ogbn-products') +graph, labels = dataset[0] +graph.ndata['label'] = labels +split_idx = dataset.get_idx_split() +train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] +graph.ndata['train_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).index_fill_(0, train_idx, True) +graph.ndata['valid_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).index_fill_(0, valid_idx, True) +graph.ndata['test_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).index_fill_(0, test_idx, True) + +model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() +opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) + +if USE_WRAPPER: + import dglnew + graph.create_formats_() + graph = dglnew.graph.wrapper.DGLGraphStorage(graph) + +num_partitions = 1000 +sampler = dgl.dataloading.ClusterGCNSampler( + graph, num_partitions, + prefetch_node_feats=['feat', 'label', 'train_mask', 'valid_mask', 'test_mask']) +# DataLoader for generic dataloading with a graph, a set of indices (any indices, like +# partition IDs here), and a graph sampler. 
+# NodeDataLoader and EdgeDataLoader are simply special cases of DataLoader where the +# indices are guaranteed to be node and edge IDs. +dataloader = dgl.dataloading.DataLoader( + graph, + torch.arange(num_partitions), + sampler, + device='cuda', + batch_size=100, + shuffle=True, + drop_last=False, + pin_memory=True, + num_workers=8, + persistent_workers=True, + use_prefetch_thread=True) # TBD: could probably remove this argument + +durations = [] +for _ in range(10): + t0 = time.time() + for it, sg in enumerate(dataloader): + x = sg.ndata['feat'] + y = sg.ndata['label'][:, 0] + m = sg.ndata['train_mask'] + y_hat = model(sg, x) + loss = F.cross_entropy(y_hat[m], y[m]) + opt.zero_grad() + loss.backward() + opt.step() + if it % 20 == 0: + acc = MF.accuracy(y_hat[m], y[m]) + mem = torch.cuda.max_memory_allocated() / 1000000 + print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') + tt = time.time() + print(tt - t0) + durations.append(tt - t0) +print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/examples/pytorch/__temporary__/dglnew/__init__.py b/examples/pytorch/__temporary__/dglnew/__init__.py new file mode 100644 index 000000000000..29409e1c19c4 --- /dev/null +++ b/examples/pytorch/__temporary__/dglnew/__init__.py @@ -0,0 +1,2 @@ +from . import graph +from . import storages diff --git a/examples/pytorch/__temporary__/dglnew/graph/__init__.py b/examples/pytorch/__temporary__/dglnew/graph/__init__.py new file mode 100644 index 000000000000..6581b140ac70 --- /dev/null +++ b/examples/pytorch/__temporary__/dglnew/graph/__init__.py @@ -0,0 +1,3 @@ +from .graph import * +from .other_feature import * +from .wrapper import * diff --git a/examples/pytorch/__temporary__/dglnew/graph/graph.py b/examples/pytorch/__temporary__/dglnew/graph/graph.py new file mode 100644 index 000000000000..8f8ffaa94045 --- /dev/null +++ b/examples/pytorch/__temporary__/dglnew/graph/graph.py @@ -0,0 +1,150 @@ +class GraphStorage(object): + def get_node_storage(self, key, ntype=None): + pass + + def get_edge_storage(self, key, etype=None): + pass + + # Required for checking whether a single dict is allowed for ndata and edata. + @property + def ntypes(self): + pass + + @property + def canonical_etypes(self): + pass + + def etypes(self): + return [etype[1] for etype in self.canonical_etypes] + + def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None, + exclude_edges=None, replace=False, output_device=None): + """Return a DGLGraph which is a subgraph induced by sampling neighboring edges of + the given nodes. + + See ``dgl.sampling.sample_neighbors`` for detailed semantics. + + Parameters + ---------- + seed_nodes : Tensor or dict[str, Tensor] + Node IDs to sample neighbors from. + + This argument can take a single ID tensor or a dictionary of node types and ID tensors. + If a single tensor is given, the graph must only have one type of nodes. + fanout : int or dict[etype, int] + The number of edges to be sampled for each node on each edge type. + + This argument can take a single int or a dictionary of edge types and ints. + If a single int is given, DGL will sample this number of edges for each node for + every edge type. + + If -1 is given for a single edge type, all the neighboring edges with that edge + type will be selected. + prob : str, optional + Feature name used as the (unnormalized) probabilities associated with each + neighboring edge of a node. The feature must have only one element for each + edge. 
+ + The features must be non-negative floats, and the sum of the features of + inbound/outbound edges for every node must be positive (though they don't have + to sum up to one). Otherwise, the result will be undefined. + + If :attr:`prob` is not None, GPU sampling is not supported. + exclude_edges: tensor or dict + Edge IDs to exclude during sampling neighbors for the seed nodes. + + This argument can take a single ID tensor or a dictionary of edge types and ID tensors. + If a single tensor is given, the graph must only have one type of nodes. + replace : bool, optional + If True, sample with replacement. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. + + Returns + ------- + DGLGraph + A sampled subgraph with the same nodes as the original graph, but only the sampled neighboring + edges. The induced edge IDs will be in ``edata[dgl.EID]``. + """ + pass + + # Required in Cluster-GCN + def subgraph(self, nodes, relabel_nodes=False, output_device=None): + """Return a subgraph induced on given nodes. + + This has the same semantics as ``dgl.node_subgraph``. + + Parameters + ---------- + nodes : nodes or dict[str, nodes] + The nodes to form the subgraph. The allowed nodes formats are: + + * Int Tensor: Each element is a node ID. The tensor must have the same device type + and ID data type as the graph's. + * iterable[int]: Each element is a node ID. + * Bool Tensor: Each :math:`i^{th}` element is a bool flag indicating whether + node :math:`i` is in the subgraph. + + If the graph is homogeneous, one can directly pass the above formats. + Otherwise, the argument must be a dictionary with keys being node types + and values being the node IDs in the above formats. + relabel_nodes : bool, optional + If True, the extracted subgraph will only have the nodes in the specified node set + and it will relabel the nodes in order. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. + + Returns + ------- + DGLGraph + The subgraph. + """ + pass + + # Required in Link Prediction + def edge_subgraph(self, edges, relabel_nodes=False, output_device=None): + """Return a subgraph induced on given edges. + + This has the same semantics as ``dgl.edge_subgraph``. + + Parameters + ---------- + edges : edges or dict[(str, str, str), edges] + The edges to form the subgraph. The allowed edges formats are: + + * Int Tensor: Each element is an edge ID. The tensor must have the same device type + and ID data type as the graph's. + * iterable[int]: Each element is an edge ID. + * Bool Tensor: Each :math:`i^{th}` element is a bool flag indicating whether + edge :math:`i` is in the subgraph. + + If the graph is homogeneous, one can directly pass the above formats. + Otherwise, the argument must be a dictionary with keys being edge types + and values being the edge IDs in the above formats. + relabel_nodes : bool, optional + If True, the extracted subgraph will only have the nodes in the specified node set + and it will relabel the nodes in order. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. + + Returns + ------- + DGLGraph + The subgraph. + """ + pass + + # Required in Link Prediction negative sampler + def find_edges(self, edges, etype=None, output_device=None): + """Return the source and destination node IDs given the edge IDs within the given edge type. 
+ """ + pass + + # Required in Link Prediction negative sampler + def num_nodes(self, ntype): + """Return the number of nodes for the given node type.""" + pass + + def global_uniform_negative_sampling(self, num_samples, exclude_self_loops=True, + replace=False, etype=None): + """Per source negative sampling as in ``dgl.dataloading.GlobalUniform``""" diff --git a/examples/pytorch/__temporary__/dglnew/graph/other_feature.py b/examples/pytorch/__temporary__/dglnew/graph/other_feature.py new file mode 100644 index 000000000000..d1684e73b962 --- /dev/null +++ b/examples/pytorch/__temporary__/dglnew/graph/other_feature.py @@ -0,0 +1,40 @@ +from collections import Mapping +from dgl.storages import wrap_storage +from dgl.utils import recursive_apply + +# A GraphStorage class where ndata and edata can be any FeatureStorage but +# otherwise the same as the wrapped DGLGraph. +class OtherFeatureGraphStorage(object): + def __init__(self, g, ndata=None, edata=None): + self.g = g + self._ndata = recursive_apply(ndata, wrap_storage) if ndata is not None else {} + self._edata = recursive_apply(edata, wrap_storage) if edata is not None else {} + + for k, v in self._ndata.items(): + if not isinstance(v, Mapping): + assert len(self.g.ntypes) == 1 + self._ndata[k] = {self.g.ntypes[0]: v} + for k, v in self._edata.items(): + if not isinstance(v, Mapping): + assert len(self.g.canonical_etypes) == 1 + self._edata[k] = {self.g.canonical_etypes[0]: v} + + def get_node_storage(self, key, ntype=None): + if ntype is None: + ntype = self.g.ntypes[0] + return self._ndata[key][ntype] + + def get_edge_storage(self, key, etype=None): + if etype is None: + etype = self.g.canonical_etypes[0] + return self._edata[key][etype] + + def __getattr__(self, key): + # I wrote it in this way because I'm too lazy to write "def sample_neighbors" + # or stuff like that. + if key in ['ntypes', 'etypes', 'canonical_etypes', 'sample_neighbors', + 'subgraph', 'edge_subgraph', 'find_edges', 'num_nodes']: + # Delegate to the wrapped DGLGraph instance. 
+ return getattr(self.g, key) + else: + return super().__getattr__(key) diff --git a/examples/pytorch/__temporary__/graphsage/ddp.py b/examples/pytorch/__temporary__/graphsage/ddp.py new file mode 100644 index 000000000000..d1366f1e1d70 --- /dev/null +++ b/examples/pytorch/__temporary__/graphsage/ddp.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import torch.distributed.optim +import torchmetrics.functional as MF +import dgl +import dgl.nn as dglnn +import time +import numpy as np +from ogb.nodeproppred import DglNodePropPredDataset + +USE_WRAPPER = False + +class SAGE(nn.Module): + def __init__(self, in_feats, n_hidden, n_classes): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) + self.dropout = nn.Dropout(0.5) + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + return h + + +def train(rank, world_size, graph, num_classes, split_idx): + torch.cuda.set_device(rank) + dist.init_process_group('nccl', 'tcp://127.0.0.1:12347', world_size=world_size, rank=rank) + + model = SAGE(graph.ndata['feat'].shape[1], 256, num_classes).cuda() + model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank) + opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) + + train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] + if USE_WRAPPER: + import dglnew + graph = dglnew.graph.wrapper.DGLGraphStorage(graph) + + sampler = dgl.dataloading.NeighborSampler( + [5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], + prefetch_labels=['label']) + dataloader = dgl.dataloading.NodeDataLoader( + graph, + train_idx, + sampler, + device='cuda', + batch_size=1000, + shuffle=True, + drop_last=False, + pin_memory=True, + num_workers=4, + persistent_workers=True, + use_ddp=True, + use_prefetch_thread=True) # TBD: could probably remove this argument + + durations = [] + for _ in range(10): + t0 = time.time() + for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): + x = blocks[0].srcdata['feat'] + y = blocks[-1].dstdata['label'][:, 0] + y_hat = model(blocks, x) + loss = F.cross_entropy(y_hat, y) + opt.zero_grad() + loss.backward() + opt.step() + if it % 20 == 0: + acc = MF.accuracy(y_hat, y) + mem = torch.cuda.max_memory_allocated() / 1000000 + print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') + tt = time.time() + if rank == 0: + print(tt - t0) + durations.append(tt - t0) + if rank == 0: + print(np.mean(durations[4:]), np.std(durations[4:])) + +if __name__ == '__main__': + dataset = DglNodePropPredDataset('ogbn-products') + graph, labels = dataset[0] + graph.ndata['label'] = labels + graph.create_formats_() + split_idx = dataset.get_idx_split() + num_classes = dataset.num_classes + n_procs = 4 + + # Tested with mp.spawn and fork. Both worked and got 4s per epoch with 4 GPUs + # and 3.86s per epoch with 8 GPUs on p2.8x, compared to 5.2s from official examples. 
+ #import torch.multiprocessing as mp + #mp.spawn(train, args=(n_procs, graph, num_classes, split_idx), nprocs=n_procs) + import dgl.multiprocessing as mp + procs = [] + for i in range(n_procs): + p = mp.Process(target=train, args=(i, n_procs, graph, num_classes, split_idx)) + p.start() + procs.append(p) + for p in procs: + p.join() diff --git a/examples/pytorch/__temporary__/graphsage/disk_storage.py b/examples/pytorch/__temporary__/graphsage/disk_storage.py new file mode 100644 index 000000000000..00f3eee0d608 --- /dev/null +++ b/examples/pytorch/__temporary__/graphsage/disk_storage.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics.functional as MF +import dgl +import dgl.nn as dglnn +import time +import numpy as np +# OGB must follow DGL if both DGL and PyG are installed. Otherwise DataLoader will hang. +# (This is a long-standing issue) +from ogb.nodeproppred import DglNodePropPredDataset + +import dglnew + +class SAGE(nn.Module): + def __init__(self, in_feats, n_hidden, n_classes): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) + self.dropout = nn.Dropout(0.5) + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + return h + +dataset = DglNodePropPredDataset('ogbn-products') +graph, labels = dataset[0] +split_idx = dataset.get_idx_split() +train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] + +# This is an example of using feature storage other than tensors +feat_np = graph.ndata['feat'].numpy() +feat = np.memmap('feat.npy', mode='w+', shape=feat_np.shape, dtype='float32') +print(feat.shape) +feat[:] = feat_np + +model = SAGE(feat.shape[1], 256, dataset.num_classes).cuda() +opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) + +graph.create_formats_() +# Because NumpyStorage is registered with memmap, one can directly add numpy memmaps +graph = dglnew.graph.OtherFeatureGraphStorage(graph, ndata={'feat': feat, 'label': labels}) +#graph = dglnew.graph.OtherFeatureGraphStorage(graph, +# ndata={'feat': dgl.storages.NumpyStorage(feat), 'label': labels}) + +sampler = dgl.dataloading.NeighborSampler( + [5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], + prefetch_labels=['label']) +dataloader = dgl.dataloading.NodeDataLoader( + graph, + train_idx, + sampler, + device='cuda', + batch_size=1000, + shuffle=True, + drop_last=False, + pin_memory=True, + num_workers=4, + use_prefetch_thread=True) # TBD: could probably remove this argument + +durations = [] +for _ in range(10): + t0 = time.time() + for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): + x = blocks[0].srcdata['feat'] + y = blocks[-1].dstdata['label'][:, 0] + y_hat = model(blocks, x) + loss = F.cross_entropy(y_hat, y) + opt.zero_grad() + loss.backward() + opt.step() + if it % 20 == 0: + acc = MF.accuracy(y_hat, y) + mem = torch.cuda.max_memory_allocated() / 1000000 + print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') + tt = time.time() + print(tt - t0) + durations.append(tt - t0) +print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/examples/pytorch/__temporary__/graphsage/link_pred.py 
b/examples/pytorch/__temporary__/graphsage/link_pred.py new file mode 100644 index 000000000000..601044a69f38 --- /dev/null +++ b/examples/pytorch/__temporary__/graphsage/link_pred.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics.functional as MF +import dgl +import dgl.nn as dglnn +import time +import numpy as np +# OGB must follow DGL if both DGL and PyG are installed. Otherwise DataLoader will hang. +# (This is a long-standing issue) +from ogb.nodeproppred import DglNodePropPredDataset + +USE_WRAPPER = True + +class SAGE(nn.Module): + def __init__(self, in_feats, n_hidden, n_classes): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) + self.dropout = nn.Dropout(0.5) + + def forward(self, pair_graph, neg_pair_graph, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + with pair_graph.local_scope(), neg_pair_graph.local_scope(): + pair_graph.ndata['h'] = neg_pair_graph.ndata['h'] = h + pair_graph.apply_edges(dgl.function.u_dot_v('h', 'h', 's')) + neg_pair_graph.apply_edges(dgl.function.u_dot_v('h', 'h', 's')) + return pair_graph.edata['s'], neg_pair_graph.edata['s'] + +dataset = DglNodePropPredDataset('ogbn-products') +graph, labels = dataset[0] +graph.ndata['label'] = labels +split_idx = dataset.get_idx_split() +train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] + +model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() +opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) + +num_edges = graph.num_edges() +train_eids = torch.arange(num_edges) +if USE_WRAPPER: + import dglnew + graph.create_formats_() + graph = dglnew.graph.wrapper.DGLGraphStorage(graph) + +sampler = dgl.dataloading.NeighborSampler( + [5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], + prefetch_labels=['label']) +dataloader = dgl.dataloading.EdgeDataLoader( + graph, + train_eids, + sampler, + device='cuda', + batch_size=1000, + shuffle=True, + drop_last=False, + pin_memory=True, + num_workers=8, + persistent_workers=True, + use_prefetch_thread=True, # TBD: could probably remove this argument + exclude='reverse_id', + reverse_eids=torch.arange(num_edges) ^ 1, + negative_sampler=dgl.dataloading.negative_sampler.Uniform(5)) + +durations = [] +for _ in range(10): + t0 = time.time() + for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader): + x = blocks[0].srcdata['feat'] + pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x) + pos_label = torch.ones_like(pos_score) + neg_label = torch.zeros_like(neg_score) + score = torch.cat([pos_score, neg_score]) + labels = torch.cat([pos_label, neg_label]) + loss = F.binary_cross_entropy_with_logits(score, labels) + opt.zero_grad() + loss.backward() + opt.step() + if it % 20 == 0: + acc = MF.auroc(score, labels.long()) + mem = torch.cuda.max_memory_allocated() / 1000000 + print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') + tt = time.time() + print(tt - t0) + t0 = time.time() + durations.append(tt - t0) +print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/examples/pytorch/__temporary__/graphsage/normal.py 
b/examples/pytorch/__temporary__/graphsage/normal.py new file mode 100644 index 000000000000..6f82809b4d3a --- /dev/null +++ b/examples/pytorch/__temporary__/graphsage/normal.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics.functional as MF +import dgl +import dgl.nn as dglnn +import time +import numpy as np +from ogb.nodeproppred import DglNodePropPredDataset + +USE_WRAPPER = True + +class SAGE(nn.Module): + def __init__(self, in_feats, n_hidden, n_classes): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) + self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) + self.dropout = nn.Dropout(0.5) + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l != len(self.layers) - 1: + h = F.relu(h) + h = self.dropout(h) + return h + +dataset = DglNodePropPredDataset('ogbn-products') +graph, labels = dataset[0] +graph.ndata['label'] = labels +split_idx = dataset.get_idx_split() +train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] + +model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda() +opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) + +if USE_WRAPPER: + import dglnew + graph.create_formats_() + graph = dglnew.graph.wrapper.DGLGraphStorage(graph) + +sampler = dgl.dataloading.NeighborSampler( + [5, 5, 5], output_device='cpu', prefetch_node_feats=['feat'], + prefetch_labels=['label']) +dataloader = dgl.dataloading.NodeDataLoader( + graph, + train_idx, + sampler, + device='cuda', + batch_size=1000, + shuffle=True, + drop_last=False, + pin_memory=True, + num_workers=16, + persistent_workers=True, + use_prefetch_thread=True) # TBD: could probably remove this argument + +durations = [] +for _ in range(10): + t0 = time.time() + for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): + x = blocks[0].srcdata['feat'] + y = blocks[-1].dstdata['label'][:, 0] + y_hat = model(blocks, x) + loss = F.cross_entropy(y_hat, y) + opt.zero_grad() + loss.backward() + opt.step() + if it % 20 == 0: + acc = MF.accuracy(y_hat, y) + mem = torch.cuda.max_memory_allocated() / 1000000 + print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') + tt = time.time() + print(tt - t0) + durations.append(tt - t0) +print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/examples/pytorch/__temporary__/rgat/rgat.py b/examples/pytorch/__temporary__/rgat/rgat.py new file mode 100644 index 000000000000..3a4364b2ccc7 --- /dev/null +++ b/examples/pytorch/__temporary__/rgat/rgat.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchmetrics.functional as MF +import dgl +import dgl.function as fn +import dgl.nn as dglnn +from dgl.utils import recursive_apply +import time +import numpy as np +from ogb.nodeproppred import DglNodePropPredDataset +import tqdm + +USE_WRAPPER = True + +class HeteroGAT(nn.Module): + def __init__(self, etypes, in_feats, n_hidden, n_classes, n_heads=4): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append(dglnn.HeteroGraphConv({ + etype: dglnn.GATConv(in_feats, n_hidden // n_heads, n_heads) + for etype in etypes})) + self.layers.append(dglnn.HeteroGraphConv({ + etype: dglnn.GATConv(n_hidden, n_hidden // n_heads, n_heads) + for etype in etypes})) + 
self.layers.append(dglnn.HeteroGraphConv({ + etype: dglnn.GATConv(n_hidden, n_hidden // n_heads, n_heads) + for etype in etypes})) + self.dropout = nn.Dropout(0.5) + self.linear = nn.Linear(n_hidden, n_classes) # Should be HeteroLinear + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + # One thing is that h might return tensors with zero rows if the number of dst nodes + # of one node type is 0. x.view(x.shape[0], -1) wouldn't work in this case. + h = recursive_apply(h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2])) + if l != len(self.layers) - 1: + h = recursive_apply(h, F.relu) + h = recursive_apply(h, self.dropout) + return self.linear(h['paper']) + +dataset = DglNodePropPredDataset('ogbn-mag') + +graph, labels = dataset[0] +graph.ndata['label'] = labels +# Preprocess: add reverse edges in "cites" relation, and add reverse edge types for the +# rest. +graph = dgl.AddReverse()(graph) +# Preprocess: precompute the author, topic, and institution features +graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes') +graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic') +graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with') +graph.edges['cites'].data['weight'] = torch.ones(graph.num_edges('cites')) # dummy edge weights + +model = HeteroGAT(graph.etypes, graph.ndata['feat']['paper'].shape[1], 256, dataset.num_classes).cuda() +opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) + +if USE_WRAPPER: + import dglnew + graph.create_formats_() + graph = dglnew.graph.wrapper.DGLGraphStorage(graph) + +split_idx = dataset.get_idx_split() +train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] + +sampler = dgl.dataloading.NeighborSampler( + [5, 5, 5], output_device='cpu', + prefetch_node_feats={k: ['feat'] for k in graph.ntypes}, + prefetch_labels={'paper': ['label']}, + prefetch_edge_feats={'cites': ['weight']}) +dataloader = dgl.dataloading.NodeDataLoader( + graph, + train_idx, + sampler, + device='cuda', + batch_size=1000, + shuffle=True, + drop_last=False, + pin_memory=True, + num_workers=8, + persistent_workers=True, + use_prefetch_thread=True) # TBD: could probably remove this argument + +durations = [] +for _ in range(10): + t0 = time.time() + for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): + x = blocks[0].srcdata['feat'] + y = blocks[-1].dstdata['label']['paper'][:, 0] + assert y.min() >= 0 and y.max() < dataset.num_classes + y_hat = model(blocks, x) + loss = F.cross_entropy(y_hat, y) + opt.zero_grad() + loss.backward() + opt.step() + if it % 20 == 0: + acc = MF.accuracy(y_hat, y) + mem = torch.cuda.max_memory_allocated() / 1000000 + print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB') + tt = time.time() + print(tt - t0) + durations.append(tt - t0) +print(np.mean(durations[4:]), np.std(durations[4:])) diff --git a/examples/pytorch/graphsage/experimental/train_dist_unsupervised.py b/examples/pytorch/graphsage/experimental/train_dist_unsupervised.py index 0d173fbb22d6..f5e785adcade 100644 --- a/examples/pytorch/graphsage/experimental/train_dist_unsupervised.py +++ b/examples/pytorch/graphsage/experimental/train_dist_unsupervised.py @@ -69,7 +69,7 @@ def inference(self, g, x, batch_size, device): y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) sampler = 
dgl.dataloading.MultiLayerNeighborSampler([None]) - dataloader = dgl.dataloading.NodeDataLoader( + dataloader = dgl.dataloading.DistNodeDataLoader( g, th.arange(g.number_of_nodes()), sampler, diff --git a/examples/pytorch/rgcn/experimental/entity_classify_dist.py b/examples/pytorch/rgcn/experimental/entity_classify_dist.py index b184c11c9fb9..4dfe4f532dd1 100644 --- a/examples/pytorch/rgcn/experimental/entity_classify_dist.py +++ b/examples/pytorch/rgcn/experimental/entity_classify_dist.py @@ -366,7 +366,7 @@ def run(args, device, data): val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(',')] sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) - dataloader = dgl.dataloading.NodeDataLoader( + dataloader = dgl.dataloading.DistNodeDataLoader( g, {'paper': train_nid}, sampler, @@ -375,7 +375,7 @@ def run(args, device, data): drop_last=False) valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts) - valid_dataloader = dgl.dataloading.NodeDataLoader( + valid_dataloader = dgl.dataloading.DistNodeDataLoader( g, {'paper': val_nid}, valid_sampler, @@ -384,7 +384,7 @@ def run(args, device, data): drop_last=False) test_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts) - test_dataloader = dgl.dataloading.NodeDataLoader( + test_dataloader = dgl.dataloading.DistNodeDataLoader( g, {'paper': test_nid}, test_sampler, diff --git a/examples/pytorch/tgn/dataloading.py b/examples/pytorch/tgn/dataloading.py index ee37c5f60b6c..8e273407a697 100644 --- a/examples/pytorch/tgn/dataloading.py +++ b/examples/pytorch/tgn/dataloading.py @@ -287,6 +287,9 @@ def __init__(self, g, eids, graph_sampler, device='cpu', collator=TemporalEdgeCo if dataloader_kwargs.get('num_workers', 0) > 0: g.create_formats_() + def __iter__(self): + return iter(self.dataloader) + # ====== Fast Mode ====== # Part of code in reservoir sampling comes from PyG library diff --git a/examples/pytorch/tgn/train.py b/examples/pytorch/tgn/train.py index 456a7d94a019..10d766da6b7f 100644 --- a/examples/pytorch/tgn/train.py +++ b/examples/pytorch/tgn/train.py @@ -292,7 +292,7 @@ def test_val(model, dataloader, sampler, criterion, args): if i < args.epochs-1 and args.fast_mode: sampler.reset() print(log_content[0], log_content[1], log_content[2]) - except: + except KeyboardInterrupt: traceback.print_exc() error_content = "Training Interreputed!" f.writelines(error_content) diff --git a/python/dgl/__init__.py b/python/dgl/__init__.py index 848a540460ee..edc94a7a18cc 100644 --- a/python/dgl/__init__.py +++ b/python/dgl/__init__.py @@ -21,9 +21,11 @@ from . import distributed from . import random from . import sampling +from . import storages from . import dataloading from . import ops from . import cuda +from . import _dataloading # legacy dataloading modules from ._ffi.runtime_ctypes import TypeCode from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs diff --git a/python/dgl/_dataloading/__init__.py b/python/dgl/_dataloading/__init__.py new file mode 100644 index 000000000000..c71c65dfc54a --- /dev/null +++ b/python/dgl/_dataloading/__init__.py @@ -0,0 +1,28 @@ +"""The ``dgl.dataloading`` package contains: + +* Data loader classes for iterating over a set of nodes or edges in a graph and generates + computation dependency via neighborhood sampling methods. + +* Various sampler classes that perform neighborhood sampling for multi-layer GNNs. + +* Negative samplers for link prediction. 
+ +For a holistic explanation on how different components work together. +Read the user guide :ref:`guide-minibatch`. + +.. note:: + This package is experimental and the interfaces may be subject + to changes in future releases. It currently only has implementations in PyTorch. +""" +from .neighbor import * +from .dataloader import * +from .cluster_gcn import * +from .shadow import * + +from . import negative_sampler +from .async_transferer import AsyncTransferer + +from .. import backend as F + +if F.get_preferred_backend() == 'pytorch': + from .pytorch import * diff --git a/python/dgl/dataloading/async_transferer.py b/python/dgl/_dataloading/async_transferer.py similarity index 97% rename from python/dgl/dataloading/async_transferer.py rename to python/dgl/_dataloading/async_transferer.py index ac201992448b..5b4332e59d3d 100644 --- a/python/dgl/dataloading/async_transferer.py +++ b/python/dgl/_dataloading/async_transferer.py @@ -101,4 +101,4 @@ def async_copy(self, tensor, device): return Transfer(transfer_id, self._handle) -_init_api("dgl.dataloading.async_transferer") +_init_api("dataloading.async_transferer", "dgl._dataloading.async_transferer") diff --git a/python/dgl/_dataloading/cluster_gcn.py b/python/dgl/_dataloading/cluster_gcn.py new file mode 100644 index 000000000000..6df0772e2df3 --- /dev/null +++ b/python/dgl/_dataloading/cluster_gcn.py @@ -0,0 +1,83 @@ +"""Cluster-GCN subgraph iterators.""" +import os +import pickle +import numpy as np + +from ..transform import metis_partition_assignment +from .. import backend as F +from .dataloader import SubgraphIterator + +class ClusterGCNSubgraphIterator(SubgraphIterator): + """Subgraph sampler following that of ClusterGCN. + + This sampler first partitions the graph with METIS partitioning, then it caches the nodes of + each partition to a file within the given cache directory. + + This is used in conjunction with :class:`dgl.dataloading.pytorch.GraphDataLoader`. + + Notes + ----- + The graph must be homogeneous and on CPU. + + Parameters + ---------- + g : DGLGraph + The original graph. + num_partitions : int + The number of partitions. + cache_directory : str + The path to the cache directory for storing the partition result. + refresh : bool + If True, recompute the partition. + + Examples + -------- + Assuming that you have a graph ``g``: + + >>> sgiter = dgl.dataloading.ClusterGCNSubgraphIterator( + ... g, num_partitions=100, cache_directory='.', refresh=True) + >>> dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=0) + >>> for subgraph_batch in dataloader: + ... train_on(subgraph_batch) + """ + def __init__(self, g, num_partitions, cache_directory, refresh=False): + if os.name == 'nt': + raise NotImplementedError("METIS partitioning is not supported on Windows yet.") + super().__init__(g) + + # First see if the cache is already there. If so, directly read from cache. + if not refresh and self._load_parts(cache_directory): + return + + # Otherwise, build the cache. 
+ assignment = F.asnumpy(metis_partition_assignment(g, num_partitions)) + self._save_parts(assignment, cache_directory) + + def _cache_file_path(self, cache_directory): + return os.path.join(cache_directory, 'cluster_gcn_cache') + + def _load_parts(self, cache_directory): + path = self._cache_file_path(cache_directory) + if not os.path.exists(path): + return False + + with open(path, 'rb') as file_: + self.part_indptr, self.part_indices = pickle.load(file_) + return True + + def _save_parts(self, assignment, cache_directory): + os.makedirs(cache_directory, exist_ok=True) + + self.part_indices = np.argsort(assignment) + num_nodes_per_part = np.bincount(assignment) + self.part_indptr = np.insert(np.cumsum(num_nodes_per_part), 0, 0) + + with open(self._cache_file_path(cache_directory), 'wb') as file_: + pickle.dump((self.part_indptr, self.part_indices), file_) + + def __len__(self): + return self.part_indptr.shape[0] - 1 + + def __getitem__(self, i): + nodes = self.part_indices[self.part_indptr[i]:self.part_indptr[i+1]] + return self.g.subgraph(nodes) diff --git a/python/dgl/_dataloading/dataloader.py b/python/dgl/_dataloading/dataloader.py new file mode 100644 index 000000000000..acba119891d2 --- /dev/null +++ b/python/dgl/_dataloading/dataloader.py @@ -0,0 +1,982 @@ +"""Data loaders""" + +from collections.abc import Mapping, Sequence +from abc import ABC, abstractproperty, abstractmethod +import re +import numpy as np +from .. import transform +from ..base import NID, EID +from .. import backend as F +from .. import utils +from ..batch import batch +from ..convert import heterograph +from ..heterograph import DGLHeteroGraph as DGLGraph +from ..distributed.dist_graph import DistGraph +from ..utils import to_device + +def _tensor_or_dict_to_numpy(ids): + if isinstance(ids, Mapping): + return {k: F.zerocopy_to_numpy(v) for k, v in ids.items()} + else: + return F.zerocopy_to_numpy(ids) + +def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids): + """Find the edges whose IDs in parent graph appeared in exclude_eids. + + Note that both arguments are numpy arrays or numpy dicts. + """ + if isinstance(frontier_parent_eids, Mapping): + result = { + k: np.isin(frontier_parent_eids[k], exclude_eids[k]).nonzero()[0] + for k in frontier_parent_eids.keys() if k in exclude_eids.keys()} + return {k: F.zerocopy_from_numpy(v) for k, v in result.items()} + else: + result = np.isin(frontier_parent_eids, exclude_eids).nonzero()[0] + return F.zerocopy_from_numpy(result) + +class _EidExcluder(): + def __init__(self, exclude_eids): + device = None + if isinstance(exclude_eids, Mapping): + for _, v in exclude_eids.items(): + if device is None: + device = F.context(v) + break + else: + device = F.context(exclude_eids) + self._exclude_eids = None + self._filter = None + + if device == F.cpu(): + # TODO(nv-dlasalle): Once Filter is implemented for the CPU, we + # should just use that irregardless of the device. + self._exclude_eids = ( + _tensor_or_dict_to_numpy(exclude_eids) if exclude_eids is not None else None) + else: + if isinstance(exclude_eids, Mapping): + self._filter = {k: utils.Filter(v) for k, v in exclude_eids.items()} + else: + self._filter = utils.Filter(exclude_eids) + + def _find_indices(self, parent_eids): + """ Find the set of edge indices to remove. 
+ """ + if self._exclude_eids is not None: + parent_eids_np = _tensor_or_dict_to_numpy(parent_eids) + return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids) + else: + assert self._filter is not None + if isinstance(parent_eids, Mapping): + located_eids = {k: self._filter[k].find_included_indices(parent_eids[k]) + for k, v in parent_eids.items() if k in self._filter} + else: + located_eids = self._filter.find_included_indices(parent_eids) + return located_eids + + def __call__(self, frontier): + parent_eids = frontier.edata[EID] + located_eids = self._find_indices(parent_eids) + + if not isinstance(located_eids, Mapping): + # (BarclayII) If frontier already has a EID field and located_eids is empty, + # the returned graph will keep EID intact. Otherwise, EID will change + # to the mapping from the new graph to the old frontier. + # So we need to test if located_eids is empty, and do the remapping ourselves. + if len(located_eids) > 0: + frontier = transform.remove_edges( + frontier, located_eids, store_ids=True) + frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID]) + else: + # (BarclayII) remove_edges only accepts removing one type of edges, + # so I need to keep track of the edge IDs left one by one. + new_eids = parent_eids.copy() + for k, v in located_eids.items(): + if len(v) > 0: + frontier = transform.remove_edges( + frontier, v, etype=k, store_ids=True) + new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID]) + frontier.edata[EID] = new_eids + return frontier + + +def exclude_edges(subg, exclude_eids, device): + """Find and remove from the subgraph the edges whose IDs in the parent + graph are given. + + Parameters + ---------- + subg : DGLGraph + The subgraph. Must have ``dgl.EID`` field containing the original + edge IDs in the parent graph. + exclude_eids : Tensor or dict + The edge IDs to exclude. + device : device + The output device of the graph. + + Returns + ------- + DGLGraph + The new subgraph with edges removed. The ``dgl.EID`` field contains + the original edge IDs in the same parent graph. + """ + if exclude_eids is None: + return subg + + if device is not None: + if isinstance(exclude_eids, Mapping): + exclude_eids = {k: F.copy_to(v, device) \ + for k, v in exclude_eids.items()} + else: + exclude_eids = F.copy_to(exclude_eids, device) + + excluder = _EidExcluder(exclude_eids) + return subg if excluder is None else excluder(subg) + + +def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map): + if isinstance(eids, Mapping): + eids = {g.to_canonical_etype(k): v for k, v in eids.items()} + exclude_eids = { + k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0) + for k, v in eids.items()} + else: + exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0) + return exclude_eids + +def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): + exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()} + reverse_etype_map = { + g.to_canonical_etype(k): g.to_canonical_etype(v) + for k, v in reverse_etype_map.items()} + exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()}) + return exclude_eids + +def _find_exclude_eids(g, exclude_mode, eids, **kwargs): + """Find all edge IDs to exclude according to :attr:`exclude_mode`. + + Parameters + ---------- + g : DGLGraph + The graph. + exclude_mode : str, optional + Can be either of the following, + + None (default) + Does not exclude any edge. + + 'self' + Exclude the given edges themselves but nothing else. 
+ + 'reverse_id' + Exclude all edges specified in ``eids``, as well as their reverse edges + of the same edge type. + + The mapping from each edge ID to its reverse edge ID is specified in + the keyword argument ``reverse_eid_map``. + + This mode assumes that the reverse of an edge with ID ``e`` and type + ``etype`` will have ID ``reverse_eid_map[e]`` and type ``etype``. + + 'reverse_types' + Exclude all edges specified in ``eids``, as well as their reverse + edges of the corresponding edge types. + + The mapping from each edge type to its reverse edge type is specified + in the keyword argument ``reverse_etype_map``. + + This mode assumes that the reverse of an edge with ID ``e`` and type ``etype`` + will have ID ``e`` and type ``reverse_etype_map[etype]``. + eids : Tensor or dict[etype, Tensor] + The edge IDs. + reverse_eid_map : Tensor or dict[etype, Tensor] + The mapping from edge ID to its reverse edge ID. + reverse_etype_map : dict[etype, etype] + The mapping from edge etype to its reverse edge type. + """ + if exclude_mode is None: + return None + elif exclude_mode == 'self': + return eids + elif exclude_mode == 'reverse_id': + return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map']) + elif exclude_mode == 'reverse_types': + return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map']) + else: + raise ValueError('unsupported mode {}'.format(exclude_mode)) + +class Sampler(object): + """An abstract class that takes in a graph and a set of seed nodes and returns a + structure representing a smaller portion of the graph for computation. It can + be either a list of bipartite graphs (i.e. :class:`BlockSampler`), or a single + subgraph. + """ + def __init__(self, output_ctx=None): + self.set_output_context(output_ctx) + + def sample(self, g, seed_nodes, exclude_eids=None): + """Sample a structure from the graph. + + Parameters + ---------- + g : DGLGraph + The original graph. + seed_nodes : Tensor or dict[ntype, Tensor] + The destination nodes by type. + + If the graph only has one node type, one can just specify a single tensor + of node IDs. + exclude_eids : Tensor or dict[etype, Tensor] + The edges to exclude from computation dependency. + + Returns + ------- + Tensor or dict[ntype, Tensor] + The nodes whose input features are required for computing the output + representation of :attr:`seed_nodes`. + any + Any data representing the structure. + """ + raise NotImplementedError + + def set_output_context(self, ctx): + """Set the device the generated block or subgraph will be output to. + This should only be set to a cuda device, when multi-processing is not + used in the dataloader (e.g., num_workers is 0). + + Parameters + ---------- + ctx : DGLContext, default None + The device context the sampled blocks will be stored on. This + should only be a CUDA context if multiprocessing is not used in + the dataloader (e.g., num_workers is 0). If this is None, the + sampled blocks will be stored on the same device as the input + graph. + """ + if ctx is not None: + self.output_device = F.to_backend_ctx(ctx) + else: + self.output_device = None + +class BlockSampler(Sampler): + """Abstract class specifying the neighborhood sampling strategy for DGL data loaders. + + The main method for BlockSampler is :meth:`sample`, + which generates a list of message flow graphs (MFGs) for a multi-layer GNN given a set of + seed nodes to have their outputs computed. 
+ + The default implementation of :meth:`sample` is + to repeat :attr:`num_layers` times the following procedure from the last layer to the first + layer: + + * Obtain a frontier. The frontier is defined as a graph with the same nodes as the + original graph but only the edges involved in message passing on the current layer. + Customizable via :meth:`sample_frontier`. + + * Optionally, if the task is link prediction or edge classfication, remove edges + connecting training node pairs. If the graph is undirected, also remove the + reverse edges. This is controlled by the argument :attr:`exclude_eids` in + :meth:`sample` method. + + * Convert the frontier into a MFG. + + * Optionally assign the IDs of the edges in the original graph selected in the first step + to the MFG, controlled by the argument ``return_eids`` in + :meth:`sample` method. + + * Prepend the MFG to the MFG list to be returned. + + All subclasses should override :meth:`sample_frontier` + method while specifying the number of layers to sample in :attr:`num_layers` argument. + + Parameters + ---------- + num_layers : int + The number of layers to sample. + return_eids : bool, default False + Whether to return the edge IDs involved in message passing in the MFG. + If True, the edge IDs will be stored as an edge feature named ``dgl.EID``. + output_ctx : DGLContext, default None + The context the sampled blocks will be stored on. This should only be + a CUDA context if multiprocessing is not used in the dataloader (e.g., + num_workers is 0). If this is None, the sampled blocks will be stored + on the same device as the input graph. + exclude_edges_in_frontier : bool, default False + If True, the :func:`sample_frontier` method will receive an argument + :attr:`exclude_eids` containing the edge IDs from the original graph to exclude. + The :func:`sample_frontier` method must return a graph that does not contain + the edges corresponding to the excluded edges. No additional postprocessing + will be done. + + Otherwise, the edges will be removed *after* :func:`sample_frontier` returns. + + Notes + ----- + For the concept of frontiers and MFGs, please refer to + :ref:`User Guide Section 6 ` and + :doc:`Minibatch Training Tutorials `. + """ + def __init__(self, num_layers, return_eids=False, output_ctx=None): + super().__init__(output_ctx) + self.num_layers = num_layers + self.return_eids = return_eids + + # pylint: disable=unused-argument + @staticmethod + def assign_block_eids(block, frontier): + """Assigns edge IDs from the original graph to the message flow graph (MFG). + + See also + -------- + BlockSampler + """ + for etype in block.canonical_etypes: + block.edges[etype].data[EID] = frontier.edges[etype].data[EID][ + block.edges[etype].data[EID]] + return block + + # This is really a hack working around the lack of GPU-based neighbor sampling + # with edge exclusion. + @classmethod + def exclude_edges_in_frontier(cls, g): + """Returns whether the sampler will exclude edges in :func:`sample_frontier`. + + If this method returns True, the method :func:`sample_frontier` will receive an + argument :attr:`exclude_eids` from :func:`sample`. :func:`sample_frontier` + is then responsible for removing those edges. + + If this method returns False, :func:`sample` will be responsible for + removing the edges. + + When subclassing :class:`BlockSampler`, this method should return True when you + would like to remove the excluded edges in your :func:`sample_frontier` method. + + By default this method returns False. 
+ + Parameters + ---------- + g : DGLGraph + The original graph + + Returns + ------- + bool + Whether :func:`sample_frontier` will receive an argument :attr:`exclude_eids`. + """ + return False + + def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None): + """Generate the frontier given the destination nodes. + + The subclasses should override this function. + + Parameters + ---------- + block_id : int + Represents which GNN layer the frontier is generated for. + g : DGLGraph + The original graph. + seed_nodes : Tensor or dict[ntype, Tensor] + The destination nodes by node type. + + If the graph only has one node type, one can just specify a single tensor + of node IDs. + exclude_eids: Tensor or dict + Edge IDs to exclude during sampling neighbors for the seed nodes. + + This argument can take a single ID tensor or a dictionary of edge types and ID tensors. + If a single tensor is given, the graph must only have one type of nodes. + + Returns + ------- + DGLGraph + The frontier generated for the current layer. + + Notes + ----- + For the concept of frontiers and MFGs, please refer to + :ref:`User Guide Section 6 ` and + :doc:`Minibatch Training Tutorials `. + """ + raise NotImplementedError + + def sample(self, g, seed_nodes, exclude_eids=None): + """Generate the a list of MFGs given the destination nodes. + + Parameters + ---------- + g : DGLGraph + The original graph. + seed_nodes : Tensor or dict[ntype, Tensor] + The destination nodes by node type. + + If the graph only has one node type, one can just specify a single tensor + of node IDs. + exclude_eids : Tensor or dict[etype, Tensor] + The edges to exclude from computation dependency. + + Returns + ------- + list[DGLGraph] + The MFGs generated for computing the multi-layer GNN output. + + Notes + ----- + For the concept of frontiers and MFGs, please refer to + :ref:`User Guide Section 6 ` and + :doc:`Minibatch Training Tutorials `. + """ + blocks = [] + + if isinstance(g, DistGraph): + # TODO:(nv-dlasalle) dist graphs may not have an associated graph, + # causing an error when trying to fetch the device, so for now, + # always assume the distributed graph's device is CPU. + graph_device = F.cpu() + else: + graph_device = g.device + + for block_id in reversed(range(self.num_layers)): + seed_nodes_in = to_device(seed_nodes, graph_device) + + if self.exclude_edges_in_frontier(g): + frontier = self.sample_frontier( + block_id, g, seed_nodes_in, exclude_eids=exclude_eids) + else: + frontier = self.sample_frontier(block_id, g, seed_nodes_in) + + if self.output_device is not None: + frontier = frontier.to(self.output_device) + seed_nodes_out = to_device(seed_nodes, self.output_device) + else: + seed_nodes_out = seed_nodes + + # Removing edges from the frontier for link prediction training falls + # into the category of frontier postprocessing + if not self.exclude_edges_in_frontier(g): + frontier = exclude_edges(frontier, exclude_eids, self.output_device) + + block = transform.to_block(frontier, seed_nodes_out) + if self.return_eids: + self.assign_block_eids(block, frontier) + + seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes} + blocks.insert(0, block) + return blocks[0].srcdata[NID], blocks[-1].dstdata[NID], blocks + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """Deprecated and identical to :meth:`sample`. + """ + return self.sample(g, seed_nodes, exclude_eids) + +class Collator(ABC): + """Abstract DGL collator for training GNNs on downstream tasks stochastically. 
+ + Provides a :attr:`dataset` object containing the collection of all nodes or edges, + as well as a :attr:`collate` method that combines a set of items from + :attr:`dataset` and obtains the message flow graphs (MFGs). + + Notes + ----- + For the concept of MFGs, please refer to + :ref:`User Guide Section 6 ` and + :doc:`Minibatch Training Tutorials `. + """ + @abstractproperty + def dataset(self): + """Returns the dataset object of the collator.""" + raise NotImplementedError + + @abstractmethod + def collate(self, items): + """Combines the items from the dataset object and obtains the list of MFGs. + + Parameters + ---------- + items : list[str, int] + The list of node or edge IDs or type-ID pairs. + + Notes + ----- + For the concept of MFGs, please refer to + :ref:`User Guide Section 6 ` and + :doc:`Minibatch Training Tutorials `. + """ + raise NotImplementedError + +class NodeCollator(Collator): + """DGL collator to combine nodes and their computation dependencies within a minibatch for + training node classification or regression on a single graph with neighborhood sampling. + + Parameters + ---------- + g : DGLGraph + The graph. + nids : Tensor or dict[ntype, Tensor] + The node set to compute outputs. + graph_sampler : dgl.dataloading.BlockSampler + The neighborhood sampler. + + Examples + -------- + To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on + a homogeneous graph where each node takes messages from all neighbors (assume + the backend is PyTorch): + + >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) + >>> collator = dgl.dataloading.NodeCollator(g, train_nid, sampler) + >>> dataloader = torch.utils.data.DataLoader( + ... collator.dataset, collate_fn=collator.collate, + ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for input_nodes, output_nodes, blocks in dataloader: + ... train_on(input_nodes, output_nodes, blocks) + + Notes + ----- + For the concept of MFGs, please refer to + :ref:`User Guide Section 6 ` and + :doc:`Minibatch Training Tutorials `. + """ + def __init__(self, g, nids, graph_sampler): + self.g = g + if not isinstance(nids, Mapping): + assert len(g.ntypes) == 1, \ + "nids should be a dict of node type and ids for graph with multiple node types" + self.graph_sampler = graph_sampler + + self.nids = utils.prepare_tensor_or_dict(g, nids, 'nids') + self._dataset = utils.maybe_flatten_dict(self.nids) + + @property + def dataset(self): + return self._dataset + + def collate(self, items): + """Find the list of MFGs necessary for computing the representation of given + nodes for a node classification/regression task. + + Parameters + ---------- + items : list[int] or list[tuple[str, int]] + Either a list of node IDs (for homogeneous graphs), or a list of node type-ID + pairs (for heterogeneous graphs). + + Returns + ------- + input_nodes : Tensor or dict[ntype, Tensor] + The input nodes necessary for computation in this minibatch. + + If the original graph has multiple node types, return a dictionary of + node type names and node ID tensors. Otherwise, return a single tensor. + output_nodes : Tensor or dict[ntype, Tensor] + The nodes whose representations are to be computed in this minibatch. + + If the original graph has multiple node types, return a dictionary of + node type names and node ID tensors. Otherwise, return a single tensor. + MFGs : list[DGLGraph] + The list of MFGs necessary for computing the representation. 
+ """ + if isinstance(items[0], tuple): + # returns a list of pairs: group them by node types into a dict + items = utils.group_as_dict(items) + items = utils.prepare_tensor_or_dict(self.g, items, 'items') + + input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(self.g, items) + + return input_nodes, output_nodes, blocks + +class EdgeCollator(Collator): + """DGL collator to combine edges and their computation dependencies within a minibatch for + training edge classification, edge regression, or link prediction on a single graph + with neighborhood sampling. + + Given a set of edges, the collate function will yield + + * A tensor of input nodes necessary for computing the representation on edges, or + a dictionary of node type names and such tensors. + + * A subgraph that contains only the edges in the minibatch and their incident nodes. + Note that the graph has an identical metagraph with the original graph. + + * If a negative sampler is given, another graph that contains the "negative edges", + connecting the source and destination nodes yielded from the given negative sampler. + + * A list of MFGs necessary for computing the representation of the incident nodes + of the edges in the minibatch. + + Parameters + ---------- + g : DGLGraph + The graph from which the edges are iterated in minibatches and the subgraphs + are generated. + eids : Tensor or dict[etype, Tensor] + The edge set in graph :attr:`g` to compute outputs. + graph_sampler : dgl.dataloading.BlockSampler + The neighborhood sampler. + g_sampling : DGLGraph, optional + The graph where neighborhood sampling and message passing is performed. + + Note that this is not necessarily the same as :attr:`g`. + + If None, assume to be the same as :attr:`g`. + exclude : str, optional + Whether and how to exclude dependencies related to the sampled edges in the + minibatch. Possible values are + + * None, which excludes nothing. + + * ``'self'``, which excludes the sampled edges themselves but nothing else. + + * ``'reverse_id'``, which excludes the reverse edges of the sampled edges. The said + reverse edges have the same edge type as the sampled edges. Only works + on edge types whose source node type is the same as its destination node type. + + * ``'reverse_types'``, which excludes the reverse edges of the sampled edges. The + said reverse edges have different edge types from the sampled edges. + + If ``g_sampling`` is given, ``exclude`` is ignored and will be always ``None``. + reverse_eids : Tensor or dict[etype, Tensor], optional + A tensor of reverse edge ID mapping. The i-th element indicates the ID of + the i-th edge's reverse edge. + + If the graph is heterogeneous, this argument requires a dictionary of edge + types and the reverse edge ID mapping tensors. + + Required and only used when ``exclude`` is set to ``reverse_id``. + + For heterogeneous graph this will be a dict of edge type and edge IDs. Note that + only the edge types whose source node type is the same as destination node type + are needed. + reverse_etypes : dict[etype, etype], optional + The mapping from the edge type to its reverse edge type. + + Required and only used when ``exclude`` is set to ``reverse_types``. + negative_sampler : callable, optional + The negative sampler. Can be omitted if no negative sampling is needed. + + The negative sampler must be a callable that takes in the following arguments: + + * The original (heterogeneous) graph. 
+
+    * The ID array of sampled edges in the minibatch, or the dictionary of edge
+      types and ID array of sampled edges in the minibatch if the graph is
+      heterogeneous.
+
+    It should return
+
+    * A pair of source and destination node ID arrays as negative samples,
+      or a dictionary of edge types and such pairs if the graph is heterogeneous.
+
+    A set of builtin negative samplers are provided in
+    :ref:`the negative sampling module `.
+
+    Examples
+    --------
+    The following example shows how to train a 3-layer GNN for edge classification on a
+    set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
+    messages from all neighbors.
+
+    Say that you have an array of source node IDs ``src`` and another array of destination
+    node IDs ``dst``. One can make it bidirectional by adding another set of edges
+    that connects from ``dst`` to ``src``:
+
+    >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
+
+    One can then know that the ID difference of an edge and its reverse edge is ``|E|``,
+    where ``|E|`` is the length of your source/destination array. The reverse edge
+    mapping can be obtained by
+
+    >>> E = len(src)
+    >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])
+
+    Note that the sampled edges as well as their reverse edges are removed from
+    computation dependencies of the incident nodes. This is a common trick to avoid
+    information leakage.
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> collator = dgl.dataloading.EdgeCollator(
+    ...     g, train_eid, sampler, exclude='reverse_id',
+    ...     reverse_eids=reverse_eids)
+    >>> dataloader = torch.utils.data.DataLoader(
+    ...     collator.dataset, collate_fn=collator.collate,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, pair_graph, blocks in dataloader:
+    ...     train_on(input_nodes, pair_graph, blocks)
+
+    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
+    homogeneous graph where each node takes messages from all neighbors (assume the
+    backend is PyTorch), with 5 uniformly chosen negative samples per edge:
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
+    >>> collator = dgl.dataloading.EdgeCollator(
+    ...     g, train_eid, sampler, exclude='reverse_id',
+    ...     reverse_eids=reverse_eids, negative_sampler=neg_sampler)
+    >>> dataloader = torch.utils.data.DataLoader(
+    ...     collator.dataset, collate_fn=collator.collate,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
+    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
+
+    For heterogeneous graphs, the reverse of an edge may have a different edge type
+    from the original edge. For instance, consider that you have an array of
+    user-item clicks, represented by a user array ``user`` and an item array ``item``.
+    You may want to build a heterogeneous graph with a user-click-item relation and an
+    item-clicked-by-user relation.
+
+    >>> g = dgl.heterograph({
+    ...     ('user', 'click', 'item'): (user, item),
+    ...     ('item', 'clicked-by', 'user'): (item, user)})
+
+    To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
+    type ``click``, you can write
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> collator = dgl.dataloading.EdgeCollator(
+    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
+    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
+    >>> dataloader = torch.utils.data.DataLoader(
+    ...     collator.dataset, collate_fn=collator.collate,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, pair_graph, blocks in dataloader:
+    ...     train_on(input_nodes, pair_graph, blocks)
+
+    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
+    ``click``, you can write
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
+    >>> collator = dgl.dataloading.EdgeCollator(
+    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
+    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
+    ...     negative_sampler=neg_sampler)
+    >>> dataloader = torch.utils.data.DataLoader(
+    ...     collator.dataset, collate_fn=collator.collate,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
+    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
+
+    Notes
+    -----
+    For the concept of MFGs, please refer to
+    :ref:`User Guide Section 6 ` and
+    :doc:`Minibatch Training Tutorials `.
+    """
+    def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
+                 reverse_eids=None, reverse_etypes=None, negative_sampler=None):
+        self.g = g
+        if not isinstance(eids, Mapping):
+            assert len(g.etypes) == 1, \
+                "eids should be a dict of etype and ids for graph with multiple etypes"
+        self.graph_sampler = graph_sampler
+
+        # One may wish to iterate over the edges in one graph while performing sampling
+        # in another graph. This may be the case for iterating over the validation and
+        # test edge sets while performing neighborhood sampling on the graph formed by
+        # only the training edge set.
+        # See GCMC for an example usage.
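+        # An illustrative sketch of that pattern, assuming ``train_g`` holds only
+        # the training edges and ``val_eids`` indexes edges of the full graph ``g``
+        # (both names are hypothetical):
+        #
+        #     sampler = MultiLayerNeighborSampler([15, 10, 5])
+        #     collator = EdgeCollator(g, val_eids, sampler, g_sampling=train_g)
+        #
+        # The minibatches then iterate over the validation edges of ``g`` while
+        # neighborhood sampling only sees the training graph ``train_g``.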
+ if g_sampling is not None: + self.g_sampling = g_sampling + self.exclude = None + else: + self.g_sampling = self.g + self.exclude = exclude + + self.reverse_eids = reverse_eids + self.reverse_etypes = reverse_etypes + self.negative_sampler = negative_sampler + + self.eids = utils.prepare_tensor_or_dict(g, eids, 'eids') + self._dataset = utils.maybe_flatten_dict(self.eids) + + @property + def dataset(self): + return self._dataset + + def _collate(self, items): + if isinstance(items[0], tuple): + # returns a list of pairs: group them by node types into a dict + items = utils.group_as_dict(items) + items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items') + + pair_graph = self.g.edge_subgraph(items) + seed_nodes = pair_graph.ndata[NID] + + exclude_eids = _find_exclude_eids( + self.g_sampling, + self.exclude, + items, + reverse_eid_map=self.reverse_eids, + reverse_etype_map=self.reverse_etypes) + + input_nodes, _, blocks = self.graph_sampler.sample_blocks( + self.g_sampling, seed_nodes, exclude_eids=exclude_eids) + + return input_nodes, pair_graph, blocks + + def _collate_with_negative_sampling(self, items): + if isinstance(items[0], tuple): + # returns a list of pairs: group them by node types into a dict + items = utils.group_as_dict(items) + items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items') + + pair_graph = self.g.edge_subgraph(items, relabel_nodes=False) + induced_edges = pair_graph.edata[EID] + + neg_srcdst = self.negative_sampler(self.g, items) + if not isinstance(neg_srcdst, Mapping): + assert len(self.g.etypes) == 1, \ + 'graph has multiple or no edge types; '\ + 'please return a dict in negative sampler.' + neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst} + # Get dtype from a tuple of tensors + dtype = F.dtype(list(neg_srcdst.values())[0][0]) + ctx = F.context(pair_graph) + neg_edges = { + etype: neg_srcdst.get(etype, (F.copy_to(F.tensor([], dtype), ctx), + F.copy_to(F.tensor([], dtype), ctx))) + for etype in self.g.canonical_etypes} + neg_pair_graph = heterograph( + neg_edges, {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes}) + + pair_graph, neg_pair_graph = transform.compact_graphs([pair_graph, neg_pair_graph]) + pair_graph.edata[EID] = induced_edges + + seed_nodes = pair_graph.ndata[NID] + + exclude_eids = _find_exclude_eids( + self.g_sampling, + self.exclude, + items, + reverse_eid_map=self.reverse_eids, + reverse_etype_map=self.reverse_etypes) + + input_nodes, _, blocks = self.graph_sampler.sample_blocks( + self.g_sampling, seed_nodes, exclude_eids=exclude_eids) + + return input_nodes, pair_graph, neg_pair_graph, blocks + + def collate(self, items): + """Combines the sampled edges into a minibatch for edge classification, edge + regression, and link prediction tasks. + + Parameters + ---------- + items : list[int] or list[tuple[str, int]] + Either a list of edge IDs (for homogeneous graphs), or a list of edge type-ID + pairs (for heterogeneous graphs). + + Returns + ------- + Either ``(input_nodes, pair_graph, blocks)``, or + ``(input_nodes, pair_graph, negative_pair_graph, blocks)`` if negative sampling is + enabled. + + input_nodes : Tensor or dict[ntype, Tensor] + The input nodes necessary for computation in this minibatch. + + If the original graph has multiple node types, return a dictionary of + node type names and node ID tensors. Otherwise, return a single tensor. + pair_graph : DGLGraph + The graph that contains only the edges in the minibatch as well as their incident + nodes. 
+ + Note that the metagraph of this graph will be identical to that of the original + graph. + negative_pair_graph : DGLGraph + The graph that contains only the edges connecting the source and destination nodes + yielded from the given negative sampler, if negative sampling is enabled. + + Note that the metagraph of this graph will be identical to that of the original + graph. + blocks : list[DGLGraph] + The list of MFGs necessary for computing the representation of the edges. + """ + if self.negative_sampler is None: + return self._collate(items) + else: + return self._collate_with_negative_sampling(items) + +class GraphCollator(object): + """Given a set of graphs as well as their graph-level data, the collate function will batch the + graphs into a batched graph, and stack the tensors into a single bigger tensor. If the + example is a container (such as sequences or mapping), the collate function preserves + the structure and collates each of the elements recursively. + + If the set of graphs has no graph-level data, the collate function will yield a batched graph. + + Examples + -------- + To train a GNN for graph classification on a set of graphs in ``dataset`` (assume + the backend is PyTorch): + + >>> dataloader = dgl.dataloading.GraphDataLoader( + ... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for batched_graph, labels in dataloader: + ... train_on(batched_graph, labels) + """ + def __init__(self): + self.graph_collate_err_msg_format = ( + "graph_collate: batch must contain DGLGraph, tensors, numpy arrays, " + "numbers, dicts or lists; found {}") + self.np_str_obj_array_pattern = re.compile(r'[SaUO]') + + #This implementation is based on torch.utils.data._utils.collate.default_collate + def collate(self, items): + """This function is similar to ``torch.utils.data._utils.collate.default_collate``. + It combines the sampled graphs and corresponding graph-level data + into a batched graph and tensors. + + Parameters + ---------- + items : list of data points or tuples + Elements in the list are expected to have the same length. + Each sub-element will be batched as a batched graph, or a + batched tensor correspondingly. + + Returns + ------- + A tuple of the batching results. 
+ """ + elem = items[0] + elem_type = type(elem) + if isinstance(elem, DGLGraph): + batched_graphs = batch(items) + return batched_graphs + elif F.is_tensor(elem): + return F.stack(items, 0) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(self.graph_collate_err_msg_format.format(elem.dtype)) + + return self.collate([F.tensor(b) for b in items]) + elif elem.shape == (): # scalars + return F.tensor(items) + elif isinstance(elem, float): + return F.tensor(items, dtype=F.float64) + elif isinstance(elem, int): + return F.tensor(items) + elif isinstance(elem, (str, bytes)): + return items + elif isinstance(elem, Mapping): + return {key: self.collate([d[key] for d in items]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(self.collate(samples) for samples in zip(*items))) + elif isinstance(elem, Sequence): + # check to make sure that the elements in batch have consistent size + item_iter = iter(items) + elem_size = len(next(item_iter)) + if not all(len(elem) == elem_size for elem in item_iter): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*items) + return [self.collate(samples) for samples in transposed] + + raise TypeError(self.graph_collate_err_msg_format.format(elem_type)) + +class SubgraphIterator(object): + """Abstract class representing an iterator that yields a subgraph given a graph. + """ + def __init__(self, g): + self.g = g diff --git a/python/dgl/_dataloading/negative_sampler.py b/python/dgl/_dataloading/negative_sampler.py new file mode 100644 index 000000000000..d9a56e2e8df3 --- /dev/null +++ b/python/dgl/_dataloading/negative_sampler.py @@ -0,0 +1,125 @@ +"""Negative samplers""" +from collections.abc import Mapping +from .. import backend as F +from ..sampling import global_uniform_negative_sampling + +class _BaseNegativeSampler(object): + def _generate(self, g, eids, canonical_etype): + raise NotImplementedError + + def __call__(self, g, eids): + """Returns negative samples. + + Parameters + ---------- + g : DGLGraph + The graph. + eids : Tensor or dict[etype, Tensor] + The sampled edges in the minibatch. + + Returns + ------- + tuple[Tensor, Tensor] or dict[etype, tuple[Tensor, Tensor]] + The returned source-destination pairs as negative samples. + """ + if isinstance(eids, Mapping): + eids = {g.to_canonical_etype(k): v for k, v in eids.items()} + neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()} + else: + assert len(g.etypes) == 1, \ + 'please specify a dict of etypes and ids for graphs with multiple edge types' + neg_pair = self._generate(g, eids, g.canonical_etypes[0]) + + return neg_pair + +class PerSourceUniform(_BaseNegativeSampler): + """Negative sampler that randomly chooses negative destination nodes + for each source node according to a uniform distribution. + + For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates + :attr:`k` pairs of negative edges ``(u, v')``, where ``v'`` is chosen + uniformly from all the nodes of type ``dsttype``. The resulting edges will + also have type ``(srctype, etype, dsttype)``. + + Parameters + ---------- + k : int + The number of negative samples per edge. 
+ + Examples + -------- + >>> g = dgl.graph(([0, 1, 2], [1, 2, 3])) + >>> neg_sampler = dgl.dataloading.negative_sampler.PerSourceUniform(2) + >>> neg_sampler(g, torch.tensor([0, 1])) + (tensor([0, 0, 1, 1]), tensor([1, 0, 2, 3])) + """ + def __init__(self, k): + self.k = k + + def _generate(self, g, eids, canonical_etype): + _, _, vtype = canonical_etype + shape = F.shape(eids) + dtype = F.dtype(eids) + ctx = F.context(eids) + shape = (shape[0] * self.k,) + src, _ = g.find_edges(eids, etype=canonical_etype) + src = F.repeat(src, self.k, 0) + dst = F.randint(shape, dtype, ctx, 0, g.number_of_nodes(vtype)) + return src, dst + +# Alias +Uniform = PerSourceUniform + +class GlobalUniform(_BaseNegativeSampler): + """Negative sampler that randomly chooses negative source-destination pairs according + to a uniform distribution. + + For each edge ``(u, v)`` of type ``(srctype, etype, dsttype)``, DGL generates at most + :attr:`k` pairs of negative edges ``(u', v')``, where ``u'`` is chosen uniformly from + all the nodes of type ``srctype`` and ``v'`` is chosen uniformly from all the nodes + of type ``dsttype``. The resulting edges will also have type + ``(srctype, etype, dsttype)``. DGL guarantees that the sampled pairs will not have + edges in between. + + Parameters + ---------- + k : int + The desired number of negative samples to generate per edge. + exclude_self_loops : bool, optional + Whether to exclude self-loops from negative samples. (Default: True) + replace : bool, optional + Whether to sample with replacement. Setting it to True will make things + faster. (Default: True) + redundancy : float, optional + Indicates how much more negative samples to actually generate during rejection sampling + before finding the unique pairs. + + Increasing it will increase the likelihood of getting :attr:`k` negative samples + per edge, but will also take more time and memory. + + (Default: automatically determined by the density of graph) + + Notes + ----- + This negative sampler will try to generate as many negative samples as possible, but + it may rarely return less than :attr:`k` negative samples per edge. + This is more likely to happen if a graph is so small or dense that not many unique + negative samples exist. 
+ + Examples + -------- + >>> g = dgl.graph(([0, 1, 2], [1, 2, 3])) + >>> neg_sampler = dgl.dataloading.negative_sampler.GlobalUniform(2, True) + >>> neg_sampler(g, torch.LongTensor([0, 1])) + (tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1])) + """ + def __init__(self, k, exclude_self_loops=True, replace=False, redundancy=None): + self.k = k + self.exclude_self_loops = exclude_self_loops + self.replace = replace + self.redundancy = redundancy + + def _generate(self, g, eids, canonical_etype): + return global_uniform_negative_sampling( + g, len(eids) * self.k, self.exclude_self_loops, self.replace, + canonical_etype, self.redundancy) diff --git a/python/dgl/dataloading/neighbor.py b/python/dgl/_dataloading/neighbor.py similarity index 100% rename from python/dgl/dataloading/neighbor.py rename to python/dgl/_dataloading/neighbor.py diff --git a/python/dgl/dataloading/pytorch/__init__.py b/python/dgl/_dataloading/pytorch/__init__.py similarity index 100% rename from python/dgl/dataloading/pytorch/__init__.py rename to python/dgl/_dataloading/pytorch/__init__.py diff --git a/python/dgl/dataloading/pytorch/dataloader.py b/python/dgl/_dataloading/pytorch/dataloader.py similarity index 87% rename from python/dgl/dataloading/pytorch/dataloader.py rename to python/dgl/_dataloading/pytorch/dataloader.py index 36057891220d..a7a6ab0f5fd4 100644 --- a/python/dgl/dataloading/pytorch/dataloader.py +++ b/python/dgl/_dataloading/pytorch/dataloader.py @@ -10,7 +10,6 @@ import torch.distributed as dist from ..dataloader import NodeCollator, EdgeCollator, GraphCollator, SubgraphIterator from ...distributed import DistGraph -from ...distributed import DistDataLoader from ...ndarray import NDArray as DGLNDArray from ... import backend as F from ...base import DGLError @@ -26,6 +25,10 @@ PYTORCH_16 = PYTORCH_VER >= LooseVersion("1.6.0") PYTORCH_17 = PYTORCH_VER >= LooseVersion("1.7.0") +def _check_graph_type(g): + if isinstance(g, DistGraph): + raise TypeError("Please use DistNodeDataLoader or DistEdgeDataLoader for DistGraph") + def _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed): # Note: will change the content of dataloader_kwargs dist_sampler_kwargs = {'shuffle': dataloader_kwargs['shuffle']} @@ -166,14 +169,6 @@ def set_epoch(self, epoch): """Set epoch number for distributed training.""" self.epoch = epoch -def _remove_kwargs_dist(kwargs): - if 'num_workers' in kwargs: - del kwargs['num_workers'] - if 'pin_memory' in kwargs: - del kwargs['pin_memory'] - print('Distributed DataLoader does not support pin_memory') - return kwargs - # The following code is a fix to the PyTorch-specific issue in # https://github.com/dmlc/dgl/issues/2137 # @@ -290,14 +285,14 @@ def _restore_storages(subgs, g): _restore_subgraph_storage(subg, g) class _NodeCollator(NodeCollator): - def collate(self, items): + def collate(self, items): # pylint: disable=missing-docstring # input_nodes, output_nodes, blocks result = super().collate(items) _pop_storages(result[-1], self.g) return result class _EdgeCollator(EdgeCollator): - def collate(self, items): + def collate(self, items): # pylint: disable=missing-docstring if self.negative_sampler is None: # input_nodes, pair_graph, blocks result = super().collate(items) @@ -381,10 +376,10 @@ def _background_node_dataloader(dl_iter, g, device, results, load_input, load_ou class _NodeDataLoaderIter: - def __init__(self, node_dataloader): + def __init__(self, node_dataloader, iter_): self.device = node_dataloader.device self.node_dataloader = node_dataloader - self.iter_ = 
iter(node_dataloader.dataloader) + self.iter_ = iter_ self.async_load = node_dataloader.async_load and ( F.device_type(self.device) == 'cuda') if self.async_load: @@ -418,10 +413,10 @@ def __next__(self): return input_nodes, output_nodes, blocks class _EdgeDataLoaderIter: - def __init__(self, edge_dataloader): + def __init__(self, edge_dataloader, iter_): self.device = edge_dataloader.device self.edge_dataloader = edge_dataloader - self.iter_ = iter(edge_dataloader.dataloader) + self.iter_ = iter_ # Make this an iterator for PyTorch Lightning compatibility def __iter__(self): @@ -441,9 +436,9 @@ def __next__(self): return result class _GraphDataLoaderIter: - def __init__(self, graph_dataloader): + def __init__(self, graph_dataloader, iter_): self.dataloader = graph_dataloader - self.iter_ = iter(graph_dataloader.dataloader) + self.iter_ = iter_ def __iter__(self): return self @@ -490,14 +485,9 @@ def _init_dataloader(collator, device, dataloader_kwargs, use_ddp, ddp_seed): else: dist_sampler = None - dataloader = DataLoader( - dataset, - collate_fn=collator.collate, - **dataloader_kwargs) + return use_scalar_batcher, scalar_batcher, dataset, collator, dist_sampler - return use_scalar_batcher, scalar_batcher, dataloader, dist_sampler - -class NodeDataLoader: +class NodeDataLoader(DataLoader): """PyTorch dataloader for batch-iterating over a set of nodes, generating the list of message flow graphs (MFGs) as computation dependency of the said minibatch. @@ -600,6 +590,7 @@ class NodeDataLoader: def __init__(self, g, nids, graph_sampler, device=None, use_ddp=False, ddp_seed=0, load_input=None, load_output=None, async_load=False, **kwargs): + _check_graph_type(g) collator_kwargs = {} dataloader_kwargs = {} for k, v in kwargs.items(): @@ -608,65 +599,42 @@ def __init__(self, g, nids, graph_sampler, device=None, use_ddp=False, ddp_seed= else: dataloader_kwargs[k] = v - if isinstance(g, DistGraph): - if device is None: - # for the distributed case default to the CPU - device = 'cpu' - assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.' - # Distributed DataLoader currently does not support heterogeneous graphs - # and does not copy features. Fallback to normal solution - self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs) - _remove_kwargs_dist(dataloader_kwargs) - self.dataloader = DistDataLoader(self.collator.dataset, - collate_fn=self.collator.collate, - **dataloader_kwargs) - self.is_distributed = True - else: - if device is None: - # default to the same device the graph is on - device = th.device(g.device) - - if not g.is_homogeneous: - if load_input or load_output: - raise DGLError('load_input/load_output not supported for heterograph yet.') - self.load_input = {} if load_input is None else load_input - self.load_output = {} if load_output is None else load_output - self.async_load = async_load - - # if the sampler supports it, tell it to output to the specified device. - # But if async_load is enabled, set_output_context should be skipped as - # we'd like to avoid any graph/data transfer graphs across devices in - # sampler. Such transfer will be handled in dataloader. 
- num_workers = dataloader_kwargs.get('num_workers', 0) - if ((not async_load) and - callable(getattr(graph_sampler, "set_output_context", None)) and - num_workers == 0): - graph_sampler.set_output_context(to_dgl_context(device)) - - self.collator = _NodeCollator(g, nids, graph_sampler, **collator_kwargs) - self.use_scalar_batcher, self.scalar_batcher, self.dataloader, self.dist_sampler = \ - _init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed) + if device is None: + # default to the same device the graph is on + device = th.device(g.device) + + if not g.is_homogeneous: + if load_input or load_output: + raise DGLError('load_input/load_output not supported for heterograph yet.') + self.load_input = {} if load_input is None else load_input + self.load_output = {} if load_output is None else load_output + self.async_load = async_load + + # if the sampler supports it, tell it to output to the specified device. + # But if async_load is enabled, set_output_context should be skipped as + # we'd like to avoid any graph/data transfer graphs across devices in + # sampler. Such transfer will be handled in dataloader. + num_workers = dataloader_kwargs.get('num_workers', 0) + if ((not async_load) and + callable(getattr(graph_sampler, "set_output_context", None)) and + num_workers == 0): + graph_sampler.set_output_context(to_dgl_context(device)) + + self.collator = _NodeCollator(g, nids, graph_sampler, **collator_kwargs) + self.use_scalar_batcher, self.scalar_batcher, self.dataloader, self.dist_sampler = \ + _init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed) - self.use_ddp = use_ddp - self.is_distributed = False + self.use_ddp = use_ddp + self.is_distributed = False - # Precompute the CSR and CSC representations so each subprocess does not - # duplicate. - if num_workers > 0: - g.create_formats_() + # Precompute the CSR and CSC representations so each subprocess does not + # duplicate. + if num_workers > 0: + g.create_formats_() self.device = device def __iter__(self): - """Return the iterator of the data loader.""" - if self.is_distributed: - # Directly use the iterator of DistDataLoader, which doesn't copy features anyway. - return iter(self.dataloader) - else: - return _NodeDataLoaderIter(self) - - def __len__(self): - """Return the number of batches of the data loader.""" - return len(self.dataloader) + return _NodeDataLoaderIter(self, super().__iter__()) def set_epoch(self, epoch): """Sets the epoch number for the underlying sampler which ensures all replicas @@ -689,7 +657,7 @@ def set_epoch(self, epoch): else: raise DGLError('set_epoch is only available when use_ddp is True.') -class EdgeDataLoader: +class EdgeDataLoader(DataLoader): """PyTorch dataloader for batch-iterating over a set of edges, generating the list of message flow graphs (MFGs) as computation dependency of the said minibatch for edge classification, edge regression, and link prediction. @@ -897,8 +865,9 @@ class EdgeDataLoader: * Link prediction on heterogeneous graph: RGCN for link prediction. 
""" collator_arglist = inspect.getfullargspec(EdgeCollator).args - - def __init__(self, g, eids, graph_sampler, device='cpu', use_ddp=False, ddp_seed=0, **kwargs): + def __init__(self, g, eids, graph_sampler, device='cpu', use_ddp=False, ddp_seed=0, + **kwargs): + _check_graph_type(g) collator_kwargs = {} dataloader_kwargs = {} for k, v in kwargs.items(): @@ -907,53 +876,30 @@ def __init__(self, g, eids, graph_sampler, device='cpu', use_ddp=False, ddp_seed else: dataloader_kwargs[k] = v - if isinstance(g, DistGraph): - if device is None: - # for the distributed case default to the CPU - device = 'cpu' - assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.' - # Distributed DataLoader currently does not support heterogeneous graphs - # and does not copy features. Fallback to normal solution - self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs) - _remove_kwargs_dist(dataloader_kwargs) - self.dataloader = DistDataLoader(self.collator.dataset, - collate_fn=self.collator.collate, - **dataloader_kwargs) - self.is_distributed = True - else: if device is None: # default to the same device the graph is on device = th.device(g.device) - # if the sampler supports it, tell it to output to the - # specified device - num_workers = dataloader_kwargs.get('num_workers', 0) - if callable(getattr(graph_sampler, "set_output_context", None)) and num_workers == 0: - graph_sampler.set_output_context(to_dgl_context(device)) + # if the sampler supports it, tell it to output to the + # specified device + num_workers = dataloader_kwargs.get('num_workers', 0) + if callable(getattr(graph_sampler, "set_output_context", None)) and num_workers == 0: + graph_sampler.set_output_context(to_dgl_context(device)) - self.collator = _EdgeCollator(g, eids, graph_sampler, **collator_kwargs) - self.use_scalar_batcher, self.scalar_batcher, self.dataloader, self.dist_sampler = \ - _init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed) - self.use_ddp = use_ddp - self.is_distributed = False + self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs) + self.use_scalar_batcher, self.scalar_batcher, dataset, collator, self.dist_sampler = \ + _init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed) + self.use_ddp = use_ddp + super().__init__(dataset, collate_fn=collator.collate, **dataloader_kwargs) - # Precompute the CSR and CSC representations so each subprocess does not duplicate. - if num_workers > 0: - g.create_formats_() + # Precompute the CSR and CSC representations so each subprocess does not duplicate. + if num_workers > 0: + g.create_formats_() self.device = device def __iter__(self): - """Return the iterator of the data loader.""" - if self.is_distributed: - # Directly use the iterator of DistDataLoader, which doesn't copy features anyway. - return iter(self.dataloader) - else: - return _EdgeDataLoaderIter(self) - - def __len__(self): - """Return the number of batches of the data loader.""" - return len(self.dataloader) + return _EdgeDataLoaderIter(self, super().__iter__()) def set_epoch(self, epoch): """Sets the epoch number for the underlying sampler which ensures all replicas @@ -976,7 +922,7 @@ def set_epoch(self, epoch): else: raise DGLError('set_epoch is only available when use_ddp is True.') -class GraphDataLoader: +class GraphDataLoader(DataLoader): """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched graph and corresponding label tensor (if provided) of the said minibatch. 
@@ -1023,7 +969,6 @@ class GraphDataLoader: ... train_on(batched_graph, labels) """ collator_arglist = inspect.getfullargspec(GraphCollator).args - def __init__(self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs): collator_kwargs = {} dataloader_kwargs = {} @@ -1058,14 +1003,11 @@ def __iter__(self): if use_ddp: self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed) dataloader_kwargs['sampler'] = self.dist_sampler - - self.dataloader = DataLoader(dataset=dataset, - collate_fn=self.collate, - **dataloader_kwargs) + super().__init__(dataset, collate_fn=self.collate, **dataloader_kwargs) def __iter__(self): """Return the iterator of the data loader.""" - return _GraphDataLoaderIter(self) + return _GraphDataLoaderIter(self, super().__iter__()) def __len__(self): """Return the number of batches of the data loader.""" diff --git a/python/dgl/_dataloading/shadow.py b/python/dgl/_dataloading/shadow.py new file mode 100644 index 000000000000..4841a9dbb1c6 --- /dev/null +++ b/python/dgl/_dataloading/shadow.py @@ -0,0 +1,98 @@ +"""ShaDow-GNN subgraph samplers.""" +from ..utils import prepare_tensor_or_dict +from ..base import NID +from .. import transform +from ..sampling import sample_neighbors +from .neighbor import NeighborSamplingMixin +from .dataloader import exclude_edges, Sampler + +class ShaDowKHopSampler(NeighborSamplingMixin, Sampler): + """K-hop subgraph sampler used by + `ShaDow-GNN `__. + + It performs node-wise neighbor sampling but instead of returning a list of + MFGs, it returns a single subgraph induced by all the sampled nodes. The + seed nodes from which the neighbors are sampled will appear the first in the + induced nodes of the subgraph. + + This is used in conjunction with :class:`dgl.dataloading.pytorch.NodeDataLoader` + and :class:`dgl.dataloading.pytorch.EdgeDataLoader`. + + Parameters + ---------- + fanouts : list[int] or list[dict[etype, int]] + List of neighbors to sample per edge type for each GNN layer, with the i-th + element being the fanout for the i-th GNN layer. + + If only a single integer is provided, DGL assumes that every edge type + will have the same fanout. + + If -1 is provided for one edge type on one layer, then all inbound edges + of that edge type will be included. + replace : bool, default True + Whether to sample with replacement + prob : str, optional + If given, the probability of each neighbor being sampled is proportional + to the edge feature value with the given name in ``g.edata``. The feature must be + a scalar on each edge. + + Examples + -------- + To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on + a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for + the first, second, and third layer respectively (assuming the backend is PyTorch): + + >>> g = dgl.data.CoraFullDataset()[0] + >>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15]) + >>> dataloader = dgl.dataloading.NodeDataLoader( + ... g, torch.arange(g.num_nodes()), sampler, + ... batch_size=5, shuffle=True, drop_last=False, num_workers=4) + >>> for input_nodes, output_nodes, (subgraph,) in dataloader: + ... print(subgraph) + ... assert torch.equal(input_nodes, subgraph.ndata[dgl.NID]) + ... assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes) + ... 
break
+    Graph(num_nodes=529, num_edges=3796,
+          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64),
+                         'feat': Scheme(shape=(8710,), dtype=torch.float32),
+                         '_ID': Scheme(shape=(), dtype=torch.int64)}
+          edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
+
+    If training on a heterogeneous graph and you want a different number of neighbors for each
+    edge type, one should instead provide a list of dicts. Each dict would specify the
+    number of neighbors to pick per edge type.
+
+    >>> sampler = dgl.dataloading.ShaDowKHopSampler([
+    ...     {('user', 'follows', 'user'): 5,
+    ...      ('user', 'plays', 'game'): 4,
+    ...      ('game', 'played-by', 'user'): 3}] * 3)
+
+    If you would like non-uniform neighbor sampling:
+
+    >>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
+    >>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
+    """
+    def __init__(self, fanouts, replace=False, prob=None, output_ctx=None):
+        super().__init__(output_ctx)
+        self.fanouts = fanouts
+        self.replace = replace
+        self.prob = prob
+        self.set_output_context(output_ctx)
+
+    def sample(self, g, seed_nodes, exclude_eids=None):
+        self._build_fanout(len(self.fanouts), g)
+        self._build_prob_arrays(g)
+        seed_nodes = prepare_tensor_or_dict(g, seed_nodes, 'seed nodes')
+        output_nodes = seed_nodes
+
+        for i in range(len(self.fanouts)):
+            fanout = self.fanouts[i]
+            frontier = sample_neighbors(
+                g, seed_nodes, fanout, replace=self.replace, prob=self.prob_arrays)
+            block = transform.to_block(frontier, seed_nodes)
+            seed_nodes = block.srcdata[NID]
+
+        subg = g.subgraph(seed_nodes, relabel_nodes=True)
+        subg = exclude_edges(subg, exclude_eids, self.output_device)
+
+        return seed_nodes, output_nodes, [subg]
diff --git a/python/dgl/dataloading/__init__.py b/python/dgl/dataloading/__init__.py
index c71c65dfc54a..69ce04f19f1f 100644
--- a/python/dgl/dataloading/__init__.py
+++ b/python/dgl/dataloading/__init__.py
@@ -14,15 +14,12 @@
 This package is experimental and the interfaces may be subject to changes in
 future releases. It currently only has implementations in PyTorch.
 """
-from .neighbor import *
-from .dataloader import *
+from .. import backend as F
+from .neighbor_sampler import *
 from .cluster_gcn import *
 from .shadow import *
-
+
+from .base import *
 from . import negative_sampler
-from .async_transferer import AsyncTransferer
-
-from .. import backend as F
-
 if F.get_preferred_backend() == 'pytorch':
-    from .pytorch import *
+    from .dataloader import *
+    from .dist_dataloader import *
diff --git a/python/dgl/dataloading/base.py b/python/dgl/dataloading/base.py
new file mode 100644
index 000000000000..b0ebe2c368b0
--- /dev/null
+++ b/python/dgl/dataloading/base.py
@@ -0,0 +1,275 @@
+"""Base classes and functionalities for dataloaders"""
+from collections.abc import Mapping
+from ..base import NID, EID
+from ..convert import heterograph
+from ..
import backend as F +from ..transform import compact_graphs +from ..frame import LazyFeature +from ..utils import recursive_apply + +def _set_lazy_features(x, xdata, feature_names): + if feature_names is None: + return + if not isinstance(feature_names, Mapping): + xdata.update({k: LazyFeature(k) for k in feature_names}) + else: + for type_, names in feature_names.items(): + x[type_].data.update({k: LazyFeature(k) for k in names}) + +def set_node_lazy_features(g, feature_names): + """Set lazy features for ``g.ndata`` if :attr:`feature_names` is a list of strings, + or ``g.nodes[ntype].data`` if :attr:`feature_names` is a dict of list of strings. + """ + return _set_lazy_features(g.nodes, g.ndata, feature_names) + +def set_edge_lazy_features(g, feature_names): + """Set lazy features for ``g.edata`` if :attr:`feature_names` is a list of strings, + or ``g.edges[etype].data`` if :attr:`feature_names` is a dict of list of strings. + """ + return _set_lazy_features(g.edges, g.edata, feature_names) + +def set_src_lazy_features(g, feature_names): + """Set lazy features for ``g.srcdata`` if :attr:`feature_names` is a list of strings, + or ``g.srcnodes[srctype].data`` if :attr:`feature_names` is a dict of list of strings. + """ + return _set_lazy_features(g.srcnodes, g.srcdata, feature_names) + +def set_dst_lazy_features(g, feature_names): + """Set lazy features for ``g.dstdata`` if :attr:`feature_names` is a list of strings, + or ``g.dstnodes[dsttype].data`` if :attr:`feature_names` is a dict of list of strings. + """ + return _set_lazy_features(g.dstnodes, g.dstdata, feature_names) + +class BlockSampler(object): + """BlockSampler is an abstract class assuming to take in a set of nodes whose + outputs are to compute, and return a list of blocks. + + Moreover, it assumes that the input node features will be put in the first block's + ``srcdata``, the output node labels will be put in the last block's ``dstdata``, and + the edge data will be put in all the blocks' ``edata``. + """ + def __init__(self, prefetch_node_feats=None, prefetch_labels=None, + prefetch_edge_feats=None, output_device=None): + self.prefetch_node_feats = prefetch_node_feats or [] + self.prefetch_labels = prefetch_labels or [] + self.prefetch_edge_feats = prefetch_edge_feats or [] + self.output_device = output_device + + def sample_blocks(self, g, seed_nodes, exclude_eids=None): + """Generates a list of blocks from the given seed nodes. + + This function must return a triplet where the first element is the input node IDs + for the first GNN layer (a tensor or a dict of tensors for heterogeneous graphs), + the second element is the output node IDs for the last GNN layer, and the third + element is the said list of blocks. + """ + raise NotImplementedError + + def assign_lazy_features(self, result): + """Assign lazy features for prefetching.""" + # A LazyFeature is a placeholder telling the dataloader where and which IDs + # to prefetch. It has the signature LazyFeature(name, id_). id_ can be None + # if the LazyFeature is set into one of the subgraph's ``xdata``, in which case the + # dataloader will infer from the subgraph's ``xdata[dgl.NID]`` (or ``xdata[dgl.EID]`` + # if the LazyFeature is set as edge features). + # + # If you want to prefetch things other than ndata and edata, you can also + # return a LazyFeature(name, id_). If a LazyFeature is returned in places other than + # in a graph's ndata/edata/srcdata/dstdata, the DataLoader will prefetch it + # from its dictionary ``other_data``. 
+ # For instance, you can run + # + # return blocks, LazyFeature('other_feat', id_) + # + # To make it work with the sampler returning the stuff above, your dataloader + # needs to have the following + # + # dataloader.attach_data('other_feat', tensor) + # + # Then you can run + # + # for blocks, other_feat in dataloader: + # train_on(blocks, other_feat) + input_nodes, output_nodes, blocks = result + set_src_lazy_features(blocks[0], self.prefetch_node_feats) + set_dst_lazy_features(blocks[-1], self.prefetch_labels) + for block in blocks: + set_edge_lazy_features(block, self.prefetch_edge_feats) + return input_nodes, output_nodes, blocks + + def sample(self, g, seed_nodes): + """Sample a list of blocks from the given seed nodes.""" + result = self.sample_blocks(g, seed_nodes) + return self.assign_lazy_features(result) + + +def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map): + if isinstance(eids, Mapping): + eids = {g.to_canonical_etype(k): v for k, v in eids.items()} + exclude_eids = { + k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0) + for k, v in eids.items()} + else: + exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0) + return exclude_eids + +def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): + exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()} + reverse_etype_map = { + g.to_canonical_etype(k): g.to_canonical_etype(v) + for k, v in reverse_etype_map.items()} + exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()}) + return exclude_eids + +def _find_exclude_eids(g, exclude_mode, eids, **kwargs): + if exclude_mode is None: + return None + elif F.is_tensor(exclude_mode) or ( + isinstance(exclude_mode, Mapping) and + all(F.is_tensor(v) for v in exclude_mode.values())): + return exclude_mode + elif exclude_mode == 'self': + return eids + elif exclude_mode == 'reverse_id': + return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map']) + elif exclude_mode == 'reverse_types': + return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map']) + else: + raise ValueError('unsupported mode {}'.format(exclude_mode)) + +def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=None, + output_device=None): + """Find all edge IDs to exclude according to :attr:`exclude_mode`. + + Parameters + ---------- + g : DGLGraph + The graph. + exclude_mode : str, optional + Can be either of the following, + + None (default) + Does not exclude any edge. + + Tensor or dict[etype, Tensor] + Exclude the given edge IDs. + + 'self' + Exclude the given edges themselves but nothing else. + + 'reverse_id' + Exclude all edges specified in ``eids``, as well as their reverse edges + of the same edge type. + + The mapping from each edge ID to its reverse edge ID is specified in + the keyword argument ``reverse_eid_map``. + + This mode assumes that the reverse of an edge with ID ``e`` and type + ``etype`` will have ID ``reverse_eid_map[e]`` and type ``etype``. + + 'reverse_types' + Exclude all edges specified in ``eids``, as well as their reverse + edges of the corresponding edge types. + + The mapping from each edge type to its reverse edge type is specified + in the keyword argument ``reverse_etype_map``. + + This mode assumes that the reverse of an edge with ID ``e`` and type ``etype`` + will have ID ``e`` and type ``reverse_etype_map[etype]``. + eids : Tensor or dict[etype, Tensor] + The edge IDs. 
+ reverse_eids : Tensor or dict[etype, Tensor] + The mapping from edge ID to its reverse edge ID. + reverse_etypes : dict[etype, etype] + The mapping from edge etype to its reverse edge type. + output_device : device + The device of the output edge IDs. + """ + exclude_eids = _find_exclude_eids( + g, + exclude, + seed_edges, + reverse_eid_map=reverse_eids, + reverse_etype_map=reverse_etypes) + if exclude_eids is not None: + exclude_eids = recursive_apply( + exclude_eids, lambda x: x.to(output_device)) + return exclude_eids + + +class EdgeBlockSampler(object): + """Adapts a :class:`BlockSampler` object's :attr:`sample` method for edge + classification and link prediction. + """ + def __init__(self, block_sampler, exclude=None, reverse_eids=None, + reverse_etypes=None, negative_sampler=None, prefetch_node_feats=None, + prefetch_labels=None, prefetch_edge_feats=None): + self.reverse_eids = reverse_eids + self.reverse_etypes = reverse_etypes + self.exclude = exclude + self.block_sampler = block_sampler + self.negative_sampler = negative_sampler + self.prefetch_node_feats = prefetch_node_feats or [] + self.prefetch_labels = prefetch_labels or [] + self.prefetch_edge_feats = prefetch_edge_feats or [] + self.output_device = block_sampler.output_device + + def _build_neg_graph(self, g, seed_edges): + neg_srcdst = self.negative_sampler(g, seed_edges) + if not isinstance(neg_srcdst, Mapping): + assert len(g.canonical_etypes) == 1, \ + 'graph has multiple or no edge types; '\ + 'please return a dict in negative sampler.' + neg_srcdst = {g.canonical_etypes[0]: neg_srcdst} + + dtype = F.dtype(list(neg_srcdst.values())[0][0]) + neg_edges = { + etype: neg_srcdst.get(etype, (F.tensor([], dtype), F.tensor([], dtype))) + for etype in g.canonical_etypes} + neg_pair_graph = heterograph( + neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}) + return neg_pair_graph + + def assign_lazy_features(self, result): + """Assign lazy features for prefetching.""" + pair_graph = result[1] + blocks = result[-1] + + set_src_lazy_features(blocks[0], self.prefetch_node_feats) + set_edge_lazy_features(pair_graph, self.prefetch_labels) + for block in blocks: + set_edge_lazy_features(block, self.prefetch_edge_feats) + # In-place updates + return result + + def sample(self, g, seed_edges): + """Samples a list of blocks, as well as a subgraph containing the sampled + edges from the original graph. + + If :attr:`negative_sampler` is given, also returns another graph containing the + negative pairs as edges. 
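+
+        In other words, a dataloader using this sampler yields
+        ``(input_nodes, pair_graph, blocks)`` per minibatch, or
+        ``(input_nodes, pair_graph, negative_pair_graph, blocks)`` if a negative
+        sampler is given.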
+ """ + exclude = self.exclude + pair_graph = g.edge_subgraph( + seed_edges, relabel_nodes=False, output_device=self.output_device) + eids = pair_graph.edata[EID] + + if self.negative_sampler is not None: + neg_graph = self._build_neg_graph(g, seed_edges) + pair_graph, neg_graph = compact_graphs([pair_graph, neg_graph]) + else: + pair_graph = compact_graphs(pair_graph) + + pair_graph.edata[EID] = eids + seed_nodes = pair_graph.ndata[NID] + + exclude_eids = find_exclude_eids( + g, seed_edges, exclude, self.reverse_eids, self.reverse_etypes, + self.output_device) + + input_nodes, _, blocks = self.block_sampler.sample_blocks(g, seed_nodes, exclude_eids) + + if self.negative_sampler is None: + return self.assign_lazy_features((input_nodes, pair_graph, blocks)) + else: + return self.assign_lazy_features((input_nodes, pair_graph, neg_graph, blocks)) diff --git a/python/dgl/dataloading/cluster_gcn.py b/python/dgl/dataloading/cluster_gcn.py index 6df0772e2df3..1cb41ab8614d 100644 --- a/python/dgl/dataloading/cluster_gcn.py +++ b/python/dgl/dataloading/cluster_gcn.py @@ -1,19 +1,20 @@ -"""Cluster-GCN subgraph iterators.""" +"""Cluster-GCN samplers.""" import os import pickle import numpy as np -from ..transform import metis_partition_assignment from .. import backend as F -from .dataloader import SubgraphIterator +from ..base import DGLError +from ..partition import metis_partition_assignment +from .base import set_node_lazy_features, set_edge_lazy_features -class ClusterGCNSubgraphIterator(SubgraphIterator): - """Subgraph sampler following that of ClusterGCN. +class ClusterGCNSampler(object): + """Cluster-GCN sampler. This sampler first partitions the graph with METIS partitioning, then it caches the nodes of each partition to a file within the given cache directory. - This is used in conjunction with :class:`dgl.dataloading.pytorch.GraphDataLoader`. + This is used in conjunction with :class:`dgl.dataloading.DataLoader`. Notes ----- @@ -23,61 +24,53 @@ class ClusterGCNSubgraphIterator(SubgraphIterator): ---------- g : DGLGraph The original graph. - num_partitions : int + k : int The number of partitions. - cache_directory : str + cache_path : str The path to the cache directory for storing the partition result. - refresh : bool - If True, recompute the partition. - - Examples - -------- - Assuming that you have a graph ``g``: - - >>> sgiter = dgl.dataloading.ClusterGCNSubgraphIterator( - ... g, num_partitions=100, cache_directory='.', refresh=True) - >>> dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=0) - >>> for subgraph_batch in dataloader: - ... train_on(subgraph_batch) """ - def __init__(self, g, num_partitions, cache_directory, refresh=False): - if os.name == 'nt': - raise NotImplementedError("METIS partitioning is not supported on Windows yet.") - super().__init__(g) - - # First see if the cache is already there. If so, directly read from cache. - if not refresh and self._load_parts(cache_directory): - return - - # Otherwise, build the cache. 
- assignment = F.asnumpy(metis_partition_assignment(g, num_partitions)) - self._save_parts(assignment, cache_directory) - - def _cache_file_path(self, cache_directory): - return os.path.join(cache_directory, 'cluster_gcn_cache') - - def _load_parts(self, cache_directory): - path = self._cache_file_path(cache_directory) - if not os.path.exists(path): - return False - - with open(path, 'rb') as file_: - self.part_indptr, self.part_indices = pickle.load(file_) - return True - - def _save_parts(self, assignment, cache_directory): - os.makedirs(cache_directory, exist_ok=True) - - self.part_indices = np.argsort(assignment) - num_nodes_per_part = np.bincount(assignment) - self.part_indptr = np.insert(np.cumsum(num_nodes_per_part), 0, 0) - - with open(self._cache_file_path(cache_directory), 'wb') as file_: - pickle.dump((self.part_indptr, self.part_indices), file_) - - def __len__(self): - return self.part_indptr.shape[0] - 1 - - def __getitem__(self, i): - nodes = self.part_indices[self.part_indptr[i]:self.part_indptr[i+1]] - return self.g.subgraph(nodes) + def __init__(self, g, k, balance_ntypes=None, balance_edges=False, mode='k-way', + prefetch_node_feats=None, prefetch_edge_feats=None, output_device=None, + cache_path='cluster_gcn.pkl'): + if os.path.exists(cache_path): + try: + with open(cache_path, 'rb') as f: + self.partition_offset, self.partition_node_ids = pickle.load(f) + except (EOFError, TypeError, ValueError): + raise DGLError( + f'The contents in the cache file {cache_path} is invalid. ' + f'Please remove the cache file {cache_path} or specify another path.') + if len(self.partition_offset) != k + 1: + raise DGLError( + f'Number of partitions in the cache does not match the value of k. ' + f'Please remove the cache file {cache_path} or specify another path.') + if len(self.partition_node_ids) != g.num_nodes(): + raise DGLError( + f'Number of nodes in the cache does not match the given graph. 
'
+                    f'Please remove the cache file {cache_path} or specify another path.')
+        else:
+            partition_ids = metis_partition_assignment(
+                g, k, balance_ntypes=balance_ntypes, balance_edges=balance_edges, mode=mode)
+            partition_ids = F.asnumpy(partition_ids)
+            partition_node_ids = np.argsort(partition_ids)
+            partition_size = F.zerocopy_from_numpy(np.bincount(partition_ids, minlength=k))
+            partition_offset = F.zerocopy_from_numpy(np.insert(np.cumsum(partition_size), 0, 0))
+            partition_node_ids = F.zerocopy_from_numpy(partition_node_ids)
+            with open(cache_path, 'wb') as f:
+                pickle.dump((partition_offset, partition_node_ids), f)
+            self.partition_offset = partition_offset
+            self.partition_node_ids = partition_node_ids
+
+        self.prefetch_node_feats = prefetch_node_feats or []
+        self.prefetch_edge_feats = prefetch_edge_feats or []
+        self.output_device = output_device
+
+    def sample(self, g, partition_ids):
+        """Samples a subgraph given a list of partition IDs."""
+        node_ids = F.cat([
+            self.partition_node_ids[self.partition_offset[i]:self.partition_offset[i+1]]
+            for i in F.asnumpy(partition_ids)], 0)
+        sg = g.subgraph(node_ids, relabel_nodes=True, output_device=self.output_device)
+        set_node_lazy_features(sg, self.prefetch_node_feats)
+        set_edge_lazy_features(sg, self.prefetch_edge_feats)
+        return sg
diff --git a/python/dgl/dataloading/dataloader.py b/python/dgl/dataloading/dataloader.py
index 01bc2d176e87..45c22125c441 100644
--- a/python/dgl/dataloading/dataloader.py
+++ b/python/dgl/dataloading/dataloader.py
@@ -1,900 +1,602 @@
-"""Data loaders"""
-
+"""DGL PyTorch DataLoaders"""
 from collections.abc import Mapping, Sequence
-from abc import ABC, abstractproperty, abstractmethod
+from queue import Queue
+import itertools
+import threading
+from distutils.version import LooseVersion
+import random
+import math
+import inspect
 import re
-import numpy as np
-from .. import transform
-from ..base import NID, EID
-from .. import backend as F
-from .. import utils
-from ..batch import batch
-from ..convert import heterograph
-from ..heterograph import DGLHeteroGraph as DGLGraph
-from ..distributed.dist_graph import DistGraph
-from ..utils import to_device
-
-def _tensor_or_dict_to_numpy(ids):
-    if isinstance(ids, Mapping):
-        return {k: F.zerocopy_to_numpy(v) for k, v in ids.items()}
-    else:
-        return F.zerocopy_to_numpy(ids)
-def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
-    """Find the edges whose IDs in parent graph appeared in exclude_eids.
+import torch
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+
+from ..base import NID, EID, DGLError, dgl_warning
+from ..batch import batch as batch_graphs
+from ..heterograph import DGLHeteroGraph
+from .. import ndarray as nd
+from ..utils import (
+    recursive_apply, ExceptionWrapper, recursive_apply_pair, set_num_threads,
+    create_shared_mem_array, get_shared_mem_array)
+from ..frame import LazyFeature
+from ..storages import wrap_storage
+from .base import BlockSampler, EdgeBlockSampler
+from .. import backend as F
-
-    Note that both arguments are numpy arrays or numpy dicts.
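+
+# Editor's note: an illustrative end-to-end sketch of the pipeline implemented in this
+# file (it assumes the NeighborSampler defined in neighbor_sampler.py, seed node IDs
+# ``train_nids``, and a graph ``g`` carrying 'feat' and 'label' node data):
+#
+#     sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'],
+#                               prefetch_labels=['label'])
+#     dataloader = DataLoader(g, train_nids, sampler, device='cuda', batch_size=1024,
+#                             shuffle=True, num_workers=4)
+#     for input_nodes, output_nodes, blocks in dataloader:
+#         x = blocks[0].srcdata['feat']    # already fetched by the prefetcher
+#         y = blocks[-1].dstdata['label']  # already fetched by the prefetcher
+#
+# The dataset classes below keep all seed IDs in one tensor and slice out each
+# minibatch, avoiding the per-item overhead of PyTorch's default batching, while
+# _PrefetchingIter overlaps feature fetching with training.
+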
+class _TensorizedDatasetIter(object): + def __init__(self, dataset, batch_size, drop_last, mapping_keys): + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.mapping_keys = mapping_keys + self.index = 0 + + # For PyTorch Lightning compatibility + def __iter__(self): + return self + + def _next_indices(self): + num_items = self.dataset.shape[0] + if self.index >= num_items: + raise StopIteration + end_idx = self.index + self.batch_size + if end_idx > num_items: + if self.drop_last: + raise StopIteration + end_idx = num_items + batch = self.dataset[self.index:end_idx] + self.index += self.batch_size + + return batch + + def __next__(self): + batch = self._next_indices() + if self.mapping_keys is None: + return batch + + # convert the type-ID pairs to dictionary + type_ids = batch[:, 0] + indices = batch[:, 1] + type_ids_sortidx = torch.argsort(type_ids) + type_ids = type_ids[type_ids_sortidx] + indices = indices[type_ids_sortidx] + type_id_uniq, type_id_count = torch.unique_consecutive(type_ids, return_counts=True) + type_id_uniq = type_id_uniq.tolist() + type_id_offset = type_id_count.cumsum(0).tolist() + type_id_offset.insert(0, 0) + id_dict = { + self.mapping_keys[type_id_uniq[i]]: indices[type_id_offset[i]:type_id_offset[i+1]] + for i in range(len(type_id_uniq))} + return id_dict + + +def _get_id_tensor_from_mapping(indices, device, keys): + lengths = torch.LongTensor([ + (indices[k].shape[0] if k in indices else 0) for k in keys], device=device) + type_ids = torch.arange(len(keys), device=device).repeat_interleave(lengths) + all_indices = torch.cat([indices[k] for k in keys if k in indices]) + return torch.stack([type_ids, all_indices], 1) + + +def _divide_by_worker(dataset): + num_samples = dataset.shape[0] + worker_info = torch.utils.data.get_worker_info() + if worker_info: + chunk_size = num_samples // worker_info.num_workers + left_over = num_samples % worker_info.num_workers + start = (chunk_size * worker_info.id) + min(left_over, worker_info.id) + end = start + chunk_size + (worker_info.id < left_over) + assert worker_info.id < worker_info.num_workers - 1 or end == num_samples + dataset = dataset[start:end] + return dataset + + +class TensorizedDataset(torch.utils.data.IterableDataset): + """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors. + When the dataset is on the GPU, this significantly reduces the overhead. """ - if isinstance(frontier_parent_eids, Mapping): - result = { - k: np.isin(frontier_parent_eids[k], exclude_eids[k]).nonzero()[0] - for k in frontier_parent_eids.keys() if k in exclude_eids.keys()} - return {k: F.zerocopy_from_numpy(v) for k, v in result.items()} - else: - result = np.isin(frontier_parent_eids, exclude_eids).nonzero()[0] - return F.zerocopy_from_numpy(result) - -class _EidExcluder(): - def __init__(self, exclude_eids): - device = None - if isinstance(exclude_eids, Mapping): - for _, v in exclude_eids.items(): - if device is None: - device = F.context(v) - break + def __init__(self, indices, batch_size, drop_last): + if isinstance(indices, Mapping): + self._mapping_keys = list(indices.keys()) + self._device = next(iter(indices.values())).device + self._tensor_dataset = _get_id_tensor_from_mapping( + indices, self._device, self._mapping_keys) else: - device = F.context(exclude_eids) - self._exclude_eids = None - self._filter = None - - if device == F.cpu(): - # TODO(nv-dlasalle): Once Filter is implemented for the CPU, we - # should just use that irregardless of the device. 
- self._exclude_eids = ( - _tensor_or_dict_to_numpy(exclude_eids) if exclude_eids is not None else None) + self._tensor_dataset = indices + self._device = indices.device + self._mapping_keys = None + self.batch_size = batch_size + self.drop_last = drop_last + + def shuffle(self): + """Shuffle the dataset.""" + # TODO: may need an in-place shuffle kernel + perm = torch.randperm(self._tensor_dataset.shape[0], device=self._device) + self._tensor_dataset[:] = self._tensor_dataset[perm] + + def __iter__(self): + dataset = _divide_by_worker(self._tensor_dataset) + return _TensorizedDatasetIter( + dataset, self.batch_size, self.drop_last, self._mapping_keys) + + def __len__(self): + num_samples = self._tensor_dataset.shape[0] + return (num_samples + (0 if self.drop_last else (self.batch_size - 1))) // self.batch_size + +def _get_shared_mem_name(id_): + return f'ddp_{id_}' + +def _generate_shared_mem_name_id(): + for _ in range(3): # 3 trials + id_ = random.getrandbits(32) + name = _get_shared_mem_name(id_) + if not nd.exist_shared_mem_array(name): + return name, id_ + raise DGLError('Unable to generate a shared memory array') + +class DDPTensorizedDataset(torch.utils.data.IterableDataset): + """Custom Dataset wrapper that returns a minibatch as tensors or dicts of tensors. + When the dataset is on the GPU, this significantly reduces the overhead. + + This class additionally saves the index tensor in shared memory and therefore + avoids duplicating the same index tensor during shuffling. + """ + def __init__(self, indices, batch_size, drop_last, ddp_seed): + if isinstance(indices, Mapping): + self._mapping_keys = list(indices.keys()) else: - if isinstance(exclude_eids, Mapping): - self._filter = {k: utils.Filter(v) for k, v in exclude_eids.items()} - else: - self._filter = utils.Filter(exclude_eids) + self._mapping_keys = None - def _find_indices(self, parent_eids): - """ Find the set of edge indices to remove. - """ - if self._exclude_eids is not None: - parent_eids_np = _tensor_or_dict_to_numpy(parent_eids) - return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids) + self.rank = dist.get_rank() + self.num_replicas = dist.get_world_size() + self.seed = ddp_seed + self.epoch = 0 + self.batch_size = batch_size + self.drop_last = drop_last + + if self.drop_last and len(indices) % self.num_replicas != 0: + self.num_samples = math.ceil((len(indices) - self.num_replicas) / self.num_replicas) else: - assert self._filter is not None - if isinstance(parent_eids, Mapping): - located_eids = {k: self._filter[k].find_included_indices(parent_eids[k]) - for k, v in parent_eids.items() if k in self._filter} + self.num_samples = math.ceil(len(indices) / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + # If drop_last is True, we create a shared memory array larger than the number + # of indices since we will need to pad it after shuffling to make it evenly + # divisible before every epoch. If drop_last is False, we create an array + # with the same size as the indices so we can trim it later. 
+ self.shared_mem_size = self.total_size if not self.drop_last else len(indices) + self.num_indices = len(indices) + + if self.rank == 0: + name, id_ = _generate_shared_mem_name_id() + if isinstance(indices, Mapping): + device = next(iter(indices.values())).device + id_tensor = _get_id_tensor_from_mapping(indices, device, self._mapping_keys) + self._tensor_dataset = create_shared_mem_array( + name, (self.shared_mem_size, 2), torch.int64) + self._tensor_dataset[:id_tensor.shape[0], :] = id_tensor else: - located_eids = self._filter.find_included_indices(parent_eids) - return located_eids - - def __call__(self, frontier): - parent_eids = frontier.edata[EID] - located_eids = self._find_indices(parent_eids) - - if not isinstance(located_eids, Mapping): - # (BarclayII) If frontier already has a EID field and located_eids is empty, - # the returned graph will keep EID intact. Otherwise, EID will change - # to the mapping from the new graph to the old frontier. - # So we need to test if located_eids is empty, and do the remapping ourselves. - if len(located_eids) > 0: - frontier = transform.remove_edges( - frontier, located_eids, store_ids=True) - frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID]) - else: - # (BarclayII) remove_edges only accepts removing one type of edges, - # so I need to keep track of the edge IDs left one by one. - new_eids = parent_eids.copy() - for k, v in located_eids.items(): - if len(v) > 0: - frontier = transform.remove_edges( - frontier, v, etype=k, store_ids=True) - new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID]) - frontier.edata[EID] = new_eids - return frontier - - -def exclude_edges(subg, exclude_eids, device): - """Find and remove from the subgraph the edges whose IDs in the parent - graph are given. - - Parameters - ---------- - subg : DGLGraph - The subgraph. Must have ``dgl.EID`` field containing the original - edge IDs in the parent graph. - exclude_eids : Tensor or dict - The edge IDs to exclude. - device : device - The output device of the graph. - - Returns - ------- - DGLGraph - The new subgraph with edges removed. The ``dgl.EID`` field contains - the original edge IDs in the same parent graph. - """ - if exclude_eids is None: - return subg - - if device is not None: - if isinstance(exclude_eids, Mapping): - exclude_eids = {k: F.copy_to(v, device) \ - for k, v in exclude_eids.items()} + self._tensor_dataset = create_shared_mem_array( + name, (self.shared_mem_size,), torch.int64) + self._tensor_dataset[:len(indices)] = indices + self._device = self._tensor_dataset.device + meta_info = torch.LongTensor([id_, self._tensor_dataset.shape[0]]) else: - exclude_eids = F.copy_to(exclude_eids, device) - - excluder = _EidExcluder(exclude_eids) - return subg if excluder is None else excluder(subg) - - -def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map): - if isinstance(eids, Mapping): - eids = {g.to_canonical_etype(k): v for k, v in eids.items()} - exclude_eids = { - k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0) - for k, v in eids.items()} + meta_info = torch.LongTensor([0, 0]) + + if dist.get_backend() == 'nccl': + # Use default CUDA device; PyTorch DDP required the users to set the CUDA + # device for each process themselves so calling .cuda() should be safe. 
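+            # (The broadcast below shares the (id_, size) of rank 0's shared-memory
+            # index array so that the other ranks can attach to the same array via
+            # get_shared_mem_array() instead of keeping their own copy of the indices.)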
+ meta_info = meta_info.cuda() + dist.broadcast(meta_info, src=0) + + if self.rank != 0: + id_, num_samples = meta_info.tolist() + name = _get_shared_mem_name(id_) + if isinstance(indices, Mapping): + indices_shared = get_shared_mem_array(name, (num_samples, 2), torch.int64) + else: + indices_shared = get_shared_mem_array(name, (num_samples,), torch.int64) + self._tensor_dataset = indices_shared + self._device = indices_shared.device + + def shuffle(self): + """Shuffles the dataset.""" + # Only rank 0 does the actual shuffling. The other ranks wait for it. + if self.rank == 0: + self._tensor_dataset[:self.num_indices] = self._tensor_dataset[ + torch.randperm(self.num_indices, device=self._device)] + if not self.drop_last: + # pad extra + self._tensor_dataset[self.num_indices:] = \ + self._tensor_dataset[:self.total_size - self.num_indices] + dist.barrier() + + def __iter__(self): + start = self.num_samples * self.rank + end = self.num_samples * (self.rank + 1) + dataset = _divide_by_worker(self._tensor_dataset[start:end]) + return _TensorizedDatasetIter( + dataset, self.batch_size, self.drop_last, self._mapping_keys) + + def __len__(self): + return (self.num_samples + (0 if self.drop_last else (self.batch_size - 1))) // \ + self.batch_size + + +def _prefetch_update_feats(feats, frames, types, get_storage_func, id_name, device, pin_memory): + for tid, frame in enumerate(frames): + type_ = types[tid] + default_id = frame.get(id_name, None) + for key in frame.keys(): + column = frame[key] + if isinstance(column, LazyFeature): + parent_key = column.name or key + if column.id_ is None and default_id is None: + raise DGLError( + 'Found a LazyFeature with no ID specified, ' + 'and the graph does not have dgl.NID or dgl.EID columns') + feats[tid, key] = get_storage_func(parent_key, type_).fetch( + column.id_ or default_id, device, pin_memory) + + +# This class exists to avoid recursion into the feature dictionary returned by the +# prefetcher when calling recursive_apply(). +class _PrefetchedGraphFeatures(object): + __slots__ = ['node_feats', 'edge_feats'] + def __init__(self, node_feats, edge_feats): + self.node_feats = node_feats + self.edge_feats = edge_feats + + +def _prefetch_for_subgraph(subg, dataloader): + node_feats, edge_feats = {}, {} + _prefetch_update_feats( + node_feats, subg._node_frames, subg.ntypes, dataloader.graph.get_node_storage, + NID, dataloader.device, dataloader.pin_memory) + _prefetch_update_feats( + edge_feats, subg._edge_frames, subg.canonical_etypes, dataloader.graph.get_edge_storage, + EID, dataloader.device, dataloader.pin_memory) + return _PrefetchedGraphFeatures(node_feats, edge_feats) + + +def _prefetch_for(item, dataloader): + if isinstance(item, DGLHeteroGraph): + return _prefetch_for_subgraph(item, dataloader) + elif isinstance(item, LazyFeature): + return dataloader.other_storages[item.name].fetch( + item.id_, dataloader.device, dataloader.pin_memory) else: - exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0) - return exclude_eids - -def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): - exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()} - reverse_etype_map = { - g.to_canonical_etype(k): g.to_canonical_etype(v) - for k, v in reverse_etype_map.items()} - exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()}) - return exclude_eids + return None -def _find_exclude_eids(g, exclude_mode, eids, **kwargs): - """Find all edge IDs to exclude according to :attr:`exclude_mode`. 
- Parameters - ---------- - g : DGLGraph - The graph. - exclude_mode : str, optional - Can be either of the following, - - None (default) - Does not exclude any edge. - - 'self' - Exclude the given edges themselves but nothing else. - - 'reverse_id' - Exclude all edges specified in ``eids``, as well as their reverse edges - of the same edge type. - - The mapping from each edge ID to its reverse edge ID is specified in - the keyword argument ``reverse_eid_map``. - - This mode assumes that the reverse of an edge with ID ``e`` and type - ``etype`` will have ID ``reverse_eid_map[e]`` and type ``etype``. - - 'reverse_types' - Exclude all edges specified in ``eids``, as well as their reverse - edges of the corresponding edge types. - - The mapping from each edge type to its reverse edge type is specified - in the keyword argument ``reverse_etype_map``. - - This mode assumes that the reverse of an edge with ID ``e`` and type ``etype`` - will have ID ``e`` and type ``reverse_etype_map[etype]``. - eids : Tensor or dict[etype, Tensor] - The edge IDs. - reverse_eid_map : Tensor or dict[etype, Tensor] - The mapping from edge ID to its reverse edge ID. - reverse_etype_map : dict[etype, etype] - The mapping from edge etype to its reverse edge type. - """ - if exclude_mode is None: - return None - elif exclude_mode == 'self': - return eids - elif exclude_mode == 'reverse_id': - return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map']) - elif exclude_mode == 'reverse_types': - return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map']) +def _await_or_return(x): + if hasattr(x, 'wait'): + return x.wait() + elif isinstance(x, _PrefetchedGraphFeatures): + node_feats = recursive_apply(x.node_feats, _await_or_return) + edge_feats = recursive_apply(x.edge_feats, _await_or_return) + return _PrefetchedGraphFeatures(node_feats, edge_feats) else: - raise ValueError('unsupported mode {}'.format(exclude_mode)) - -class Sampler(object): - """An abstract class that takes in a graph and a set of seed nodes and returns a - structure representing a smaller portion of the graph for computation. It can - be either a list of bipartite graphs (i.e. :class:`BlockSampler`), or a single - subgraph. + return x + + +def _prefetch(batch, dataloader, stream): + # feats has the same nested structure of batch, except that + # (1) each subgraph is replaced with a pair of node features and edge features, both + # being dictionaries whose keys are (type_id, column_name) and values are either + # tensors or futures. + # (2) each LazyFeature object is replaced with a tensor or future. + # (3) everything else are replaced with None. + # + # Once the futures are fetched, this function waits for them to complete by + # calling its wait() method. 
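+    #
+    # When a side CUDA stream is passed in (see ``use_alternate_streams``), the copies
+    # issued here run on that stream and can overlap with compute on the default stream;
+    # the caller records an event on the stream afterwards, and _PrefetchingIter.__next__
+    # waits on that event before handing the batch to the user.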
+ with torch.cuda.stream(stream): + feats = recursive_apply(batch, _prefetch_for, dataloader) + feats = recursive_apply(feats, _await_or_return) + return feats + + +def _assign_for(item, feat): + if isinstance(item, DGLHeteroGraph): + subg = item + for (tid, key), value in feat.node_feats.items(): + assert isinstance(subg._node_frames[tid][key], LazyFeature) + subg._node_frames[tid][key] = value + for (tid, key), value in feat.edge_feats.items(): + assert isinstance(subg._edge_frames[tid][key], LazyFeature) + subg._edge_frames[tid][key] = value + return subg + elif isinstance(item, LazyFeature): + return feat + else: + return item + + +def _prefetcher_entry(dataloader_it, dataloader, queue, num_threads, use_alternate_streams): + # PyTorch will set the number of threads to 1 which slows down pin_memory() calls + # in main process if a prefetching thread is created. + if num_threads is not None: + torch.set_num_threads(num_threads) + if use_alternate_streams: + stream = ( + torch.cuda.Stream(device=dataloader.device) + if dataloader.device.type == 'cuda' else None) + else: + stream = None + + try: + for batch in dataloader_it: + batch = recursive_apply(batch, restore_parent_storage_columns, dataloader.graph) + feats = _prefetch(batch, dataloader, stream) + + queue.put(( + # batch will be already in pinned memory as per the behavior of + # PyTorch DataLoader. + recursive_apply(batch, lambda x: x.to(dataloader.device, non_blocking=True)), + feats, + stream.record_event() if stream is not None else None, + None)) + queue.put((None, None, None, None)) + except: # pylint: disable=bare-except + queue.put((None, None, None, ExceptionWrapper(where='in prefetcher'))) + + +# DGLHeteroGraphs have the semantics of lazy feature slicing with subgraphs. Such behavior depends +# on that DGLHeteroGraph's ndata and edata are maintained by Frames. So to maintain compatibility +# with older code, DGLHeteroGraphs and other graph storages are handled separately: (1) +# DGLHeteroGraphs will preserve the lazy feature slicing for subgraphs. (2) Other graph storages +# will not have lazy feature slicing; all feature slicing will be eager. +def remove_parent_storage_columns(item, g): + """Removes the storage objects in the given graphs' Frames if it is a sub-frame of the + given parent graph, so that the storages are not serialized during IPC from PyTorch + DataLoader workers. """ - def __init__(self, output_ctx=None): - self.set_output_context(output_ctx) - - def sample(self, g, seed_nodes, exclude_eids=None): - """Sample a structure from the graph. - - Parameters - ---------- - g : DGLGraph - The original graph. - seed_nodes : Tensor or dict[ntype, Tensor] - The destination nodes by type. - - If the graph only has one node type, one can just specify a single tensor - of node IDs. - exclude_eids : Tensor or dict[etype, Tensor] - The edges to exclude from computation dependency. - - Returns - ------- - Tensor or dict[ntype, Tensor] - The nodes whose input features are required for computing the output - representation of :attr:`seed_nodes`. - any - Any data representing the structure. - """ - raise NotImplementedError - - def set_output_context(self, ctx): - """Set the device the generated block or subgraph will be output to. - This should only be set to a cuda device, when multi-processing is not - used in the dataloader (e.g., num_workers is 0). - - Parameters - ---------- - ctx : DGLContext, default None - The device context the sampled blocks will be stored on. 
This - should only be a CUDA context if multiprocessing is not used in - the dataloader (e.g., num_workers is 0). If this is None, the - sampled blocks will be stored on the same device as the input - graph. - """ - if ctx is not None: - self.output_device = F.to_backend_ctx(ctx) - else: - self.output_device = None - -class BlockSampler(Sampler): - """Abstract class specifying the neighborhood sampling strategy for DGL data loaders. - - The main method for BlockSampler is :meth:`sample`, - which generates a list of message flow graphs (MFGs) for a multi-layer GNN given a set of - seed nodes to have their outputs computed. - - The default implementation of :meth:`sample` is - to repeat :attr:`num_layers` times the following procedure from the last layer to the first - layer: - - * Obtain a frontier. The frontier is defined as a graph with the same nodes as the - original graph but only the edges involved in message passing on the current layer. - Customizable via :meth:`sample_frontier`. - - * Optionally, if the task is link prediction or edge classfication, remove edges - connecting training node pairs. If the graph is undirected, also remove the - reverse edges. This is controlled by the argument :attr:`exclude_eids` in - :meth:`sample` method. - - * Convert the frontier into a MFG. - - * Optionally assign the IDs of the edges in the original graph selected in the first step - to the MFG, controlled by the argument ``return_eids`` in - :meth:`sample` method. - - * Prepend the MFG to the MFG list to be returned. - - All subclasses should override :meth:`sample_frontier` - method while specifying the number of layers to sample in :attr:`num_layers` argument. - - Parameters - ---------- - num_layers : int - The number of layers to sample. - return_eids : bool, default False - Whether to return the edge IDs involved in message passing in the MFG. - If True, the edge IDs will be stored as an edge feature named ``dgl.EID``. - output_ctx : DGLContext, default None - The context the sampled blocks will be stored on. This should only be - a CUDA context if multiprocessing is not used in the dataloader (e.g., - num_workers is 0). If this is None, the sampled blocks will be stored - on the same device as the input graph. - exclude_edges_in_frontier : bool, default False - If True, the :func:`sample_frontier` method will receive an argument - :attr:`exclude_eids` containing the edge IDs from the original graph to exclude. - The :func:`sample_frontier` method must return a graph that does not contain - the edges corresponding to the excluded edges. No additional postprocessing - will be done. - - Otherwise, the edges will be removed *after* :func:`sample_frontier` returns. - - Notes - ----- - For the concept of frontiers and MFGs, please refer to - :ref:`User Guide Section 6 ` and - :doc:`Minibatch Training Tutorials `. + if not isinstance(item, DGLHeteroGraph) or not isinstance(g, DGLHeteroGraph): + return item + + for subframe, frame in zip( + itertools.chain(item._node_frames, item._edge_frames), + itertools.chain(g._node_frames, g._edge_frames)): + for key in list(subframe.keys()): + subcol = subframe._columns[key] # directly get the column object + if isinstance(subcol, LazyFeature): + continue + col = frame._columns.get(key, None) + if col is None: + continue + if col.storage is subcol.storage: + subcol.storage = None + return item + + +def restore_parent_storage_columns(item, g): + """Restores the storage objects in the given graphs' Frames if it is a sub-frame of the + given parent graph (i.e. 
when the storage object is None). """ - def __init__(self, num_layers, return_eids=False, output_ctx=None): - super().__init__(output_ctx) - self.num_layers = num_layers - self.return_eids = return_eids - - # pylint: disable=unused-argument - @staticmethod - def assign_block_eids(block, frontier): - """Assigns edge IDs from the original graph to the message flow graph (MFG). - - See also - -------- - BlockSampler - """ - for etype in block.canonical_etypes: - block.edges[etype].data[EID] = frontier.edges[etype].data[EID][ - block.edges[etype].data[EID]] - return block - - # This is really a hack working around the lack of GPU-based neighbor sampling - # with edge exclusion. - @classmethod - def exclude_edges_in_frontier(cls, g): - """Returns whether the sampler will exclude edges in :func:`sample_frontier`. - - If this method returns True, the method :func:`sample_frontier` will receive an - argument :attr:`exclude_eids` from :func:`sample`. :func:`sample_frontier` - is then responsible for removing those edges. - - If this method returns False, :func:`sample` will be responsible for - removing the edges. - - When subclassing :class:`BlockSampler`, this method should return True when you - would like to remove the excluded edges in your :func:`sample_frontier` method. - - By default this method returns False. - - Parameters - ---------- - g : DGLGraph - The original graph - - Returns - ------- - bool - Whether :func:`sample_frontier` will receive an argument :attr:`exclude_eids`. - """ - return False - - def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None): - """Generate the frontier given the destination nodes. - - The subclasses should override this function. - - Parameters - ---------- - block_id : int - Represents which GNN layer the frontier is generated for. - g : DGLGraph - The original graph. - seed_nodes : Tensor or dict[ntype, Tensor] - The destination nodes by node type. - - If the graph only has one node type, one can just specify a single tensor - of node IDs. - exclude_eids: Tensor or dict - Edge IDs to exclude during sampling neighbors for the seed nodes. - - This argument can take a single ID tensor or a dictionary of edge types and ID tensors. - If a single tensor is given, the graph must only have one type of nodes. - - Returns - ------- - DGLGraph - The frontier generated for the current layer. - - Notes - ----- - For the concept of frontiers and MFGs, please refer to - :ref:`User Guide Section 6 ` and - :doc:`Minibatch Training Tutorials `. - """ - raise NotImplementedError - - def sample(self, g, seed_nodes, exclude_eids=None): - """Generate the a list of MFGs given the destination nodes. - - Parameters - ---------- - g : DGLGraph - The original graph. - seed_nodes : Tensor or dict[ntype, Tensor] - The destination nodes by node type. - - If the graph only has one node type, one can just specify a single tensor - of node IDs. - exclude_eids : Tensor or dict[etype, Tensor] - The edges to exclude from computation dependency. - - Returns - ------- - list[DGLGraph] - The MFGs generated for computing the multi-layer GNN output. - - Notes - ----- - For the concept of frontiers and MFGs, please refer to - :ref:`User Guide Section 6 ` and - :doc:`Minibatch Training Tutorials `. - """ - blocks = [] - - if isinstance(g, DistGraph): - # TODO:(nv-dlasalle) dist graphs may not have an associated graph, - # causing an error when trying to fetch the device, so for now, - # always assume the distributed graph's device is CPU. 
- graph_device = F.cpu() + if not isinstance(item, DGLHeteroGraph) or not isinstance(g, DGLHeteroGraph): + return item + + for subframe, frame in zip( + itertools.chain(item._node_frames, item._edge_frames), + itertools.chain(g._node_frames, g._edge_frames)): + for key in subframe.keys(): + subcol = subframe._columns[key] + if isinstance(subcol, LazyFeature): + continue + col = frame._columns.get(key, None) + if col is None: + continue + if subcol.storage is None: + subcol.storage = col.storage + return item + + +class _PrefetchingIter(object): + def __init__(self, dataloader, dataloader_it, use_thread=False, use_alternate_streams=True, + num_threads=None): + self.queue = Queue(1) + self.dataloader_it = dataloader_it + self.dataloader = dataloader + self.graph_sampler = self.dataloader.graph_sampler + self.pin_memory = self.dataloader.pin_memory + self.num_threads = num_threads + + self.use_thread = use_thread + self.use_alternate_streams = use_alternate_streams + if use_thread: + thread = threading.Thread( + target=_prefetcher_entry, + args=(dataloader_it, dataloader, self.queue, num_threads, use_alternate_streams), + daemon=True) + thread.start() + self.thread = thread + + def __iter__(self): + return self + + def _next_non_threaded(self): + batch = next(self.dataloader_it) + batch = recursive_apply(batch, restore_parent_storage_columns, self.dataloader.graph) + device = self.dataloader.device + if self.use_alternate_streams: + stream = torch.cuda.Stream(device=device) if device.type == 'cuda' else None else: - graph_device = g.device - - for block_id in reversed(range(self.num_layers)): - seed_nodes_in = to_device(seed_nodes, graph_device) - - if self.exclude_edges_in_frontier(g): - frontier = self.sample_frontier( - block_id, g, seed_nodes_in, exclude_eids=exclude_eids) - else: - frontier = self.sample_frontier(block_id, g, seed_nodes_in) - - if self.output_device is not None: - frontier = frontier.to(self.output_device) - seed_nodes_out = to_device(seed_nodes, self.output_device) - else: - seed_nodes_out = seed_nodes - - # Removing edges from the frontier for link prediction training falls - # into the category of frontier postprocessing - if not self.exclude_edges_in_frontier(g): - frontier = exclude_edges(frontier, exclude_eids, self.output_device) - - block = transform.to_block(frontier, seed_nodes_out) - if self.return_eids: - self.assign_block_eids(block, frontier) - - seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes} - blocks.insert(0, block) - return blocks[0].srcdata[NID], blocks[-1].dstdata[NID], blocks - - def sample_blocks(self, g, seed_nodes, exclude_eids=None): - """Deprecated and identical to :meth:`sample`. - """ - return self.sample(g, seed_nodes, exclude_eids) - -class Collator(ABC): - """Abstract DGL collator for training GNNs on downstream tasks stochastically. - - Provides a :attr:`dataset` object containing the collection of all nodes or edges, - as well as a :attr:`collate` method that combines a set of items from - :attr:`dataset` and obtains the message flow graphs (MFGs). - - Notes - ----- - For the concept of MFGs, please refer to - :ref:`User Guide Section 6 ` and - :doc:`Minibatch Training Tutorials `. - """ - @abstractproperty - def dataset(self): - """Returns the dataset object of the collator.""" - raise NotImplementedError - - @abstractmethod - def collate(self, items): - """Combines the items from the dataset object and obtains the list of MFGs. 
- - Parameters - ---------- - items : list[str, int] - The list of node or edge IDs or type-ID pairs. - - Notes - ----- - For the concept of MFGs, please refer to - :ref:`User Guide Section 6 ` and - :doc:`Minibatch Training Tutorials `. - """ - raise NotImplementedError - -class NodeCollator(Collator): - """DGL collator to combine nodes and their computation dependencies within a minibatch for - training node classification or regression on a single graph with neighborhood sampling. - - Parameters - ---------- - g : DGLGraph - The graph. - nids : Tensor or dict[ntype, Tensor] - The node set to compute outputs. - graph_sampler : dgl.dataloading.BlockSampler - The neighborhood sampler. - - Examples - -------- - To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on - a homogeneous graph where each node takes messages from all neighbors (assume - the backend is PyTorch): - - >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) - >>> collator = dgl.dataloading.NodeCollator(g, train_nid, sampler) - >>> dataloader = torch.utils.data.DataLoader( - ... collator.dataset, collate_fn=collator.collate, - ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) - >>> for input_nodes, output_nodes, blocks in dataloader: - ... train_on(input_nodes, output_nodes, blocks) - - Notes - ----- - For the concept of MFGs, please refer to - :ref:`User Guide Section 6 ` and - :doc:`Minibatch Training Tutorials `. + stream = None + feats = _prefetch(batch, self.dataloader, stream) + batch = recursive_apply(batch, lambda x: x.to(device, non_blocking=True)) + stream_event = stream.record_event() if stream is not None else None + return batch, feats, stream_event + + def _next_threaded(self): + batch, feats, stream_event, exception = self.queue.get() + if batch is None: + self.thread.join() + if exception is None: + raise StopIteration + exception.reraise() + return batch, feats, stream_event + + def __next__(self): + batch, feats, stream_event = \ + self._next_non_threaded() if not self.use_thread else self._next_threaded() + batch = recursive_apply_pair(batch, feats, _assign_for) + if stream_event is not None: + stream_event.wait() + return batch + + +# Make them classes to work with pickling in mp.spawn +class CollateWrapper(object): + """Wraps a collate function with :func:`remove_parent_storage_columns` for serializing + from PyTorch DataLoader workers. """ - def __init__(self, g, nids, graph_sampler): + def __init__(self, sample_func, g): + self.sample_func = sample_func self.g = g - if not isinstance(nids, Mapping): - assert len(g.ntypes) == 1, \ - "nids should be a dict of node type and ids for graph with multiple node types" - self.graph_sampler = graph_sampler - self.nids = utils.prepare_tensor_or_dict(g, nids, 'nids') - self._dataset = utils.maybe_flatten_dict(self.nids) + def __call__(self, items): + batch = self.sample_func(self.g, items) + return recursive_apply(batch, remove_parent_storage_columns, self.g) - @property - def dataset(self): - return self._dataset - def collate(self, items): - """Find the list of MFGs necessary for computing the representation of given - nodes for a node classification/regression task. - - Parameters - ---------- - items : list[int] or list[tuple[str, int]] - Either a list of node IDs (for homogeneous graphs), or a list of node type-ID - pairs (for heterogeneous graphs). - - Returns - ------- - input_nodes : Tensor or dict[ntype, Tensor] - The input nodes necessary for computation in this minibatch. 
- - If the original graph has multiple node types, return a dictionary of - node type names and node ID tensors. Otherwise, return a single tensor. - output_nodes : Tensor or dict[ntype, Tensor] - The nodes whose representations are to be computed in this minibatch. - - If the original graph has multiple node types, return a dictionary of - node type names and node ID tensors. Otherwise, return a single tensor. - MFGs : list[DGLGraph] - The list of MFGs necessary for computing the representation. - """ - if isinstance(items[0], tuple): - # returns a list of pairs: group them by node types into a dict - items = utils.group_as_dict(items) - items = utils.prepare_tensor_or_dict(self.g, items, 'items') - - input_nodes, output_nodes, blocks = self.graph_sampler.sample(self.g, items) - - return input_nodes, output_nodes, blocks - -class EdgeCollator(Collator): - """DGL collator to combine edges and their computation dependencies within a minibatch for - training edge classification, edge regression, or link prediction on a single graph - with neighborhood sampling. - - Given a set of edges, the collate function will yield - - * A tensor of input nodes necessary for computing the representation on edges, or - a dictionary of node type names and such tensors. - - * A subgraph that contains only the edges in the minibatch and their incident nodes. - Note that the graph has an identical metagraph with the original graph. - - * If a negative sampler is given, another graph that contains the "negative edges", - connecting the source and destination nodes yielded from the given negative sampler. - - * A list of MFGs necessary for computing the representation of the incident nodes - of the edges in the minibatch. - - Parameters - ---------- - g : DGLGraph - The graph from which the edges are iterated in minibatches and the subgraphs - are generated. - eids : Tensor or dict[etype, Tensor] - The edge set in graph :attr:`g` to compute outputs. - graph_sampler : dgl.dataloading.BlockSampler - The neighborhood sampler. - g_sampling : DGLGraph, optional - The graph where neighborhood sampling and message passing is performed. - - Note that this is not necessarily the same as :attr:`g`. - - If None, assume to be the same as :attr:`g`. - exclude : str, optional - Whether and how to exclude dependencies related to the sampled edges in the - minibatch. Possible values are - - * None, which excludes nothing. - - * ``'self'``, which excludes the sampled edges themselves but nothing else. - - * ``'reverse_id'``, which excludes the reverse edges of the sampled edges. The said - reverse edges have the same edge type as the sampled edges. Only works - on edge types whose source node type is the same as its destination node type. - - * ``'reverse_types'``, which excludes the reverse edges of the sampled edges. The - said reverse edges have different edge types from the sampled edges. - - If ``g_sampling`` is given, ``exclude`` is ignored and will be always ``None``. - reverse_eids : Tensor or dict[etype, Tensor], optional - A tensor of reverse edge ID mapping. The i-th element indicates the ID of - the i-th edge's reverse edge. - - If the graph is heterogeneous, this argument requires a dictionary of edge - types and the reverse edge ID mapping tensors. - - Required and only used when ``exclude`` is set to ``reverse_id``. - - For heterogeneous graph this will be a dict of edge type and edge IDs. Note that - only the edge types whose source node type is the same as destination node type - are needed. 
- reverse_etypes : dict[etype, etype], optional - The mapping from the edge type to its reverse edge type. - - Required and only used when ``exclude`` is set to ``reverse_types``. - negative_sampler : callable, optional - The negative sampler. Can be omitted if no negative sampling is needed. +class WorkerInitWrapper(object): + """Wraps the :attr:`worker_init_fn` argument of the DataLoader to set the number of DGL + OMP threads to 1 for PyTorch DataLoader workers. + """ + def __init__(self, func): + self.func = func - The negative sampler must be a callable that takes in the following arguments: + def __call__(self, worker_id): + set_num_threads(1) + if self.func is not None: + self.func(worker_id) - * The original (heterogeneous) graph. - * The ID array of sampled edges in the minibatch, or the dictionary of edge - types and ID array of sampled edges in the minibatch if the graph is - heterogeneous. +def create_tensorized_dataset(indices, batch_size, drop_last, use_ddp, ddp_seed): + """Converts a given indices tensor to a TensorizedDataset, an IterableDataset + that returns views of the original tensor, to reduce overhead from having + a list of scalar tensors in default PyTorch DataLoader implementation. + """ + if use_ddp: + return DDPTensorizedDataset(indices, batch_size, drop_last, ddp_seed) + else: + return TensorizedDataset(indices, batch_size, drop_last) - It should return - * A pair of source and destination node ID arrays as negative samples, - or a dictionary of edge types and such pairs if the graph is heterogenenous. +class DataLoader(torch.utils.data.DataLoader): + """DataLoader class.""" + def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False, + ddp_seed=0, batch_size=1, drop_last=False, shuffle=False, + use_prefetch_thread=False, use_alternate_streams=True, **kwargs): + self.graph = graph - A set of builtin negative samplers are provided in - :ref:`the negative sampling module `. + try: + if isinstance(indices, Mapping): + indices = {k: (torch.tensor(v) if not torch.is_tensor(v) else v) + for k, v in indices.items()} + else: + indices = torch.tensor(indices) if not torch.is_tensor(indices) else indices + except: # pylint: disable=bare-except + # ignore when it fails to convert to torch Tensors. + pass + + if (torch.is_tensor(indices) or ( + isinstance(indices, Mapping) and + all(torch.is_tensor(v) for v in indices.values()))): + self.dataset = create_tensorized_dataset( + indices, batch_size, drop_last, use_ddp, ddp_seed) + else: + self.dataset = indices - Examples - -------- - The following example shows how to train a 3-layer GNN for edge classification on a - set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes - messages from all neighbors. - - Say that you have an array of source node IDs ``src`` and another array of destination - node IDs ``dst``. One can make it bidirectional by adding another set of edges - that connects from ``dst`` to ``src``: - - >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src]))) - - One can then know that the ID difference of an edge and its reverse edge is ``|E|``, - where ``|E|`` is the length of your source/destination array. The reverse edge - mapping can be obtained by - - >>> E = len(src) - >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)]) - - Note that the sampled edges as well as their reverse edges are removed from - computation dependencies of the incident nodes. This is a common trick to avoid - information leakage. 
- - >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) - >>> collator = dgl.dataloading.EdgeCollator( - ... g, train_eid, sampler, exclude='reverse_id', - ... reverse_eids=reverse_eids) - >>> dataloader = torch.utils.data.DataLoader( - ... collator.dataset, collate_fn=collator.collate, - ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) - >>> for input_nodes, pair_graph, blocks in dataloader: - ... train_on(input_nodes, pair_graph, blocks) - - To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a - homogeneous graph where each node takes messages from all neighbors (assume the - backend is PyTorch), with 5 uniformly chosen negative samples per edge: - - >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) - >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5) - >>> collator = dgl.dataloading.EdgeCollator( - ... g, train_eid, sampler, exclude='reverse_id', - ... reverse_eids=reverse_eids, negative_sampler=neg_sampler) - >>> dataloader = torch.utils.data.DataLoader( - ... collator.dataset, collate_fn=collator.collate, - ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) - >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader: - ... train_on(input_nodse, pair_graph, neg_pair_graph, blocks) - - For heterogeneous graphs, the reverse of an edge may have a different edge type - from the original edge. For instance, consider that you have an array of - user-item clicks, representated by a user array ``user`` and an item array ``item``. - You may want to build a heterogeneous graph with a user-click-item relation and an - item-clicked-by-user relation. - - >>> g = dgl.heterograph({ - ... ('user', 'click', 'item'): (user, item), - ... ('item', 'clicked-by', 'user'): (item, user)}) - - To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with - type ``click``, you can write - - >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) - >>> collator = dgl.dataloading.EdgeCollator( - ... g, {'click': train_eid}, sampler, exclude='reverse_types', - ... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'}) - >>> dataloader = torch.utils.data.DataLoader( - ... collator.dataset, collate_fn=collator.collate, - ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) - >>> for input_nodes, pair_graph, blocks in dataloader: - ... train_on(input_nodes, pair_graph, blocks) - - To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type - ``click``, you can write - - >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5]) - >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5) - >>> collator = dgl.dataloading.EdgeCollator( - ... g, train_eid, sampler, exclude='reverse_types', - ... reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'}, - ... negative_sampler=neg_sampler) - >>> dataloader = torch.utils.data.DataLoader( - ... collator.dataset, collate_fn=collator.collate, - ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) - >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader: - ... train_on(input_nodes, pair_graph, neg_pair_graph, blocks) - - Notes - ----- - For the concept of MFGs, please refer to - :ref:`User Guide Section 6 ` and - :doc:`Minibatch Training Tutorials `. 
- """ - def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None, - reverse_eids=None, reverse_etypes=None, negative_sampler=None): - self.g = g - if not isinstance(eids, Mapping): - assert len(g.etypes) == 1, \ - "eids should be a dict of etype and ids for graph with multiple etypes" + self.ddp_seed = ddp_seed + self._shuffle_dataset = shuffle self.graph_sampler = graph_sampler - - # One may wish to iterate over the edges in one graph while perform sampling in - # another graph. This may be the case for iterating over validation and test - # edge set while perform neighborhood sampling on the graph formed by only - # the training edge set. - # See GCMC for an example usage. + self.device = torch.device(device) + self.use_alternate_streams = use_alternate_streams + if self.device.type == 'cuda' and self.device.index is None: + self.device = torch.device('cuda', torch.cuda.current_device()) + self.use_prefetch_thread = use_prefetch_thread + worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None)) + + # Instantiate all the formats if the number of workers is greater than 0. + if kwargs.get('num_workers', 0) > 0 and hasattr(self.graph, 'create_formats_'): + self.graph.create_formats_() + + self.other_storages = {} + + super().__init__( + self.dataset, + collate_fn=CollateWrapper(self.graph_sampler.sample, graph), + batch_size=None, + worker_init_fn=worker_init_fn, + **kwargs) + + def __iter__(self): + if self._shuffle_dataset: + self.dataset.shuffle() + # When using multiprocessing PyTorch sometimes set the number of PyTorch threads to 1 + # when spawning new Python threads. This drastically slows down pinning features. + num_threads = torch.get_num_threads() if self.num_workers > 0 else None + return _PrefetchingIter( + self, super().__iter__(), use_thread=self.use_prefetch_thread, + use_alternate_streams=self.use_alternate_streams, num_threads=num_threads) + + # To allow data other than node/edge data to be prefetched. 
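+    # The name registered here must match the name of the LazyFeature emitted by the
+    # sampler; the data can be anything ``wrap_storage`` accepts (e.g. a tensor) and is
+    # then sliced and transferred by the prefetcher just like node/edge features.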
+ def attach_data(self, name, data): + """Add a data other than node and edge features for prefetching.""" + self.other_storages[name] = wrap_storage(data) + + +# Alias +class NodeDataLoader(DataLoader): + """NodeDataLoader class.""" + + +class EdgeDataLoader(DataLoader): + """EdgeDataLoader class.""" + def __init__(self, graph, indices, graph_sampler, device='cpu', use_ddp=False, + ddp_seed=0, batch_size=1, drop_last=False, shuffle=False, + use_prefetch_thread=False, use_alternate_streams=True, + exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, + g_sampling=None, **kwargs): if g_sampling is not None: - self.g_sampling = g_sampling - self.exclude = None - else: - self.g_sampling = self.g - self.exclude = exclude - - self.reverse_eids = reverse_eids - self.reverse_etypes = reverse_etypes - self.negative_sampler = negative_sampler - - self.eids = utils.prepare_tensor_or_dict(g, eids, 'eids') - self._dataset = utils.maybe_flatten_dict(self.eids) - - @property - def dataset(self): - return self._dataset - - def _collate(self, items): - if isinstance(items[0], tuple): - # returns a list of pairs: group them by node types into a dict - items = utils.group_as_dict(items) - items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items') - - pair_graph = self.g.edge_subgraph(items) - seed_nodes = pair_graph.ndata[NID] - - exclude_eids = _find_exclude_eids( - self.g_sampling, - self.exclude, - items, - reverse_eid_map=self.reverse_eids, - reverse_etype_map=self.reverse_etypes) - - input_nodes, _, blocks = self.graph_sampler.sample( - self.g_sampling, seed_nodes, exclude_eids=exclude_eids) - - return input_nodes, pair_graph, blocks - - def _collate_with_negative_sampling(self, items): - if isinstance(items[0], tuple): - # returns a list of pairs: group them by node types into a dict - items = utils.group_as_dict(items) - items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items') - - pair_graph = self.g.edge_subgraph(items, relabel_nodes=False) - induced_edges = pair_graph.edata[EID] - - neg_srcdst = self.negative_sampler(self.g, items) - if not isinstance(neg_srcdst, Mapping): - assert len(self.g.etypes) == 1, \ - 'graph has multiple or no edge types; '\ - 'please return a dict in negative sampler.' - neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst} - # Get dtype from a tuple of tensors - dtype = F.dtype(list(neg_srcdst.values())[0][0]) - ctx = F.context(pair_graph) - neg_edges = { - etype: neg_srcdst.get(etype, (F.copy_to(F.tensor([], dtype), ctx), - F.copy_to(F.tensor([], dtype), ctx))) - for etype in self.g.canonical_etypes} - neg_pair_graph = heterograph( - neg_edges, {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes}) - - pair_graph, neg_pair_graph = transform.compact_graphs([pair_graph, neg_pair_graph]) - pair_graph.edata[EID] = induced_edges - - seed_nodes = pair_graph.ndata[NID] - - exclude_eids = _find_exclude_eids( - self.g_sampling, - self.exclude, - items, - reverse_eid_map=self.reverse_eids, - reverse_etype_map=self.reverse_etypes) - - input_nodes, _, blocks = self.graph_sampler.sample( - self.g_sampling, seed_nodes, exclude_eids=exclude_eids) - - return input_nodes, pair_graph, neg_pair_graph, blocks - - def collate(self, items): - """Combines the sampled edges into a minibatch for edge classification, edge - regression, and link prediction tasks. 
- - Parameters - ---------- - items : list[int] or list[tuple[str, int]] - Either a list of edge IDs (for homogeneous graphs), or a list of edge type-ID - pairs (for heterogeneous graphs). - - Returns - ------- - Either ``(input_nodes, pair_graph, blocks)``, or - ``(input_nodes, pair_graph, negative_pair_graph, blocks)`` if negative sampling is - enabled. - - input_nodes : Tensor or dict[ntype, Tensor] - The input nodes necessary for computation in this minibatch. - - If the original graph has multiple node types, return a dictionary of - node type names and node ID tensors. Otherwise, return a single tensor. - pair_graph : DGLGraph - The graph that contains only the edges in the minibatch as well as their incident - nodes. - - Note that the metagraph of this graph will be identical to that of the original - graph. - negative_pair_graph : DGLGraph - The graph that contains only the edges connecting the source and destination nodes - yielded from the given negative sampler, if negative sampling is enabled. - - Note that the metagraph of this graph will be identical to that of the original - graph. - blocks : list[DGLGraph] - The list of MFGs necessary for computing the representation of the edges. - """ - if self.negative_sampler is None: - return self._collate(items) - else: - return self._collate_with_negative_sampling(items) + dgl_warning( + "g_sampling is deprecated. " + "Please merge g_sampling and the original graph into one graph and use " + "the exclude argument to specify which edges you don't want to sample.") + if isinstance(graph_sampler, BlockSampler): + graph_sampler = EdgeBlockSampler( + graph_sampler, exclude=exclude, reverse_eids=reverse_eids, + reverse_etypes=reverse_etypes, negative_sampler=negative_sampler) + + super().__init__( + graph, indices, graph_sampler, device=device, use_ddp=use_ddp, ddp_seed=ddp_seed, + batch_size=batch_size, drop_last=drop_last, shuffle=shuffle, + use_prefetch_thread=use_prefetch_thread, use_alternate_streams=use_alternate_streams, + **kwargs) + + +######## Graph DataLoaders ######## +# GraphDataLoader loads a set of graphs so it's not relevant to the above. They are currently +# copied from the old DataLoader implementation. + +PYTORCH_VER = LooseVersion(torch.__version__) +PYTORCH_16 = PYTORCH_VER >= LooseVersion("1.6.0") +PYTORCH_17 = PYTORCH_VER >= LooseVersion("1.7.0") + +def _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed): + # Note: will change the content of dataloader_kwargs + dist_sampler_kwargs = {'shuffle': dataloader_kwargs['shuffle']} + dataloader_kwargs['shuffle'] = False + if PYTORCH_16: + dist_sampler_kwargs['seed'] = ddp_seed + if PYTORCH_17: + dist_sampler_kwargs['drop_last'] = dataloader_kwargs['drop_last'] + dataloader_kwargs['drop_last'] = False + + return DistributedSampler(dataset, **dist_sampler_kwargs) class GraphCollator(object): """Given a set of graphs as well as their graph-level data, the collate function will batch the @@ -939,8 +641,8 @@ def collate(self, items): """ elem = items[0] elem_type = type(elem) - if isinstance(elem, DGLGraph): - batched_graphs = batch(items) + if isinstance(elem, DGLHeteroGraph): + batched_graphs = batch_graphs(items) return batched_graphs elif F.is_tensor(elem): return F.stack(items, 0) @@ -975,8 +677,89 @@ def collate(self, items): raise TypeError(self.graph_collate_err_msg_format.format(elem_type)) -class SubgraphIterator(object): - """Abstract class representing an iterator that yields a subgraph given a graph. 
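For the edge-oriented path, a similarly hedged sketch, not part of the patch: the new EdgeDataLoader wraps a plain BlockSampler into an EdgeBlockSampler, and with a negative sampler each minibatch follows the tuple layout documented for the collator it replaces. ``g`` and ``train_eid`` are assumed to exist.

# Illustrative sketch (not part of the patch); assumes a DGLGraph `g` and a tensor
# of training edge IDs `train_eid`.
import dgl

sampler = dgl.dataloading.NeighborSampler([10, 10])
dataloader = dgl.dataloading.EdgeDataLoader(
    g, train_eid, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
    batch_size=1024, shuffle=True, drop_last=False)
for input_nodes, pair_graph, neg_pair_graph, blocks in dataloader:
    pass  # score pair_graph / neg_pair_graph and compute a link-prediction loss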
+class GraphDataLoader(torch.utils.data.DataLoader): + """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched + graph and corresponding label tensor (if provided) of the said minibatch. + + Parameters + ---------- + collate_fn : Function, default is None + The customized collate function. Will use the default collate + function if not given. + use_ddp : boolean, optional + If True, tells the DataLoader to split the training set for each + participating process appropriately using + :class:`torch.utils.data.distributed.DistributedSampler`. + + Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`. + ddp_seed : int, optional + The seed for shuffling the dataset in + :class:`torch.utils.data.distributed.DistributedSampler`. + + Only effective when :attr:`use_ddp` is True. + kwargs : dict + Arguments being passed to :py:class:`torch.utils.data.DataLoader`. + + Examples + -------- + To train a GNN for graph classification on a set of graphs in ``dataset`` (assume + the backend is PyTorch): + + >>> dataloader = dgl.dataloading.GraphDataLoader( + ... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for batched_graph, labels in dataloader: + ... train_on(batched_graph, labels) + + **Using with Distributed Data Parallel** + + If you are using PyTorch's distributed training (e.g. when using + :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by + turning on the :attr:`use_ddp` option: + + >>> dataloader = dgl.dataloading.GraphDataLoader( + ... dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for epoch in range(start_epoch, n_epochs): + ... dataloader.set_epoch(epoch) + ... for batched_graph, labels in dataloader: + ... train_on(batched_graph, labels) """ - def __init__(self, g): - self.g = g + collator_arglist = inspect.getfullargspec(GraphCollator).args + + def __init__(self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs): + collator_kwargs = {} + dataloader_kwargs = {} + for k, v in kwargs.items(): + if k in self.collator_arglist: + collator_kwargs[k] = v + else: + dataloader_kwargs[k] = v + + if collate_fn is None: + self.collate = GraphCollator(**collator_kwargs).collate + else: + self.collate = collate_fn + + self.use_ddp = use_ddp + if use_ddp: + self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed) + dataloader_kwargs['sampler'] = self.dist_sampler + + super().__init__(dataset=dataset, collate_fn=self.collate, **dataloader_kwargs) + + def set_epoch(self, epoch): + """Sets the epoch number for the underlying sampler which ensures all replicas + to use a different ordering for each epoch. + + Only available when :attr:`use_ddp` is True. + + Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`. + + Parameters + ---------- + epoch : int + The epoch number. + """ + if self.use_ddp: + self.dist_sampler.set_epoch(epoch) + else: + raise DGLError('set_epoch is only available when use_ddp is True.') diff --git a/python/dgl/dataloading/dist_dataloader.py b/python/dgl/dataloading/dist_dataloader.py new file mode 100644 index 000000000000..984bb93d0403 --- /dev/null +++ b/python/dgl/dataloading/dist_dataloader.py @@ -0,0 +1,103 @@ +"""Distributed dataloaders. +""" +import inspect +from ..distributed import DistDataLoader +# Still depends on the legacy NodeCollator... 
+from .._dataloading.dataloader import NodeCollator, EdgeCollator + +def _remove_kwargs_dist(kwargs): + if 'num_workers' in kwargs: + del kwargs['num_workers'] + if 'pin_memory' in kwargs: + del kwargs['pin_memory'] + print('Distributed DataLoaders do not support pin_memory.') + return kwargs + +class DistNodeDataLoader(DistDataLoader): + """PyTorch dataloader for batch-iterating over a set of nodes, generating the list + of message flow graphs (MFGs) as computation dependency of the said minibatch, on + a distributed graph. + + All the arguments have the same meaning as the single-machine counterpart + :class:`dgl.dataloading.pytorch.NodeDataLoader` except the first argument + :attr:`g` which must be a :class:`dgl.distributed.DistGraph`. + + Parameters + ---------- + g : DistGraph + The distributed graph. + + nids, graph_sampler, device, kwargs : + See :class:`dgl.dataloading.pytorch.NodeDataLoader`. + + See also + -------- + dgl.dataloading.pytorch.NodeDataLoader + """ + def __init__(self, g, nids, graph_sampler, device=None, **kwargs): + collator_kwargs = {} + dataloader_kwargs = {} + _collator_arglist = inspect.getfullargspec(NodeCollator).args + for k, v in kwargs.items(): + if k in _collator_arglist: + collator_kwargs[k] = v + else: + dataloader_kwargs[k] = v + if device is None: + # for the distributed case default to the CPU + device = 'cpu' + assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.' + # Distributed DataLoader currently does not support heterogeneous graphs + # and does not copy features. Fallback to normal solution + self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs) + _remove_kwargs_dist(dataloader_kwargs) + super().__init__(self.collator.dataset, + collate_fn=self.collator.collate, + **dataloader_kwargs) + self.device = device + +class DistEdgeDataLoader(DistDataLoader): + """PyTorch dataloader for batch-iterating over a set of edges, generating the list + of message flow graphs (MFGs) as computation dependency of the said minibatch for + edge classification, edge regression, and link prediction, on a distributed + graph. + + All the arguments have the same meaning as the single-machine counterpart + :class:`dgl.dataloading.pytorch.EdgeDataLoader` except the first argument + :attr:`g` which must be a :class:`dgl.distributed.DistGraph`. + + Parameters + ---------- + g : DistGraph + The distributed graph. + + eids, graph_sampler, device, kwargs : + See :class:`dgl.dataloading.pytorch.EdgeDataLoader`. + + See also + -------- + dgl.dataloading.pytorch.EdgeDataLoader + """ + def __init__(self, g, eids, graph_sampler, device=None, **kwargs): + collator_kwargs = {} + dataloader_kwargs = {} + _collator_arglist = inspect.getfullargspec(EdgeCollator).args + for k, v in kwargs.items(): + if k in _collator_arglist: + collator_kwargs[k] = v + else: + dataloader_kwargs[k] = v + + if device is None: + # for the distributed case default to the CPU + device = 'cpu' + assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.' + # Distributed DataLoader currently does not support heterogeneous graphs + # and does not copy features. 
Fallback to normal solution + self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs) + _remove_kwargs_dist(dataloader_kwargs) + super().__init__(self.collator.dataset, + collate_fn=self.collator.collate, + **dataloader_kwargs) + + self.device = device diff --git a/python/dgl/dataloading/negative_sampler.py b/python/dgl/dataloading/negative_sampler.py index d9a56e2e8df3..54b6eca45cc0 100644 --- a/python/dgl/dataloading/negative_sampler.py +++ b/python/dgl/dataloading/negative_sampler.py @@ -1,7 +1,6 @@ """Negative samplers""" from collections.abc import Mapping from .. import backend as F -from ..sampling import global_uniform_negative_sampling class _BaseNegativeSampler(object): def _generate(self, g, eids, canonical_etype): @@ -26,7 +25,7 @@ def __call__(self, g, eids): eids = {g.to_canonical_etype(k): v for k, v in eids.items()} neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()} else: - assert len(g.etypes) == 1, \ + assert len(g.canonical_etypes) == 1, \ 'please specify a dict of etypes and ids for graphs with multiple edge types' neg_pair = self._generate(g, eids, g.canonical_etypes[0]) @@ -64,7 +63,7 @@ def _generate(self, g, eids, canonical_etype): shape = (shape[0] * self.k,) src, _ = g.find_edges(eids, etype=canonical_etype) src = F.repeat(src, self.k, 0) - dst = F.randint(shape, dtype, ctx, 0, g.number_of_nodes(vtype)) + dst = F.randint(shape, dtype, ctx, 0, g.num_nodes(vtype)) return src, dst # Alias @@ -90,14 +89,6 @@ class GlobalUniform(_BaseNegativeSampler): replace : bool, optional Whether to sample with replacement. Setting it to True will make things faster. (Default: True) - redundancy : float, optional - Indicates how much more negative samples to actually generate during rejection sampling - before finding the unique pairs. - - Increasing it will increase the likelihood of getting :attr:`k` negative samples - per edge, but will also take more time and memory. - - (Default: automatically determined by the density of graph) Notes ----- @@ -113,13 +104,11 @@ class GlobalUniform(_BaseNegativeSampler): >>> neg_sampler(g, torch.LongTensor([0, 1])) (tensor([0, 1, 3, 2]), tensor([2, 0, 2, 1])) """ - def __init__(self, k, exclude_self_loops=True, replace=False, redundancy=None): + def __init__(self, k, exclude_self_loops=True, replace=False): self.k = k self.exclude_self_loops = exclude_self_loops self.replace = replace - self.redundancy = redundancy def _generate(self, g, eids, canonical_etype): - return global_uniform_negative_sampling( - g, len(eids) * self.k, self.exclude_self_loops, self.replace, - canonical_etype, self.redundancy) + return g.global_uniform_negative_sampling( + len(eids) * self.k, self.exclude_self_loops, self.replace, canonical_etype) diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py new file mode 100644 index 000000000000..526d5223228d --- /dev/null +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -0,0 +1,124 @@ +"""Data loading components for neighbor sampling""" +from ..base import NID, EID +from ..transform import to_block +from .base import BlockSampler + +class NeighborSampler(BlockSampler): + """Sampler that builds computational dependency of node representations via + neighbor sampling for multilayer GNN. + + This sampler will make every node gather messages from a fixed number of neighbors + per edge type. The neighbors are picked uniformly. 
+
+    Parameters
+    ----------
+    fanouts : list[int] or list[dict[etype, int]]
+        List of neighbors to sample per edge type for each GNN layer, with the i-th
+        element being the fanout for the i-th GNN layer.
+
+        If only a single integer is provided, DGL assumes that every edge type
+        will have the same fanout.
+
+        If -1 is provided for one edge type on one layer, then all inbound edges
+        of that edge type will be included.
+    replace : bool, default False
+        Whether to sample with replacement.
+    prob : str, optional
+        If given, the probability of each neighbor being sampled is proportional
+        to the edge feature value with the given name in ``g.edata``. The feature must be
+        a scalar on each edge.
+
+    Examples
+    --------
+    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
+    a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for
+    the first, second, and third layer respectively (assuming the backend is PyTorch):
+
+    >>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
+    >>> dataloader = dgl.dataloading.NodeDataLoader(
+    ...     g, train_nid, sampler,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for input_nodes, output_nodes, blocks in dataloader:
+    ...     train_on(blocks)
+
+    If you are training on a heterogeneous graph and want a different number of neighbors
+    for each edge type, provide a list of dicts instead. Each dict specifies the number
+    of neighbors to pick per edge type.
+
+    >>> sampler = dgl.dataloading.NeighborSampler([
+    ...     {('user', 'follows', 'user'): 5,
+    ...      ('user', 'plays', 'game'): 4,
+    ...      ('game', 'played-by', 'user'): 3}] * 3)
+
+    If you would like non-uniform neighbor sampling:
+
+    >>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
+    >>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')
+
+    Notes
+    -----
+    For the concept of MFGs, please refer to
+    :ref:`User Guide Section 6 ` and
+    :doc:`Minibatch Training Tutorials `.
+    """
+    def __init__(self, fanouts, edge_dir='in', prob=None, replace=False, **kwargs):
+        super().__init__(**kwargs)
+        self.fanouts = fanouts
+        self.edge_dir = edge_dir
+        self.prob = prob
+        self.replace = replace
+
+    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
+        output_nodes = seed_nodes
+        blocks = []
+        for fanout in reversed(self.fanouts):
+            frontier = g.sample_neighbors(
+                seed_nodes, fanout, edge_dir=self.edge_dir, prob=self.prob,
+                replace=self.replace, output_device=self.output_device,
+                exclude_edges=exclude_eids)
+            eid = frontier.edata[EID]
+            block = to_block(frontier, seed_nodes)
+            block.edata[EID] = eid
+            seed_nodes = block.srcdata[NID]
+            blocks.insert(0, block)
+
+        return seed_nodes, output_nodes, blocks
+
+MultiLayerNeighborSampler = NeighborSampler
+
+class MultiLayerFullNeighborSampler(NeighborSampler):
+    """Sampler that builds computational dependency of node representations by taking messages
+    from all neighbors for multilayer GNN.
+
+    This sampler will make every node gather messages from every single neighbor per edge type.
+
+    Parameters
+    ----------
+    num_layers : int
+        The number of GNN layers to sample.
+    return_eids : bool, default False
+        Whether to return the edge IDs involved in message passing in the MFG.
+        If True, the edge IDs will be stored as an edge feature named ``dgl.EID``.
+ + Examples + -------- + To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on + a homogeneous graph where each node takes messages from all neighbors for the first, + second, and third layer respectively (assuming the backend is PyTorch): + + >>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3) + >>> dataloader = dgl.dataloading.NodeDataLoader( + ... g, train_nid, sampler, + ... batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for input_nodes, output_nodes, blocks in dataloader: + ... train_on(blocks) + + Notes + ----- + For the concept of MFGs, please refer to + :ref:`User Guide Section 6 ` and + :doc:`Minibatch Training Tutorials `. + """ + def __init__(self, num_layers, edge_dir='in', prob=None, replace=False, **kwargs): + super().__init__([-1] * num_layers, edge_dir=edge_dir, prob=prob, replace=replace, + **kwargs) diff --git a/python/dgl/dataloading/shadow.py b/python/dgl/dataloading/shadow.py index 4841a9dbb1c6..86675922e775 100644 --- a/python/dgl/dataloading/shadow.py +++ b/python/dgl/dataloading/shadow.py @@ -1,12 +1,10 @@ """ShaDow-GNN subgraph samplers.""" -from ..utils import prepare_tensor_or_dict -from ..base import NID +from ..sampling.utils import EidExcluder from .. import transform -from ..sampling import sample_neighbors -from .neighbor import NeighborSamplingMixin -from .dataloader import exclude_edges, Sampler +from ..base import NID +from .base import set_node_lazy_features, set_edge_lazy_features -class ShaDowKHopSampler(NeighborSamplingMixin, Sampler): +class ShaDowKHopSampler(object): """K-hop subgraph sampler used by `ShaDow-GNN `__. @@ -70,29 +68,32 @@ class ShaDowKHopSampler(NeighborSamplingMixin, Sampler): If you would like non-uniform neighbor sampling: >>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works - >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15], prob='p') + >>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p') """ - def __init__(self, fanouts, replace=False, prob=None, output_ctx=None): - super().__init__(output_ctx) + def __init__(self, fanouts, replace=False, prob=None, prefetch_node_feats=None, + prefetch_edge_feats=None, output_device=None): self.fanouts = fanouts self.replace = replace self.prob = prob - self.set_output_context(output_ctx) + self.prefetch_node_feats = prefetch_node_feats + self.prefetch_edge_feats = prefetch_edge_feats + self.output_device = output_device - def sample(self, g, seed_nodes, exclude_eids=None): - self._build_fanout(len(self.fanouts), g) - self._build_prob_arrays(g) - seed_nodes = prepare_tensor_or_dict(g, seed_nodes, 'seed nodes') + def sample(self, g, seed_nodes, exclude_edges=None): + """Sample a subgraph given a tensor of seed nodes.""" output_nodes = seed_nodes - - for i in range(len(self.fanouts)): - fanout = self.fanouts[i] - frontier = sample_neighbors( - g, seed_nodes, fanout, replace=self.replace, prob=self.prob_arrays) + for fanout in reversed(self.fanouts): + frontier = g.sample_neighbors( + seed_nodes, fanout, output_device=self.output_device, + replace=self.replace, prob=self.prob, exclude_edges=exclude_edges) block = transform.to_block(frontier, seed_nodes) seed_nodes = block.srcdata[NID] - subg = g.subgraph(seed_nodes, relabel_nodes=True) - subg = exclude_edges(subg, exclude_eids, self.output_device) + subg = g.subgraph(seed_nodes, relabel_nodes=True, output_device=self.output_device) + if exclude_edges is not None: + subg = EidExcluder(exclude_edges)(subg) + + 
set_node_lazy_features(subg, self.prefetch_node_feats) + set_edge_lazy_features(subg, self.prefetch_edge_feats) - return seed_nodes, output_nodes, [subg] + return seed_nodes, output_nodes, subg diff --git a/python/dgl/distributed/dist_graph.py b/python/dgl/distributed/dist_graph.py index 81fa760bb2cd..d64269995912 100644 --- a/python/dgl/distributed/dist_graph.py +++ b/python/dgl/distributed/dist_graph.py @@ -26,6 +26,7 @@ from . import role from .server_state import ServerState from .rpc_server import start_server +from . import graph_services from .graph_services import find_edges as dist_find_edges from .graph_services import out_degrees as dist_out_degrees from .graph_services import in_degrees as dist_in_degrees @@ -1223,6 +1224,20 @@ def barrier(self): ''' self._client.barrier() + def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None, + exclude_edges=None, replace=False, + output_device=None): + # pylint: disable=unused-argument + """Sample neighbors from a distributed graph.""" + # Currently prob, exclude_edges, output_device, and edge_dir are ignored. + if len(self.etypes) > 1: + frontier = graph_services.sample_etype_neighbors( + self, seed_nodes, ETYPE, fanout, replace=replace) + else: + frontier = graph_services.sample_neighbors( + self, seed_nodes, fanout, replace=replace) + return frontier + def _get_ndata_names(self, ntype=None): ''' Get the names of all node data. ''' diff --git a/python/dgl/frame.py b/python/dgl/frame.py index 74cde3d73dcd..820b2be2a36c 100644 --- a/python/dgl/frame.py +++ b/python/dgl/frame.py @@ -7,6 +7,7 @@ from . import backend as F from .base import DGLError, dgl_warning from .init import zero_initializer +from .storages import TensorStorage class _LazyIndex(object): def __init__(self, index): @@ -38,6 +39,23 @@ def flatten(self): flat_index = F.gather_row(flat_index, index) return flat_index +class LazyFeature(object): + """Placeholder for prefetching from DataLoader. + """ + __slots__ = ['name', 'id_'] + def __init__(self, name=None, id_=None): + self.name = name + self.id_ = id_ + + def to(self, *args, **kwargs): # pylint: disable=invalid-name, unused-argument + """No-op. For compatibility of :meth:`Frame.to` method.""" + return self + + @property + def data(self): + """No-op. For compatibility of :meth:`Frame.__repr__` method.""" + return self + class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): """The column scheme. @@ -77,7 +95,7 @@ def infer_scheme(tensor): """ return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor)) -class Column(object): +class Column(TensorStorage): """A column is a compact store of features of multiple nodes/edges. It batches all the feature tensors together along the first dimension @@ -120,7 +138,7 @@ class Column(object): Index tensor """ def __init__(self, storage, scheme=None, index=None, device=None): - self.storage = storage + super().__init__(storage) self.scheme = scheme if scheme else infer_scheme(storage) self.index = index self.device = device @@ -336,10 +354,13 @@ def __init__(self, data=None, num_rows=None): assert not isinstance(data, Frame) # sanity check for code refactor # Note that we always create a new column for the given data. # This avoids two frames accidentally sharing the same column. 
- self._columns = {k : Column.create(v) for k, v in data.items()} + self._columns = {k : v if isinstance(v, LazyFeature) else Column.create(v) + for k, v in data.items()} self._num_rows = num_rows # infer num_rows & sanity check for name, col in self._columns.items(): + if isinstance(col, LazyFeature): + continue if self._num_rows is None: self._num_rows = len(col) elif len(col) != self._num_rows: @@ -504,6 +525,10 @@ def update_column(self, name, data): data : Column or data convertible to Column The column data. """ + if isinstance(data, LazyFeature): + self._columns[name] = data + return + col = Column.create(data) if len(col) != self.num_rows: raise DGLError('Expected data to have %d rows, got %d.' % diff --git a/python/dgl/geometry/capi.py b/python/dgl/geometry/capi.py index 0ad9fee81c50..b75adb6c1b97 100644 --- a/python/dgl/geometry/capi.py +++ b/python/dgl/geometry/capi.py @@ -1,6 +1,6 @@ """Python interfaces to DGL farthest point sampler.""" -from dgl._ffi.base import DGLError import numpy as np +from .._ffi.base import DGLError from .._ffi.function import _init_api from .. import backend as F from .. import ndarray as nd diff --git a/python/dgl/heterograph.py b/python/dgl/heterograph.py index 2735470bdad1..f8fefc9fc805 100644 --- a/python/dgl/heterograph.py +++ b/python/dgl/heterograph.py @@ -1600,6 +1600,14 @@ def set_batch_num_edges(self, val): # View ################################################################# + def get_node_storage(self, key, ntype=None): + """Get storage object of node feature of type :attr:`ntype` and name :attr:`key`.""" + return self._node_frames[self.get_ntype_id(ntype)]._columns[key] + + def get_edge_storage(self, key, etype=None): + """Get storage object of edge feature of type :attr:`etype` and name :attr:`key`.""" + return self._edge_frames[self.get_etype_id(etype)]._columns[key] + @property def nodes(self): """Return a node view diff --git a/python/dgl/sampling/__init__.py b/python/dgl/sampling/__init__.py index 7ae2515549a4..b48a5e1a2d4f 100644 --- a/python/dgl/sampling/__init__.py +++ b/python/dgl/sampling/__init__.py @@ -10,3 +10,4 @@ from .neighbor import * from .node2vec_randomwalk import * from .negative import * +from . import utils diff --git a/python/dgl/sampling/negative.py b/python/dgl/sampling/negative.py index 3b6d4dec4fae..ca34c3c60431 100644 --- a/python/dgl/sampling/negative.py +++ b/python/dgl/sampling/negative.py @@ -3,6 +3,8 @@ from numpy.polynomial import polynomial from .._ffi.function import _init_api from .. import backend as F +from .. import utils +from ..heterograph import DGLHeteroGraph __all__ = [ 'global_uniform_negative_sampling'] @@ -99,5 +101,7 @@ def global_uniform_negative_sampling( src, dst = _CAPI_DGLGlobalUniformNegativeSampling( g._graph, etype_id, num_samples, 3, exclude_self_loops, replace, redundancy) return F.from_dgl_nd(src), F.from_dgl_nd(dst) +DGLHeteroGraph.global_uniform_negative_sampling = utils.alias_func( + global_uniform_negative_sampling) _init_api('dgl.sampling.negative', __name__) diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index 8c10f2ded11e..186dc3f7a669 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -6,6 +6,7 @@ from ..heterograph import DGLHeteroGraph from .. import ndarray as nd from .. 
import utils +from .utils import EidExcluder __all__ = [ 'sample_etype_neighbors', @@ -15,7 +16,7 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, replace=False, copy_ndata=True, copy_edata=True, etype_sorted=False, - _dist_training=False): + _dist_training=False, output_device=None): """Sample neighboring edges of the given nodes and return the induced subgraph. For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges @@ -77,6 +78,8 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No A hint telling whether the etypes are already sorted. (Default: False) + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -142,10 +145,13 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No for i, etype in enumerate(ret.canonical_etypes): ret.edges[etype].data[EID] = induced_edges[i] - return ret + return ret if output_device is None else ret.to(output_device) + +DGLHeteroGraph.sample_etype_neighbors = utils.alias_func(sample_etype_neighbors) def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, - copy_ndata=True, copy_edata=True, _dist_training=False, exclude_edges=None): + copy_ndata=True, copy_edata=True, _dist_training=False, + exclude_edges=None, output_device=None): """Sample neighboring edges of the given nodes and return the induced subgraph. For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges @@ -210,12 +216,13 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, Internal argument. Do not use. (Default: False) + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- DGLGraph - A sampled subgraph containing only the sampled neighboring edges, with the - same device as the input graph. + A sampled subgraph containing only the sampled neighboring edges. 
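A small, hedged example of the method form registered here via ``utils.alias_func``, illustrative only and not part of the patch; the sampled edges are random, so the printed edge IDs will vary.

# Illustrative sketch (not part of the patch).
import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
seeds = torch.tensor([1, 2])
# The functional form and the new method form are equivalent.
sub1 = dgl.sampling.sample_neighbors(g, seeds, 2)
sub2 = g.sample_neighbors(seeds, 2, edge_dir='in')
print(sub2.edata[dgl.EID])  # IDs of the sampled edges in the parent graph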
Notes ----- @@ -280,6 +287,22 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, tensor([False, False, False]) """ + if g.device == F.cpu(): + frontier = _sample_neighbors( + g, nodes, fanout, edge_dir=edge_dir, prob=prob, replace=replace, + copy_ndata=copy_ndata, copy_edata=copy_edata, exclude_edges=exclude_edges) + else: + frontier = _sample_neighbors( + g, nodes, fanout, edge_dir=edge_dir, prob=prob, replace=replace, + copy_ndata=copy_ndata, copy_edata=copy_edata) + if exclude_edges is not None: + eid_excluder = EidExcluder(exclude_edges) + frontier = eid_excluder(frontier) + return frontier if output_device is None else frontier.to(output_device) + +def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, + copy_ndata=True, copy_edata=True, _dist_training=False, + exclude_edges=None): if not isinstance(nodes, dict): if len(g.ntypes) > 1: raise DGLError("Must specify node type when the graph is not homogeneous.") @@ -357,9 +380,11 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, return ret +DGLHeteroGraph.sample_neighbors = utils.alias_func(sample_neighbors) + def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in', tag_offset_name='_TAG_OFFSET', replace=False, - copy_ndata=True, copy_edata=True): + copy_ndata=True, copy_edata=True, output_device=None): r"""Sample neighboring edges of the given nodes and return the induced subgraph, where each neighbor's probability to be picked is determined by its tag. @@ -439,6 +464,8 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in', edge features. (Default: True) + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -523,11 +550,12 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in', utils.set_new_frames(ret, edge_frames=edge_frames) ret.edata[EID] = induced_edges[0] - return ret + return ret if output_device is None else ret.to(output_device) +DGLHeteroGraph.sample_neighbors_biased = utils.alias_func(sample_neighbors_biased) def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, - copy_ndata=True, copy_edata=True): + copy_ndata=True, copy_edata=True, output_device=None): """Select the neighboring edges with k-largest (or k-smallest) weights of the given nodes and return the induced subgraph. @@ -581,6 +609,8 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, edge features. (Default: True) + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -655,6 +685,8 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, if copy_edata: edge_frames = utils.extract_edge_subframes(g, induced_edges) utils.set_new_frames(ret, edge_frames=edge_frames) - return ret + return ret if output_device is None else ret.to(output_device) + +DGLHeteroGraph.select_topk = utils.alias_func(select_topk) _init_api('dgl.sampling.neighbor', __name__) diff --git a/python/dgl/sampling/utils.py b/python/dgl/sampling/utils.py new file mode 100644 index 000000000000..f25b84075dfb --- /dev/null +++ b/python/dgl/sampling/utils.py @@ -0,0 +1,79 @@ +"""Sampling utilities""" +from collections.abc import Mapping +import numpy as np + +from ..utils import recursive_apply, recursive_apply_pair +from ..base import EID +from .. import backend as F +from .. 
import transform, utils + +def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids): + """Find the edges whose IDs in parent graph appeared in exclude_eids. + + Note that both arguments are numpy arrays or numpy dicts. + """ + func = lambda x, y: np.isin(x, y).nonzero()[0] + result = recursive_apply_pair(frontier_parent_eids, exclude_eids, func) + return recursive_apply(result, F.zerocopy_from_numpy) + +class EidExcluder(object): + """Class that finds the edges whose IDs in parent graph appeared in exclude_eids. + + The edge IDs can be both CPU and GPU tensors. + """ + def __init__(self, exclude_eids): + device = None + if isinstance(exclude_eids, Mapping): + for _, v in exclude_eids.items(): + if device is None: + device = F.context(v) + break + else: + device = F.context(exclude_eids) + self._exclude_eids = None + self._filter = None + + if device == F.cpu(): + # TODO(nv-dlasalle): Once Filter is implemented for the CPU, we + # should just use that irregardless of the device. + self._exclude_eids = ( + recursive_apply(exclude_eids, F.zerocopy_to_numpy) + if exclude_eids is not None else None) + else: + self._filter = recursive_apply(exclude_eids, utils.Filter) + + def _find_indices(self, parent_eids): + """ Find the set of edge indices to remove. + """ + if self._exclude_eids is not None: + parent_eids_np = recursive_apply(parent_eids, F.zerocopy_to_numpy) + return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids) + else: + assert self._filter is not None + func = lambda x, y: x.find_included_indices(y) + return recursive_apply_pair(self._filter, parent_eids, func) + + def __call__(self, frontier): + parent_eids = frontier.edata[EID] + located_eids = self._find_indices(parent_eids) + + if not isinstance(located_eids, Mapping): + # (BarclayII) If frontier already has a EID field and located_eids is empty, + # the returned graph will keep EID intact. Otherwise, EID will change + # to the mapping from the new graph to the old frontier. + # So we need to test if located_eids is empty, and do the remapping ourselves. + if len(located_eids) > 0: + frontier = transform.remove_edges( + frontier, located_eids, store_ids=True) + frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID]) + else: + # (BarclayII) remove_edges only accepts removing one type of edges, + # so I need to keep track of the edge IDs left one by one. + new_eids = parent_eids.copy() + for k, v in located_eids.items(): + if len(v) > 0: + frontier = transform.remove_edges( + frontier, v, etype=k, store_ids=True) + new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID]) + frontier.edata[EID] = new_eids + return frontier diff --git a/python/dgl/storages/__init__.py b/python/dgl/storages/__init__.py new file mode 100644 index 000000000000..bfa19e496e85 --- /dev/null +++ b/python/dgl/storages/__init__.py @@ -0,0 +1,9 @@ +"""Feature storage classes for DataLoading""" +from .. import backend as F + +from .base import * +from .numpy import * +if F.get_preferred_backend() == 'pytorch': + from .pytorch_tensor import * +else: + from .tensor import * diff --git a/python/dgl/storages/base.py b/python/dgl/storages/base.py new file mode 100644 index 000000000000..bd2ecc5d1183 --- /dev/null +++ b/python/dgl/storages/base.py @@ -0,0 +1,83 @@ +"""Base classes and functionalities for feature storages.""" + +import threading + + +STORAGE_WRAPPERS = {} +def register_storage_wrapper(type_): + """Decorator that associates a type to a ``FeatureStorage`` object. 
+    """
+    def deco(cls):
+        STORAGE_WRAPPERS[type_] = cls
+        return cls
+    return deco
+
+def wrap_storage(storage):
+    """Wrap an object into a FeatureStorage as specified by the ``register_storage_wrapper``
+    decorators.
+    """
+    for type_, storage_cls in STORAGE_WRAPPERS.items():
+        if isinstance(storage, type_):
+            return storage_cls(storage)
+
+    assert isinstance(storage, FeatureStorage), (
+        "The frame column must be a tensor or a FeatureStorage object, got {}"
+        .format(type(storage)))
+    return storage
+
+class _FuncWrapper(object):
+    def __init__(self, func):
+        self.func = func
+
+    def __call__(self, buf, *args):
+        buf[0] = self.func(*args)
+
+class ThreadedFuture(object):
+    """Wraps a function into a future asynchronously executed by a Python
+    ``threading.Thread``. The function is executed upon instantiation of
+    this object.
+    """
+    def __init__(self, target, args):
+        self.buf = [None]
+
+        thread = threading.Thread(
+            target=_FuncWrapper(target),
+            args=[self.buf] + list(args),
+            daemon=True)
+        thread.start()
+        self.thread = thread
+
+    def wait(self):
+        """Blocks the current thread until the result becomes available and returns it."""
+        self.thread.join()
+        return self.buf[0]
+
+class FeatureStorage(object):
+    """Feature storage object which should support a fetch() operation. It is the
+    counterpart of a tensor for homogeneous graphs, or a dict of tensors for heterogeneous
+    graphs where the keys are node/edge types.
+    """
+    def requires_ddp(self):
+        """Whether the FeatureStorage requires the DataLoader to set use_ddp.
+        """
+        return False
+
+    def fetch(self, indices, device, pin_memory=False):
+        """Retrieve the features at the given indices.
+
+        If :attr:`indices` is a tensor, this is equivalent to
+
+        .. code::
+
+           storage[indices]
+
+        If :attr:`indices` is a dict of tensors, this is equivalent to
+
+        .. code::
+
+           {k: storage[k][indices[k]] for k in indices.keys()}
+
+        The subclasses can choose to utilize or ignore the flag :attr:`pin_memory`
+        depending on the underlying framework.
+        """
+        raise NotImplementedError
diff --git a/python/dgl/storages/numpy.py b/python/dgl/storages/numpy.py
new file mode 100644
index 000000000000..02de9a22f8a9
--- /dev/null
+++ b/python/dgl/storages/numpy.py
@@ -0,0 +1,18 @@
+"""Feature storage for ``numpy.memmap`` object."""
+import numpy as np
+from .base import FeatureStorage, ThreadedFuture, register_storage_wrapper
+from .. 
import backend as F + +@register_storage_wrapper(np.memmap) +class NumpyStorage(FeatureStorage): + """FeatureStorage that asynchronously reads features from a ``numpy.memmap`` object.""" + def __init__(self, arr): + self.arr = arr + + def _fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument + result = F.zerocopy_from_numpy(self.arr[indices]) + result = F.copy_to(result, device) + return result + + def fetch(self, indices, device, pin_memory=False): + return ThreadedFuture(target=self._fetch, args=(indices, device, pin_memory)) diff --git a/python/dgl/storages/pytorch_tensor.py b/python/dgl/storages/pytorch_tensor.py new file mode 100644 index 000000000000..8fdb30a03b8f --- /dev/null +++ b/python/dgl/storages/pytorch_tensor.py @@ -0,0 +1,32 @@ +"""Feature storages for PyTorch tensors.""" + +import torch +from .base import FeatureStorage, register_storage_wrapper + +def _fetch_cpu(indices, tensor, feature_shape, device, pin_memory): + result = torch.empty( + indices.shape[0], *feature_shape, dtype=tensor.dtype, + pin_memory=pin_memory) + torch.index_select(tensor, 0, indices, out=result) + result = result.to(device, non_blocking=True) + return result + +def _fetch_cuda(indices, tensor, device): + return torch.index_select(tensor, 0, indices).to(device) + +@register_storage_wrapper(torch.Tensor) +class TensorStorage(FeatureStorage): + """Feature storages for slicing a PyTorch tensor.""" + def __init__(self, tensor): + self.storage = tensor + self.feature_shape = tensor.shape[1:] + self.is_cuda = (tensor.device.type == 'cuda') + + def fetch(self, indices, device, pin_memory=False): + device = torch.device(device) + if not self.is_cuda: + # CPU to CPU or CUDA - use pin_memory and async transfer if possible + return _fetch_cpu(indices, self.storage, self.feature_shape, device, pin_memory) + else: + # CUDA to CUDA or CPU + return _fetch_cuda(indices, self.storage, device) diff --git a/python/dgl/storages/tensor.py b/python/dgl/storages/tensor.py new file mode 100644 index 000000000000..d454119ec3c3 --- /dev/null +++ b/python/dgl/storages/tensor.py @@ -0,0 +1,17 @@ +"""Feature storages for tensors across different frameworks.""" +from .base import FeatureStorage +from .. import backend as F +from ..utils import recursive_apply_pair + +def _fetch(indices, tensor, device): + return F.copy_to(F.gather_row(tensor, indices), device) + +class TensorStorage(FeatureStorage): + """FeatureStorage that synchronously slices features from a tensor and transfers + it to the given device. + """ + def __init__(self, tensor): + self.storage = tensor + + def fetch(self, indices, device, pin_memory=False): # pylint: disable=unused-argument + return recursive_apply_pair(indices, self.storage, _fetch, device) diff --git a/python/dgl/subgraph.py b/python/dgl/subgraph.py index 93b4a449e235..ee07ba127e6d 100644 --- a/python/dgl/subgraph.py +++ b/python/dgl/subgraph.py @@ -13,11 +13,12 @@ from . import ndarray as nd from .heterograph import DGLHeteroGraph from . import utils +from .utils import recursive_apply __all__ = ['node_subgraph', 'edge_subgraph', 'node_type_subgraph', 'edge_type_subgraph', 'in_subgraph', 'out_subgraph', 'khop_in_subgraph', 'khop_out_subgraph'] -def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True): +def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True, output_device=None): """Return a subgraph induced on the given nodes. 
A node-induced subgraph is a graph with edges whose endpoints are both in the @@ -53,6 +54,8 @@ def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True): resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will also store the raw IDs of the specified nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -150,11 +153,13 @@ def _process_nodes(ntype, v): # bug in #1453. if not relabel_nodes: induced_nodes = None - return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + return subg if output_device is None else subg.to(output_device) DGLHeteroGraph.subgraph = utils.alias_func(node_subgraph) -def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **deprecated_kwargs): +def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, output_device=None, + **deprecated_kwargs): """Return a subgraph induced on the given edges. An edge-induced subgraph is equivalent to creating a new graph using the given @@ -190,6 +195,8 @@ def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **depreca resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will also store the raw IDs of the incident nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -301,11 +308,12 @@ def _process_edges(etype, e): induced_edges.append(_process_edges(cetype, eids)) sgi = graph._graph.edge_subgraph(induced_edges, not relabel_nodes) induced_nodes = sgi.induced_nodes if relabel_nodes else None - return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + return subg if output_device is None else subg.to(output_device) DGLHeteroGraph.edge_subgraph = utils.alias_func(edge_subgraph) -def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True): +def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None): """Return the subgraph induced on the inbound edges of all the edge types of the given nodes. @@ -340,6 +348,8 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True): resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. 
Returns ------- @@ -426,11 +436,12 @@ def in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True): sgi = _CAPI_DGLInSubgraph(graph._graph, nodes_all_types, relabel_nodes) induced_nodes = sgi.induced_nodes if relabel_nodes else None induced_edges = sgi.induced_edges - return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + return subg if output_device is None else subg.to(output_device) DGLHeteroGraph.in_subgraph = utils.alias_func(in_subgraph) -def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True): +def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None): """Return the subgraph induced on the outbound edges of all the edge types of the given nodes. @@ -465,6 +476,8 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True): resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -551,11 +564,12 @@ def out_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True): sgi = _CAPI_DGLOutSubgraph(graph._graph, nodes_all_types, relabel_nodes) induced_nodes = sgi.induced_nodes if relabel_nodes else None induced_edges = sgi.induced_edges - return _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + subg = _create_hetero_subgraph(graph, sgi, induced_nodes, induced_edges, store_ids=store_ids) + return subg if output_device is None else subg.to(output_device) DGLHeteroGraph.out_subgraph = utils.alias_func(out_subgraph) -def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): +def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None): """Return the subgraph induced by k-hop in-neighborhood of the specified node(s). We can expand a set of nodes by including the predecessors of them. From a @@ -594,6 +608,8 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. 
Returns ------- @@ -693,6 +709,8 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True) sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids) + if output_device is not None: + sub_g = sub_g.to(output_device) if relabel_nodes: if is_mapping: seed_inverse_indices = dict() @@ -702,13 +720,16 @@ def khop_in_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): else: seed_inverse_indices = F.slice_axis( inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])) + if output_device is not None: + seed_inverse_indices = recursive_apply( + seed_inverse_indices, lambda x: F.copy_to(x, output_device)) return sub_g, seed_inverse_indices else: return sub_g DGLHeteroGraph.khop_in_subgraph = utils.alias_func(khop_in_subgraph) -def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): +def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True, output_device=None): """Return the subgraph induced by k-hop out-neighborhood of the specified node(s). We can expand a set of nodes by including the successors of them. From a @@ -747,6 +768,8 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): resulting graph under name ``dgl.EID``; if ``relabel_nodes`` is ``True``, it will also store the raw IDs of the extracted nodes in the ``ndata`` of the resulting graph under name ``dgl.NID``. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -847,6 +870,8 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): for hop_nodes in k_hop_nodes_], dim=0), return_inverse=True) sub_g = node_subgraph(graph, k_hop_nodes, relabel_nodes=relabel_nodes, store_ids=store_ids) + if output_device is not None: + sub_g = sub_g.to(output_device) if relabel_nodes: if is_mapping: seed_inverse_indices = dict() @@ -856,13 +881,16 @@ def khop_out_subgraph(graph, nodes, k, *, relabel_nodes=True, store_ids=True): else: seed_inverse_indices = F.slice_axis( inverse_indices[nty], axis=0, begin=0, end=len(nodes[nty])) + if output_device is not None: + seed_inverse_indices = recursive_apply( + seed_inverse_indices, lambda x: F.copy_to(x, output_device)) return sub_g, seed_inverse_indices else: return sub_g DGLHeteroGraph.khop_out_subgraph = utils.alias_func(khop_out_subgraph) -def node_type_subgraph(graph, ntypes): +def node_type_subgraph(graph, ntypes, output_device=None): """Return the subgraph induced on given node types. A node-type-induced subgraph contains all the nodes of the given subset of @@ -877,6 +905,8 @@ def node_type_subgraph(graph, ntypes): The graph to extract subgraphs from. ntypes : list[str] The type names of the nodes in the subgraph. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -935,11 +965,11 @@ def node_type_subgraph(graph, ntypes): etypes.append(graph.canonical_etypes[etid]) if len(etypes) == 0: raise DGLError('There are no edges among nodes of the specified types.') - return edge_type_subgraph(graph, etypes) + return edge_type_subgraph(graph, etypes, output_device=output_device) DGLHeteroGraph.node_type_subgraph = utils.alias_func(node_type_subgraph) -def edge_type_subgraph(graph, etypes): +def edge_type_subgraph(graph, etypes, output_device=None): """Return the subgraph induced on given edge types. 
An edge-type-induced subgraph contains all the edges of the given subset of @@ -960,6 +990,8 @@ def edge_type_subgraph(graph, etypes): * ``(str, str, str)`` for source node type, edge type and destination node type. * or one ``str`` for the edge type name if the name can uniquely identify a triplet format in the graph. + output_device : Framework-specific device context object, optional + The output device. Default is the same as the input graph. Returns ------- @@ -1029,7 +1061,7 @@ def edge_type_subgraph(graph, etypes): hgidx = heterograph_index.create_heterograph_from_relations( metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type, "int64")) hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames) - return hg + return hg if output_device is None else hg.to(output_device) DGLHeteroGraph.edge_type_subgraph = utils.alias_func(edge_type_subgraph) diff --git a/python/dgl/utils/__init__.py b/python/dgl/utils/__init__.py index 788c5f40a800..6e15327b9824 100644 --- a/python/dgl/utils/__init__.py +++ b/python/dgl/utils/__init__.py @@ -4,3 +4,4 @@ from .checks import * from .shared_mem import * from .filter import * +from .exception import * diff --git a/python/dgl/utils/exception.py b/python/dgl/utils/exception.py new file mode 100644 index 000000000000..b56c2e2a29df --- /dev/null +++ b/python/dgl/utils/exception.py @@ -0,0 +1,57 @@ +"""Exception wrapper classes to properly display exceptions under multithreading or +multiprocessing. +""" +import sys +import traceback + +# The following code is borrowed from PyTorch. Basically when a subprocess or thread +# throws an exception, you will need to wrap the exception with ExceptionWrapper class +# and put it in the queue you are normally retrieving from. + +# NOTE [ Python Traceback Reference Cycle Problem ] +# +# When using sys.exc_info(), it is important to **not** store the exc_info[2], +# which is the traceback, because otherwise you will run into the traceback +# reference cycle problem, i.e., the traceback holding reference to the frame, +# and the frame (which holds reference to all the object in its temporary scope) +# holding reference the traceback. + +class KeyErrorMessage(str): + r"""str subclass that returns itself in repr""" + def __repr__(self): # pylint: disable=invalid-repr-returned + return self + + +class ExceptionWrapper(object): + r"""Wraps an exception plus traceback to communicate across threads""" + def __init__(self, exc_info=None, where="in background"): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.where = where + + def reraise(self): + r"""Reraises the wrapped exception in the current thread""" + # Format a message such as: "Caught ValueError in DataLoader worker + # process 2. Original Traceback:", followed by the traceback. + msg = "Caught {} {}.\nOriginal {}".format( + self.exc_type.__name__, self.where, self.exc_msg) + if self.exc_type == KeyError: + # KeyError calls repr() on its argument (usually a dict key). This + # makes stack traces unreadable. It will not be changed in Python + # (https://bugs.python.org/issue2651), so we work around it. 
+ msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) + try: + exception = self.exc_type(msg) + except TypeError: + # If the exception takes multiple arguments, don't try to + # instantiate since we don't know how to + raise RuntimeError(msg) from None + raise exception diff --git a/python/dgl/utils/internal.py b/python/dgl/utils/internal.py index 30cdcd685a2d..b232e7299717 100644 --- a/python/dgl/utils/internal.py +++ b/python/dgl/utils/internal.py @@ -1,7 +1,7 @@ """Internal utilities.""" from __future__ import absolute_import, division -from collections.abc import Mapping, Iterable +from collections.abc import Mapping, Iterable, Sequence from collections import defaultdict from functools import wraps import numpy as np @@ -910,4 +910,31 @@ def _fn(*args, **kwargs): _fn.__doc__ = """Alias of :func:`dgl.{}`.""".format(func.__name__) return _fn +def recursive_apply(data, fn, *args, **kwargs): + """Recursively apply a function to every element in a container. + """ + if isinstance(data, str): # str is a Sequence + return fn(data, *args, **kwargs) + elif isinstance(data, Mapping): + return {k: recursive_apply(v, fn, *args, **kwargs) for k, v in data.items()} + elif isinstance(data, Sequence): + return [recursive_apply(v, fn, *args, **kwargs) for v in data] + else: + return fn(data, *args, **kwargs) + +def recursive_apply_pair(data1, data2, fn, *args, **kwargs): + """Recursively apply a function to every pair of elements in two containers with the + same nested structure. + """ + if isinstance(data1, str) or isinstance(data2, str): + return fn(data1, data2, *args, **kwargs) + elif isinstance(data1, Mapping) and isinstance(data2, Mapping): + return { + k: recursive_apply_pair(data1[k], data2[k], fn, *args, **kwargs) + for k in data1.keys()} + elif isinstance(data1, Sequence) and isinstance(data2, Sequence): + return [recursive_apply_pair(x, y, fn, *args, **kwargs) for x, y in zip(data1, data2)] + else: + return fn(data1, data2, *args, **kwargs) + _init_api("dgl.utils.internal") diff --git a/python/dgl/view.py b/python/dgl/view.py index 2eeaa7318515..7f7d873815ed 100644 --- a/python/dgl/view.py +++ b/python/dgl/view.py @@ -6,6 +6,7 @@ from .base import ALL, DGLError from . import backend as F +from .frame import LazyFeature NodeSpace = namedtuple('NodeSpace', ['data']) EdgeSpace = namedtuple('EdgeSpace', ['data']) @@ -66,7 +67,9 @@ def __getitem__(self, key): return self._graph._get_n_repr(self._ntid, self._nodes)[key] def __setitem__(self, key, val): - if isinstance(self._ntype, list): + if isinstance(val, LazyFeature): + self._graph._node_frames[self._ntid][key] = val + elif isinstance(self._ntype, list): assert isinstance(val, dict), \ 'Current HeteroNodeDataView has multiple node types, ' \ 'please passing the node type and the corresponding data through a dict.' 
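As an aside, the nested-container helpers added to ``dgl/utils/internal.py`` above behave roughly as follows. This is an illustrative sketch, not part of the patch.

# Illustrative sketch (not part of the patch).
import torch
from dgl.utils import recursive_apply, recursive_apply_pair

nids = {'user': [torch.tensor([0, 1]), torch.tensor([2])],
        'game': [torch.tensor([3, 4])]}
# Apply a function to every leaf, preserving the dict/list nesting.
sizes = recursive_apply(nids, lambda x: x.shape[0])
# sizes == {'user': [2, 1], 'game': [2]}
# Combine two containers with identical structure, pairwise.
doubled = recursive_apply_pair(nids, nids, lambda x, y: x + y)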
@@ -89,36 +92,33 @@ def __delitem__(self, key): else: self._graph._pop_n_repr(self._ntid, key) + def _transpose(self, as_dict=False): + if isinstance(self._ntype, list): + ret = defaultdict(dict) + for (i, ntype) in enumerate(self._ntype): + data = self._graph._get_n_repr(self._ntid[i], self._nodes) + for key in self._graph._node_frames[self._ntid[i]]: + ret[key][ntype] = data[key] + else: + ret = self._graph._get_n_repr(self._ntid, self._nodes) + if as_dict: + ret = {key: ret[key] for key in self._graph._node_frames[self._ntid]} + return ret + def __len__(self): - assert isinstance(self._ntype, list) is False, \ - 'Current HeteroNodeDataView has multiple node types, ' \ - 'can not support len().' - return len(self._graph._node_frames[self._ntid]) + return len(self._transpose()) def __iter__(self): - assert isinstance(self._ntype, list) is False, \ - 'Current HeteroNodeDataView has multiple node types, ' \ - 'can not be iterated.' - return iter(self._graph._node_frames[self._ntid]) + return iter(self._transpose()) def keys(self): - return self._graph._node_frames[self._ntid].keys() + return self._transpose().keys() def values(self): - return self._graph._node_frames[self._ntid].values() + return self._transpose().values() def __repr__(self): - if isinstance(self._ntype, list): - ret = defaultdict(dict) - for (i, ntype) in enumerate(self._ntype): - data = self._graph._get_n_repr(self._ntid[i], self._nodes) - for key in self._graph._node_frames[self._ntid[i]]: - ret[key][ntype] = data[key] - return repr(ret) - else: - data = self._graph._get_n_repr(self._ntid, self._nodes) - return repr({key : data[key] - for key in self._graph._node_frames[self._ntid]}) + return repr(self._transpose(as_dict=True)) class HeteroEdgeView(object): """A EdgeView class to act as G.edges for a DGLHeteroGraph.""" @@ -181,7 +181,9 @@ def __getitem__(self, key): return self._graph._get_e_repr(self._etid, self._edges)[key] def __setitem__(self, key, val): - if isinstance(self._etype, list): + if isinstance(val, LazyFeature): + self._graph._edge_frames[self._etid][key] = val + elif isinstance(self._etype, list): assert isinstance(val, dict), \ 'Current HeteroEdgeDataView has multiple edge types, ' \ 'please pass the edge type and the corresponding data through a dict.' @@ -204,33 +206,30 @@ def __delitem__(self, key): else: self._graph._pop_e_repr(self._etid, key) + def _transpose(self, as_dict=False): + if isinstance(self._etype, list): + ret = defaultdict(dict) + for (i, etype) in enumerate(self._etype): + data = self._graph._get_e_repr(self._etid[i], self._edges) + for key in self._graph._edge_frames[self._etid[i]]: + ret[key][etype] = data[key] + else: + ret = self._graph._get_e_repr(self._etid, self._edges) + if as_dict: + ret = {key: ret[key] for key in self._graph._edge_frames[self._etid]} + return ret + def __len__(self): - assert isinstance(self._etype, list) is False, \ - 'Current HeteroEdgeDataView has multiple edge types, ' \ - 'can not support len().' - return len(self._graph._edge_frames[self._etid]) + return len(self._transpose()) def __iter__(self): - assert isinstance(self._etype, list) is False, \ - 'Current HeteroEdgeDataView has multiple edge types, ' \ - 'can not be iterated.' 
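The _transpose() helper introduced above replaces the per-method assertions, so len(), iteration and keys()/values() on a node or edge data view now also work when the view spans several node or edge types, enumerating the feature names across all of them. A small sketch of the new behaviour on a two-node-type graph; the graph and features below are made up for illustration:

    import torch
    import dgl

    g = dgl.heterograph({
        ('user', 'plays', 'game'): (torch.tensor([0, 1]), torch.tensor([0, 0]))})
    g.nodes['user'].data['h'] = torch.randn(2, 4)
    g.nodes['game'].data['h'] = torch.randn(1, 4)

    # g.ndata covers both node types here; these calls used to assert.
    assert len(g.ndata) == 1                 # one distinct feature name
    assert list(g.ndata.keys()) == ['h']
    for name in g.ndata:                     # iterates over feature names
        print(name)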
- return iter(self._graph._edge_frames[self._etid]) + return iter(self._transpose()) def keys(self): - return self._graph._edge_frames[self._etid].keys() + return self._transpose().keys() def values(self): - return self._graph._edge_frames[self._etid].values() + return self._transpose().values() def __repr__(self): - if isinstance(self._etype, list): - ret = defaultdict(dict) - for (i, etype) in enumerate(self._etype): - data = self._graph._get_e_repr(self._etid[i], self._edges) - for key in self._graph._edge_frames[self._etid[i]]: - ret[key][etype] = data[key] - return repr(ret) - else: - data = self._graph._get_e_repr(self._etid, self._edges) - return repr({key : data[key] - for key in self._graph._edge_frames[self._etid]}) + return repr(self._transpose(as_dict=True)) diff --git a/tests/compute/test_async_transferer.py b/tests/compute/_test_async_transferer.py similarity index 100% rename from tests/compute/test_async_transferer.py rename to tests/compute/_test_async_transferer.py diff --git a/tests/compute/test_subgraph.py b/tests/compute/test_subgraph.py index 0acc81d7e605..511654fab38b 100644 --- a/tests/compute/test_subgraph.py +++ b/tests/compute/test_subgraph.py @@ -594,3 +594,6 @@ def test_khop_out_subgraph(idtype): assert edge_set == {(0, 1)} assert F.array_equal(F.astype(inv['user'], idtype), F.tensor([0], idtype)) assert F.array_equal(F.astype(inv['game'], idtype), F.tensor([0], idtype)) + +if __name__ == '__main__': + test_khop_out_subgraph(F.int64) diff --git a/tests/compute/test_transform.py b/tests/compute/test_transform.py index 4076113946ee..9f91da5da532 100644 --- a/tests/compute/test_transform.py +++ b/tests/compute/test_transform.py @@ -2308,3 +2308,4 @@ def test_module_add_edge(idtype): if __name__ == '__main__': test_partition_with_halo() + test_module_heat_kernel(F.int32) diff --git a/tests/distributed/test_mp_dataloader.py b/tests/distributed/test_mp_dataloader.py index 918640a2bbc7..bd0f00a3367c 100644 --- a/tests/distributed/test_mp_dataloader.py +++ b/tests/distributed/test_mp_dataloader.py @@ -146,14 +146,14 @@ def start_dist_neg_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, g num_negs = 5 sampler = dgl.dataloading.MultiLayerNeighborSampler([5,10]) negative_sampler=dgl.dataloading.negative_sampler.Uniform(num_negs) - dataloader = dgl.dataloading.EdgeDataLoader(dist_graph, - train_eid, - sampler, - batch_size=batch_size, - negative_sampler=negative_sampler, - shuffle=True, - drop_last=False, - num_workers=num_workers) + dataloader = dgl.dataloading.DistEdgeDataLoader(dist_graph, + train_eid, + sampler, + batch_size=batch_size, + negative_sampler=negative_sampler, + shuffle=True, + drop_last=False, + num_workers=num_workers) for _ in range(2): for _, (_, pos_graph, neg_graph, blocks) in zip(range(0, num_edges_to_sample, batch_size), dataloader): block = blocks[-1] @@ -288,7 +288,7 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_ # We need to test creating DistDataLoader multiple times. for i in range(2): # Create DataLoader for constructing blocks - dataloader = dgl.dataloading.NodeDataLoader( + dataloader = dgl.dataloading.DistNodeDataLoader( dist_graph, train_nid, sampler, @@ -339,7 +339,7 @@ def start_edge_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_ # We need to test creating DistDataLoader multiple times. 
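The distributed tests now spell the loaders as DistNodeDataLoader and DistEdgeDataLoader; on a DistGraph these take the place of the former NodeDataLoader and EdgeDataLoader with otherwise unchanged arguments. A sketch of the node-classification variant under the assumption of an already initialized dgl.distributed setup, where dist_graph and train_nid are hypothetical placeholders for the DistGraph and the training node IDs:

    import dgl

    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])
    # dist_graph and train_nid are hypothetical handles from an existing
    # dgl.distributed setup; the remaining arguments mirror the test above.
    dataloader = dgl.dataloading.DistNodeDataLoader(
        dist_graph, train_nid, sampler,
        batch_size=1024, shuffle=True, drop_last=False)

    for input_nodes, output_nodes, blocks in dataloader:
        pass  # run the usual mini-batch training step on `blocks`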
for i in range(2): # Create DataLoader for constructing blocks - dataloader = dgl.dataloading.EdgeDataLoader( + dataloader = dgl.dataloading.DistEdgeDataLoader( dist_graph, train_eid, sampler, diff --git a/tests/pytorch/test_dataloader.py b/tests/pytorch/test_dataloader.py index 84e87b00f773..7c44fef86b99 100644 --- a/tests/pytorch/test_dataloader.py +++ b/tests/pytorch/test_dataloader.py @@ -10,193 +10,6 @@ from itertools import product import pytest -def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator): - seeds = defaultdict(list) - - for item in dl: - if mode == 'node': - input_nodes, output_nodes, blocks = item - elif mode == 'edge': - input_nodes, pair_graph, blocks = item - output_nodes = pair_graph.ndata[dgl.NID] - elif mode == 'link': - input_nodes, pair_graph, neg_graph, blocks = item - output_nodes = pair_graph.ndata[dgl.NID] - for ntype in pair_graph.ntypes: - assert F.array_equal(pair_graph.nodes[ntype].data[dgl.NID], neg_graph.nodes[ntype].data[dgl.NID]) - - if len(g.ntypes) > 1: - for ntype in g.ntypes: - assert F.array_equal(input_nodes[ntype], blocks[0].srcnodes[ntype].data[dgl.NID]) - assert F.array_equal(output_nodes[ntype], blocks[-1].dstnodes[ntype].data[dgl.NID]) - else: - assert F.array_equal(input_nodes, blocks[0].srcdata[dgl.NID]) - assert F.array_equal(output_nodes, blocks[-1].dstdata[dgl.NID]) - - prev_dst = {ntype: None for ntype in g.ntypes} - for block in blocks: - for canonical_etype in block.canonical_etypes: - utype, etype, vtype = canonical_etype - uu, vv = block.all_edges(order='eid', etype=canonical_etype) - src = block.srcnodes[utype].data[dgl.NID] - dst = block.dstnodes[vtype].data[dgl.NID] - assert F.array_equal( - block.srcnodes[utype].data['feat'], g.nodes[utype].data['feat'][src]) - assert F.array_equal( - block.dstnodes[vtype].data['feat'], g.nodes[vtype].data['feat'][dst]) - if prev_dst[utype] is not None: - assert F.array_equal(src, prev_dst[utype]) - u = src[uu] - v = dst[vv] - assert F.asnumpy(g.has_edges_between(u, v, etype=canonical_etype)).all() - eid = block.edges[canonical_etype].data[dgl.EID] - assert F.array_equal( - block.edges[canonical_etype].data['feat'], - g.edges[canonical_etype].data['feat'][eid]) - ufound, vfound = g.find_edges(eid, etype=canonical_etype) - assert F.array_equal(ufound, u) - assert F.array_equal(vfound, v) - for ntype in block.dsttypes: - src = block.srcnodes[ntype].data[dgl.NID] - dst = block.dstnodes[ntype].data[dgl.NID] - assert F.array_equal(src[:block.number_of_dst_nodes(ntype)], dst) - prev_dst[ntype] = dst - - if mode == 'node': - for ntype in blocks[-1].dsttypes: - seeds[ntype].append(blocks[-1].dstnodes[ntype].data[dgl.NID]) - elif mode == 'edge' or mode == 'link': - for etype in pair_graph.canonical_etypes: - seeds[etype].append(pair_graph.edges[etype].data[dgl.EID]) - - # Check if all nodes/edges are iterated - seeds = {k: F.cat(v, 0) for k, v in seeds.items()} - for k, v in seeds.items(): - if k in nids: - seed_set = set(F.asnumpy(nids[k])) - elif isinstance(k, tuple) and k[1] in nids: - seed_set = set(F.asnumpy(nids[k[1]])) - else: - continue - - v_set = set(F.asnumpy(v)) - assert v_set == seed_set - -def test_neighbor_sampler_dataloader(): - g = dgl.heterograph({('user', 'follow', 'user'): ([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])}, - {'user': 6}).long() - g = dgl.to_bidirected(g).to(F.ctx()) - g.ndata['feat'] = F.randn((6, 8)) - g.edata['feat'] = F.randn((10, 4)) - reverse_eids = F.tensor([5, 6, 7, 8, 9, 0, 1, 2, 3, 4], dtype=F.int64) - g_sampler1 = 
dgl.dataloading.MultiLayerNeighborSampler([2, 2], return_eids=True) - g_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True) - - hg = dgl.heterograph({ - ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]), - ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]), - ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]), - ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]) - }).long().to(F.ctx()) - for ntype in hg.ntypes: - hg.nodes[ntype].data['feat'] = F.randn((hg.number_of_nodes(ntype), 8)) - for etype in hg.canonical_etypes: - hg.edges[etype].data['feat'] = F.randn((hg.number_of_edges(etype), 4)) - hg_sampler1 = dgl.dataloading.MultiLayerNeighborSampler( - [{'play': 1, 'played-by': 1, 'follow': 2, 'followed-by': 1}] * 2, return_eids=True) - hg_sampler2 = dgl.dataloading.MultiLayerFullNeighborSampler(2, return_eids=True) - reverse_etypes = {'follow': 'followed-by', 'followed-by': 'follow', 'play': 'played-by', 'played-by': 'play'} - - collators = [] - graphs = [] - nids = [] - modes = [] - for seeds, sampler in product( - [F.tensor([0, 1, 2, 3, 5], dtype=F.int64), F.tensor([4, 5], dtype=F.int64)], - [g_sampler1, g_sampler2]): - collators.append(dgl.dataloading.NodeCollator(g, seeds, sampler)) - graphs.append(g) - nids.append({'user': seeds}) - modes.append('node') - - collators.append(dgl.dataloading.EdgeCollator(g, seeds, sampler)) - graphs.append(g) - nids.append({'follow': seeds}) - modes.append('edge') - - collators.append(dgl.dataloading.EdgeCollator( - g, seeds, sampler, exclude='self')) - graphs.append(g) - nids.append({'follow': seeds}) - modes.append('edge') - - collators.append(dgl.dataloading.EdgeCollator( - g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids)) - graphs.append(g) - nids.append({'follow': seeds}) - modes.append('edge') - - collators.append(dgl.dataloading.EdgeCollator( - g, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) - graphs.append(g) - nids.append({'follow': seeds}) - modes.append('link') - - collators.append(dgl.dataloading.EdgeCollator( - g, seeds, sampler, exclude='self', negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) - graphs.append(g) - nids.append({'follow': seeds}) - modes.append('link') - - collators.append(dgl.dataloading.EdgeCollator( - g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids, - negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) - graphs.append(g) - nids.append({'follow': seeds}) - modes.append('link') - - for seeds, sampler in product( - [{'user': F.tensor([0, 1, 3, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}, - {'user': F.tensor([4, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}], - [hg_sampler1, hg_sampler2]): - collators.append(dgl.dataloading.NodeCollator(hg, seeds, sampler)) - graphs.append(hg) - nids.append(seeds) - modes.append('node') - - for seeds, sampler in product( - [{'follow': F.tensor([0, 1, 3, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}, - {'follow': F.tensor([4, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}], - [hg_sampler1, hg_sampler2]): - collators.append(dgl.dataloading.EdgeCollator(hg, seeds, sampler)) - graphs.append(hg) - nids.append(seeds) - modes.append('edge') - - collators.append(dgl.dataloading.EdgeCollator( - hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes)) - graphs.append(hg) - nids.append(seeds) - modes.append('edge') - - 
collators.append(dgl.dataloading.EdgeCollator( - hg, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) - graphs.append(hg) - nids.append(seeds) - modes.append('link') - - collators.append(dgl.dataloading.EdgeCollator( - hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes, - negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) - graphs.append(hg) - nids.append(seeds) - modes.append('link') - - for _g, nid, collator, mode in zip(graphs, nids, collators, modes): - dl = DataLoader( - collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False) - assert isinstance(iter(dl), Iterator) - _check_neighbor_sampling_dataloader(_g, nid, dl, mode, collator) def test_graph_dataloader(): batch_size = 16 @@ -213,15 +26,12 @@ def test_graph_dataloader(): def test_cluster_gcn(num_workers): dataset = dgl.data.CoraFullDataset() g = dataset[0] - sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(g, 100, '.', refresh=True) - dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=num_workers) - for sg in dataloader: - assert sg.batch_size == 4 - - sgiter = dgl.dataloading.ClusterGCNSubgraphIterator(g, 100, '.', refresh=False) # use cache - dataloader = dgl.dataloading.GraphDataLoader(sgiter, batch_size=4, num_workers=num_workers) - for sg in dataloader: - assert sg.batch_size == 4 + sampler = dgl.dataloading.ClusterGCNSampler(g, 100) + dataloader = dgl.dataloading.DataLoader( + g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers) + assert len(dataloader) == 25 + for i, sg in enumerate(dataloader): + pass @pytest.mark.parametrize('num_workers', [0, 4]) def test_shadow(num_workers): @@ -230,7 +40,7 @@ def test_shadow(num_workers): dataloader = dgl.dataloading.NodeDataLoader( g, torch.arange(g.num_nodes()), sampler, batch_size=5, shuffle=True, drop_last=False, num_workers=num_workers) - for i, (input_nodes, output_nodes, (subgraph,)) in enumerate(dataloader): + for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader): assert torch.equal(input_nodes, subgraph.ndata[dgl.NID]) assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes) assert torch.equal(subgraph.ndata['label'], g.ndata['label'][input_nodes]) @@ -288,37 +98,25 @@ def _check_device(data): else: assert data.device == F.ctx() -@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2', 'shadow']) +@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2']) def test_node_dataloader(sampler_name): g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])) g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu()) g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu()) - for load_input, load_output in [(None, None), ({'feat': g1.ndata['feat']}, {'label': g1.ndata['label']})]: - for async_load in [False, True]: - for num_workers in [0, 1, 2]: - sampler = { - 'full': dgl.dataloading.MultiLayerFullNeighborSampler(2), - 'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]), - 'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]), - 'shadow': dgl.dataloading.ShaDowKHopSampler([3, 3])}[sampler_name] - dataloader = dgl.dataloading.NodeDataLoader( - g1, g1.nodes(), sampler, device=F.ctx(), - load_input=load_input, - load_output=load_output, - async_load=async_load, - batch_size=g1.num_nodes(), - num_workers=num_workers) - for input_nodes, output_nodes, blocks in dataloader: - _check_device(input_nodes) - _check_device(output_nodes) - _check_device(blocks) - if 
load_input: - _check_device(blocks[0].srcdata['feat']) - OPS.copy_u_sum(blocks[0], blocks[0].srcdata['feat']) - if load_output: - _check_device(blocks[-1].dstdata['label']) - OPS.copy_u_sum(blocks[-1], blocks[-1].dstdata['label']) + for num_workers in [0, 1, 2]: + sampler = { + 'full': dgl.dataloading.MultiLayerFullNeighborSampler(2), + 'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]), + 'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name] + dataloader = dgl.dataloading.NodeDataLoader( + g1, g1.nodes(), sampler, device=F.ctx(), + batch_size=g1.num_nodes(), + num_workers=num_workers) + for input_nodes, output_nodes, blocks in dataloader: + _check_device(input_nodes) + _check_device(output_nodes) + _check_device(blocks) g2 = dgl.heterograph({ ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]), @@ -332,30 +130,19 @@ def test_node_dataloader(sampler_name): sampler = { 'full': dgl.dataloading.MultiLayerFullNeighborSampler(2), 'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2), - 'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3]), - 'shadow': dgl.dataloading.ShaDowKHopSampler([{etype: 3 for etype in g2.etypes}] * 2)}[sampler_name] + 'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name] - for async_load in [False, True]: - dataloader = dgl.dataloading.NodeDataLoader( - g2, {nty: g2.nodes(nty) for nty in g2.ntypes}, - sampler, device=F.ctx(), async_load=async_load, batch_size=batch_size) - assert isinstance(iter(dataloader), Iterator) - for input_nodes, output_nodes, blocks in dataloader: - _check_device(input_nodes) - _check_device(output_nodes) - _check_device(blocks) - - status = False - try: - dgl.dataloading.NodeDataLoader( - g2, {nty: g2.nodes(nty) for nty in g2.ntypes}, - sampler, device=F.ctx(), load_input={'feat': g1.ndata['feat']}, batch_size=batch_size) - except dgl.DGLError: - status = True - assert status + dataloader = dgl.dataloading.NodeDataLoader( + g2, {nty: g2.nodes(nty) for nty in g2.ntypes}, + sampler, device=F.ctx(), batch_size=batch_size) + assert isinstance(iter(dataloader), Iterator) + for input_nodes, output_nodes, blocks in dataloader: + _check_device(input_nodes) + _check_device(output_nodes) + _check_device(blocks) -@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'shadow']) +@pytest.mark.parametrize('sampler_name', ['full', 'neighbor']) @pytest.mark.parametrize('neg_sampler', [ dgl.dataloading.negative_sampler.Uniform(2), dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3), @@ -366,8 +153,7 @@ def test_edge_dataloader(sampler_name, neg_sampler): sampler = { 'full': dgl.dataloading.MultiLayerFullNeighborSampler(2), - 'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]), - 'shadow': dgl.dataloading.ShaDowKHopSampler([3, 3])}[sampler_name] + 'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name] # no negative sampler dataloader = dgl.dataloading.EdgeDataLoader( @@ -399,7 +185,7 @@ def test_edge_dataloader(sampler_name, neg_sampler): sampler = { 'full': dgl.dataloading.MultiLayerFullNeighborSampler(2), 'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2), - 'shadow': dgl.dataloading.ShaDowKHopSampler([{etype: 3 for etype in g2.etypes}] * 2)}[sampler_name] + }[sampler_name] # no negative sampler dataloader = dgl.dataloading.EdgeDataLoader( @@ -424,11 +210,10 @@ def test_edge_dataloader(sampler_name, neg_sampler): _check_device(blocks) if 
__name__ == '__main__': - test_neighbor_sampler_dataloader() test_graph_dataloader() test_cluster_gcn(0) test_neighbor_nonuniform(0) - for sampler in ['full', 'neighbor', 'shadow']: + for sampler in ['full', 'neighbor']: test_node_dataloader(sampler) for neg_sampler in [ dgl.dataloading.negative_sampler.Uniform(2), diff --git a/tutorials/dist/1_node_classification.py b/tutorials/dist/1_node_classification.py index ec8a059f2bc0..f77a5da53d3f 100644 --- a/tutorials/dist/1_node_classification.py +++ b/tutorials/dist/1_node_classification.py @@ -265,7 +265,8 @@ def forward(self, blocks, x): Distributed mini-batch sampler ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -We can use the same `NodeDataLoader` to create a distributed mini-batch sampler for +We can use the same :class:`~dgl.dataloading.pytorch.DistNodeDataLoader`, the distributed counterpart +of :class:`~dgl.dataloading.pytorch.NodeDataLoader`, to create a distributed mini-batch sampler for node classification. @@ -274,10 +275,10 @@ def forward(self, blocks, x): .. code-block:: python sampler = dgl.dataloading.MultiLayerNeighborSampler([25,10]) - train_dataloader = dgl.dataloading.NodeDataLoader( + train_dataloader = dgl.dataloading.DistNodeDataLoader( g, train_nid, sampler, batch_size=1024, shuffle=True, drop_last=False) - valid_dataloader = dgl.dataloading.NodeDataLoader( + valid_dataloader = dgl.dataloading.DistNodeDataLoader( g, valid_nid, sampler, batch_size=1024, shuffle=False, drop_last=False) @@ -432,4 +433,4 @@ def forward(self, blocks, x): ip_addr3 ip_addr4 -''' \ No newline at end of file +''' diff --git a/tutorials/large/L2_large_link_prediction.py b/tutorials/large/L2_large_link_prediction.py index eb7642dba3db..1398a47bc81e 100644 --- a/tutorials/large/L2_large_link_prediction.py +++ b/tutorials/large/L2_large_link_prediction.py @@ -369,12 +369,6 @@ def closure(): # Ultimately, they require the model to predict one scalar score given # a node pair among a set of node pairs. # -# ``dgl.dataloading.EdgeDataLoader`` allows you to iterate over -# the edges of a new graph with the same nodes, while performing -# neighbor sampling on the original graph with ``g_sampling`` argument. -# This functionality enables convenient evaluation of a link prediction -# model. -# # Assuming that you have the following test set with labels, where # ``test_pos_src`` and ``test_pos_dst`` are ground truth node pairs # with edges in between (or *positive* pairs), and ``test_neg_src`` @@ -383,10 +377,16 @@ def closure(): # # Positive pairs -test_pos_src, test_pos_dst = graph.edges() -# Negative pairs +# These are randomly generated as an example. You will need to +# replace them with your own ground truth. +n_test_pos = 1000 +test_pos_src, test_pos_dst = ( + torch.randint(0, graph.num_nodes(), (n_test_pos,)), + torch.randint(0, graph.num_nodes(), (n_test_pos,))) +# Negative pairs. Likewise, you will need to replace them with your +# own ground truth. 
test_neg_src = test_pos_src -test_neg_dst = torch.randint(0, graph.num_nodes(), (graph.num_edges(),)) +test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,)) ###################################################################### @@ -398,10 +398,20 @@ def closure(): test_src = torch.cat([test_pos_src, test_pos_dst]) test_dst = torch.cat([test_neg_src, test_neg_dst]) test_graph = dgl.graph((test_src, test_dst), num_nodes=graph.num_nodes()) -test_graph.edata['label'] = torch.cat( +test_ground_truth = torch.cat( [torch.ones_like(test_pos_src), torch.zeros_like(test_neg_src)]) +###################################################################### +# You will need to merge the test graph with the original graph. The +# testing edges' ID will be starting from ``graph.num_edges()``. +# + +new_graph = dgl.merge([graph, test_graph]) +test_edge_ids = torch.arange(graph.num_edges(), new_graph.num_edges()) + + + ###################################################################### # Then you could create a new ``EdgeDataLoader`` instance that # iterates on the new ``test_graph``, but uses the original ``graph`` @@ -412,11 +422,11 @@ def closure(): # test_dataloader = dgl.dataloading.EdgeDataLoader( # The following arguments are specific to EdgeDataLoader. - test_graph, # The graph to iterate edges over - torch.arange(test_graph.number_of_edges()), # The edges to iterate over + new_graph, # The graph to iterate edges over + test_edge_ids, # The edges to iterate over sampler, # The neighbor sampler device=device, # Put the MFGs on CPU or GPU - g_sampling=graph, # Graph to sample neighbors + exclude=test_edge_ids, # Do not sample test edges as neighbors # The following arguments are inherited from PyTorch DataLoader. batch_size=1024, # Batch size shuffle=True, # Whether to shuffle the nodes for every epoch @@ -447,7 +457,10 @@ def closure(): outputs = model(mfgs, inputs) test_preds.append(predictor(pair_graph, outputs)) - test_labels.append(pair_graph.edata['label']) + test_labels.append( + # Need to map the IDs of test edges in the merged graph back + # to that of test_ground_truth. + test_ground_truth[pair_graph.edata[dgl.EID] - graph.num_edges()]) test_preds = torch.cat(test_preds).cpu().numpy() test_labels = torch.cat(test_labels).cpu().numpy() diff --git a/tutorials/multi/2_node_classification.py b/tutorials/multi/2_node_classification.py index aa67f4e93829..2815c7f6c590 100644 --- a/tutorials/multi/2_node_classification.py +++ b/tutorials/multi/2_node_classification.py @@ -158,7 +158,6 @@ def run(proc_id, devices): # Copied from previous tutorial with changes highlighted. for epoch in range(10): - train_dataloader.set_epoch(epoch) # <--- necessary for dataloader with DDP. model.train() with tqdm.tqdm(train_dataloader) as tq:
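The link-prediction tutorial above relies on one piece of edge-ID bookkeeping: after dgl.merge([graph, test_graph]), the test edges receive IDs from graph.num_edges() up to new_graph.num_edges() - 1, so subtracting graph.num_edges() from pair_graph.edata[dgl.EID] indexes back into test_ground_truth. A toy sketch of that arithmetic; the graphs below are made up for illustration:

    import torch
    import dgl

    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))          # edge IDs 0, 1
    extra = dgl.graph((torch.tensor([2, 0]), torch.tensor([0, 2])),
                      num_nodes=g.num_nodes())
    merged = dgl.merge([g, extra])

    # Edges of `extra` come after those of `g`, so their merged IDs start
    # at g.num_edges(); subtracting it recovers their original positions.
    test_edge_ids = torch.arange(g.num_edges(), merged.num_edges())
    print(test_edge_ids)                    # tensor([2, 3])
    print(test_edge_ids - g.num_edges())    # tensor([0, 1])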