Commit: [DGL-LifeSci] Pre-trained GIN (dmlc#1558)
* Update

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
mufeili authored May 24, 2020
1 parent 3e69692 commit 165c67c
Showing 12 changed files with 666 additions and 2 deletions.
5 changes: 5 additions & 0 deletions apps/life_sci/docs/source/api/model.gnn.rst
@@ -44,6 +44,11 @@ Weave
.. automodule:: dgllife.model.gnn.weave
:members:

GIN
---
.. automodule:: dgllife.model.gnn.gin
:members:

WLN
---
.. automodule:: dgllife.model.gnn.wln
5 changes: 5 additions & 0 deletions apps/life_sci/docs/source/api/model.zoo.rst
@@ -54,6 +54,11 @@ Weave Predictor
.. automodule:: dgllife.model.model_zoo.weave_predictor
:members:

GIN Predictor
`````````````
.. automodule:: dgllife.model.model_zoo.gin_predictor
:members:

Generative Models
-----------------

3 changes: 3 additions & 0 deletions apps/life_sci/docs/source/api/utils.mols.rst
@@ -114,6 +114,7 @@ For using featurization methods like above in creating node features:
dgllife.utils.BaseAtomFeaturizer.feat_size
dgllife.utils.CanonicalAtomFeaturizer
dgllife.utils.CanonicalAtomFeaturizer.feat_size
dgllife.utils.PretrainAtomFeaturizer

Featurization for Edges
```````````````````````
@@ -134,6 +135,7 @@ We consider the following bond descriptors:
dgllife.utils.bond_is_in_ring_one_hot
dgllife.utils.bond_is_in_ring
dgllife.utils.bond_stereo_one_hot
dgllife.utils.bond_direction_one_hot

For using featurization methods like above in creating edge features:

@@ -144,3 +146,4 @@ For using featurization methods like above in creating edge features:
dgllife.utils.BaseBondFeaturizer.feat_size
dgllife.utils.CanonicalBondFeaturizer
dgllife.utils.CanonicalBondFeaturizer.feat_size
dgllife.utils.PretrainBondFeaturizer
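
The two featurizers registered above turn a molecule's categorical atom and bond descriptors into integer features that the GIN model embeds. A minimal usage sketch follows; using mol_to_bigraph as the graph builder and pairing add_self_loop=True with the featurizers' default self-loop handling are assumptions based on the pre-training setup, not part of this diff:

from rdkit import Chem

from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, mol_to_bigraph

# Build a bidirected molecular graph whose ndata/edata store integer
# (categorical) features, ready to be embedded by the pre-trained GIN.
mol = Chem.MolFromSmiles('CCO')
g = mol_to_bigraph(mol,
                   add_self_loop=True,  # assumed to match the featurizers' self-loop setting
                   node_featurizer=PretrainAtomFeaturizer(),
                   edge_featurizer=PretrainBondFeaturizer())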
1 change: 1 addition & 0 deletions apps/life_sci/python/dgllife/model/gnn/__init__.py
@@ -7,3 +7,4 @@
from .schnet import *
from .wln import *
from .weave import *
from .gin import *
200 changes: 200 additions & 0 deletions apps/life_sci/python/dgllife/model/gnn/gin.py
@@ -0,0 +1,200 @@
"""Graph Isomorphism Networks."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['GIN']

# pylint: disable=W0221, C0103
class GINLayer(nn.Module):
r"""Single Layer GIN from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
Parameters
----------
num_edge_emb_list : list of int
        num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g., num_edge_emb_list[0] can be
        the number of bond types and num_edge_emb_list[1] can be the number of
        bond direction types.
emb_dim : int
The size of each embedding vector.
batch_norm : bool
Whether to apply batch normalization to the output of message passing.
Default to True.
activation : None or callable
Activation function to apply to the output node representations.
Default to None.
"""
def __init__(self, num_edge_emb_list, emb_dim, batch_norm=True, activation=None):
super(GINLayer, self).__init__()

self.mlp = nn.Sequential(
nn.Linear(emb_dim, 2 * emb_dim),
nn.ReLU(),
nn.Linear(2 * emb_dim, emb_dim)
)
self.edge_embeddings = nn.ModuleList()
for num_emb in num_edge_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
nn.init.xavier_uniform_(emb_module.weight.data)
self.edge_embeddings.append(emb_module)

if batch_norm:
self.bn = nn.BatchNorm1d(emb_dim)
else:
self.bn = None
self.activation = activation

def forward(self, g, node_feats, categorical_edge_feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : FloatTensor of shape (N, emb_dim)
* Input node features
* N is the total number of nodes in the batch of graphs
* emb_dim is the input node feature size, which must match emb_dim in initialization
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as len(self.edge_embeddings)
            * E is the total number of edges in the batch of graphs

Returns
-------
node_feats : float32 tensor of shape (N, emb_dim)
Output node representations
"""
edge_embeds = []
for i, feats in enumerate(categorical_edge_feats):
edge_embeds.append(self.edge_embeddings[i](feats))
edge_embeds = torch.stack(edge_embeds, dim=0).sum(0)
g = g.local_var()
g.ndata['feat'] = node_feats
g.edata['feat'] = edge_embeds
g.update_all(fn.u_add_e('feat', 'feat', 'm'), fn.sum('m', 'feat'))

node_feats = self.mlp(g.ndata.pop('feat'))
if self.bn is not None:
node_feats = self.bn(node_feats)
if self.activation is not None:
node_feats = self.activation(node_feats)

return node_feats

class GIN(nn.Module):
r"""Graph Isomorphism Network from `Strategies for
    Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__.

    This module is for updating node representations only.

    Parameters
----------
num_node_emb_list : list of int
        num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variable. E.g., num_node_emb_list[0] can be
        the number of atom types and num_node_emb_list[1] can be the number of
        atom chirality types.
num_edge_emb_list : list of int
        num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g., num_edge_emb_list[0] can be
        the number of bond types and num_edge_emb_list[1] can be the number of
        bond direction types.
num_layers : int
Number of GIN layers to use. Default to 5.
emb_dim : int
The size of each embedding vector. Default to 300.
JK : str
JK for jumping knowledge as in `Representation Learning on Graphs with
        Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how to combine the node representations across all layers for the final output.
        This argument can be ``'concat'``, ``'last'``, ``'max'`` or ``'sum'``.
        Default to ``'last'``.

* ``'concat'``: concatenate the output node representations from all GIN layers
* ``'last'``: use the node representations from the last GIN layer
* ``'max'``: apply max pooling to the node representations across all GIN layers
        * ``'sum'``: sum the output node representations from all GIN layers

dropout : float
        Dropout to apply to the output of each GIN layer. Default to 0.5.
"""
def __init__(self, num_node_emb_list, num_edge_emb_list,
num_layers=5, emb_dim=300, JK='last', dropout=0.5):
super(GIN, self).__init__()

self.num_layers = num_layers
self.JK = JK
self.dropout = nn.Dropout(dropout)

if num_layers < 2:
raise ValueError('Number of GNN layers must be greater '
'than 1, got {:d}'.format(num_layers))

self.node_embeddings = nn.ModuleList()
for num_emb in num_node_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
nn.init.xavier_uniform_(emb_module.weight.data)
self.node_embeddings.append(emb_module)

self.gnn_layers = nn.ModuleList()
for layer in range(num_layers):
if layer == num_layers - 1:
self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim))
else:
self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim, activation=F.relu))

def forward(self, g, categorical_node_feats, categorical_edge_feats):
"""Update node representations
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
categorical_node_feats : list of LongTensor of shape (N)
* Input categorical node features
* len(categorical_node_feats) should be the same as len(self.node_embeddings)
* N is the total number of nodes in the batch of graphs
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
            * len(categorical_edge_feats) should be the same as
              len(num_edge_emb_list) used in initialization
            * E is the total number of edges in the batch of graphs

Returns
-------
final_node_feats : float32 tensor of shape (N, M)
Output node representations, N for the number of nodes and
M for output size. In particular, M will be emb_dim * (num_layers + 1)
if self.JK == 'concat' and emb_dim otherwise.
"""
node_embeds = []
for i, feats in enumerate(categorical_node_feats):
node_embeds.append(self.node_embeddings[i](feats))
node_embeds = torch.stack(node_embeds, dim=0).sum(0)

all_layer_node_feats = [node_embeds]
for layer in range(self.num_layers):
node_feats = self.gnn_layers[layer](g, all_layer_node_feats[layer],
categorical_edge_feats)
node_feats = self.dropout(node_feats)
all_layer_node_feats.append(node_feats)

if self.JK == 'concat':
final_node_feats = torch.cat(all_layer_node_feats, dim=1)
elif self.JK == 'last':
final_node_feats = all_layer_node_feats[-1]
        elif self.JK == 'max':
            # Stack to (num_layers + 1, N, emb_dim), then pool across layers
            final_node_feats = torch.stack(all_layer_node_feats, dim=0).max(dim=0)[0]
        elif self.JK == 'sum':
            final_node_feats = torch.stack(all_layer_node_feats, dim=0).sum(dim=0)
        else:
            raise ValueError("Expect self.JK to be 'concat', 'last', "
                             "'max' or 'sum', got {}".format(self.JK))

return final_node_feats
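
To make the new module concrete, here is a small usage sketch; the toy graph, the embedding-table sizes, and the dgl.graph constructor (available in newer DGL releases) are illustrative assumptions rather than part of this commit:

import dgl
import torch

from dgllife.model import GIN

# Two categorical node variables (e.g. atom type, chirality) and two
# categorical edge variables (e.g. bond type, bond direction).
model = GIN(num_node_emb_list=[120, 3], num_edge_emb_list=[6, 3])
model.eval()  # deterministic run: disables dropout, freezes batch-norm stats

# A toy 3-node graph with edges in both directions.
g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
node_feats = [torch.zeros(g.number_of_nodes(), dtype=torch.long),
              torch.zeros(g.number_of_nodes(), dtype=torch.long)]
edge_feats = [torch.zeros(g.number_of_edges(), dtype=torch.long),
              torch.zeros(g.number_of_edges(), dtype=torch.long)]
out = model(g, node_feats, edge_feats)  # shape (3, 300) with JK='last', emb_dim=300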
1 change: 1 addition & 0 deletions apps/life_sci/python/dgllife/model/model_zoo/__init__.py
@@ -12,3 +12,4 @@
from .wln_reaction_center import *
from .wln_reaction_ranking import *
from .weave_predictor import *
from .gin_predictor import *
122 changes: 122 additions & 0 deletions apps/life_sci/python/dgllife/model/model_zoo/gin_predictor.py
@@ -0,0 +1,122 @@
"""GIN-based model for regression and classification on graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch.nn as nn

from dgl.nn.pytorch.glob import GlobalAttentionPooling, SumPooling, AvgPooling, MaxPooling

from ..gnn.gin import GIN

__all__ = ['GINPredictor']

# pylint: disable=W0221
class GINPredictor(nn.Module):
"""GIN-based model for regression and classification on graphs.
GIN was first introduced in `How Powerful Are Graph Neural Networks
<https://arxiv.org/abs/1810.00826>`__ for general graph property
prediction problems. It was further extended in `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
for pre-training and semi-supervised learning on large-scale datasets.
For classification tasks, the output will be logits, i.e. values before
sigmoid or softmax.
Parameters
----------
num_node_emb_list : list of int
        num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variable. E.g., num_node_emb_list[0] can be
        the number of atom types and num_node_emb_list[1] can be the number of
        atom chirality types.
num_edge_emb_list : list of int
        num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g., num_edge_emb_list[0] can be
        the number of bond types and num_edge_emb_list[1] can be the number of
        bond direction types.
num_layers : int
Number of GIN layers to use. Default to 5.
emb_dim : int
The size of each embedding vector. Default to 300.
JK : str
JK for jumping knowledge as in `Representation Learning on Graphs with
        Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how to combine the node representations across all layers for the final output.
        This argument can be ``'concat'``, ``'last'``, ``'max'`` or ``'sum'``.
        Default to ``'last'``.

* ``'concat'``: concatenate the output node representations from all GIN layers
* ``'last'``: use the node representations from the last GIN layer
* ``'max'``: apply max pooling to the node representations across all GIN layers
        * ``'sum'``: sum the output node representations from all GIN layers

dropout : float
Dropout to apply to the output of each GIN layer. Default to 0.5.
readout : str
Readout for computing graph representations out of node representations, which
can be ``'sum'``, ``'mean'``, ``'max'``, or ``'attention'``. Default to 'mean'.
n_tasks : int
Number of tasks, which is also the output size. Default to 1.
"""
def __init__(self, num_node_emb_list, num_edge_emb_list, num_layers=5,
emb_dim=300, JK='last', dropout=0.5, readout='mean', n_tasks=1):
super(GINPredictor, self).__init__()

if num_layers < 2:
raise ValueError('Number of GNN layers must be greater '
'than 1, got {:d}'.format(num_layers))

self.gnn = GIN(num_node_emb_list=num_node_emb_list,
num_edge_emb_list=num_edge_emb_list,
num_layers=num_layers,
emb_dim=emb_dim,
JK=JK,
dropout=dropout)

if readout == 'sum':
self.readout = SumPooling()
elif readout == 'mean':
self.readout = AvgPooling()
elif readout == 'max':
self.readout = MaxPooling()
elif readout == 'attention':
if JK == 'concat':
self.readout = GlobalAttentionPooling(
gate_nn=nn.Linear((num_layers + 1) * emb_dim, 1))
else:
self.readout = GlobalAttentionPooling(
gate_nn=nn.Linear(emb_dim, 1))
else:
raise ValueError("Expect readout to be 'sum', 'mean', "
"'max' or 'attention', got {}".format(readout))

if JK == 'concat':
self.predict = nn.Linear((num_layers + 1) * emb_dim, n_tasks)
else:
self.predict = nn.Linear(emb_dim, n_tasks)

def forward(self, g, categorical_node_feats, categorical_edge_feats):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
categorical_node_feats : list of LongTensor of shape (N)
* Input categorical node features
* len(categorical_node_feats) should be the same as len(num_node_emb_list)
* N is the total number of nodes in the batch of graphs
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
            * len(categorical_edge_feats) should be the same as
              len(num_edge_emb_list) used in initialization
            * E is the total number of edges in the batch of graphs

Returns
-------
FloatTensor of shape (B, n_tasks)
* Predictions on graphs
* B for the number of graphs in the batch
"""
node_feats = self.gnn(g, categorical_node_feats, categorical_edge_feats)
graph_feats = self.readout(g, node_feats)
return self.predict(graph_feats)
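
A matching graph-level sketch for the predictor; the batch of toy graphs and the task count are again illustrative assumptions, with dgl.batch merging graphs so the readout pools node features per graph:

import dgl
import torch

from dgllife.model import GINPredictor

model = GINPredictor(num_node_emb_list=[120, 3], num_edge_emb_list=[6, 3],
                     readout='mean', n_tasks=2)
model.eval()

# Batch two toy graphs; forward returns one row of logits per graph.
bg = dgl.batch([dgl.graph(([0, 1], [1, 0])),
                dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))])
node_feats = [torch.zeros(bg.number_of_nodes(), dtype=torch.long),
              torch.zeros(bg.number_of_nodes(), dtype=torch.long)]
edge_feats = [torch.zeros(bg.number_of_edges(), dtype=torch.long),
              torch.zeros(bg.number_of_edges(), dtype=torch.long)]
logits = model(bg, node_feats, edge_feats)  # shape (2, 2)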