forked from dmlc/dgl
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Sampling] New sampling pipeline plus asynchronous prefetching (dmlc#3665)

* initial update
* more
* more
* multi-gpu example
* cluster gcn, finalize homogeneous
* more explanation
* fix
* bunch of fixes
* fix
* RGAT example and more fixes
* shadow-gnn sampler and some changes in unit test
* fix
* wth
* more fixes
* remove shadow+node/edge dataloader tests for possible ux changes
* lints
* add legacy dataloading import just in case
* fix
* update pylint for f-strings
* fix
* lint
* lint
* lint again
* cherry-picking commit fa9f494
* oops
* fix
* add sample_neighbors in dist_graph
* fix
* lint
* fix
* fix
* fix
* fix tutorial
* fix
* fix
* fix
* fix warning
* remove debug
* add get_foo_storage apis
* lint
Showing 62 changed files with 4,011 additions and 1,487 deletions.
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

USE_WRAPPER = True

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)

    def forward(self, sg, x):
        h = x
        for l, layer in enumerate(self.layers):
            h = layer(sg, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

dataset = DglNodePropPredDataset('ogbn-products')
graph, labels = dataset[0]
graph.ndata['label'] = labels
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
# Store the splits as boolean node masks so they can be prefetched together with the features.
graph.ndata['train_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).index_fill_(0, train_idx, True)
graph.ndata['valid_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).index_fill_(0, valid_idx, True)
graph.ndata['test_mask'] = torch.zeros(graph.num_nodes(), dtype=torch.bool).index_fill_(0, test_idx, True)

model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

if USE_WRAPPER:
    # Wrap the in-memory DGLGraph in the new GraphStorage interface (see dglnew below).
    import dglnew
    graph.create_formats_()
    graph = dglnew.graph.wrapper.DGLGraphStorage(graph)

num_partitions = 1000
sampler = dgl.dataloading.ClusterGCNSampler(
    graph, num_partitions,
    prefetch_node_feats=['feat', 'label', 'train_mask', 'valid_mask', 'test_mask'])
# DataLoader for generic dataloading with a graph, a set of indices (any indices, like
# partition IDs here), and a graph sampler.
# NodeDataLoader and EdgeDataLoader are simply special cases of DataLoader where the
# indices are guaranteed to be node and edge IDs.
dataloader = dgl.dataloading.DataLoader(
    graph,
    torch.arange(num_partitions),
    sampler,
    device='cuda',
    batch_size=100,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
    num_workers=8,
    persistent_workers=True,
    use_prefetch_thread=True)  # TBD: could probably remove this argument

durations = []
for _ in range(10):
    t0 = time.time()
    for it, sg in enumerate(dataloader):
        # Node features, labels and the training mask were prefetched onto the GPU.
        x = sg.ndata['feat']
        y = sg.ndata['label'][:, 0]
        m = sg.ndata['train_mask']
        y_hat = model(sg, x)
        loss = F.cross_entropy(y_hat[m], y[m])
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 == 0:
            acc = MF.accuracy(y_hat[m], y[m])
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)
# Report mean and std of the epoch time, skipping the first few warm-up epochs.
print(np.mean(durations[4:]), np.std(durations[4:]))
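As the comment in the example above notes, NodeDataLoader and EdgeDataLoader are just special cases of the generic DataLoader whose indices happen to be node or edge IDs. For comparison, the following is only a rough sketch of the node-classification counterpart of the same pipeline: it assumes the dgl.dataloading.NeighborSampler class and its prefetch_node_feats/prefetch_labels arguments introduced by this pipeline, and the fan-outs, batch size, and worker count are illustrative values, not taken from this commit.

# Sketch only (assumptions noted above): neighbor-sampling node classification where the
# DataLoader indices are training node IDs rather than Cluster-GCN partition IDs.
sampler = dgl.dataloading.NeighborSampler(
    [15, 10, 5],                              # illustrative fan-outs for a 3-layer model
    prefetch_node_feats=['feat'],             # asynchronously prefetch input features...
    prefetch_labels=['label'])                # ...and labels of the seed nodes
node_dataloader = dgl.dataloading.DataLoader(
    graph,
    train_idx,                                # node IDs instead of partition IDs
    sampler,
    device='cuda',
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=4)

for input_nodes, output_nodes, blocks in node_dataloader:
    x = blocks[0].srcdata['feat']             # prefetched input features
    y = blocks[-1].dstdata['label'][:, 0]     # prefetched labels of the seed nodes
    # A block-based model applies self.layers[l] to blocks[l]; the loss and optimizer
    # steps then proceed exactly as in the Cluster-GCN loop above.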
@@ -0,0 +1,2 @@
from . import graph
from . import storages
@@ -0,0 +1,3 @@
from .graph import *
from .other_feature import *
from .wrapper import *
@@ -0,0 +1,150 @@
class GraphStorage(object):
    def get_node_storage(self, key, ntype=None):
        # Return the feature storage for node feature ``key`` of node type ``ntype``.
        pass

    def get_edge_storage(self, key, etype=None):
        # Return the feature storage for edge feature ``key`` of edge type ``etype``.
        pass

    # Required for checking whether a single dict is allowed for ndata and edata.
    @property
    def ntypes(self):
        pass

    @property
    def canonical_etypes(self):
        pass

    def etypes(self):
        return [etype[1] for etype in self.canonical_etypes]

    def sample_neighbors(self, seed_nodes, fanout, edge_dir='in', prob=None,
                         exclude_edges=None, replace=False, output_device=None):
        """Return a DGLGraph which is a subgraph induced by sampling neighboring edges of
        the given nodes.

        See ``dgl.sampling.sample_neighbors`` for detailed semantics.

        Parameters
        ----------
        seed_nodes : Tensor or dict[str, Tensor]
            Node IDs to sample neighbors from.
            This argument can take a single ID tensor or a dictionary of node types and ID tensors.
            If a single tensor is given, the graph must only have one type of nodes.
        fanout : int or dict[etype, int]
            The number of edges to be sampled for each node on each edge type.
            This argument can take a single int or a dictionary of edge types and ints.
            If a single int is given, DGL will sample this number of edges for each node for
            every edge type.
            If -1 is given for a single edge type, all the neighboring edges with that edge
            type will be selected.
        prob : str, optional
            Feature name used as the (unnormalized) probabilities associated with each
            neighboring edge of a node.  The feature must have only one element for each
            edge.
            The features must be non-negative floats, and the sum of the features of
            inbound/outbound edges for every node must be positive (though they don't have
            to sum up to one).  Otherwise, the result will be undefined.
            If :attr:`prob` is not None, GPU sampling is not supported.
        exclude_edges : Tensor or dict
            Edge IDs to exclude during sampling neighbors for the seed nodes.
            This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
            If a single tensor is given, the graph must only have one type of edges.
        replace : bool, optional
            If True, sample with replacement.
        output_device : Framework-specific device context object, optional
            The output device.  Default is the same as the input graph.

        Returns
        -------
        DGLGraph
            A sampled subgraph with the same nodes as the original graph, but only the
            sampled neighboring edges.  The induced edge IDs will be in ``edata[dgl.EID]``.
        """
        pass
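    # Illustrative usage (hypothetical sketch: ``storage`` is some object implementing
    # this interface, on a graph with a single node and edge type):
    #
    #     sg = storage.sample_neighbors(torch.tensor([0, 1, 2]), fanout=5)
    #     sampled_eids = sg.edata[dgl.EID]   # IDs of the sampled edges in the original graph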

    # Required in Cluster-GCN
    def subgraph(self, nodes, relabel_nodes=False, output_device=None):
        """Return a subgraph induced on given nodes.

        This has the same semantics as ``dgl.node_subgraph``.

        Parameters
        ----------
        nodes : nodes or dict[str, nodes]
            The nodes to form the subgraph.  The allowed nodes formats are:

            * Int Tensor: Each element is a node ID.  The tensor must have the same device type
              and ID data type as the graph's.
            * iterable[int]: Each element is a node ID.
            * Bool Tensor: Each :math:`i^{th}` element is a bool flag indicating whether
              node :math:`i` is in the subgraph.

            If the graph is homogeneous, one can directly pass the above formats.
            Otherwise, the argument must be a dictionary with keys being node types
            and values being the node IDs in the above formats.
        relabel_nodes : bool, optional
            If True, the extracted subgraph will only have the nodes in the specified node set
            and it will relabel the nodes in order.
        output_device : Framework-specific device context object, optional
            The output device.  Default is the same as the input graph.

        Returns
        -------
        DGLGraph
            The subgraph.
        """
        pass

    # Required in Link Prediction
    def edge_subgraph(self, edges, relabel_nodes=False, output_device=None):
        """Return a subgraph induced on given edges.

        This has the same semantics as ``dgl.edge_subgraph``.

        Parameters
        ----------
        edges : edges or dict[(str, str, str), edges]
            The edges to form the subgraph.  The allowed edges formats are:

            * Int Tensor: Each element is an edge ID.  The tensor must have the same device type
              and ID data type as the graph's.
            * iterable[int]: Each element is an edge ID.
            * Bool Tensor: Each :math:`i^{th}` element is a bool flag indicating whether
              edge :math:`i` is in the subgraph.

            If the graph is homogeneous, one can directly pass the above formats.
            Otherwise, the argument must be a dictionary with keys being edge types
            and values being the edge IDs in the above formats.
        relabel_nodes : bool, optional
            If True, the extracted subgraph will only have the nodes in the specified node set
            and it will relabel the nodes in order.
        output_device : Framework-specific device context object, optional
            The output device.  Default is the same as the input graph.

        Returns
        -------
        DGLGraph
            The subgraph.
        """
        pass

    # Required in Link Prediction negative sampler
    def find_edges(self, edges, etype=None, output_device=None):
        """Return the source and destination node IDs given the edge IDs within the given edge type."""
        pass

    # Required in Link Prediction negative sampler
    def num_nodes(self, ntype):
        """Return the number of nodes for the given node type."""
        pass

    def global_uniform_negative_sampling(self, num_samples, exclude_self_loops=True,
                                         replace=False, etype=None):
        """Per source negative sampling as in ``dgl.dataloading.GlobalUniform``"""