From 96179b0c96a5dfed18f961cdd0635ae6a8ce84b8 Mon Sep 17 00:00:00 2001 From: Lingfan Yu Date: Thu, 16 Aug 2018 14:05:50 -0400 Subject: [PATCH] Deep Generative Models of Graphs (#14) * model code for generative graphs * batched version for dynamic graph generation using padding * renaming function train back to forward * remove old util function for padding DGMG * override networkx clear to reset state, add dgl.nn * Dynamic graph without batching * use relative import path * load dataset, pad batch * bug fix * experimental batch and unbatch * dgmg batched version * minor tweak * move preprocessing padding into data loading * batch graph test code * minor * batched graph class and test cases * make dgl.nn.gcn a simple layer plus minor fix * update dgmg model * test forward using attribute field * use frame append, minor changes * moving networkx operations out of forward * revert some changes * remove structural immutability check --- .gitignore | 1 + examples/pytorch/generative_graph/model.py | 255 +++++++++++++++++++++ examples/pytorch/generative_graph/util.py | 156 +++++++++++++ python/dgl/__init__.py | 1 + python/dgl/backend/numpy.py | 9 +- python/dgl/backend/pytorch.py | 4 +- python/dgl/batch.py | 133 +++++++++++ python/dgl/graph.py | 24 +- python/dgl/nn/__init__.py | 9 + python/dgl/nn/pytorch/__init__.py | 1 + python/dgl/nn/pytorch/gcn.py | 49 ++++ tests/test_graph_batch.py | 127 ++++++++++ 12 files changed, 762 insertions(+), 7 deletions(-) create mode 100644 examples/pytorch/generative_graph/model.py create mode 100644 examples/pytorch/generative_graph/util.py create mode 100644 python/dgl/batch.py create mode 100644 python/dgl/nn/__init__.py create mode 100644 python/dgl/nn/pytorch/__init__.py create mode 100644 python/dgl/nn/pytorch/gcn.py create mode 100644 tests/test_graph_batch.py diff --git a/.gitignore b/.gitignore index 88821af7d54d..dd0950dc3fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -131,6 +131,7 @@ examples/pytorch/data/ind.citeseer.ally 
examples/pytorch/data/ind.citeseer.allx examples/pytorch/.DS_Store examples/.DS_Store +examples/pytorch/generative_graph/*.p .DS_Store # data directory diff --git a/examples/pytorch/generative_graph/model.py b/examples/pytorch/generative_graph/model.py new file mode 100644 index 000000000000..c02674142d4d --- /dev/null +++ b/examples/pytorch/generative_graph/model.py @@ -0,0 +1,255 @@ +import dgl +from dgl.graph import DGLGraph +from dgl.nn import GCN +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import argparse +from util import DataLoader, elapsed +import time + +class MLP(nn.Module): + def __init__(self, num_hidden, num_classes, num_layers): + super(MLP, self).__init__() + layers = [] + # hidden layers + for _ in range(num_layers): + layers.append(nn.Linear(num_hidden, num_hidden)) + layers.append(nn.Sigmoid()) + # output projection + layers.append(nn.Linear(num_hidden, num_classes)) + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +def move2cuda(x): + # recursively move a object to cuda + if isinstance(x, torch.Tensor): + # if Tensor, move directly + return x.cuda() + else: + try: + # iterable, recursively move each element + x = [move2cuda(i) for i in x] + return x + except: + # don't do anything for other types like basic types + return x + + +class DGMG(nn.Module): + def __init__(self, node_num_hidden, graph_num_hidden, T, num_MLP_layers=1, loss_func=None, dropout=0.0, use_cuda=False): + super(DGMG, self).__init__() + # hidden size of node and graph + self.node_num_hidden = node_num_hidden + self.graph_num_hidden = graph_num_hidden + # use GCN as a simple propagation model + self.gcn = nn.ModuleList([GCN(node_num_hidden, node_num_hidden, F.relu, dropout) for _ in range(T)]) + # project node repr to graph repr (higher dimension) + self.graph_project = nn.Linear(node_num_hidden, graph_num_hidden) + # add node + self.fan = MLP(graph_num_hidden, 2, num_MLP_layers) + # add 
edge + self.fae = MLP(graph_num_hidden + node_num_hidden, 1, num_MLP_layers) + # select node to add edge + self.fs = MLP(node_num_hidden * 2, 1, num_MLP_layers) + # init node state + self.finit = MLP(graph_num_hidden, node_num_hidden, num_MLP_layers) + # loss function + self.loss_func = loss_func + # use gpu + self.use_cuda = use_cuda + + def decide_add_node(self, hGs): + h = self.fan(hGs) + p = F.softmax(h, dim=1) + # calc loss + self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step]) + + def decide_add_edge(self, batched_graph, hGs): + hvs = batched_graph.get_n_repr((self.sample_node_curr_idx - 1).tolist())['h'] + h = self.fae(torch.cat((hGs, hvs), dim=1)) + p = torch.sigmoid(h) + p = torch.cat([1 - p, p], dim=1) + self.loss += self.loss_func(p, self.labels[self.step], self.masks[self.step]) + + def select_node_to_add_edge(self, batched_graph, indices): + node_indices = self.sample_node_curr_idx[indices].tolist() + node_start = self.sample_node_start_idx[indices].tolist() + node_repr = batched_graph.get_n_repr()['h'] + for i, j, idx in zip(node_start, node_indices, indices): + hu = node_repr.narrow(0, i, j-i) + hv = node_repr.narrow(0, j-1, 1) + huv = torch.cat((hu, hv.expand(j-i, -1)), dim=1) + s = F.softmax(self.fs(huv), dim=0).view(1, -1) + dst = self.node_select[self.step][idx].view(-1) + self.loss += self.loss_func(s, dst) + + def update_graph_repr(self, batched_graph, hGs, indices, indices_tensor): + start = self.sample_node_start_idx[indices].tolist() + stop = self.sample_node_curr_idx[indices].tolist() + node_repr = batched_graph.get_n_repr()['h'] + graph_repr = self.graph_project(node_repr) + new_hGs = [] + for i, j in zip(start, stop): + h = graph_repr.narrow(0, i, j-i) + hG = torch.sum(h, 0, keepdim=True) + new_hGs.append(hG) + new_hGs = torch.cat(new_hGs, dim=0) + return hGs.index_copy(0, indices_tensor, new_hGs) + + def propagate(self, batched_graph, indices): + edge_src = [self.sample_edge_src[idx][0: 
self.sample_edge_count[idx]] for idx in indices] + edge_dst = [self.sample_edge_dst[idx][0: self.sample_edge_count[idx]] for idx in indices] + u = np.concatenate(edge_src).tolist() + v = np.concatenate(edge_dst).tolist() + for gcn in self.gcn: + gcn.forward(batched_graph, u, v, attribute='h') + + def forward(self, training=False, ground_truth=None): + if not training: + raise NotImplementedError("inference is not implemented yet") + + assert(ground_truth is not None) + signals, (batched_graph, self.sample_edge_src, self.sample_edge_dst) = ground_truth + nsteps, self.labels, self.node_select, self.masks, active_step, label1_set, label1_set_tensor = signals + # init loss + self.loss = 0 + + batch_size = len(self.sample_edge_src) + # initial node repr for each sample + hVs = torch.zeros(len(batched_graph), self.node_num_hidden) + # FIXME: what's the initial grpah repr for empty graph? + hGs = torch.zeros(batch_size, self.graph_num_hidden) + + if self.use_cuda: + hVs = hVs.cuda() + hGs = hGs.cuda() + batched_graph.set_n_repr({'h': hVs}) + + self.sample_node_start_idx = batched_graph.query_node_start_offset() + self.sample_node_curr_idx = self.sample_node_start_idx.copy() + self.sample_edge_count = np.zeros(batch_size, dtype=int) + + self.step = 0 + while self.step < nsteps: + if self.step % 2 == 0: # add node step + if active_step[self.step]: + # decide whether to add node + self.decide_add_node(hGs) + + # calculate initial state for new node + hvs = self.finit(hGs) + + # add node + update = label1_set[self.step] + if len(update) > 0: + hvs = torch.index_select(hvs, 0, label1_set_tensor[self.step]) + scatter_indices = self.sample_node_curr_idx[update] + batched_graph.set_n_repr({'h': hvs}, scatter_indices.tolist()) + self.sample_node_curr_idx[update] += 1 + + # get new graph repr + hGs = self.update_graph_repr(batched_graph, hGs, update, label1_set_tensor[self.step]) + else: + # all samples are masked + pass + + else: # add edge step + + # decide whether to add edge, 
which edge to add + # and also add edge + self.decide_add_edge(batched_graph, hGs) + + # propagate + to_add_edge = label1_set[self.step] + if len(to_add_edge) > 0: + # at least one graph needs update + self.select_node_to_add_edge(batched_graph, to_add_edge) + # update edge count for each sample + self.sample_edge_count[to_add_edge] += 2 # undirected graph + + # perform gcn propagation + self.propagate(batched_graph, to_add_edge) + + # get new graph repr + hGs = self.update_graph_repr(batched_graph, hGs, label1_set[self.step], label1_set_tensor[self.step]) + + self.step += 1 + + +def main(args): + + if torch.cuda.is_available() and args.gpu >= 0: + torch.cuda.set_device(args.gpu) + use_cuda = True + else: + use_cuda = False + + + def masked_cross_entropy(x, label, mask=None): + # x: propability tensor, i.e. after softmax + x = torch.log(x) + if mask is not None: + x = x[mask] + label = label[mask] + return F.nll_loss(x, label) + + model = DGMG(args.n_hidden_node, args.n_hidden_graph, args.n_layers, + loss_func=masked_cross_entropy, dropout=args.dropout, use_cuda=use_cuda) + if use_cuda: + model.cuda() + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + # training loop + for ep in range(args.n_epochs): + print("epoch: {}".format(ep)) + for idx, ground_truth in enumerate(DataLoader(args.dataset, args.batch_size)): + if use_cuda: + count, label, node_list, mask, active, label1, label1_tensor = ground_truth[0] + label, node_list, mask, label1_tensor = move2cuda((label, node_list, mask, label1_tensor)) + ground_truth[0] = (count, label, node_list, mask, active, label1, label1_tensor) + ground_truth[1][0].set_device(dgl.gpu(args.gpu)) + + optimizer.zero_grad() + # create new empty graphs + start = time.time() + model.forward(True, ground_truth) + end = time.time() + elapsed("model forward", start, end) + start = time.time() + model.loss.backward() + optimizer.step() + end = time.time() + elapsed("model backward", start, end) + print("iter {}: loss 
{}".format(idx, model.loss.item())) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='DGMG') + parser.add_argument("--dropout", type=float, default=0, + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=20, + help="number of training epochs") + parser.add_argument("--n-hidden-node", type=int, default=16, + help="number of hidden DGMG node units") + parser.add_argument("--n-hidden-graph", type=int, default=32, + help="number of hidden DGMG graph units") + parser.add_argument("--n-layers", type=int, default=2, + help="number of hidden gcn layers") + parser.add_argument("--dataset", type=str, default='samples.p', + help="dataset pickle file") + parser.add_argument("--batch-size", type=int, default=32, + help="batch size") + args = parser.parse_args() + print(args) + + main(args) diff --git a/examples/pytorch/generative_graph/util.py b/examples/pytorch/generative_graph/util.py new file mode 100644 index 000000000000..41fe4f4d9bde --- /dev/null +++ b/examples/pytorch/generative_graph/util.py @@ -0,0 +1,156 @@ +import networkx as nx +import pickle +import random +import dgl +import numpy as np +import torch + +def convert_graph_to_ordering(g): + ordering = [] + h = nx.DiGraph() + h.add_edges_from(g.edges) + for n in h.nodes(): + ordering.append(n) + for m in h.predecessors(n): + ordering.append((m, n)) + return ordering + +def generate_dataset(): + n = 15 + m = 2 + n_samples = 1024 + samples = [] + for _ in range(n_samples): + g = nx.barabasi_albert_graph(n, m) + samples.append(convert_graph_to_ordering(g)) + + with open('samples.p', 'wb') as f: + pickle.dump(samples, f) + +class DataLoader(object): + def __init__(self, fname, batch_size, shuffle=True): + with open(fname, 'rb') as f: + datasets = pickle.load(f) + if shuffle: + random.shuffle(datasets) + num = 
len(datasets) // batch_size + + # pre-process dataset + self.ground_truth = [] + for i in range(num): + batch = datasets[i*batch_size: (i+1)*batch_size] + padded_signals = pad_ground_truth(batch) + merged_graph = generate_merged_graph(batch) + self.ground_truth.append([padded_signals, merged_graph]) + + def __iter__(self): + return iter(self.ground_truth) + + +def generate_merged_graph(batch): + n_graphs = len(batch) + graph_list = [] + # build each sample graph + new_edges = [] + for ordering in batch: + g = dgl.DGLGraph() + node_count = 0 + edge_list = [] + for step in ordering: + if isinstance(step, int): + node_count += 1 + else: + assert isinstance(step, tuple) + edge_list.append(step) + edge_list.append(tuple(reversed(step))) + g.add_nodes_from(range(node_count)) + g.add_edges_from(edge_list) + new_edges.append(zip(*edge_list)) + graph_list.append(g) + # batch + bg = dgl.batch(graph_list) + # get new edges + new_edges = [bg.query_new_edge(g, *edges) for g, edges in zip(graph_list, new_edges)] + new_src, new_dst = zip(*new_edges) + return bg, new_src, new_dst + +def expand_ground_truth(ordering): + node_list = [] + action = [] + label = [] + first_step = True + for i in ordering: + if isinstance(i, int): + if not first_step: + # add not to add edge + action.append(1) + label.append(0) + node_list.append(-1) + else: + first_step = False + action.append(0) # add node + label.append(1) + node_list.append(i) + else: + assert(isinstance(i, tuple)) + action.append(1) + label.append(1) + node_list.append(i[0]) # select src node to add + # add not to add node + action.append(0) + label.append(0) + node_list.append(-1) + return len(action), action, label, node_list + +def pad_ground_truth(batch): + a = [] + bz = len(batch) + for sample in batch: + a.append(expand_ground_truth(sample)) + length, action, label, node_list = zip(*a) + step = [0] * bz + new_label = [] + new_node_list = [] + mask_for_batch = [] + next_action = 0 + count = 0 + active_step = [] # steps at 
least some graphs are not masked + label1_set = [] # graphs who decide to add node or edge + label1_set_tensor = [] + while any([step[i] < length[i] for i in range(bz)]): + node_select = [] + label_select = [] + mask = [] + label1 = [] + not_all_masked = False + for sample_idx in range(bz): + if step[sample_idx] < length[sample_idx] and \ + action[sample_idx][step[sample_idx]] == next_action: + mask.append(1) + node_select.append(node_list[sample_idx][step[sample_idx]]) + label_select.append(label[sample_idx][step[sample_idx]]) + # if decide to add node or add edge, record sample_idx + if label_select[-1] == 1: + label1.append(sample_idx) + step[sample_idx] += 1 + not_all_masked = True + else: + mask.append(0) + node_select.append(-1) + label_select.append(0) + next_action = 1 - next_action + new_node_list.append(torch.LongTensor(node_select)) + mask_for_batch.append(torch.ByteTensor(mask)) + new_label.append(torch.LongTensor(label_select)) + active_step.append(not_all_masked) + label1_set.append(np.array(label1)) + label1_set_tensor.append(torch.LongTensor(label1)) + count += 1 + + return count, new_label, new_node_list, mask_for_batch, active_step, label1_set, label1_set_tensor + +def elapsed(msg, start, end): + print("{}: {} ms".format(msg, int((end-start)*1000))) + +if __name__ == '__main__': + generate_dataset() diff --git a/python/dgl/__init__.py b/python/dgl/__init__.py index 10ac3c13e5d0..9e935183861c 100644 --- a/python/dgl/__init__.py +++ b/python/dgl/__init__.py @@ -2,3 +2,4 @@ from .graph import DGLGraph from .graph import __MSG__, __REPR__ from .context import cpu, gpu +from .batch import batch, unbatch diff --git a/python/dgl/backend/numpy.py b/python/dgl/backend/numpy.py index 8bc39f05e96e..ed2ad58e865a 100644 --- a/python/dgl/backend/numpy.py +++ b/python/dgl/backend/numpy.py @@ -12,8 +12,13 @@ def asnumpy(a): def pack(arrays): return np.concatenate(arrays, axis=0) -def unpack(a): - return np.split(a, a.shape[0], axis=0) +def unpack(a, 
split_size_or_sections=None): + if split_size_or_sections is None: + indices_or_sections = a.shape[0] + else: + # convert split size to split indices by cumsum + indices_or_sections = np.cumsum(split_size_or_sections)[:-1] + return np.split(a, indices_or_sections, axis=0) def shape(a): return a.shape diff --git a/python/dgl/backend/pytorch.py b/python/dgl/backend/pytorch.py index e65b2913e8f2..d0a13e2a0884 100644 --- a/python/dgl/backend/pytorch.py +++ b/python/dgl/backend/pytorch.py @@ -32,8 +32,8 @@ def asnumpy(a): def pack(tensors): return th.cat(tensors) -def unpack(x): - return th.split(x, 1) +def unpack(x, indices_or_sections=1): + return th.split(x, indices_or_sections) def shape(x): return x.shape diff --git a/python/dgl/batch.py b/python/dgl/batch.py new file mode 100644 index 000000000000..522f1a6c1744 --- /dev/null +++ b/python/dgl/batch.py @@ -0,0 +1,133 @@ +from dgl.graph import DGLGraph +import dgl.backend as F +import dgl +import numpy as np + +class BatchedDGLGraph(DGLGraph): + def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr): + super(BatchedDGLGraph, self).__init__(**attr) + self.graph_list = graph_list + self.graph_idx = {} + for idx, g in enumerate(self.graph_list): + self.graph_idx[g] = idx + + self.num_nodes = [len(g) for g in self.graph_list] + self.num_edges = [g.size() for g in self.graph_list] + + # calc index offset + self.node_offset = np.cumsum([0] + self.num_nodes) + self.edge_offset = np.cumsum([0] + self.num_edges) + + # in-order add relabeled nodes + self.add_nodes_from(range(self.node_offset[-1])) + + # in-order add relabeled edges + self.new_edge_list = [np.array(g.edges) + offset + for g, offset in zip(self.graph_list, self.node_offset[:-1])] + self.new_edges = np.concatenate(self.new_edge_list) + self.add_edges_from(self.new_edges) + + assert self.size() == self.edge_offset[-1] + + # set new node attr + if node_attrs: + attrs = {} + for key in node_attrs: + vals = [g.pop_n_repr(key) for g in 
self.graph_list] + attrs[key] = F.pack(vals) + self.set_n_repr(attrs) + else: + for g in self.graph_list: + self._node_frame.append(g._node_frame) + + # set new edge attr + if edge_attrs: + attrs = {} + for key in edge_attrs: + vals = [g.pop_e_repr(key) for g in self.graph_list] + attrs[key] = F.pack(vals) + self.set_e_repr(attrs) + else: + for g in self.graph_list: + self._edge_frame.append(g._edge_frame) + + def query_new_node(self, g, u): + idx = self.graph_idx[g] + offset = self.node_offset[idx] + if isinstance(u, (int, np.array, F.Tensor)): + return u + offset + else: + return np.array(u) + offset + + def query_new_edge(self, g, src, dst): + idx = self.graph_idx[g] + offset = self.node_offset[idx] + if isinstance(src, (int, np.ndarray, F.Tensor)) and \ + isinstance(dst, (int, np.ndarray, F.Tensor)): + return src + offset, dst + offset + else: + return np.array(src) + offset, np.array(dst) + offset + + def query_node_start_offset(self): + return self.node_offset[:-1].copy() + + def query_edge_start_offset(self): + return self.edge_offset[:-1].copy() + + +def unbatch(graph_batch): + """Unbatch the graph and return a list of subgraphs. + + Parameters + ---------- + graph_batch : DGLGraph + The batched graph. 
+ """ + graph_list = graph_batch.graph_list + num_graphs = len(graph_list) + # split and set node attrs + attrs = [{} for _ in range(num_graphs)] # node attr dict for each graph + for key in graph_batch.get_n_attr_list(): + vals = F.unpack(graph_batch.pop_n_repr(key), graph_batch.num_nodes) + for attr, val in zip(attrs, vals): + attr[key] = val + for attr, g in zip(attrs, graph_list): + g.set_n_repr(attr) + + # split and set edge attrs + attrs = [{} for _ in range(num_graphs)] # edge attr dict for each graph + for key in graph_batch.get_e_attr_list(): + vals = F.unpack(graph_batch.pop_e_repr(key), graph_batch.num_edges) + for attr, val in zip(attrs, vals): + attr[key] = val + for attr, g in zip(attrs, graph_list): + g.set_e_repr(attr) + + return graph_list + + +# FIXME (lingfan): Do we really need the batch API? +# Can't we let user call BatchedDGLGraph(graph_list) directly +# and make unbatch a member function of BatchedDGLGraph +def batch(graph_list, node_attrs=None, edge_attrs=None): + """Batch a list of DGLGraphs into one single graph. + Once batch is called, the structure of both merged graph and graphs in graph_list + must not be mutated, or unbatch's behavior will be undefined. + + Parameters + ---------- + graph_list : iterable + A list of DGLGraphs to be batched. 
+ node_attrs : str or iterable + A list of node attributes needed for merged graph + It's user's responsibility to make sure node_attrs exists + edge_attrs : str or iterable + A list of edge attributes needed for merged graph + It's user's responsibility to make sure edge_attrs exists + + Return + ------ + newgrh: DGLGraph + one single merged graph + """ + return BatchedDGLGraph(graph_list, node_attrs, edge_attrs) diff --git a/python/dgl/graph.py b/python/dgl/graph.py index e4bd7ff8f54f..b80a3d226b02 100644 --- a/python/dgl/graph.py +++ b/python/dgl/graph.py @@ -33,6 +33,7 @@ def __setitem__(self, key, val): def __getitem__(self, key): return self._dict[key] def __delitem__(self, key): + # FIXME: add callback del self._dict[key] def __len__(self): return len(self._dict) @@ -51,6 +52,7 @@ def __setitem__(self, key, val): def __getitem__(self, key): return self._dict[key] def __delitem__(self, key): + # FIXME: add callback del self._dict[key] def __len__(self): return len(self._dict) @@ -78,6 +80,12 @@ def __init__(self, graph_data=None, **attr): self.adjlist_outer_dict_factory = None self.adjlist_inner_dict_factory = lambda : _AdjInnerDict(self._add_edge_callback) self.edge_attr_dict_factory = dict + self._context = context.cpu() + # call base class init + super(DGLGraph, self).__init__(graph_data, **attr) + self._init_state() + + def _init_state(self): # cached graph and storage self._cached_graph = None self._node_frame = Frame() @@ -91,9 +99,16 @@ def __init__(self, graph_data=None, **attr): self._edge_func = None self._edge_cb_state = True self._edge_list = [] - self._context = context.cpu() - # call base class init - super(DGLGraph, self).__init__(graph_data, **attr) + + def clear(self): + super(DGLGraph, self).clear() + self._init_state() + + def get_n_attr_list(self): + return self._node_frame.schemes + + def get_e_attr_list(self): + return self._edge_frame.schemes def set_n_repr(self, hu, u=ALL): """Set node(s) representation. 
@@ -764,6 +779,8 @@ def _batch_update_by_edge( new_node_repr = update_func(node_repr, reduced_msgs) self.set_n_repr(new_node_repr, new2old) else: + u = utils.convert_to_id_tensor(u, self.context) + v = utils.convert_to_id_tensor(v, self.context) self._batch_sendto(u, v, message_func) unique_v = F.unique(v) self._batch_recv(unique_v, reduce_func, update_func) @@ -990,6 +1007,7 @@ def edge_list(self): """Return edges in the addition order.""" return self._edge_list + def _get_repr(attr_dict): if len(attr_dict) == 1 and __REPR__ in attr_dict: return attr_dict[__REPR__] diff --git a/python/dgl/nn/__init__.py b/python/dgl/nn/__init__.py new file mode 100644 index 000000000000..4c0b6ab938f4 --- /dev/null +++ b/python/dgl/nn/__init__.py @@ -0,0 +1,9 @@ +import os +__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower() + +if __backend__ == 'numpy': + pass +elif __backend__ == 'pytorch': + from .pytorch import * +else: + raise Exception("Unsupported backend %s" % __backend__) diff --git a/python/dgl/nn/pytorch/__init__.py b/python/dgl/nn/pytorch/__init__.py new file mode 100644 index 000000000000..e490b75d0dc6 --- /dev/null +++ b/python/dgl/nn/pytorch/__init__.py @@ -0,0 +1 @@ +from .gcn import GCN diff --git a/python/dgl/nn/pytorch/gcn.py b/python/dgl/nn/pytorch/gcn.py new file mode 100644 index 000000000000..f4e02da21b33 --- /dev/null +++ b/python/dgl/nn/pytorch/gcn.py @@ -0,0 +1,49 @@ +""" +Semi-Supervised Classification with Graph Convolutional Networks +Paper: https://arxiv.org/abs/1609.02907 +Code: https://github.com/tkipf/gcn + +GCN with SPMV specialization. 
+""" +import torch.nn as nn +from dgl.base import ALL, is_all + +class NodeUpdateModule(nn.Module): + def __init__(self, in_feats, out_feats, activation=None): + super(NodeUpdateModule, self).__init__() + self.linear = nn.Linear(in_feats, out_feats) + self.activation = activation + self.attribute = None + + def set_attribute_to_update(self, attribute): + self.attribute = attribute + + def forward(self, node, accum, attribute=None): + if self.attribute: + accum = accum[self.attribute] + h = self.linear(accum) + if self.activation: + h = self.activation(h) + if self.attribute: + return {self.attribute: h} + else: + return h + +class GCN(nn.Module): + def __init__(self, + in_feats, + out_feats, + activation, + dropout=0): + super(GCN, self).__init__() + self.dropout = dropout + # input layer + self.update_func = NodeUpdateModule(in_feats, out_feats, activation) + + def forward(self, g, u=ALL, v=ALL, attribute=None): + self.update_func.set_attribute_to_update(attribute) + if is_all(u) and is_all(v): + g.update_all('from_src', 'sum', self.update_func, batchable=True) + else: + g.update_by_edge(u, v, 'from_src', 'sum', self.update_func, batchable=True) + return g diff --git a/tests/test_graph_batch.py b/tests/test_graph_batch.py new file mode 100644 index 000000000000..705b9cb86178 --- /dev/null +++ b/tests/test_graph_batch.py @@ -0,0 +1,127 @@ +import networkx as nx +import dgl +import torch +import numpy as np + +def tree1(): + """Generate a tree + 0 + / \ + 1 2 + / \ + 3 4 + Edges are from leaves to root. + """ + g = dgl.DGLGraph() + g.add_node(0) + g.add_node(1) + g.add_node(2) + g.add_node(3) + g.add_node(4) + g.add_edge(3, 1) + g.add_edge(4, 1) + g.add_edge(1, 0) + g.add_edge(2, 0) + g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4])) + return g + +def tree2(): + """Generate a tree + 1 + / \ + 4 3 + / \ + 2 0 + Edges are from leaves to root. 
+ """ + g = dgl.DGLGraph() + g.add_node(0) + g.add_node(1) + g.add_node(2) + g.add_node(3) + g.add_node(4) + g.add_edge(2, 4) + g.add_edge(0, 4) + g.add_edge(4, 1) + g.add_edge(3, 1) + g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4])) + return g + +def test_batch_unbatch(): + t1 = tree1() + t2 = tree2() + f1 = t1.get_n_repr() + f2 = t2.get_n_repr() + + bg = dgl.batch([t1, t2]) + dgl.unbatch(bg) + + assert(f1.equal(t1.get_n_repr())) + assert(f2.equal(t2.get_n_repr())) + + +def test_batch_sendrecv(): + t1 = tree1() + t2 = tree2() + + bg = dgl.batch([t1, t2]) + bg.register_message_func(lambda src, edge: src, batchable=True) + bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True) + bg.register_update_func(lambda node, accum: accum, batchable=True) + e1 = [(3, 1), (4, 1)] + e2 = [(2, 4), (0, 4)] + + u1, v1 = bg.query_new_edge(t1, *zip(*e1)) + u2, v2 = bg.query_new_edge(t2, *zip(*e2)) + u = np.concatenate((u1, u2)).tolist() + v = np.concatenate((v1, v2)).tolist() + + bg.sendto(u, v) + bg.recv(v) + + dgl.unbatch(bg) + assert t1.get_n_repr()[1] == 7 + assert t2.get_n_repr()[4] == 2 + + +def test_batch_propagate(): + t1 = tree1() + t2 = tree2() + + bg = dgl.batch([t1, t2]) + bg.register_message_func(lambda src, edge: src, batchable=True) + bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True) + bg.register_update_func(lambda node, accum: accum, batchable=True) + # get leaves. 
+ + order = [] + + # step 1 + e1 = [(3, 1), (4, 1)] + e2 = [(2, 4), (0, 4)] + u1, v1 = bg.query_new_edge(t1, *zip(*e1)) + u2, v2 = bg.query_new_edge(t2, *zip(*e2)) + u = np.concatenate((u1, u2)).tolist() + v = np.concatenate((v1, v2)).tolist() + order.append((u, v)) + + # step 2 + e1 = [(1, 0), (2, 0)] + e2 = [(4, 1), (3, 1)] + u1, v1 = bg.query_new_edge(t1, *zip(*e1)) + u2, v2 = bg.query_new_edge(t2, *zip(*e2)) + u = np.concatenate((u1, u2)).tolist() + v = np.concatenate((v1, v2)).tolist() + order.append((u, v)) + + bg.propagate(iterator=order) + dgl.unbatch(bg) + + assert t1.get_n_repr()[0] == 9 + assert t2.get_n_repr()[1] == 5 + + +if __name__ == '__main__': + test_batch_unbatch() + test_batch_sendrecv() + test_batch_propagate()