[Tensorflow] Several nn & example (dmlc#1191)
* several nn example

* appnp

* fix lint

* lint

* add dgi

* fix

* fix

* fix

* fff

* docs

* 111

* fix

* change init

* change result

* tune hyperparameters +1

* fix

* fix lint

* fix

* fix
VoVAllen authored Jan 19, 2020
1 parent 31a7d50 commit a00636a
Showing 36 changed files with 3,058 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/nn.rst
@@ -9,3 +9,4 @@ NN Modules

   nn.pytorch
   nn.mxnet
+   nn.tensorflow
118 changes: 118 additions & 0 deletions docs/source/api/python/nn.tensorflow.rst
@@ -0,0 +1,118 @@
.. _apinn-tensorflow:

NN Modules (Tensorflow)
=======================

.. contents:: Contents
    :local:

We welcome your contribution! If you want a model to be implemented in DGL as an NN module,
please `create an issue <https://github.com/dmlc/dgl/issues>`_ whose title starts with "[Feature Request] NN Module XXXModel".

If you want to contribute an NN module, please `create a pull request <https://github.com/dmlc/dgl/pulls>`_ whose title
starts with "[NN] XXXModel in tensorflow NN Modules"; our team members will review it.

Conv Layers
----------------------------------------

.. automodule:: dgl.nn.tensorflow.conv

GraphConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.conv.GraphConv
    :members: weight, bias, forward, reset_parameters
    :show-inheritance:

RelGraphConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.conv.RelGraphConv
    :members: forward
    :show-inheritance:

GATConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.conv.GATConv
    :members: forward
    :show-inheritance:

SAGEConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.conv.SAGEConv
    :members: forward
    :show-inheritance:

SGConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.conv.SGConv
    :members: forward
    :show-inheritance:

APPNPConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.conv.APPNPConv
    :members: forward
    :show-inheritance:

GINConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.conv.GINConv
    :members: forward
    :show-inheritance:


Global Pooling Layers
----------------------------------------

.. automodule:: dgl.nn.tensorflow.glob

SumPooling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.glob.SumPooling
    :members:
    :show-inheritance:

AvgPooling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.glob.AvgPooling
    :members:
    :show-inheritance:

MaxPooling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.glob.MaxPooling
    :members:
    :show-inheritance:

SortPooling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.glob.SortPooling
    :members:
    :show-inheritance:

GlobalAttentionPooling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.tensorflow.glob.GlobalAttentionPooling
    :members:
    :show-inheritance:


Utility Modules
----------------------------------------

Edge Softmax
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: dgl.nn.tensorflow.softmax
    :members: edge_softmax
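
Below is a brief usage sketch (not part of this commit) of the modules documented above. It assumes the TensorFlow backend is selected (e.g. `DGLBACKEND=tensorflow`) and uses an illustrative toy graph; `edge_softmax` is imported from the `dgl.nn.tensorflow.softmax` module documented here.

```python
import tensorflow as tf
import dgl
from dgl.nn.tensorflow import GraphConv
from dgl.nn.tensorflow.softmax import edge_softmax

# Toy 4-node cycle, built with the DGL 0.4-era construction API.
g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

feat = tf.random.normal((4, 5))                # one 5-d feature per node
conv = GraphConv(5, 2, activation=tf.nn.relu)
h = conv(g, feat)                              # modules take (graph, features); h has shape (4, 2)

scores = tf.random.normal((g.number_of_edges(), 1))
alpha = edge_softmax(g, scores)                # normalizes scores over each node's incoming edges
```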
2 changes: 1 addition & 1 deletion examples/pytorch/dgi/gcn.py
@@ -31,5 +31,5 @@ def forward(self, features):
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
-            h = layer(h, self.g)
+            h = layer(self.g, h)
        return h
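
This one-line fix updates the older PyTorch DGI example to DGL's NN-module calling convention, which takes the graph first (`layer(graph, features)`); the new TensorFlow modules in this commit follow the same convention (see `gcn.py` below).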
10 changes: 5 additions & 5 deletions examples/pytorch/gat/README.md
@@ -41,11 +41,11 @@ python3 train_ppi.py --gpu=0
Results
-------

-| Dataset | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) |
-| ------- | ------------- | ------- | ------------------- | ------------------- |
-| Cora | 84.02(0.40) | 0.0113 | 0.0982 (**8.7x**) | 0.0424 (**3.8x**) |
-| Citeseer | 70.91(0.79) | 0.0111 | n/a | n/a |
-| Pubmed | 78.57(0.75) | 0.0115 | n/a | n/a |
+| Dataset  | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) |
+| -------- | ------------- | ------- | ------------------- | ------------------- |
+| Cora     | 84.02(0.40)   | 0.0113  | 0.0982 (**8.7x**)   | 0.0424 (**3.8x**)   |
+| Citeseer | 70.91(0.79)   | 0.0111  | n/a                 | n/a                 |
+| Pubmed   | 78.57(0.75)   | 0.0115  | n/a                 | n/a                 |

* All the accuracy numbers are obtained after 300 epochs.
* The time measures how long it takes to train one epoch.
38 changes: 38 additions & 0 deletions examples/tensorflow/dgi/README.md
@@ -0,0 +1,38 @@
Deep Graph Infomax (DGI)
========================

- Paper link: [https://arxiv.org/abs/1809.10341](https://arxiv.org/abs/1809.10341)
- Author's code repo (in PyTorch):
[https://github.com/PetarV-/DGI](https://github.com/PetarV-/DGI)

Dependencies
------------
- tensorflow 2.1+
- requests

```bash
pip install tensorflow requests
```

How to run
----------

Run with one of the following commands:

```bash
python3 train.py --dataset=cora --gpu=0 --self-loop
```

```bash
python3 train.py --dataset=citeseer --gpu=0
```

```bash
python3 train.py --dataset=pubmed --gpu=0
```

Results
-------
* cora: ~81.6 (80.9-82.9) (paper: 82.3)
* citeseer: ~70.2 (paper: 71.8)
* pubmed: ~77.2 (paper: 76.8)
75 changes: 75 additions & 0 deletions examples/tensorflow/dgi/dgi.py
@@ -0,0 +1,75 @@
"""
Deep Graph Infomax in DGL
References
----------
Paper: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI
"""

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import math
from gcn import GCN


class Encoder(layers.Layer):
    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super(Encoder, self).__init__()
        self.g = g
        self.conv = GCN(g, in_feats, n_hidden, n_hidden,
                        n_layers, activation, dropout)

    def call(self, features, corrupt=False):
        # Corruption permutes node features across the graph; the shuffled
        # features provide DGI's negative examples.
        if corrupt:
            perm = np.random.permutation(self.g.number_of_nodes())
            features = tf.gather(features, perm)
        features = self.conv(features)
        return features


class Discriminator(layers.Layer):
    def __init__(self, n_hidden):
        super(Discriminator, self).__init__()
        # Uniform init in [-1/sqrt(d), 1/sqrt(d)].
        uinit = tf.keras.initializers.RandomUniform(
            -1.0 / math.sqrt(n_hidden), 1.0 / math.sqrt(n_hidden))
        self.weight = tf.Variable(initial_value=uinit(
            shape=(n_hidden, n_hidden), dtype='float32'), trainable=True)

    def call(self, features, summary):
        # Bilinear score features @ W @ summary, one logit per node.
        features = tf.matmul(features, tf.matmul(
            self.weight, tf.expand_dims(summary, -1)))
        return features


class DGI(tf.keras.Model):
    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super(DGI, self).__init__()
        self.encoder = Encoder(g, in_feats, n_hidden,
                               n_layers, activation, dropout)
        self.discriminator = Discriminator(n_hidden)
        self.loss = tf.nn.sigmoid_cross_entropy_with_logits

    def call(self, features):
        positive = self.encoder(features, corrupt=False)
        negative = self.encoder(features, corrupt=True)
        # Graph summary: sigmoid of the mean positive embedding.
        summary = tf.nn.sigmoid(tf.reduce_mean(positive, axis=0))

        positive = self.discriminator(positive, summary)
        negative = self.discriminator(negative, summary)

        # Binary cross-entropy: real nodes should score 1, corrupted nodes 0.
        l1 = self.loss(labels=tf.ones(positive.shape), logits=positive)
        l2 = self.loss(labels=tf.zeros(negative.shape), logits=negative)

        return tf.reduce_mean(l1) + tf.reduce_mean(l2)


class Classifier(layers.Layer):
    def __init__(self, n_hidden, n_classes):
        super(Classifier, self).__init__()
        # n_hidden is unused here; Dense infers the input width.
        self.fc = layers.Dense(n_classes)

    def call(self, features):
        features = self.fc(features)
        return features
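
For context, here is a minimal sketch of how these classes fit together; the commit's `train.py` is not shown in this excerpt and may differ, and the hyperparameters below are illustrative only. `g` and `features` are assumed to be a DGL graph and its float32 node-feature tensor.

```python
import tensorflow as tf
from dgi import DGI

# Assumed inputs: g (a DGLGraph) and features (float32 node features).
model = DGI(g, in_feats=features.shape[1], n_hidden=512,
            n_layers=1, activation=tf.nn.relu, dropout=0.0)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

for epoch in range(300):
    with tf.GradientTape() as tape:
        loss = model(features)   # BCE over real vs. corrupted node scores
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

# Frozen embeddings then feed the downstream Classifier.
embeds = tf.stop_gradient(model.encoder(features, corrupt=False))
```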
36 changes: 36 additions & 0 deletions examples/tensorflow/dgi/gcn.py
@@ -0,0 +1,36 @@
"""
This code was copied from the GCN implementation in DGL examples.
"""
import tensorflow as tf
from tensorflow.keras import layers

from dgl.nn.tensorflow import GraphConv

class GCN(layers.Layer):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.g = g
        self.layers = []
        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        # output layer
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.dropout = layers.Dropout(dropout)

    def call(self, features):
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(self.g, h)
        return h
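
A quick way to exercise this GCN on its own (not in the commit; the graph, sizes, and labels here are illustrative):

```python
import tensorflow as tf
import dgl
from gcn import GCN

# Toy 3-node cycle.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])

model = GCN(g, in_feats=4, n_hidden=8, n_classes=3,
            n_layers=1, activation=tf.nn.relu, dropout=0.5)
logits = model(tf.random.normal((3, 4)))       # shape (3, 3)

# Supervised node classification would then minimize, e.g.:
labels = tf.constant([0, 1, 2])
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits))
```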
