diff --git a/docs/source/api/python/nn.rst b/docs/source/api/python/nn.rst index 6fd3adae6a76..37d9ec2e338c 100644 --- a/docs/source/api/python/nn.rst +++ b/docs/source/api/python/nn.rst @@ -9,3 +9,4 @@ NN Modules nn.pytorch nn.mxnet + nn.tensorflow diff --git a/docs/source/api/python/nn.tensorflow.rst b/docs/source/api/python/nn.tensorflow.rst new file mode 100644 index 000000000000..8218cbcd743a --- /dev/null +++ b/docs/source/api/python/nn.tensorflow.rst @@ -0,0 +1,118 @@ +.. _apinn-tensorflow: + +NN Modules (Tensorflow) +==================== + +.. contents:: Contents + :local: + +We welcome your contribution! If you want a model to be implemented in DGL as a NN module, +please `create an issue `_ started with "[Feature Request] NN Module XXXModel". + +If you want to contribute a NN module, please `create a pull request `_ started +with "[NN] XXXModel in tensorflow NN Modules" and our team member would review this PR. + +Conv Layers +---------------------------------------- + +.. automodule:: dgl.nn.tensorflow.conv + +GraphConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.conv.GraphConv + :members: weight, bias, forward, reset_parameters + :show-inheritance: + +RelGraphConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.conv.RelGraphConv + :members: forward + :show-inheritance: + +GATConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.conv.GATConv + :members: forward + :show-inheritance: + +SAGEConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.conv.SAGEConv + :members: forward + :show-inheritance: + +SGConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.conv.SGConv + :members: forward + :show-inheritance: + +APPNPConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.conv.APPNPConv + :members: forward + :show-inheritance: + +GINConv +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.conv.GINConv + :members: forward + :show-inheritance: + + +Global Pooling Layers +---------------------------------------- + +.. automodule:: dgl.nn.tensorflow.glob + +SumPooling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.glob.SumPooling + :members: + :show-inheritance: + +AvgPooling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.glob.AvgPooling + :members: + :show-inheritance: + +MaxPooling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.glob.MaxPooling + :members: + :show-inheritance: + +SortPooling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.glob.SortPooling + :members: + :show-inheritance: + +GlobalAttentionPooling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: dgl.nn.tensorflow.glob.GlobalAttentionPooling + :members: + :show-inheritance: + + +Utility Modules +---------------------------------------- + +Edge Softmax +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
automodule:: dgl.nn.tensorflow.softmax + :members: edge_softmax diff --git a/examples/pytorch/dgi/gcn.py b/examples/pytorch/dgi/gcn.py index 1b494e656e00..420ae839be16 100644 --- a/examples/pytorch/dgi/gcn.py +++ b/examples/pytorch/dgi/gcn.py @@ -31,5 +31,5 @@ def forward(self, features): for i, layer in enumerate(self.layers): if i != 0: h = self.dropout(h) - h = layer(h, self.g) + h = layer(self.g, h) return h diff --git a/examples/pytorch/gat/README.md b/examples/pytorch/gat/README.md index bcb1bea812b2..dd9e9e904e1a 100644 --- a/examples/pytorch/gat/README.md +++ b/examples/pytorch/gat/README.md @@ -41,11 +41,11 @@ python3 train_ppi.py --gpu=0 Results ------- -| Dataset | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) | -| ------- | ------------- | ------- | ------------------- | ------------------- | -| Cora | 84.02(0.40) | 0.0113 | 0.0982 (**8.7x**) | 0.0424 (**3.8x**) | -| Citeseer | 70.91(0.79) | 0.0111 | n/a | n/a | -| Pubmed | 78.57(0.75) | 0.0115 | n/a | n/a | +| Dataset | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) | +| -------- | ------------- | ------- | ------------------- | ------------------- | +| Cora | 84.02(0.40) | 0.0113 | 0.0982 (**8.7x**) | 0.0424 (**3.8x**) | +| Citeseer | 70.91(0.79) | 0.0111 | n/a | n/a | +| Pubmed | 78.57(0.75) | 0.0115 | n/a | n/a | * All the accuracy numbers are obtained after 300 epochs. * The time measures how long it takes to train one epoch. diff --git a/examples/tensorflow/dgi/README.md b/examples/tensorflow/dgi/README.md new file mode 100644 index 000000000000..b49693f471e9 --- /dev/null +++ b/examples/tensorflow/dgi/README.md @@ -0,0 +1,38 @@ +Deep Graph Infomax (DGI) +======================== + +- Paper link: [https://arxiv.org/abs/1809.10341](https://arxiv.org/abs/1809.10341) +- Author's code repo (in Pytorch): + [https://github.com/PetarV-/DGI](https://github.com/PetarV-/DGI) + +Dependencies +------------ +- tensorflow 2.1+ +- requests + +```bash +pip install tensorflow requests +``` + +How to run +---------- + +Run with following: + +```bash +python3 train.py --dataset=cora --gpu=0 --self-loop +``` + +```bash +python3 train.py --dataset=citeseer --gpu=0 +``` + +```bash +python3 train.py --dataset=pubmed --gpu=0 +``` + +Results +------- +* cora: ~81.6 (80.9-82.9) (paper: 82.3) +* citeseer: ~70.2 (paper: 71.8) +* pubmed: ~77.2 (paper: 76.8) diff --git a/examples/tensorflow/dgi/dgi.py b/examples/tensorflow/dgi/dgi.py new file mode 100644 index 000000000000..b941b4e94d3c --- /dev/null +++ b/examples/tensorflow/dgi/dgi.py @@ -0,0 +1,75 @@ +""" +Deep Graph Infomax in DGL + +References +---------- +Papers: https://arxiv.org/abs/1809.10341 +Author's code: https://github.com/PetarV-/DGI +""" + +import tensorflow as tf +from tensorflow.keras import layers +import numpy as np +import math +from gcn import GCN + + +class Encoder(layers.Layer): + def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout): + super(Encoder, self).__init__() + self.g = g + self.conv = GCN(g, in_feats, n_hidden, n_hidden, + n_layers, activation, dropout) + + def call(self, features, corrupt=False): + if corrupt: + perm = np.random.permutation(self.g.number_of_nodes()) + features = tf.gather(features, perm) + features = self.conv(features) + return features + + +class Discriminator(layers.Layer): + def __init__(self, n_hidden): + super(Discriminator, self).__init__() + uinit = tf.keras.initializers.RandomUniform( + -1.0/math.sqrt(n_hidden), 1.0/math.sqrt(n_hidden)) + self.weight = 
tf.Variable(initial_value=uinit( + shape=(n_hidden, n_hidden), dtype='float32'), trainable=True) + + def call(self, features, summary): + features = tf.matmul(features, tf.matmul( + self.weight, tf.expand_dims(summary, -1))) + return features + + +class DGI(tf.keras.Model): + def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout): + super(DGI, self).__init__() + self.encoder = Encoder(g, in_feats, n_hidden, + n_layers, activation, dropout) + self.discriminator = Discriminator(n_hidden) + self.loss = tf.nn.sigmoid_cross_entropy_with_logits + + def call(self, features): + positive = self.encoder(features, corrupt=False) + negative = self.encoder(features, corrupt=True) + summary = tf.nn.sigmoid(tf.reduce_mean(positive, axis=0)) + + positive = self.discriminator(positive, summary) + negative = self.discriminator(negative, summary) + + l1 = self.loss(tf.ones(positive.shape),positive) + l2 = self.loss(tf.zeros(negative.shape), negative) + + return tf.reduce_mean(l1) + tf.reduce_mean(l2) + + +class Classifier(layers.Layer): + def __init__(self, n_hidden, n_classes): + super(Classifier, self).__init__() + self.fc = layers.Dense(n_classes) + + def call(self, features): + features = self.fc(features) + return features diff --git a/examples/tensorflow/dgi/gcn.py b/examples/tensorflow/dgi/gcn.py new file mode 100644 index 000000000000..e896b8a0e30e --- /dev/null +++ b/examples/tensorflow/dgi/gcn.py @@ -0,0 +1,36 @@ +""" +This code was copied from the GCN implementation in DGL examples. +""" +import tensorflow as tf +from tensorflow.keras import layers + +from dgl.nn.tensorflow import GraphConv + +class GCN(layers.Layer): + def __init__(self, + g, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout): + super(GCN, self).__init__() + self.g = g + self.layers =[] + # input layer + self.layers.append(GraphConv(in_feats, n_hidden, activation=activation)) + # hidden layers + for i in range(n_layers - 1): + self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation)) + # output layer + self.layers.append(GraphConv(n_hidden, n_classes)) + self.dropout = layers.Dropout(dropout) + + def call(self, features): + h = features + for i, layer in enumerate(self.layers): + if i != 0: + h = self.dropout(h) + h = layer(self.g, h) + return h diff --git a/examples/tensorflow/dgi/train.py b/examples/tensorflow/dgi/train.py new file mode 100644 index 000000000000..33357cc9fcfd --- /dev/null +++ b/examples/tensorflow/dgi/train.py @@ -0,0 +1,170 @@ +import argparse +import time +import numpy as np +import networkx as nx +import tensorflow as tf +from tensorflow.keras import layers +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from dgi import DGI, Classifier + + +def evaluate(model, features, labels, mask): + logits = model(features, training=False) + logits = logits[mask] + labels = labels[mask] + indices = tf.math.argmax(logits, axis=1) + acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32)) + return acc.numpy().item() + + +def main(args): + # load and preprocess dataset + data = load_data(args) + if args.gpu < 0: + device = "/cpu:0" + else: + device = "/gpu:{}".format(args.gpu) + with tf.device(device): + features = tf.convert_to_tensor(data.features, dtype=tf.float32) + labels = tf.convert_to_tensor(data.labels, dtype=tf.int64) + train_mask = tf.convert_to_tensor(data.train_mask, dtype=tf.bool) + val_mask = tf.convert_to_tensor(data.val_mask, dtype=tf.bool) + test_mask = tf.convert_to_tensor(data.test_mask, dtype=tf.bool) + 
in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + + # graph preprocess + g = data.graph + # add self loop + if args.self_loop: + g.remove_edges_from(nx.selfloop_edges(g)) + g.add_edges_from(zip(g.nodes(), g.nodes())) + g = DGLGraph(g) + n_edges = g.number_of_edges() + + # create DGI model + dgi = DGI(g, + in_feats, + args.n_hidden, + args.n_layers, + tf.keras.layers.PReLU(alpha_initializer=tf.constant_initializer(0.25)), + args.dropout) + + dgi_optimizer = tf.keras.optimizers.Adam( + learning_rate=args.dgi_lr) + + # train deep graph infomax + cnt_wait = 0 + best = 1e9 + best_t = 0 + dur = [] + for epoch in range(args.n_dgi_epochs): + if epoch >= 3: + t0 = time.time() + + with tf.GradientTape() as tape: + loss = dgi(features) + # Manually Weight Decay + # We found Tensorflow has a different implementation on weight decay + # of Adam(W) optimizer with PyTorch. And this results in worse results. + # Manually adding weights to the loss to do weight decay solves this problem. + for weight in dgi.trainable_weights: + loss = loss + \ + args.weight_decay * tf.nn.l2_loss(weight) + grads = tape.gradient(loss, dgi.trainable_weights) + dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights)) + + if loss < best: + best = loss + best_t = epoch + cnt_wait = 0 + dgi.save_weights('best_dgi.pkl') + else: + cnt_wait += 1 + + if cnt_wait == args.patience: + print('Early stopping!') + break + + if epoch >= 3: + dur.append(time.time() - t0) + + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | " + "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.numpy().item(), + n_edges / np.mean(dur) / 1000)) + + # create classifier model + classifier = Classifier(args.n_hidden, n_classes) + + classifier_optimizer = tf.keras.optimizers.Adam(learning_rate=args.classifier_lr) + + # train classifier + print('Loading {}th epoch'.format(best_t)) + dgi.load_weights('best_dgi.pkl') + embeds = dgi.encoder(features, corrupt=False) + embeds = tf.stop_gradient(embeds) + dur = [] + loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True) + for epoch in range(args.n_classifier_epochs): + if epoch >= 3: + t0 = time.time() + with tf.GradientTape() as tape: + preds = classifier(embeds) + loss = loss_fcn(labels[train_mask], preds[train_mask]) + # Manually Weight Decay + # We found Tensorflow has a different implementation on weight decay + # of Adam(W) optimizer with PyTorch. And this results in worse results. + # Manually adding weights to the loss to do weight decay solves this problem. 
+ # In original code, there's no weight decay applied in this part + # link: https://github.com/PetarV-/DGI/blob/master/execute.py#L121 + # for weight in classifier.trainable_weights: + # loss = loss + \ + # args.weight_decay * tf.nn.l2_loss(weight) + grads = tape.gradient(loss, classifier.trainable_weights) + classifier_optimizer.apply_gradients(zip(grads, classifier.trainable_weights)) + if epoch >= 3: + dur.append(time.time() - t0) + + acc = evaluate(classifier, embeds, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}".format(epoch, np.mean(dur), loss.numpy().item(), + acc, n_edges / np.mean(dur) / 1000)) + + print() + acc = evaluate(classifier, embeds, labels, test_mask) + print("Test Accuracy {:.4f}".format(acc)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='DGI') + register_data_args(parser) + parser.add_argument("--dropout", type=float, default=0., + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--dgi-lr", type=float, default=1e-3, + help="dgi learning rate") + parser.add_argument("--classifier-lr", type=float, default=1e-2, + help="classifier learning rate") + parser.add_argument("--n-dgi-epochs", type=int, default=300, + help="number of training epochs") + parser.add_argument("--n-classifier-epochs", type=int, default=300, + help="number of training epochs") + parser.add_argument("--n-hidden", type=int, default=512, + help="number of hidden gcn units") + parser.add_argument("--n-layers", type=int, default=1, + help="number of hidden gcn layers") + parser.add_argument("--weight-decay", type=float, default=0., + help="Weight for L2 loss") + parser.add_argument("--patience", type=int, default=20, + help="early stop patience condition") + parser.add_argument("--self-loop", action='store_true', + help="graph self-loop (default=False)") + parser.set_defaults(self_loop=False) + args = parser.parse_args() + print(args) + + main(args) diff --git a/examples/tensorflow/gat/README.md b/examples/tensorflow/gat/README.md new file mode 100644 index 000000000000..f29603f6538c --- /dev/null +++ b/examples/tensorflow/gat/README.md @@ -0,0 +1,47 @@ +Graph Attention Networks (GAT) +============ + +- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903) +- Author's code repo (in Tensorflow): + [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT). +- Popular pytorch implementation: + [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT). + +Dependencies +------------ +- tensorflow 2.1.0+ +- requests + +```bash +pip install tensorflow requests +``` + +How to run +---------- + +Run with following: + +```bash +python3 train.py --dataset=cora --gpu=0 +``` + +```bash +python3 train.py --dataset=citeseer --gpu=0 --early-stop +``` + +```bash +python3 train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001 --early-stop +``` + + +Results +------- + +| Dataset | Test Accuracy | Baseline (paper) | +| -------- | ------------- | ---------------- | +| Cora | 84.2 | 83.0(+-0.7) | +| Citeseer | 70.9 | 72.5(+-0.7) | +| Pubmed | 78.5 | 79.0(+-0.3) | + +* All the accuracy numbers are obtained after 200 epochs. +* All time is measured on EC2 p3.2xlarge instance w/ V100 GPU. 
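+
+Usage sketch
+------------
+
+The training script drives everything through `train.py`, but the underlying `dgl.nn.tensorflow.GATConv` layer can also be called directly. The snippet below is a minimal, illustrative sketch; the toy graph, feature size, and head count are made-up values, and it relies only on the constructor defaults and the `(graph, feat)` call signature of `GATConv`.
+
+```python
+import tensorflow as tf
+import dgl
+from dgl.nn.tensorflow import GATConv
+
+# Toy graph: 4 nodes with a few edges, plus self-loops so no node has zero in-degree.
+g = dgl.DGLGraph()
+g.add_nodes(4)
+g.add_edges([0, 1, 2], [1, 2, 3])
+g.add_edges(g.nodes(), g.nodes())
+
+feat = tf.random.normal((4, 10))               # (N, in_feats)
+conv = GATConv(in_feats=10, out_feats=8, num_heads=2)
+out = conv(g, feat)                            # (N, num_heads, out_feats) == (4, 2, 8)
+```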
diff --git a/examples/tensorflow/gat/gat.py b/examples/tensorflow/gat/gat.py new file mode 100644 index 000000000000..a98040e2aa58 --- /dev/null +++ b/examples/tensorflow/gat/gat.py @@ -0,0 +1,56 @@ +""" +Graph Attention Networks in DGL using SPMV optimization. +References +---------- +Paper: https://arxiv.org/abs/1710.10903 +Author's code: https://github.com/PetarV-/GAT +Pytorch implementation: https://github.com/Diego999/pyGAT +""" + +import tensorflow as tf +from tensorflow.keras import layers +import dgl.function as fn +from dgl.nn.tensorflow import edge_softmax, GATConv + + +class GAT(tf.keras.Model): + def __init__(self, + g, + num_layers, + in_dim, + num_hidden, + num_classes, + heads, + activation, + feat_drop, + attn_drop, + negative_slope, + residual): + super(GAT, self).__init__() + self.g = g + self.num_layers = num_layers + self.gat_layers = [] + self.activation = activation + # input projection (no residual) + self.gat_layers.append(GATConv( + in_dim, num_hidden, heads[0], + feat_drop, attn_drop, negative_slope, False, self.activation)) + # hidden layers + for l in range(1, num_layers): + # due to multi-head, the in_dim = num_hidden * num_heads + self.gat_layers.append(GATConv( + num_hidden * heads[l-1], num_hidden, heads[l], + feat_drop, attn_drop, negative_slope, residual, self.activation)) + # output projection + self.gat_layers.append(GATConv( + num_hidden * heads[-2], num_classes, heads[-1], + feat_drop, attn_drop, negative_slope, residual, None)) + + def call(self, inputs): + h = inputs + for l in range(self.num_layers): + h = self.gat_layers[l](self.g, h) + h = tf.reshape(h, (h.shape[0], -1)) + # output projection + logits = tf.reduce_mean(self.gat_layers[-1](self.g, h), axis=1) + return logits diff --git a/examples/tensorflow/gat/train.py b/examples/tensorflow/gat/train.py new file mode 100644 index 000000000000..2b82e45c40b0 --- /dev/null +++ b/examples/tensorflow/gat/train.py @@ -0,0 +1,179 @@ +""" +Graph Attention Networks in DGL using SPMV optimization. +Multiple heads are also batched together for faster training. +Compared with the original paper, this code does not implement +early stopping. 
+References +---------- +Paper: https://arxiv.org/abs/1710.10903 +Author's code: https://github.com/PetarV-/GAT +Pytorch implementation: https://github.com/Diego999/pyGAT +""" + +import argparse +import numpy as np +import networkx as nx +import time +import tensorflow as tf +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from gat import GAT +from utils import EarlyStopping + +def accuracy(logits, labels): + indices = tf.math.argmax(logits, axis=1) + acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32)) + return acc.numpy().item() + + +def evaluate(model, features, labels, mask): + logits = model(features, training=False) + logits = logits[mask] + labels = labels[mask] + return accuracy(logits, labels) + + +def main(args): + # load and preprocess dataset + data = load_data(args) + + if args.gpu < 0: + device = "/cpu:0" + else: + device = "/gpu:{}".format(args.gpu) + + with tf.device(device): + + features = tf.convert_to_tensor(data.features, dtype=tf.float32) + labels = tf.convert_to_tensor(data.labels, dtype=tf.int64) + train_mask = tf.convert_to_tensor(data.train_mask, dtype=tf.bool) + val_mask = tf.convert_to_tensor(data.val_mask, dtype=tf.bool) + test_mask = tf.convert_to_tensor(data.test_mask, dtype=tf.bool) + num_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.numpy().sum(), + val_mask.numpy().sum(), + test_mask.numpy().sum())) + + g = data.graph + # add self loop + g.remove_edges_from(nx.selfloop_edges(g)) + g = DGLGraph(g) + g.add_edges(g.nodes(), g.nodes()) + n_edges = g.number_of_edges() + # create model + heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads] + model = GAT(g, + args.num_layers, + num_feats, + args.num_hidden, + n_classes, + heads, + tf.nn.elu, + args.in_drop, + args.attn_drop, + args.negative_slope, + args.residual) + print(model) + if args.early_stop: + stopper = EarlyStopping(patience=100) + + # loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy( + # from_logits=False) + loss_fcn = tf.nn.sparse_softmax_cross_entropy_with_logits + + # use optimizer + optimizer = tf.keras.optimizers.Adam( + learning_rate=args.lr, epsilon=1e-8) + + # initialize graph + dur = [] + for epoch in range(args.epochs): + if epoch >= 3: + t0 = time.time() + # forward + with tf.GradientTape() as tape: + tape.watch(model.trainable_weights) + logits = model(features, training=True) + loss_value = tf.reduce_mean(loss_fcn( + labels=labels[train_mask], logits=logits[train_mask])) + # Manually Weight Decay + # We found Tensorflow has a different implementation on weight decay + # of Adam(W) optimizer with PyTorch. And this results in worse results. + # Manually adding weights to the loss to do weight decay solves this problem. 
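+                # Note that tf.nn.l2_loss(w) computes sum(w ** 2) / 2, so each term below
+                # adds (weight_decay / 2) * ||w||^2 to the loss; its gradient is
+                # weight_decay * w, i.e. the classic L2 penalty that PyTorch's Adam
+                # applies through its `weight_decay` argument.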
+ for weight in model.trainable_weights: + loss_value = loss_value + \ + args.weight_decay*tf.nn.l2_loss(weight) + + grads = tape.gradient(loss_value, model.trainable_weights) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + if epoch >= 3: + dur.append(time.time() - t0) + + train_acc = accuracy(logits[train_mask], labels[train_mask]) + + if args.fastmode: + val_acc = accuracy(logits[val_mask], labels[val_mask]) + else: + val_acc = evaluate(model, features, labels, val_mask) + if args.early_stop: + if stopper.step(val_acc, model): + break + + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" + " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}". + format(epoch, np.mean(dur), loss_value.numpy().item(), train_acc, + val_acc, n_edges / np.mean(dur) / 1000)) + + print() + if args.early_stop: + model.load_weights('es_checkpoint.pb') + acc = evaluate(model, features, labels, test_mask) + print("Test Accuracy {:.4f}".format(acc)) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='GAT') + register_data_args(parser) + parser.add_argument("--gpu", type=int, default=-1, + help="which GPU to use. Set -1 to use CPU.") + parser.add_argument("--epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--num-heads", type=int, default=8, + help="number of hidden attention heads") + parser.add_argument("--num-out-heads", type=int, default=1, + help="number of output attention heads") + parser.add_argument("--num-layers", type=int, default=1, + help="number of hidden layers") + parser.add_argument("--num-hidden", type=int, default=8, + help="number of hidden units") + parser.add_argument("--residual", action="store_true", default=False, + help="use residual connection") + parser.add_argument("--in-drop", type=float, default=.6, + help="input feature dropout") + parser.add_argument("--attn-drop", type=float, default=.6, + help="attention dropout") + parser.add_argument("--lr", type=float, default=0.005, + help="learning rate") + parser.add_argument('--weight-decay', type=float, default=5e-4, + help="weight decay") + parser.add_argument('--negative-slope', type=float, default=0.2, + help="the negative slope of leaky relu") + parser.add_argument('--early-stop', action='store_true', default=False, + help="indicates whether to use early stop or not") + parser.add_argument('--fastmode', action="store_true", default=False, + help="skip re-evaluate the validation set") + args = parser.parse_args() + print(args) + + main(args) diff --git a/examples/tensorflow/gat/utils.py b/examples/tensorflow/gat/utils.py new file mode 100644 index 000000000000..220c3427bda7 --- /dev/null +++ b/examples/tensorflow/gat/utils.py @@ -0,0 +1,28 @@ +import numpy as np + +class EarlyStopping: + def __init__(self, patience=10): + self.patience = patience + self.counter = 0 + self.best_score = None + self.early_stop = False + + def step(self, acc, model): + score = acc + if self.best_score is None: + self.best_score = score + self.save_checkpoint(model) + elif score < self.best_score: + self.counter += 1 + print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(model) + self.counter = 0 + return self.early_stop + + def save_checkpoint(self, model): + '''Saves model when validation loss decrease.''' + model.save_weights('es_checkpoint.pb') diff --git a/examples/tensorflow/gcn/README.md 
b/examples/tensorflow/gcn/README.md new file mode 100644 index 000000000000..3eafc4cd4d69 --- /dev/null +++ b/examples/tensorflow/gcn/README.md @@ -0,0 +1,35 @@ +Graph Convolutional Networks (GCN) +============ + +- Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907) +- Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn). Note that the original code is +implemented with Tensorflow for the paper. + +Dependencies +------------ +- Tensorflow 2.1+ +- requests + +``bash +pip install tensorflow requests +`` + +Codes +----- +The folder contains three implementations of GCN: +- `gcn.py` uses DGL's predefined graph convolution module. +- `gcn_mp.py` uses user-defined message and reduce functions. +- `gcn_builtin.py` improves from `gcn_mp.py` by using DGL's builtin functions + so SPMV optimization could be applied. + +Results +------- + +Run with following (available dataset: "cora", "citeseer", "pubmed") +```bash +python3 train.py --dataset cora --gpu 0 --self-loop +``` + +* cora: ~0.810 (0.79-0.83) (paper: 0.815) +* citeseer: 0.707 (paper: 0.703) +* pubmed: 0.792 (paper: 0.790) diff --git a/examples/tensorflow/gcn/gcn.py b/examples/tensorflow/gcn/gcn.py new file mode 100644 index 000000000000..3bb8e38d7ae6 --- /dev/null +++ b/examples/tensorflow/gcn/gcn.py @@ -0,0 +1,39 @@ +"""GCN using DGL nn package + +References: +- Semi-Supervised Classification with Graph Convolutional Networks +- Paper: https://arxiv.org/abs/1609.02907 +- Code: https://github.com/tkipf/gcn +""" +import tensorflow as tf +from tensorflow.keras import layers +from dgl.nn.tensorflow import GraphConv + +class GCN(tf.keras.Model): + def __init__(self, + g, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout): + super(GCN, self).__init__() + self.g = g + self.layer_list = [] + # input layer + self.layer_list.append(GraphConv(in_feats, n_hidden, activation=activation)) + # hidden layers + for i in range(n_layers - 1): + self.layer_list.append(GraphConv(n_hidden, n_hidden, activation=activation)) + # output layer + self.layer_list.append(GraphConv(n_hidden, n_classes)) + self.dropout = layers.Dropout(dropout) + + def call(self, features): + h = features + for i, layer in enumerate(self.layer_list): + if i != 0: + h = self.dropout(h) + h = layer(self.g, h) + return h diff --git a/examples/tensorflow/gcn/gcn_builtin.py b/examples/tensorflow/gcn/gcn_builtin.py index 99b9897c20a2..d5d6354833ca 100644 --- a/examples/tensorflow/gcn/gcn_builtin.py +++ b/examples/tensorflow/gcn/gcn_builtin.py @@ -21,7 +21,8 @@ def __init__(self, super(GCNLayer, self).__init__() self.g = g - w_init = tf.random_normal_initializer() + w_init = tf.keras.initializers.VarianceScaling( + scale=1.0, mode="fan_out", distribution="uniform") self.weight = tf.Variable(initial_value=w_init(shape=(in_feats, out_feats), dtype='float32'), trainable=True) @@ -144,7 +145,7 @@ def main(args): args.dropout) optimizer = tf.keras.optimizers.Adam( - learning_rate=args.lr, decay=args.weight_decay) + learning_rate=args.lr) loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True) @@ -157,6 +158,13 @@ def main(args): with tf.GradientTape() as tape: logits = model(features) loss_value = loss_fcn(labels[train_mask], logits[train_mask]) + # Manually Weight Decay + # We found Tensorflow has a different implementation on weight decay + # of Adam(W) optimizer with PyTorch. And this results in worse results. + # Manually adding weights to the loss to do weight decay solves this problem. 
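+            # tf.keras.optimizers.Adam has no L2 weight-decay argument (its `decay`
+            # option only schedules the learning rate), so the penalty is folded
+            # into the loss here.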
+ for weight in model.trainable_weights: + loss_value = loss_value + \ + args.weight_decay*tf.nn.l2_loss(weight) grads = tape.gradient(loss_value, model.trainable_weights) optimizer.apply_gradients(zip(grads, model.trainable_weights)) diff --git a/examples/tensorflow/gcn/gcn_mp.py b/examples/tensorflow/gcn/gcn_mp.py index 72e7b2483598..d75cbfed63af 100644 --- a/examples/tensorflow/gcn/gcn_mp.py +++ b/examples/tensorflow/gcn/gcn_mp.py @@ -151,7 +151,7 @@ def main(args): args.dropout) optimizer = tf.keras.optimizers.Adam( - learning_rate=args.lr, decay=args.weight_decay) + learning_rate=args.lr) loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True) @@ -164,7 +164,13 @@ def main(args): with tf.GradientTape() as tape: logits = model(features) loss_value = loss_fcn(labels[train_mask], logits[train_mask]) - + # Manually Weight Decay + # We found Tensorflow has a different implementation on weight decay + # of Adam(W) optimizer with PyTorch. And this results in worse results. + # Manually adding weights to the loss to do weight decay solves this problem. + for weight in model.trainable_weights: + loss_value = loss_value + \ + args.weight_decay*tf.nn.l2_loss(weight) grads = tape.gradient(loss_value, model.trainable_weights) optimizer.apply_gradients(zip(grads, model.trainable_weights)) diff --git a/examples/tensorflow/gcn/train.py b/examples/tensorflow/gcn/train.py new file mode 100644 index 000000000000..e779578da3a5 --- /dev/null +++ b/examples/tensorflow/gcn/train.py @@ -0,0 +1,132 @@ +import argparse +import time +import numpy as np +import networkx as nx +import tensorflow as tf +from dgl import DGLGraph +from dgl.data import register_data_args, load_data +from gcn import GCN + + +def evaluate(model, features, labels, mask): + logits = model(features, training=False) + logits = logits[mask] + labels = labels[mask] + indices = tf.math.argmax(logits, axis=1) + acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32)) + return acc.numpy().item() + + +def main(args): + # load and preprocess dataset + data = load_data(args) + + if args.gpu < 0: + device = "/cpu:0" + else: + device = "/gpu:{}".format(args.gpu) + + with tf.device(device): + features = tf.convert_to_tensor(data.features, dtype=tf.float32) + labels = tf.convert_to_tensor(data.labels, dtype=tf.int64) + train_mask = tf.convert_to_tensor(data.train_mask, dtype=tf.bool) + val_mask = tf.convert_to_tensor(data.val_mask, dtype=tf.bool) + test_mask = tf.convert_to_tensor(data.test_mask, dtype=tf.bool) + in_feats = features.shape[1] + n_classes = data.num_labels + n_edges = data.graph.number_of_edges() + print("""----Data statistics------' + #Edges %d + #Classes %d + #Train samples %d + #Val samples %d + #Test samples %d""" % + (n_edges, n_classes, + train_mask.numpy().sum(), + val_mask.numpy().sum(), + test_mask.numpy().sum())) + + # graph preprocess and calculate normalization factor + g = data.graph + if args.self_loop: + g.remove_edges_from(nx.selfloop_edges(g)) + g.add_edges_from(zip(g.nodes(), g.nodes())) + g = DGLGraph(g) + n_edges = g.number_of_edges() + # normalization + degs = tf.cast(tf.identity(g.in_degrees()), dtype=tf.float32) + norm = tf.math.pow(degs, -0.5) + norm = tf.where(tf.math.is_inf(norm), tf.zeros_like(norm), norm) + + g.ndata['norm'] = tf.expand_dims(norm, -1) + + # create GCN model + model = GCN(g, + in_feats, + args.n_hidden, + n_classes, + args.n_layers, + tf.nn.relu, + args.dropout) + + loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True) + # use optimizer + 
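+        # weight decay is not configured on the optimizer; it is added to the loss
+        # manually inside the training loop below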
optimizer = tf.keras.optimizers.Adam( + learning_rate=args.lr, epsilon=1e-8) + + # initialize graph + dur = [] + for epoch in range(args.n_epochs): + if epoch >= 3: + t0 = time.time() + # forward + with tf.GradientTape() as tape: + logits = model(features) + loss_value = loss_fcn(labels[train_mask], logits[train_mask]) + # Manually Weight Decay + # We found Tensorflow has a different implementation on weight decay + # of Adam(W) optimizer with PyTorch. And this results in worse results. + # Manually adding weights to the loss to do weight decay solves this problem. + for weight in model.trainable_weights: + loss_value = loss_value + \ + args.weight_decay*tf.nn.l2_loss(weight) + + grads = tape.gradient(loss_value, model.trainable_weights) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + if epoch >= 3: + dur.append(time.time() - t0) + + acc = evaluate(model, features, labels, val_mask) + print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " + "ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss_value.numpy().item(), + acc, n_edges / np.mean(dur) / 1000)) + + acc = evaluate(model, features, labels, test_mask) + print("Test Accuracy {:.4f}".format(acc)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='GCN') + register_data_args(parser) + parser.add_argument("--dropout", type=float, default=0.5, + help="dropout probability") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-epochs", type=int, default=200, + help="number of training epochs") + parser.add_argument("--n-hidden", type=int, default=16, + help="number of hidden gcn units") + parser.add_argument("--n-layers", type=int, default=1, + help="number of hidden gcn layers") + parser.add_argument("--weight-decay", type=float, default=5e-4, + help="Weight for L2 loss") + parser.add_argument("--self-loop", action='store_true', + help="graph self-loop (default=False)") + parser.set_defaults(self_loop=False) + args = parser.parse_args() + print(args) + + main(args) diff --git a/examples/tensorflow/rgcn/README.md b/examples/tensorflow/rgcn/README.md new file mode 100644 index 000000000000..4ced18682f68 --- /dev/null +++ b/examples/tensorflow/rgcn/README.md @@ -0,0 +1,33 @@ +# Relational-GCN + +* Paper: [https://arxiv.org/abs/1703.06103](https://arxiv.org/abs/1703.06103) +* Author's code for entity classification: [https://github.com/tkipf/relational-gcn](https://github.com/tkipf/relational-gcn) +* Author's code for link prediction: [https://github.com/MichSchli/RelationPrediction](https://github.com/MichSchli/RelationPrediction) + +### Dependencies +* Tensorflow 2.1+ +* requests +* rdflib +* pandas + +``` +pip install requests torch rdflib pandas +``` + +Example code was tested with rdflib 4.2.2 and pandas 0.23.4 + +### Entity Classification +AIFB: accuracy 97.22% (DGL), 95.83% (paper) +``` +python3 entity_classify.py -d aifb --testing --gpu 0 +``` + +MUTAG: accuracy 75% (DGL), 73.23% (paper) +``` +python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 +``` + +BGS: accuracy 79.3% (DGL n-base=25), 83.10% (paper n-base=40) +``` +python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 25 --testing --gpu 0 --relabel +``` diff --git a/examples/tensorflow/rgcn/entity_classify.py b/examples/tensorflow/rgcn/entity_classify.py new file mode 100644 index 000000000000..a8d53ce1a876 --- /dev/null +++ 
b/examples/tensorflow/rgcn/entity_classify.py @@ -0,0 +1,180 @@ +""" +Modeling Relational Data with Graph Convolutional Networks +Paper: https://arxiv.org/abs/1703.06103 +Code: https://github.com/tkipf/relational-gcn + +Difference compared to tkipf/relation-gcn +* l2norm applied to all weights +* remove nodes that won't be touched +""" + +import argparse +import numpy as np +import time +import tensorflow as tf +from tensorflow.keras import layers +from dgl import DGLGraph +from dgl.nn.tensorflow import RelGraphConv +from dgl.contrib.data import load_data +from functools import partial + +from model import BaseRGCN + +class EntityClassify(BaseRGCN): + def create_features(self): + features = tf.range(self.num_nodes) + return features + + def build_input_layer(self): + return RelGraphConv(self.num_nodes, self.h_dim, self.num_rels, "basis", + self.num_bases, activation=tf.nn.relu, self_loop=self.use_self_loop, + dropout=self.dropout) + + def build_hidden_layer(self, idx): + return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis", + self.num_bases, activation=tf.nn.relu, self_loop=self.use_self_loop, + dropout=self.dropout) + + def build_output_layer(self): + return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis", + self.num_bases, activation=partial(tf.nn.softmax, axis=1), + self_loop=self.use_self_loop) + +def acc(logits, labels, mask): + logits = tf.gather(logits, mask) + labels = tf.gather(labels, mask) + indices = tf.math.argmax(logits, axis=1) + acc = tf.reduce_mean(tf.cast(indices == labels, dtype=tf.float32)) + return acc + +def main(args): + # load graph data + data = load_data(args.dataset, bfs_level=args.bfs_level, relabel=args.relabel) + num_nodes = data.num_nodes + num_rels = data.num_rels + num_classes = data.num_classes + labels = data.labels + train_idx = data.train_idx + test_idx = data.test_idx + + # split dataset into train, validate, test + if args.validation: + val_idx = train_idx[:len(train_idx) // 5] + train_idx = train_idx[len(train_idx) // 5:] + else: + val_idx = train_idx + + # since the nodes are featureless, the input feature is then the node id. + feats = tf.range(num_nodes, dtype=tf.int64) + + # edge type and normalization factor + edge_type = tf.convert_to_tensor(data.edge_type) + edge_norm = tf.expand_dims(tf.convert_to_tensor(data.edge_norm), 1) + labels = tf.reshape(tf.convert_to_tensor(labels), (-1, )) + + # check cuda + if args.gpu < 0: + device = "/cpu:0" + use_cuda = False + else: + device = "/gpu:{}".format(args.gpu) + use_cuda = True + + with tf.device(device): + + # create graph + g = DGLGraph() + g.add_nodes(num_nodes) + g.add_edges(data.edge_src, data.edge_dst) + + # create model + model = EntityClassify(len(g), + args.n_hidden, + num_classes, + num_rels, + num_bases=args.n_bases, + num_hidden_layers=args.n_layers - 2, + dropout=args.dropout, + use_self_loop=args.use_self_loop, + use_cuda=use_cuda) + + # optimizer + optimizer = tf.keras.optimizers.Adam( + learning_rate=args.lr) + # training loop + print("start training...") + forward_time = [] + backward_time = [] + loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=False) + for epoch in range(args.n_epochs): + t0 = time.time() + with tf.GradientTape() as tape: + logits = model(g, feats, edge_type, edge_norm) + loss = loss_fcn(tf.gather(labels, train_idx), tf.gather(logits, train_idx)) + # Manually Weight Decay + # We found Tensorflow has a different implementation on weight decay + # of Adam(W) optimizer with PyTorch. And this results in worse results. 
+ # Manually adding weights to the loss to do weight decay solves this problem. + for weight in model.trainable_weights: + loss = loss + \ + args.l2norm * tf.nn.l2_loss(weight) + t1 = time.time() + grads = tape.gradient(loss, model.trainable_weights) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + t2 = time.time() + + forward_time.append(t1 - t0) + backward_time.append(t2 - t1) + print("Epoch {:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}". + format(epoch, forward_time[-1], backward_time[-1])) + train_acc = acc(logits, labels, train_idx) + val_loss = loss_fcn(tf.gather(labels, val_idx), tf.gather(logits, val_idx)) + val_acc = acc(logits, labels, val_idx) + print("Train Accuracy: {:.4f} | Train Loss: {:.4f} | Validation Accuracy: {:.4f} | Validation loss: {:.4f}". + format(train_acc, loss.numpy().item(), val_acc, val_loss.numpy().item())) + print() + + logits = model(g, feats, edge_type, edge_norm) + test_loss = loss_fcn(tf.gather(labels, test_idx), tf.gather(logits, test_idx)) + test_acc = acc(logits, labels, test_idx) + print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.numpy().item())) + print() + + print("Mean forward time: {:4f}".format(np.mean(forward_time[len(forward_time) // 4:]))) + print("Mean backward time: {:4f}".format(np.mean(backward_time[len(backward_time) // 4:]))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='RGCN') + parser.add_argument("--dropout", type=float, default=0, + help="dropout probability") + parser.add_argument("--n-hidden", type=int, default=16, + help="number of hidden units") + parser.add_argument("--gpu", type=int, default=-1, + help="gpu") + parser.add_argument("--lr", type=float, default=1e-2, + help="learning rate") + parser.add_argument("--n-bases", type=int, default=-1, + help="number of filter weight matrices, default: -1 [use all]") + parser.add_argument("--n-layers", type=int, default=2, + help="number of propagation rounds") + parser.add_argument("-e", "--n-epochs", type=int, default=50, + help="number of training epochs") + parser.add_argument("-d", "--dataset", type=str, required=True, + help="dataset to use") + parser.add_argument("--l2norm", type=float, default=0, + help="l2 norm coef") + parser.add_argument("--relabel", default=False, action='store_true', + help="remove untouched nodes and relabel") + parser.add_argument("--use-self-loop", default=False, action='store_true', + help="include self feature as a special relation") + fp = parser.add_mutually_exclusive_group(required=False) + fp.add_argument('--validation', dest='validation', action='store_true') + fp.add_argument('--testing', dest='validation', action='store_false') + parser.set_defaults(validation=True) + + args = parser.parse_args() + print(args) + args.bfs_level = args.n_layers + 1 # pruning used nodes for memory + main(args) diff --git a/examples/tensorflow/rgcn/model.py b/examples/tensorflow/rgcn/model.py new file mode 100644 index 000000000000..7b539d3b2414 --- /dev/null +++ b/examples/tensorflow/rgcn/model.py @@ -0,0 +1,49 @@ +import tensorflow as tf +from tensorflow.keras import layers + +class BaseRGCN(layers.Layer): + def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, + num_hidden_layers=1, dropout=0, + use_self_loop=False, use_cuda=False): + super(BaseRGCN, self).__init__() + self.num_nodes = num_nodes + self.h_dim = h_dim + self.out_dim = out_dim + self.num_rels = num_rels + self.num_bases = None if num_bases < 0 else num_bases + self.num_hidden_layers = 
num_hidden_layers + self.dropout = dropout + self.use_self_loop = use_self_loop + self.use_cuda = use_cuda + + # create rgcn layers + self.build_model() + + def build_model(self): + self.layers = [] + # i2h + i2h = self.build_input_layer() + if i2h is not None: + self.layers.append(i2h) + # h2h + for idx in range(self.num_hidden_layers): + h2h = self.build_hidden_layer(idx) + self.layers.append(h2h) + # h2o + h2o = self.build_output_layer() + if h2o is not None: + self.layers.append(h2o) + + def build_input_layer(self): + return None + + def build_hidden_layer(self, idx): + raise NotImplementedError + + def build_output_layer(self): + return None + + def call(self, g, h, r, norm): + for layer in self.layers: + h = layer(g, h, r, norm) + return h diff --git a/examples/tensorflow/rgcn/utils.py b/examples/tensorflow/rgcn/utils.py new file mode 100644 index 000000000000..ec57800ee455 --- /dev/null +++ b/examples/tensorflow/rgcn/utils.py @@ -0,0 +1,165 @@ +""" +Utility functions for link prediction +Most code is adapted from authors' implementation of RGCN link prediction: +https://github.com/MichSchli/RelationPrediction + +""" + +import numpy as np +import tensorflow as tf +import dgl + +####################################################################### +# +# Utility function for building training and testing graphs +# +####################################################################### + +def get_adj_and_degrees(num_nodes, triplets): + """ Get adjacency list and degrees of the graph + """ + adj_list = [[] for _ in range(num_nodes)] + for i,triplet in enumerate(triplets): + adj_list[triplet[0]].append([i, triplet[2]]) + adj_list[triplet[2]].append([i, triplet[0]]) + + degrees = np.array([len(a) for a in adj_list]) + adj_list = [np.array(a) for a in adj_list] + return adj_list, degrees + +def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size): + """Sample edges by neighborhool expansion. + + This guarantees that the sampled edges form a connected graph, which + may help deeper GNNs that require information from more than one hop. 
+ """ + edges = np.zeros((sample_size), dtype=np.int32) + + #initialize + sample_counts = np.array([d for d in degrees]) + picked = np.array([False for _ in range(n_triplets)]) + seen = np.array([False for _ in degrees]) + + for i in range(0, sample_size): + weights = sample_counts * seen + + if np.sum(weights) == 0: + weights = np.ones_like(weights) + weights[np.where(sample_counts == 0)] = 0 + + probabilities = (weights) / np.sum(weights) + chosen_vertex = np.random.choice(np.arange(degrees.shape[0]), + p=probabilities) + chosen_adj_list = adj_list[chosen_vertex] + seen[chosen_vertex] = True + + chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0])) + chosen_edge = chosen_adj_list[chosen_edge] + edge_number = chosen_edge[0] + + while picked[edge_number]: + chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0])) + chosen_edge = chosen_adj_list[chosen_edge] + edge_number = chosen_edge[0] + + edges[i] = edge_number + other_vertex = chosen_edge[1] + picked[edge_number] = True + sample_counts[chosen_vertex] -= 1 + sample_counts[other_vertex] -= 1 + seen[other_vertex] = True + + return edges + +def sample_edge_uniform(adj_list, degrees, n_triplets, sample_size): + """Sample edges uniformly from all the edges.""" + all_edges = np.arange(n_triplets) + return np.random.choice(all_edges, sample_size, replace=False) + +def generate_sampled_graph_and_labels(triplets, sample_size, split_size, + num_rels, adj_list, degrees, + negative_rate, sampler="uniform"): + """Get training graph and signals + First perform edge neighborhood sampling on graph, then perform negative + sampling to generate negative samples + """ + # perform edge neighbor sampling + if sampler == "uniform": + edges = sample_edge_uniform(adj_list, degrees, len(triplets), sample_size) + elif sampler == "neighbor": + edges = sample_edge_neighborhood(adj_list, degrees, len(triplets), sample_size) + else: + raise ValueError("Sampler type must be either 'uniform' or 'neighbor'.") + + # relabel nodes to have consecutive node ids + edges = triplets[edges] + src, rel, dst = edges.transpose() + uniq_v, edges = np.unique((src, dst), return_inverse=True) + src, dst = np.reshape(edges, (2, -1)) + relabeled_edges = np.stack((src, rel, dst)).transpose() + + # negative sampling + samples, labels = negative_sampling(relabeled_edges, len(uniq_v), + negative_rate) + + # further split graph, only half of the edges will be used as graph + # structure, while the rest half is used as unseen positive samples + split_size = int(sample_size * split_size) + graph_split_ids = np.random.choice(np.arange(sample_size), + size=split_size, replace=False) + src = src[graph_split_ids] + dst = dst[graph_split_ids] + rel = rel[graph_split_ids] + + # build DGL graph + print("# sampled nodes: {}".format(len(uniq_v))) + print("# sampled edges: {}".format(len(src) * 2)) + g, rel, norm = build_graph_from_triplets(len(uniq_v), num_rels, + (src, rel, dst)) + return g, uniq_v, rel, norm, samples, labels + +def comp_deg_norm(g): + g = g.local_var() + in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy() + norm = 1.0 / in_deg + norm[np.isinf(norm)] = 0 + return norm + +def build_graph_from_triplets(num_nodes, num_rels, triplets): + """ Create a DGL graph. The graph is bidirectional because RGCN authors + use reversed relations. 
+ This function also generates edge type and normalization factor + (reciprocal of node incoming degree) + """ + g = dgl.DGLGraph() + g.add_nodes(num_nodes) + src, rel, dst = triplets + src, dst = np.concatenate((src, dst)), np.concatenate((dst, src)) + rel = np.concatenate((rel, rel + num_rels)) + edges = sorted(zip(dst, src, rel)) + dst, src, rel = np.array(edges).transpose() + g.add_edges(src, dst) + norm = comp_deg_norm(g) + print("# nodes: {}, # edges: {}".format(num_nodes, len(src))) + return g, rel, norm + +def build_test_graph(num_nodes, num_rels, edges): + src, rel, dst = edges.transpose() + print("Test graph:") + return build_graph_from_triplets(num_nodes, num_rels, (src, rel, dst)) + +def negative_sampling(pos_samples, num_entity, negative_rate): + size_of_batch = len(pos_samples) + num_to_generate = size_of_batch * negative_rate + neg_samples = np.tile(pos_samples, (negative_rate, 1)) + labels = np.zeros(size_of_batch * (negative_rate + 1), dtype=np.float32) + labels[: size_of_batch] = 1 + values = np.random.randint(num_entity, size=num_to_generate) + choices = np.random.uniform(size=num_to_generate) + subj = choices > 0.5 + obj = choices <= 0.5 + neg_samples[subj, 0] = values[subj] + neg_samples[obj, 2] = values[obj] + + return np.concatenate((pos_samples, neg_samples)), labels + diff --git a/python/dgl/nn/tensorflow/__init__.py b/python/dgl/nn/tensorflow/__init__.py new file mode 100644 index 000000000000..8d6e05feaa25 --- /dev/null +++ b/python/dgl/nn/tensorflow/__init__.py @@ -0,0 +1,5 @@ +"""Package for Tensorflow-specific NN modules.""" +from .conv import * +from .softmax import * +from .utils import * +from .glob import * diff --git a/python/dgl/nn/tensorflow/conv/__init__.py b/python/dgl/nn/tensorflow/conv/__init__.py new file mode 100644 index 000000000000..73f21c87679d --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/__init__.py @@ -0,0 +1,8 @@ +"""TF NN conv module""" +from .gatconv import GATConv +from .relgraphconv import RelGraphConv +from .graphconv import GraphConv +from .ginconv import GINConv +from .sageconv import SAGEConv +from .sgconv import SGConv +from .appnpconv import APPNPConv diff --git a/python/dgl/nn/tensorflow/conv/appnpconv.py b/python/dgl/nn/tensorflow/conv/appnpconv.py new file mode 100644 index 000000000000..300c6d8e98a8 --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/appnpconv.py @@ -0,0 +1,77 @@ +"""TF Module for APPNPConv""" +# pylint: disable= no-member, arguments-differ, invalid-name +import tensorflow as tf +from tensorflow.keras import layers +import numpy as np + +from .... import function as fn + + +class APPNPConv(layers.Layer): + r"""Approximate Personalized Propagation of Neural Predictions + layer from paper `Predict then Propagate: Graph Neural Networks + meet Personalized PageRank `__. + + .. math:: + H^{0} & = X + + H^{t+1} & = (1-\alpha)\left(\hat{D}^{-1/2} + \hat{A} \hat{D}^{-1/2} H^{t}\right) + \alpha H^{0} + + Parameters + ---------- + k : int + Number of iterations :math:`K`. + alpha : float + The teleport probability :math:`\alpha`. + edge_drop : float, optional + Dropout rate on edges that controls the + messages received by each node. Default: ``0``. + """ + + def __init__(self, + k, + alpha, + edge_drop=0.): + super(APPNPConv, self).__init__() + self._k = k + self._alpha = alpha + self.edge_drop = layers.Dropout(edge_drop) + + def call(self, graph, feat): + r"""Compute APPNP layer. + + Parameters + ---------- + graph : DGLGraph + The graph. 
+ feat : tf.Tensor + The input feature of shape :math:`(N, *)` :math:`N` is the + number of nodes, and :math:`*` could be of any shape. + + Returns + ------- + tf.Tensor + The output feature of shape :math:`(N, *)` where :math:`*` + should be the same as input shape. + """ + graph = graph.local_var() + degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32), + clip_value_min=1, clip_value_max=np.inf) + norm = tf.pow(degs, -0.5) + shp = norm.shape + (1,) * (feat.ndim - 1) + norm = tf.reshape(norm, shp) + feat_0 = feat + for _ in range(self._k): + # normalization by src node + feat = feat * norm + graph.ndata['h'] = feat + graph.edata['w'] = self.edge_drop( + tf.ones(graph.number_of_edges(), 1)) + graph.update_all(fn.u_mul_e('h', 'w', 'm'), + fn.sum('m', 'h')) + feat = graph.ndata.pop('h') + # normalization by dst node + feat = feat * norm + feat = (1 - self._alpha) * feat + self._alpha * feat_0 + return feat diff --git a/python/dgl/nn/tensorflow/conv/gatconv.py b/python/dgl/nn/tensorflow/conv/gatconv.py new file mode 100644 index 000000000000..72cf4573ef6d --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/gatconv.py @@ -0,0 +1,126 @@ +"""Tensorflow modules for graph attention networks(GAT).""" +# pylint: disable= no-member, arguments-differ, invalid-name +import tensorflow as tf +from tensorflow.keras import layers +import numpy as np + +from .... import function as fn +from ..softmax import edge_softmax +from ..utils import Identity + +# pylint: enable=W0235 + + +class GATConv(layers.Layer): + r"""Apply `Graph Attention Network `__ + over an input signal. + + .. math:: + h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)} + + where :math:`\alpha_{ij}` is the attention score bewteen node :math:`i` and + node :math:`j`: + + .. math:: + \alpha_{ij}^{l} & = \mathrm{softmax_i} (e_{ij}^{l}) + + e_{ij}^{l} & = \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + num_heads : int + Number of heads in Multi-Head Attention. + feat_drop : float, optional + Dropout rate on feature, defaults: ``0``. + attn_drop : float, optional + Dropout rate on attention weight, defaults: ``0``. + negative_slope : float, optional + LeakyReLU angle of negative slope. + residual : bool, optional + If True, use residual connection. + activation : callable activation function/layer or None, optional. + If not None, applies an activation function to the updated node features. + Default: ``None``. 
+ """ + + def __init__(self, + in_feats, + out_feats, + num_heads, + feat_drop=0., + attn_drop=0., + negative_slope=0.2, + residual=False, + activation=None): + super(GATConv, self).__init__() + self._num_heads = num_heads + self._in_feats = in_feats + self._out_feats = out_feats + xinit = tf.keras.initializers.VarianceScaling(scale=np.sqrt( + 2), mode="fan_avg", distribution="untruncated_normal") + self.fc = layers.Dense( + out_feats * num_heads, use_bias=False, kernel_initializer=xinit) + self.attn_l = tf.Variable(initial_value=xinit( + shape=(1, num_heads, out_feats), dtype='float32'), trainable=True) + + self.attn_r = tf.Variable(initial_value=xinit( + shape=(1, num_heads, out_feats), dtype='float32'), trainable=True) + self.feat_drop = layers.Dropout(rate=feat_drop) + self.attn_drop = layers.Dropout(rate=attn_drop) + self.leaky_relu = layers.LeakyReLU(alpha=negative_slope) + if residual: + if in_feats != out_feats: + self.res_fc = layers.Dense( + num_heads * out_feats, use_bias=False, kernel_initializer=xinit) + else: + self.res_fc = Identity() + else: + self.res_fc = None + # self.register_buffer('res_fc', None) + self.activation = activation + + def call(self, graph, feat): + r"""Compute graph attention network layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : tf.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + tf.Tensor + The output feature of shape :math:`(N, H, D_{out})` where :math:`H` + is the number of heads, and :math:`D_{out}` is size of output feature. + """ + graph = graph.local_var() + h = self.feat_drop(feat) + feat = tf.reshape(self.fc(h), (-1, self._num_heads, self._out_feats)) + el = tf.reduce_sum(feat * self.attn_l, axis=-1, keepdims=True) + er = tf.reduce_sum(feat * self.attn_r, axis=-1, keepdims=True) + graph.ndata.update({'ft': feat, 'el': el, 'er': er}) + # compute edge attention + graph.apply_edges(fn.u_add_v('el', 'er', 'e')) + e = self.leaky_relu(graph.edata.pop('e')) + # compute softmax + graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) + # message passing + graph.update_all(fn.u_mul_e('ft', 'a', 'm'), + fn.sum('m', 'ft')) + rst = graph.ndata['ft'] + # residual + if self.res_fc is not None: + resval = tf.reshape(self.res_fc( + h), (h.shape[0], -1, self._out_feats)) + rst = rst + resval + # activation + if self.activation: + rst = self.activation(rst) + return rst diff --git a/python/dgl/nn/tensorflow/conv/ginconv.py b/python/dgl/nn/tensorflow/conv/ginconv.py new file mode 100644 index 000000000000..496117144804 --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/ginconv.py @@ -0,0 +1,75 @@ +"""Tensorflow Module for Graph Isomorphism Network layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import tensorflow as tf +from tensorflow.keras import layers + +from .... import function as fn + + +class GINConv(layers.Layer): + r"""Graph Isomorphism Network layer from paper `How Powerful are Graph + Neural Networks? `__. + + .. math:: + h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} + + \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i) + \right\}\right)\right) + + Parameters + ---------- + apply_func : callable activation function/layer or None + If not None, apply this function to the updated node feature, + the :math:`f_\Theta` in the formula. + aggregator_type : str + Aggregator type to use (``sum``, ``max`` or ``mean``). 
+ init_eps : float, optional + Initial :math:`\epsilon` value, default: ``0``. + learn_eps : bool, optional + If True, :math:`\epsilon` will be a learnable parameter. + """ + def __init__(self, + apply_func, + aggregator_type, + init_eps=0, + learn_eps=False): + super(GINConv, self).__init__() + self.apply_func = apply_func + if aggregator_type == 'sum': + self._reducer = fn.sum + elif aggregator_type == 'max': + self._reducer = fn.max + elif aggregator_type == 'mean': + self._reducer = fn.mean + else: + raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type)) + # to specify whether eps is trainable or not. + self.eps = tf.Variable(initial_value=[init_eps], dtype=tf.float32, trainable=learn_eps) + + def call(self, graph, feat): + r"""Compute Graph Isomorphism Network layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : tf.Tensor + The input feature of shape :math:`(N, D)` where :math:`D` + could be any positive integer, :math:`N` is the number + of nodes. If ``apply_func`` is not None, :math:`D` should + fit the input dimensionality requirement of ``apply_func``. + + Returns + ------- + tf.Tensor + The output feature of shape :math:`(N, D_{out})` where + :math:`D_{out}` is the output dimensionality of ``apply_func``. + If ``apply_func`` is None, :math:`D_{out}` should be the same + as input dimensionality. + """ + graph = graph.local_var() + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh')) + rst = (1 + self.eps) * feat + graph.ndata['neigh'] + if self.apply_func is not None: + rst = self.apply_func(rst) + return rst diff --git a/python/dgl/nn/tensorflow/conv/graphconv.py b/python/dgl/nn/tensorflow/conv/graphconv.py new file mode 100644 index 000000000000..42db8acc8b8c --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/graphconv.py @@ -0,0 +1,150 @@ +"""Tensorflow modules for graph convolutions(GCN).""" +# pylint: disable= no-member, arguments-differ, invalid-name +import tensorflow as tf +from tensorflow.keras import layers +import numpy as np + +from .... import function as fn + +# pylint: disable=W0235 + + +class GraphConv(layers.Layer): + r"""Apply graph convolution over an input signal. + + Graph convolution is introduced in `GCN `__ + and can be described as below: + + .. math:: + h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)}) + + where :math:`\mathcal{N}(i)` is the neighbor set of node :math:`i`. :math:`c_{ij}` is equal + to the product of the square root of node degrees: + :math:`\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}`. :math:`\sigma` is an activation + function. + + The model parameters are initialized as in the + `original implementation `__ where + the weight :math:`W^{(l)}` is initialized using Glorot uniform initialization + and the bias is initialized to be zero. + + Notes + ----- + Zero in degree nodes could lead to invalid normalizer. A common practice + to avoid this is to add a self-loop for each node in the graph, which + can be achieved by: + + >>> g = ... # some DGLGraph + >>> g.add_edges(g.nodes(), g.nodes()) + + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + norm : bool, optional + If True, the normalizer :math:`c_{ij}` is applied. Default: ``True``. + bias : bool, optional + If True, adds a learnable bias to the output. Default: ``True``. 
+ activation: callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + + Attributes + ---------- + weight : tf.Tensor + The learnable weight tensor. + bias : tf.Tensor + The learnable bias tensor. + """ + + def __init__(self, + in_feats, + out_feats, + norm=True, + bias=True, + activation=None): + super(GraphConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._norm = norm + + xinit = tf.keras.initializers.glorot_uniform() + self.weight = tf.Variable(initial_value=xinit( + shape=(in_feats, out_feats), dtype='float32'), trainable=True) + + if bias: + zeroinit = tf.keras.initializers.zeros() + self.bias = tf.Variable(initial_value=zeroinit( + shape=(out_feats), dtype='float32'), trainable=True) + + self._activation = activation + + def call(self, graph, feat): + r"""Compute graph convolution. + + Notes + ----- + * Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional + dimensions, :math:`N` is the number of nodes. + * Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are + the same shape as the input. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : tf.Tensor + The input feature + + Returns + ------- + tf.Tensor + The output feature + """ + graph = graph.local_var() + if self._norm: + in_degree = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32), clip_value_min=1, + clip_value_max=np.inf) + norm = tf.pow(in_degree, -0.5) + shp = norm.shape + (1,) * (feat.ndim - 1) + norm = tf.reshape(norm, shp) + feat = feat * norm + + if self._in_feats > self._out_feats: + # mult W first to reduce the feature size for aggregation. + feat = tf.matmul(feat, self.weight) + graph.ndata['h'] = feat + graph.update_all(fn.copy_src(src='h', out='m'), + fn.sum(msg='m', out='h')) + rst = graph.ndata['h'] + else: + # aggregate first then mult W + graph.ndata['h'] = feat + graph.update_all(fn.copy_src(src='h', out='m'), + fn.sum(msg='m', out='h')) + rst = graph.ndata['h'] + rst = tf.matmul(rst, self.weight) + + if self._norm: + rst = rst * norm + + if self.bias is not None: + rst = rst + self.bias + + if self._activation is not None: + rst = self._activation(rst) + + return rst + + def extra_repr(self): + """Set the extra representation of the module, + which will come into effect when printing the model. + """ + summary = 'in={_in_feats}, out={_out_feats}' + summary += ', normalization={_norm}' + if '_activation' in self.__dict__: + summary += ', activation={_activation}' + return summary.format(**self.__dict__) diff --git a/python/dgl/nn/tensorflow/conv/relgraphconv.py b/python/dgl/nn/tensorflow/conv/relgraphconv.py new file mode 100644 index 000000000000..4f82483b2491 --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/relgraphconv.py @@ -0,0 +1,197 @@ +"""Tensorflow Module for Relational graph convolution layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import tensorflow as tf +from tensorflow.keras import layers + +from .... import function as fn +from .. import utils + + +class RelGraphConv(layers.Layer): + r"""Relational graph convolution layer. + + Relational graph convolution is introduced in "`Modeling Relational Data with Graph + Convolutional Networks `__" + and can be described as below: + + .. 
math:: + + h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}} + \sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)}) + + where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation + :math:`r`. :math:`c_{i,r}` is the normalizer equal + to :math:`|\mathcal{N}^r(i)|`. :math:`\sigma` is an activation function. :math:`W_0` + is the self-loop weight. + + The basis regularization decomposes :math:`W_r` by: + + .. math:: + + W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)} + + where :math:`B` is the number of bases. + + The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B` + number of block diagonal matrices. We refer :math:`B` as the number of bases. + + Parameters + ---------- + in_feat : int + Input feature size. + out_feat : int + Output feature size. + num_rels : int + Number of relations. + regularizer : str + Which weight regularizer to use "basis" or "bdd" + num_bases : int, optional + Number of bases. If is none, use number of relations. Default: None. + bias : bool, optional + True if bias is added. Default: True + activation : callable, optional + Activation function. Default: None + self_loop : bool, optional + True to include self loop message. Default: False + dropout : float, optional + Dropout rate. Default: 0.0 + """ + + def __init__(self, + in_feat, + out_feat, + num_rels, + regularizer="basis", + num_bases=None, + bias=True, + activation=None, + self_loop=False, + dropout=0.0): + super(RelGraphConv, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + self.num_rels = num_rels + self.regularizer = regularizer + self.num_bases = num_bases + if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0: + self.num_bases = self.num_rels + self.bias = bias + self.activation = activation + self.self_loop = self_loop + + xinit = tf.keras.initializers.glorot_uniform() + zeroinit = tf.keras.initializers.zeros() + + if regularizer == "basis": + # add basis weights + self.weight = tf.Variable(initial_value=xinit( + shape=(self.num_bases, self.in_feat, self.out_feat), + dtype='float32'), trainable=True) + if self.num_bases < self.num_rels: + # linear combination coefficients + self.w_comp = tf.Variable(initial_value=xinit( + shape=(self.num_rels, self.num_bases), dtype='float32'), trainable=True) + # message func + self.message_func = self.basis_message_func + elif regularizer == "bdd": + if in_feat % num_bases != 0 or out_feat % num_bases != 0: + raise ValueError( + 'Feature size must be a multiplier of num_bases.') + # add block diagonal weights + self.submat_in = in_feat // self.num_bases + self.submat_out = out_feat // self.num_bases + + # assuming in_feat and out_feat are both divisible by num_bases + self.weight = tf.Variable(initial_value=xinit( + shape=(self.num_rels, self.num_bases * + self.submat_in * self.submat_out), + dtype='float32'), trainable=True) + # message func + self.message_func = self.bdd_message_func + else: + raise ValueError("Regularizer must be either 'basis' or 'bdd'") + + # bias + if self.bias: + self.h_bias = tf.Variable(initial_value=zeroinit( + shape=(out_feat), dtype='float32'), trainable=True) + + # weight for self loop + if self.self_loop: + self.loop_weight = tf.Variable(initial_value=xinit( + shape=(in_feat, out_feat), dtype='float32'), trainable=True) + + self.dropout = layers.Dropout(rate=dropout) + + def basis_message_func(self, edges): + """Message function for basis regularizer""" + if self.num_bases < self.num_rels: + # generate all 
weights from bases + weight = tf.reshape(self.weight, (self.num_bases, + self.in_feat * self.out_feat)) + weight = tf.reshape(tf.matmul(self.w_comp, weight), ( + self.num_rels, self.in_feat, self.out_feat)) + else: + weight = self.weight + + msg = utils.bmm_maybe_select( + edges.src['h'], weight, edges.data['type']) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def bdd_message_func(self, edges): + """Message function for block-diagonal-decomposition regularizer""" + if ((edges.src['h'].dtype == tf.int64) and + len(edges.src['h'].shape) == 1): + raise TypeError( + 'Block decomposition does not allow integer ID feature.') + weight = tf.reshape(tf.gather( + self.weight, edges.data['type']), (-1, self.submat_in, self.submat_out)) + node = tf.reshape(edges.src['h'], (-1, 1, self.submat_in)) + msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat)) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def call(self, g, x, etypes, norm=None): + """ Forward computation + + Parameters + ---------- + g : DGLGraph + The graph. + x : tf.Tensor + Input node features. Could be either + * :math:`(|V|, D)` dense tensor + * :math:`(|V|,)` int64 vector, representing the categorical values of each + node. We then treat the input feature as an one-hot encoding feature. + etypes : tf.Tensor + Edge type tensor. Shape: :math:`(|E|,)` + norm : tf.Tensor + Optional edge normalizer tensor. Shape: :math:`(|E|, 1)` + + Returns + ------- + tf.Tensor + New node features. + """ + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = tf.cast(etypes, tf.int64) + if norm is not None: + g.edata['norm'] = norm + if self.self_loop: + loop_message = utils.matmul_maybe_select(x, self.loop_weight) + # message passing + g.update_all(self.message_func, fn.sum(msg='msg', out='h')) + # apply bias and activation + node_repr = g.ndata['h'] + if self.bias: + node_repr = node_repr + self.h_bias + if self.self_loop: + node_repr = node_repr + loop_message + if self.activation: + node_repr = self.activation(node_repr) + node_repr = self.dropout(node_repr) + return node_repr diff --git a/python/dgl/nn/tensorflow/conv/sageconv.py b/python/dgl/nn/tensorflow/conv/sageconv.py new file mode 100644 index 000000000000..cda26e33e2fa --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/sageconv.py @@ -0,0 +1,127 @@ +"""Tensorflow Module for GraphSAGE layer""" +# pylint: disable= no-member, arguments-differ, invalid-name +import tensorflow as tf +from tensorflow.keras import layers + +from .... import function as fn + + +class SAGEConv(layers.Layer): + r"""GraphSAGE layer from paper `Inductive Representation Learning on + Large Graphs `__. + + .. math:: + h_{\mathcal{N}(i)}^{(l+1)} & = \mathrm{aggregate} + \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right) + + h_{i}^{(l+1)} & = \sigma \left(W \cdot \mathrm{concat} + (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1} + b) \right) + + h_{i}^{(l+1)} & = \mathrm{norm}(h_{i}^{l}) + + Parameters + ---------- + in_feats : int + Input feature size. + out_feats : int + Output feature size. + feat_drop : float + Dropout rate on features, default: ``0``. + aggregator_type : str + Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``). + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + norm : callable activation function/layer or None, optional + If not None, applies normalization to the updated node features. 
+ activation : callable activation function/layer or None, optional + If not None, applies an activation function to the updated node features. + Default: ``None``. + """ + + def __init__(self, + in_feats, + out_feats, + aggregator_type, + feat_drop=0., + bias=True, + norm=None, + activation=None): + super(SAGEConv, self).__init__() + self._in_feats = in_feats + self._out_feats = out_feats + self._aggre_type = aggregator_type + self.norm = norm + self.feat_drop = layers.Dropout(feat_drop) + self.activation = activation + # aggregator type: mean/pool/lstm/gcn + if aggregator_type == 'pool': + self.fc_pool = layers.Dense(in_feats) + if aggregator_type == 'lstm': + self.lstm = layers.LSTM(units=in_feats) + if aggregator_type != 'gcn': + self.fc_self = layers.Dense(out_feats, use_bias=bias) + self.fc_neigh = layers.Dense(out_feats, use_bias=bias) + + def _lstm_reducer(self, nodes): + """LSTM reducer + NOTE(zihao): lstm reducer with default schedule (degree bucketing) + is slow, we could accelerate this with degree padding in the future. + """ + m = nodes.mailbox['m'] # (B, L, D) + rst = self.lstm(m) + return {'neigh': rst} + + def call(self, graph, feat): + r"""Compute GraphSAGE layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : tf.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + tf.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + """ + graph = graph.local_var() + feat = self.feat_drop(feat) + h_self = feat + if self._aggre_type == 'mean': + graph.ndata['h'] = feat + graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) + h_neigh = graph.ndata['neigh'] + elif self._aggre_type == 'gcn': + graph.ndata['h'] = feat + graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh')) + # divide in_degrees + degs = tf.cast(graph.in_degrees(), tf.float32) + h_neigh = (graph.ndata['neigh'] + graph.ndata['h'] + ) / (tf.expand_dims(degs, -1) + 1) + elif self._aggre_type == 'pool': + graph.ndata['h'] = tf.nn.relu(self.fc_pool(feat)) + graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh')) + h_neigh = graph.ndata['neigh'] + elif self._aggre_type == 'lstm': + graph.ndata['h'] = feat + graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer) + h_neigh = graph.ndata['neigh'] + else: + raise KeyError( + 'Aggregator type {} not recognized.'.format(self._aggre_type)) + # GraphSAGE GCN does not require fc_self. + if self._aggre_type == 'gcn': + rst = self.fc_neigh(h_neigh) + else: + rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) + # activation + if self.activation is not None: + rst = self.activation(rst) + # normalization + if self.norm is not None: + rst = self.norm(rst) + return rst diff --git a/python/dgl/nn/tensorflow/conv/sgconv.py b/python/dgl/nn/tensorflow/conv/sgconv.py new file mode 100644 index 000000000000..46f5d176fea5 --- /dev/null +++ b/python/dgl/nn/tensorflow/conv/sgconv.py @@ -0,0 +1,99 @@ +"""tf Module for Simplifying Graph Convolution layer""" +# pylint: disable= no-member, arguments-differ, invalid-name, W0613 +import tensorflow as tf +from tensorflow.keras import layers +import numpy as np + +from .... import function as fn + + +class SGConv(layers.Layer): + r"""Simplifying Graph Convolution layer from paper `Simplifying Graph + Convolutional Networks `__. + + .. 
math:: + H^{l+1} = (\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2})^K H^{l} \Theta^{l} + + Parameters + ---------- + in_feats : int + Number of input features. + out_feats : int + Number of output features. + k : int + Number of hops :math:`K`. Defaults:``1``. + cached : bool + If True, the module would cache + + .. math:: + (\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}})^K X\Theta + + at the first forward call. This parameter should only be set to + ``True`` in Transductive Learning setting. + bias : bool + If True, adds a learnable bias to the output. Default: ``True``. + norm : callable activation function/layer or None, optional + If not None, applies normalization to the updated node features. + """ + + def __init__(self, + in_feats, + out_feats, + k=1, + cached=False, + bias=True, + norm=None): + super(SGConv, self).__init__() + self.fc = layers.Dense(out_feats, use_bias=bias) + self._cached = cached + self._cached_h = None + self._k = k + self.norm = norm + + def call(self, graph, feat): + r"""Compute Simplifying Graph Convolution layer. + + Parameters + ---------- + graph : DGLGraph + The graph. + feat : tf.Tensor + The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}` + is size of input feature, :math:`N` is the number of nodes. + + Returns + ------- + tf.Tensor + The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` + is size of output feature. + + Notes + ----- + If ``cache`` is se to True, ``feat`` and ``graph`` should not change during + training, or you will get wrong results. + """ + graph = graph.local_var() + if self._cached_h is not None: + feat = self._cached_h + else: + # compute normalization + degs = tf.clip_by_value(tf.cast( + graph.in_degrees(), tf.float32), clip_value_min=1, clip_value_max=np.inf) + norm = tf.pow(degs, -0.5) + norm = tf.expand_dims(norm, 1) + # compute (D^-1 A^k D)^k X + for _ in range(self._k): + feat = feat * norm + graph.ndata['h'] = feat + graph.update_all(fn.copy_u('h', 'm'), + fn.sum('m', 'h')) + feat = graph.ndata.pop('h') + feat = feat * norm + + if self.norm is not None: + feat = self.norm(feat) + + # cache feature + if self._cached: + self._cached_h = feat + return self.fc(feat) diff --git a/python/dgl/nn/tensorflow/glob.py b/python/dgl/nn/tensorflow/glob.py new file mode 100644 index 000000000000..fe801d845828 --- /dev/null +++ b/python/dgl/nn/tensorflow/glob.py @@ -0,0 +1,258 @@ +"""Tensorflow modules for graph global pooling.""" +# pylint: disable= no-member, arguments-differ, invalid-name, W0235 +import tensorflow as tf +from tensorflow.keras import layers + + +from ... import BatchedDGLGraph +from ...batched_graph import sum_nodes, mean_nodes, max_nodes, \ + softmax_nodes, topk_nodes + + +__all__ = ['SumPooling', 'AvgPooling', + 'MaxPooling', 'SortPooling', 'WeightAndSum', 'GlobalAttentionPooling'] + + +class SumPooling(layers.Layer): + r"""Apply sum pooling over the nodes in the graph. + + .. math:: + r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k + """ + + def __init__(self): + super(SumPooling, self).__init__() + + def call(self, graph, feat): + r"""Compute sum pooling. + + + Parameters + ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. + feat : tf.Tensor + The input feature with shape :math:`(N, *)` where + :math:`N` is the number of nodes in the graph. + + Returns + ------- + tf.Tensor + The output feature with shape :math:`(*)` (if + input graph is a BatchedDGLGraph, the result shape + would be :math:`(B, *)`. 
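As a quick illustration of the readout modules in this file (an editorial sketch, not part of the patch; assumes ``DGLBACKEND=tensorflow``, with arbitrary graph sizes and feature width), sum pooling reduces each graph in a batch to a single row:

```python
import networkx as nx
import tensorflow as tf
import dgl
import dgl.nn.tensorflow as nn

g1 = dgl.DGLGraph(nx.path_graph(3))
g2 = dgl.DGLGraph(nx.path_graph(4))
bg = dgl.batch([g1, g2])                        # BatchedDGLGraph with 3 + 4 nodes
feat = tf.random.normal((bg.number_of_nodes(), 6))
readout = nn.SumPooling()(bg, feat)             # shape (2, 6): one row per graph in the batch
```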
+ """ + with graph.local_scope(): + graph.ndata['h'] = feat + readout = sum_nodes(graph, 'h') + return readout + + +class AvgPooling(layers.Layer): + r"""Apply average pooling over the nodes in the graph. + + .. math:: + r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k + """ + + def __init__(self): + super(AvgPooling, self).__init__() + + def call(self, graph, feat): + r"""Compute average pooling. + + Parameters + ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. + feat : tf.Tensor + The input feature with shape :math:`(N, *)` where + :math:`N` is the number of nodes in the graph. + + Returns + ------- + tf.Tensor + The output feature with shape :math:`(*)` (if + input graph is a BatchedDGLGraph, the result shape + would be :math:`(B, *)`. + """ + with graph.local_scope(): + graph.ndata['h'] = feat + readout = mean_nodes(graph, 'h') + return readout + + +class MaxPooling(layers.Layer): + r"""Apply max pooling over the nodes in the graph. + + .. math:: + r^{(i)} = \max_{k=1}^{N_i}\left( x^{(i)}_k \right) + """ + + def __init__(self): + super(MaxPooling, self).__init__() + + def call(self, graph, feat): + r"""Compute max pooling. + + Parameters + ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. + feat : tf.Tensor + The input feature with shape :math:`(N, *)` where + :math:`N` is the number of nodes in the graph. + + Returns + ------- + tf.Tensor + The output feature with shape :math:`(*)` (if + input graph is a BatchedDGLGraph, the result shape + would be :math:`(B, *)`. + """ + with graph.local_scope(): + graph.ndata['h'] = feat + readout = max_nodes(graph, 'h') + return readout + + +class SortPooling(layers.Layer): + r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification + `__) over the nodes in the graph. + + Parameters + ---------- + k : int + The number of nodes to hold for each graph. + """ + + def __init__(self, k): + super(SortPooling, self).__init__() + self.k = k + + def call(self, graph, feat): + r"""Compute sort pooling. + + Parameters + ---------- + graph : DGLGraph or BatchedDGLGraph + The graph. + feat : tf.Tensor + The input feature with shape :math:`(N, D)` where + :math:`N` is the number of nodes in the graph. + + Returns + ------- + tf.Tensor + The output feature with shape :math:`(k * D)` (if + input graph is a BatchedDGLGraph, the result shape + would be :math:`(B, k * D)`. + """ + with graph.local_scope(): + # Sort the feature of each node in ascending order. + feat = tf.sort(feat, -1) + graph.ndata['h'] = feat + # Sort nodes according to their last features. + ret = tf.reshape(topk_nodes(graph, 'h', self.k, idx=-1)[0], ( + -1, self.k * feat.shape[-1])) + if isinstance(graph, BatchedDGLGraph): + return ret + else: + return tf.squeeze(ret, 0) + + +class GlobalAttentionPooling(layers.Layer): + r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks + `__) over the nodes in the graph. + + .. math:: + r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate} + \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right) + + Parameters + ---------- + gate_nn : tf.layers.Layer + A neural network that computes attention scores for each feature. + feat_nn : tf.layers.Layer, optional + A neural network applied to each feature before combining them + with attention scores. + """ + + def __init__(self, gate_nn, feat_nn=None): + super(GlobalAttentionPooling, self).__init__() + self.gate_nn = gate_nn + self.feat_nn = feat_nn + + def call(self, graph, feat): + r"""Compute global attention pooling. 
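A usage sketch for this pooling module (editorial, not part of the patch; assumes ``DGLBACKEND=tensorflow``; the ``Dense(1)`` gate network and ``Dense(10)`` feature network are arbitrary illustrative choices):

```python
import networkx as nx
import tensorflow as tf
from tensorflow.keras import layers
import dgl
import dgl.nn.tensorflow as nn

g = dgl.DGLGraph(nx.path_graph(6))
feat = tf.random.normal((6, 5))
gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10))
out = gap(g, feat)                              # shape (10,): a single graph collapses to one vector
```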
+ + Parameters + ---------- + graph : DGLGraph + The graph. + feat : tf.Tensor + The input feature with shape :math:`(N, D)` where + :math:`N` is the number of nodes in the graph. + + Returns + ------- + tf.Tensor + The output feature with shape :math:`(D)` (if + input graph is a BatchedDGLGraph, the result shape + would be :math:`(B, D)`. + """ + with graph.local_scope(): + gate = self.gate_nn(feat) + assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis." + feat = self.feat_nn(feat) if self.feat_nn else feat + + graph.ndata['gate'] = gate + gate = softmax_nodes(graph, 'gate') + graph.ndata.pop('gate') + + graph.ndata['r'] = feat * gate + readout = sum_nodes(graph, 'r') + graph.ndata.pop('r') + + return readout + + +class WeightAndSum(layers.Layer): + """Compute importance weights for atoms and perform a weighted sum. + + Parameters + ---------- + in_feats : int + Input atom feature size + """ + + def __init__(self, in_feats): + super(WeightAndSum, self).__init__() + self.in_feats = in_feats + self.atom_weighting = tf.keras.Sequential( + layers.Dense(1), + layers.Activation(tf.nn.sigmoid) + ) + + def call(self, bg, feats): + """Compute molecule representations out of atom representations + + Parameters + ---------- + bg : BatchedDGLGraph + B Batched DGLGraphs for processing multiple molecules in parallel + feats : FloatTensor of shape (N, self.in_feats) + Representations for all atoms in the molecules + * N is the total number of atoms in all molecules + + Returns + ------- + FloatTensor of shape (B, self.in_feats) + Representations for B molecules + """ + with bg.local_scope(): + bg.ndata['h'] = feats + bg.ndata['w'] = self.atom_weighting(bg.ndata['h']) + h_g_sum = sum_nodes(bg, 'h', 'w') + + return h_g_sum diff --git a/python/dgl/nn/tensorflow/softmax.py b/python/dgl/nn/tensorflow/softmax.py new file mode 100644 index 000000000000..236ff8481e1e --- /dev/null +++ b/python/dgl/nn/tensorflow/softmax.py @@ -0,0 +1,44 @@ +"""tf modules for graph related softmax.""" +# pylint: disable= no-member, arguments-differ +import tensorflow as tf + +from ... 
import function as fn +from ...base import ALL, is_all + +__all__ = ['edge_softmax'] + + +def edge_softmax_real(graph, score, eids=ALL): + """Edge Softmax function""" + if not is_all(eids): + graph = graph.edge_subgraph(tf.cast(eids, tf.int64)) + g = graph.local_var() + g.edata['s'] = score + g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax')) + g.apply_edges(fn.e_sub_v('s', 'smax', 'out')) + g.edata['out'] = tf.math.exp(g.edata['out']) + g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum')) + g.apply_edges(fn.e_div_v('out', 'out_sum', 'out')) + out = g.edata['out'] + + def edge_softmax_backward(grad_out): + g = graph.local_var() + # clear backward cache explicitly + g.edata['out'] = out + g.edata['grad_s'] = out * grad_out + g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum')) + g.apply_edges(fn.e_mul_v('out', 'accum', 'out')) + grad_score = g.edata['grad_s'] - g.edata['out'] + return grad_score + + return out, edge_softmax_backward + + +def edge_softmax(graph, logits, eids=ALL): + """Closure for tf.custom_gradient""" + + @tf.custom_gradient + def _lambda(logits): + return edge_softmax_real(graph, logits, eids=eids) + + return _lambda(logits) diff --git a/python/dgl/nn/tensorflow/utils.py b/python/dgl/nn/tensorflow/utils.py new file mode 100644 index 000000000000..264ead09c880 --- /dev/null +++ b/python/dgl/nn/tensorflow/utils.py @@ -0,0 +1,99 @@ +"""Utilities for tf NN package""" +# pylint: disable=no-member, invalid-name +from tensorflow.keras import layers # pylint: disable=W0235 +import tensorflow as tf + + +def matmul_maybe_select(A, B): + """Perform Matrix multiplication C = A * B but A could be an integer id vector. + + If A is an integer vector, we treat it as multiplying a one-hot encoded tensor. + In this case, the expensive dense matrix multiply can be replaced by a much + cheaper index lookup. + + For example, + :: + + A = [2, 0, 1], + B = [[0.1, 0.2], + [0.3, 0.4], + [0.5, 0.6]] + + then matmul_maybe_select(A, B) is equivalent to + :: + + [[0, 0, 1], [[0.1, 0.2], + [1, 0, 0], * [0.3, 0.4], + [0, 1, 0]] [0.5, 0.6]] + + In all other cases, perform a normal matmul. + + Parameters + ---------- + A : tf.Tensor + lhs tensor + B : tf.Tensor + rhs tensor + + Returns + ------- + C : tf.Tensor + result tensor + """ + if A.dtype == tf.int64 and len(A.shape) == 1: + return tf.gather(B, A) + else: + return tf.matmul(A, B) + + +def bmm_maybe_select(A, B, index): + """Slice submatrices of A by the given index and perform bmm. + + B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of + N matrices of shape (D1, D2). The input index is an integer vector of length M. + A could be either: + (1) a dense tensor of shape (M, D1), + (2) an integer vector of length M. + The result C is a 2D matrix of shape (M, D2) + + For case (1), C is computed by bmm: + :: + + C[i, :] = matmul(A[i, :], B[index[i], :, :]) + + For case (2), C is computed by index select: + :: + + C[i, :] = B[index[i], A[i], :] + + Parameters + ---------- + A : tf.Tensor + lhs tensor + B : tf.Tensor + rhs tensor + index : tf.Tensor + index tensor + + Returns + ------- + C : tf.Tensor + return tensor + """ + if A.dtype == tf.int64 and len(A.shape) == 1: + # following is a faster version of B[index, A, :] + B = tf.reshape(B, (-1, B.shape[2])) + flatidx = index * B.shape[1] + A + return tf.gather(B, flatidx) + else: + BB = tf.gather(B, index) + return tf.squeeze(tf.matmul(tf.expand_dims(A, 1), BB)) + + +class Identity(layers.Layer): + """A placeholder identity operator that is argument-insensitive. 
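A small sketch of the ``matmul_maybe_select`` helper defined above (editorial, not part of the patch; assumes ``DGLBACKEND=tensorflow``): an int64 id vector takes the cheap gather path, while the equivalent dense one-hot matmul produces the same rows.

```python
import tensorflow as tf
from dgl.nn.tensorflow.utils import matmul_maybe_select

B = tf.constant([[0.1, 0.2],
                 [0.3, 0.4],
                 [0.5, 0.6]])
ids = tf.constant([2, 0, 1], dtype=tf.int64)          # integer ids -> tf.gather path
rows = matmul_maybe_select(ids, B)                    # rows 2, 0, 1 of B
dense = matmul_maybe_select(tf.one_hot(ids, 3), B)    # one-hot matmul, same values as `rows`
```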
+ """ + + def call(self, x): + """Return input""" + return x diff --git a/tests/backend/tensorflow/__init__.py b/tests/backend/tensorflow/__init__.py index 669f73e5e201..51956371e15d 100644 --- a/tests/backend/tensorflow/__init__.py +++ b/tests/backend/tensorflow/__init__.py @@ -18,8 +18,8 @@ def array_equal(a, b): def allclose(a, b, rtol=1e-4, atol=1e-4): - return np.allclose(a.numpy(), - b.numpy(), rtol=rtol, atol=atol) + return np.allclose(tf.convert_to_tensor(a).numpy(), + tf.convert_to_tensor(b).numpy(), rtol=rtol, atol=atol) def randn(shape): diff --git a/tests/scripts/task_unit_test.sh b/tests/scripts/task_unit_test.sh index 76003f7016b3..f89135d32ffe 100644 --- a/tests/scripts/task_unit_test.sh +++ b/tests/scripts/task_unit_test.sh @@ -23,6 +23,13 @@ export PYTHONPATH=tests:${PWD}/python:$PYTHONPATH export DGL_DOWNLOAD_DIR=${PWD} export TF_FORCE_GPU_ALLOW_GROWTH=true +if [ $2 == "gpu" ] +then + export CUDA_VISIBLE_DEVICES=0 +else + export CUDA_VISIBLE_DEVICES=-1 +fi + conda activate ${DGLBACKEND}-ci python3 -m pytest -v --junitxml=pytest_compute.xml tests/compute || fail "compute" diff --git a/tests/tensorflow/test_nn.py b/tests/tensorflow/test_nn.py new file mode 100644 index 000000000000..7af23e0236f0 --- /dev/null +++ b/tests/tensorflow/test_nn.py @@ -0,0 +1,379 @@ +import tensorflow as tf +from tensorflow.keras import layers +import networkx as nx +import dgl +import dgl.nn.tensorflow as nn +import dgl.function as fn +import backend as F +from copy import deepcopy + +import numpy as np +import scipy as sp + +def _AXWb(A, X, W, b): + X = tf.matmul(X, W) + Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape) + return Y + b + +def test_graph_conv(): + g = dgl.DGLGraph(nx.path_graph(3)) + ctx = F.ctx() + adj = tf.sparse.to_dense(tf.sparse.reorder(g.adjacency_matrix(ctx=ctx))) + + conv = nn.GraphConv(5, 2, norm=False, bias=True) + # conv = conv + print(conv) + # test#1: basic + h0 = F.ones((3, 5)) + h1 = conv(g, h0) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias)) + # test#2: more-dim + h0 = F.ones((3, 5, 5)) + h1 = conv(g, h0) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + assert F.allclose(h1, _AXWb(adj, h0, conv.weight, conv.bias)) + + conv = nn.GraphConv(5, 2) + # conv = conv + # test#3: basic + h0 = F.ones((3, 5)) + h1 = conv(g, h0) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + # test#4: basic + h0 = F.ones((3, 5, 5)) + h1 = conv(g, h0) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + + conv = nn.GraphConv(5, 2) + # conv = conv + # test#3: basic + h0 = F.ones((3, 5)) + h1 = conv(g, h0) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + # test#4: basic + h0 = F.ones((3, 5, 5)) + h1 = conv(g, h0) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + + # test rest_parameters + # old_weight = deepcopy(conv.weight.data) + # conv.reset_parameters() + # new_weight = conv.weight.data + # assert not F.allclose(old_weight, new_weight) + +def _S2AXWb(A, N, X, W, b): + X1 = X * N + X1 = th.matmul(A, X1.view(X1.shape[0], -1)) + X1 = X1 * N + X2 = X1 * N + X2 = th.matmul(A, X2.view(X2.shape[0], -1)) + X2 = X2 * N + X = th.cat([X, X1, X2], dim=-1) + Y = th.matmul(X, W.rot90()) + + return Y + b + +def test_simple_pool(): + ctx = F.ctx() + g = dgl.DGLGraph(nx.path_graph(15)) + + sum_pool = nn.SumPooling() + avg_pool = nn.AvgPooling() + max_pool = nn.MaxPooling() + sort_pool = nn.SortPooling(10) # k = 10 + print(sum_pool, avg_pool, max_pool, sort_pool) + + # test#1: basic + h0 
= F.randn((g.number_of_nodes(), 5)) + h1 = sum_pool(g, h0) + assert F.allclose(h1, F.sum(h0, 0)) + h1 = avg_pool(g, h0) + assert F.allclose(h1, F.mean(h0, 0)) + h1 = max_pool(g, h0) + assert F.allclose(h1, F.max(h0, 0)) + h1 = sort_pool(g, h0) + assert h1.shape[0] == 10 * 5 and h1.ndim == 1 + + # test#2: batched graph + g_ = dgl.DGLGraph(nx.path_graph(5)) + bg = dgl.batch([g, g_, g, g_, g]) + h0 = F.randn((bg.number_of_nodes(), 5)) + h1 = sum_pool(bg, h0) + truth = tf.stack([F.sum(h0[:15], 0), + F.sum(h0[15:20], 0), + F.sum(h0[20:35], 0), + F.sum(h0[35:40], 0), + F.sum(h0[40:55], 0)], 0) + assert F.allclose(h1, truth) + + h1 = avg_pool(bg, h0) + truth = tf.stack([F.mean(h0[:15], 0), + F.mean(h0[15:20], 0), + F.mean(h0[20:35], 0), + F.mean(h0[35:40], 0), + F.mean(h0[40:55], 0)], 0) + assert F.allclose(h1, truth) + + h1 = max_pool(bg, h0) + truth = tf.stack([F.max(h0[:15], 0), + F.max(h0[15:20], 0), + F.max(h0[20:35], 0), + F.max(h0[35:40], 0), + F.max(h0[40:55], 0)], 0) + assert F.allclose(h1, truth) + + h1 = sort_pool(bg, h0) + assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2 + +def uniform_attention(g, shape): + a = F.ones(shape) + target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1) + return a / tf.cast(tf.reshape(g.in_degrees(g.edges()[1]), target_shape), tf.float32) + +def test_edge_softmax(): + # Basic + g = dgl.DGLGraph(nx.path_graph(3)) + edata = F.ones((g.number_of_edges(), 1)) + a = nn.edge_softmax(g, edata) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + assert F.allclose(a, uniform_attention(g, a.shape)) + + # Test higher dimension case + edata = F.ones((g.number_of_edges(), 3, 1)) + a = nn.edge_softmax(g, edata) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + assert F.allclose(a, uniform_attention(g, a.shape)) + + # Test both forward and backward with Tensorflow built-in softmax. 
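    # The dense reference below exploits the structure of the complete graph built next:
    # edge (i, j) is added with id i * 30 + j, so reshaping the (900, 1) scores to (30, 30)
    # places all edges that share destination j in column j. A softmax over axis 0 of that
    # matrix is therefore an exact reference for edge_softmax, and the two tf.GradientTape
    # blocks compare both the forward values and the gradients with respect to the scores.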
+ g = dgl.DGLGraph() + g.add_nodes(30) + # build a complete graph + for i in range(30): + for j in range(30): + g.add_edge(i, j) + + + score = F.randn((900, 1)) + with tf.GradientTape() as tape: + tape.watch(score) + grad = F.randn((900, 1)) + y = tf.reshape(F.softmax(tf.reshape(score,(30, 30)), dim=0), (-1, 1)) + grads = tape.gradient(y, [score]) + grad_score = grads[0] + + with tf.GradientTape() as tape: + tape.watch(score) + y_dgl = nn.edge_softmax(g, score) + assert len(g.ndata) == 0 + assert len(g.edata) == 0 + # check forward + assert F.allclose(y_dgl, y) + grads = tape.gradient(y_dgl, [score]) + # checkout gradient + assert F.allclose(grads[0], grad_score) + print(grads[0][:10], grad_score[:10]) + + # Test 2 + def generate_rand_graph(n): + arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64) + return dgl.DGLGraph(arr, readonly=True) + + g = generate_rand_graph(50) + a1 = F.randn((g.number_of_edges(), 1)) + a2 = tf.identity(a1) + with tf.GradientTape() as tape: + tape.watch(a1) + g.edata['s'] = a1 + g.group_apply_edges('dst', lambda edges: {'ss':F.softmax(edges.data['s'], 1)}) + loss = tf.reduce_sum(g.edata['ss']) + a1_grad = tape.gradient(loss, [a1])[0] + + with tf.GradientTape() as tape: + tape.watch(a2) + builtin_sm = nn.edge_softmax(g, a2) + loss = tf.reduce_sum(builtin_sm) + a2_grad = tape.gradient(loss, [a2])[0] + print(a1_grad - a2_grad) + assert len(g.ndata) == 0 + assert len(g.edata) == 2 + assert F.allclose(a1_grad, a2_grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend + +def test_partial_edge_softmax(): + g = dgl.DGLGraph() + g.add_nodes(30) + # build a complete graph + for i in range(30): + for j in range(30): + g.add_edge(i, j) + + score = F.randn((300, 1)) + grad = F.randn((300, 1)) + import numpy as np + eids = np.random.choice(900, 300, replace=False).astype('int64') + eids = F.zerocopy_from_numpy(eids) + # compute partial edge softmax + with tf.GradientTape() as tape: + tape.watch(score) + y_1 = nn.edge_softmax(g, score, eids) + grads = tape.gradient(y_1, [score]) + grad_1 = grads[0] + # compute edge softmax on edge subgraph + subg = g.edge_subgraph(eids) + with tf.GradientTape() as tape: + tape.watch(score) + y_2 = nn.edge_softmax(subg, score) + grads = tape.gradient(y_2, [score]) + grad_2 = grads[0] + + assert F.allclose(y_1, y_2) + assert F.allclose(grad_1, grad_2) + +def test_glob_att_pool(): + g = dgl.DGLGraph(nx.path_graph(10)) + + gap = nn.GlobalAttentionPooling(layers.Dense(1), layers.Dense(10)) + print(gap) + + # test#1: basic + h0 = F.randn((g.number_of_nodes(), 5)) + h1 = gap(g, h0) + assert h1.shape[0] == 10 and h1.ndim == 1 + + # test#2: batched graph + bg = dgl.batch([g, g, g, g]) + h0 = F.randn((bg.number_of_nodes(), 5)) + h1 = gap(bg, h0) + assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2 + + +def test_rgcn(): + etype = [] + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + # 5 etypes + R = 5 + for i in range(g.number_of_edges()): + etype.append(i % 5) + B = 2 + I = 10 + O = 8 + + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) + h = tf.random.normal((100, I)) + r = tf.constant(etype) + h_new = rgc_basis(g, h, r) + assert list(h_new.shape) == [100, O] + + rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) + h = tf.random.normal((100, I)) + r = tf.constant(etype) + h_new = rgc_bdd(g, h, r) + assert list(h_new.shape) == [100, O] + + # with norm + norm = tf.zeros((g.number_of_edges(), 1)) + + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) + h = tf.random.normal((100, I)) + r = 
tf.constant(etype) + h_new = rgc_basis(g, h, r, norm) + assert list(h_new.shape) == [100, O] + + rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B) + h = tf.random.normal((100, I)) + r = tf.constant(etype) + h_new = rgc_bdd(g, h, r, norm) + assert list(h_new.shape) == [100, O] + + # id input + rgc_basis = nn.RelGraphConv(I, O, R, "basis", B) + h = tf.constant(np.random.randint(0, I, (100,))) + r = tf.constant(etype) + h_new = rgc_basis(g, h, r) + assert list(h_new.shape) == [100, O] + +def test_gat_conv(): + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + gat = nn.GATConv(5, 2, 4) + feat = F.randn((100, 5)) + h = gat(g, feat) + assert h.shape[-1] == 2 and h.shape[-2] == 4 + +def test_sage_conv(): + for aggre_type in ['mean', 'pool', 'gcn', 'lstm']: + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + sage = nn.SAGEConv(5, 10, aggre_type) + feat = F.randn((100, 5)) + h = sage(g, feat) + assert h.shape[-1] == 10 + +def test_sgc_conv(): + ctx = F.ctx() + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + # not cached + sgc = nn.SGConv(5, 10, 3) + feat = F.randn((100, 5)) + + h = sgc(g, feat) + assert h.shape[-1] == 10 + + # cached + sgc = nn.SGConv(5, 10, 3, True) + h_0 = sgc(g, feat) + h_1 = sgc(g, feat + 1) + assert F.allclose(h_0, h_1) + assert h_0.shape[-1] == 10 + +def test_appnp_conv(): + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + appnp = nn.APPNPConv(10, 0.1) + feat = F.randn((100, 5)) + + h = appnp(g, feat) + assert h.shape[-1] == 5 + +def test_gin_conv(): + for aggregator_type in ['mean', 'max', 'sum']: + g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) + gin = nn.GINConv( + tf.keras.layers.Dense(12), + aggregator_type + ) + feat = F.randn((100, 5)) + gin = gin + h = gin(g, feat) + assert h.shape[-1] == 12 + + +if __name__ == '__main__': + test_graph_conv() + test_edge_softmax() + test_partial_edge_softmax() + # test_set2set() + test_glob_att_pool() + test_simple_pool() + # test_set_trans() + test_rgcn() + # test_tagconv() + test_gat_conv() + test_sage_conv() + test_sgc_conv() + test_appnp_conv() + test_gin_conv() + # test_agnn_conv() + # test_gated_graph_conv() + # test_nn_conv() + # test_gmm_conv() + # test_dense_graph_conv() + # test_dense_sage_conv() + # test_dense_cheb_conv() + # test_sequential() +
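To close, a minimal sketch (editorial, not part of the patch) of composing the new tensorflow ``GraphConv`` layer into a two-layer GCN. It assumes the tensorflow backend is active (``DGLBACKEND=tensorflow``); the ``TwoLayerGCN`` wrapper and the layer sizes are purely illustrative. Self-loops are added as recommended in the ``GraphConv`` notes to avoid zero in-degree normalizers.

```python
import networkx as nx
import tensorflow as tf
from tensorflow.keras import layers
import dgl
from dgl.nn.tensorflow import GraphConv


class TwoLayerGCN(layers.Layer):
    """Illustrative two-layer GCN built from the tensorflow GraphConv module."""

    def __init__(self, in_feats, hidden_feats, out_feats):
        super(TwoLayerGCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_feats, activation=tf.nn.relu)
        self.conv2 = GraphConv(hidden_feats, out_feats)

    def call(self, g, feat):
        # graph-first calling convention, matching the layers in this patch
        h = self.conv1(g, feat)
        return self.conv2(g, h)


g = dgl.DGLGraph(nx.path_graph(5))
g.add_edges(g.nodes(), g.nodes())               # self-loops avoid zero in-degree normalizers
model = TwoLayerGCN(4, 8, 2)
logits = model(g, tf.random.normal((5, 4)))     # shape (5, 2)
```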