[DGL-LifeSci] Pre-trained GIN (dmlc#1558)
* Update
* Fix
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
Showing 12 changed files with 666 additions and 2 deletions.
@@ -7,3 +7,4 @@
 from .schnet import *
 from .wln import *
 from .weave import *
+from .gin import *
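With this re-export in place, the new module should be importable alongside the other GNN modules. A minimal sketch, assuming the standard dgllife package layout of this PR (the exact top-level module name is an assumption):

    # Assumed import path, based on the package layout in this commit
    from dgllife.model.gnn import GIN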
@@ -0,0 +1,200 @@
"""Graph Isomorphism Networks."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['GIN']

# pylint: disable=W0221, C0103
class GINLayer(nn.Module):
    r"""Single GIN layer from `Strategies for
    Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__

    Parameters
    ----------
    num_edge_emb_list : list of int
        num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
        the number of bond types and num_edge_emb_list[1] can be the number of
        bond direction types.
    emb_dim : int
        The size of each embedding vector.
    batch_norm : bool
        Whether to apply batch normalization to the output of message passing.
        Default to True.
    activation : None or callable
        Activation function to apply to the output node representations.
        Default to None.
    """
    def __init__(self, num_edge_emb_list, emb_dim, batch_norm=True, activation=None):
        super(GINLayer, self).__init__()

        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim),
            nn.ReLU(),
            nn.Linear(2 * emb_dim, emb_dim)
        )
        self.edge_embeddings = nn.ModuleList()
        for num_emb in num_edge_emb_list:
            emb_module = nn.Embedding(num_emb, emb_dim)
            nn.init.xavier_uniform_(emb_module.weight.data)
            self.edge_embeddings.append(emb_module)

        if batch_norm:
            self.bn = nn.BatchNorm1d(emb_dim)
        else:
            self.bn = None
        self.activation = activation
    def forward(self, g, node_feats, categorical_edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : FloatTensor of shape (N, emb_dim)
            * Input node features
            * N is the total number of nodes in the batch of graphs
            * emb_dim is the input node feature size, which must match emb_dim in initialization
        categorical_edge_feats : list of LongTensor of shape (E)
            * Input categorical edge features
            * len(categorical_edge_feats) should be the same as len(self.edge_embeddings)
            * E is the total number of edges in the batch of graphs

        Returns
        -------
        node_feats : float32 tensor of shape (N, emb_dim)
            Output node representations
        """
        # Embed each categorical edge feature and sum the embeddings
        edge_embeds = []
        for i, feats in enumerate(categorical_edge_feats):
            edge_embeds.append(self.edge_embeddings[i](feats))
        edge_embeds = torch.stack(edge_embeds, dim=0).sum(0)
        g = g.local_var()
        g.ndata['feat'] = node_feats
        g.edata['feat'] = edge_embeds
        # Message passing: add edge embeddings to source node features,
        # then sum the messages over incoming edges
        g.update_all(fn.u_add_e('feat', 'feat', 'm'), fn.sum('m', 'feat'))

        node_feats = self.mlp(g.ndata.pop('feat'))
        if self.bn is not None:
            node_feats = self.bn(node_feats)
        if self.activation is not None:
            node_feats = self.activation(node_feats)

        return node_feats

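# A minimal sketch of using GINLayer on its own (hypothetical sizes, not
# part of the commit): with two categorical edge features of vocabulary
# sizes 2 and 3,
#
#     layer = GINLayer(num_edge_emb_list=[2, 3], emb_dim=16)
#     out_feats = layer(g, node_feats, [bond_types, bond_dirs])
#
# where node_feats is a float32 tensor of shape (N, 16) and bond_types,
# bond_dirs are LongTensors of shape (E,).
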
class GIN(nn.Module):
    r"""Graph Isomorphism Network from `Strategies for
    Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__

    This module is for updating node representations only.

    Parameters
    ----------
    num_node_emb_list : list of int
        num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variable. E.g. num_node_emb_list[0] can be
        the number of atom types and num_node_emb_list[1] can be the number of
        atom chirality types.
    num_edge_emb_list : list of int
        num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
        the number of bond types and num_edge_emb_list[1] can be the number of
        bond direction types.
    num_layers : int
        Number of GIN layers to use. Default to 5.
    emb_dim : int
        The size of each embedding vector. Default to 300.
    JK : str
        JK for jumping knowledge as in `Representation Learning on Graphs with
        Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how to combine the node representations from all layers for the final output.
        There are four options for this argument, ``'concat'``, ``'last'``, ``'max'``
        and ``'sum'``. Default to 'last'.

        * ``'concat'``: concatenate the output node representations from all GIN layers
        * ``'last'``: use the node representations from the last GIN layer
        * ``'max'``: apply max pooling to the node representations across all GIN layers
        * ``'sum'``: sum the output node representations from all GIN layers
    dropout : float
        Dropout to apply to the output of each GIN layer. Default to 0.5.
    """
    def __init__(self, num_node_emb_list, num_edge_emb_list,
                 num_layers=5, emb_dim=300, JK='last', dropout=0.5):
        super(GIN, self).__init__()

        self.num_layers = num_layers
        self.JK = JK
        self.dropout = nn.Dropout(dropout)

        if num_layers < 2:
            raise ValueError('Number of GNN layers must be greater '
                             'than 1, got {:d}'.format(num_layers))

        self.node_embeddings = nn.ModuleList()
        for num_emb in num_node_emb_list:
            emb_module = nn.Embedding(num_emb, emb_dim)
            nn.init.xavier_uniform_(emb_module.weight.data)
            self.node_embeddings.append(emb_module)

        self.gnn_layers = nn.ModuleList()
        for layer in range(num_layers):
            if layer == num_layers - 1:
                # No activation for the last GIN layer
                self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim))
            else:
                self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim, activation=F.relu))
    def forward(self, g, categorical_node_feats, categorical_edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        categorical_node_feats : list of LongTensor of shape (N)
            * Input categorical node features
            * len(categorical_node_feats) should be the same as len(self.node_embeddings)
            * N is the total number of nodes in the batch of graphs
        categorical_edge_feats : list of LongTensor of shape (E)
            * Input categorical edge features
            * len(categorical_edge_feats) should be the same as
              len(num_edge_emb_list) in the arguments
            * E is the total number of edges in the batch of graphs

        Returns
        -------
        final_node_feats : float32 tensor of shape (N, M)
            Output node representations, N for the number of nodes and
            M for output size. In particular, M will be emb_dim * (num_layers + 1)
            if self.JK == 'concat' and emb_dim otherwise.
        """
        # Embed each categorical node feature and sum the embeddings
        node_embeds = []
        for i, feats in enumerate(categorical_node_feats):
            node_embeds.append(self.node_embeddings[i](feats))
        node_embeds = torch.stack(node_embeds, dim=0).sum(0)

        all_layer_node_feats = [node_embeds]
        for layer in range(self.num_layers):
            node_feats = self.gnn_layers[layer](g, all_layer_node_feats[layer],
                                                categorical_edge_feats)
            node_feats = self.dropout(node_feats)
            all_layer_node_feats.append(node_feats)
        # Combine the node representations from all layers with jumping knowledge
        if self.JK == 'concat':
            final_node_feats = torch.cat(all_layer_node_feats, dim=1)
        elif self.JK == 'last':
            final_node_feats = all_layer_node_feats[-1]
        elif self.JK == 'max':
            all_layer_node_feats = [h.unsqueeze(0) for h in all_layer_node_feats]
            final_node_feats = torch.max(torch.cat(all_layer_node_feats, dim=0), dim=0)[0]
        elif self.JK == 'sum':
            all_layer_node_feats = [h.unsqueeze(0) for h in all_layer_node_feats]
            final_node_feats = torch.sum(torch.cat(all_layer_node_feats, dim=0), dim=0)
        else:
            raise ValueError("Expect self.JK to be 'concat', 'last', "
                             "'max' or 'sum', got {}".format(self.JK))

        return final_node_feats
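For reference, a minimal usage sketch of the GIN module above. The vocabulary sizes and the toy graph are made up for illustration, and graph construction uses the DGLGraph API of the DGL generation this commit targets:

    import dgl
    import torch
    # Assumed import path; see the __init__.py change above
    from dgllife.model.gnn import GIN

    # Hypothetical vocabularies: 3 atom types and 2 chirality types for nodes,
    # 2 bond types and 2 bond direction types for edges
    model = GIN(num_node_emb_list=[3, 2], num_edge_emb_list=[2, 2],
                num_layers=2, emb_dim=16, JK='last', dropout=0.1)

    # A toy graph with 4 nodes and 3 edges
    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.add_edges([0, 1, 2], [1, 2, 3])

    # One LongTensor per categorical feature, one entry per node/edge
    node_feats = [torch.zeros(4).long(), torch.ones(4).long()]
    edge_feats = [torch.zeros(3).long(), torch.ones(3).long()]
    out = model(g, node_feats, edge_feats)  # float32 tensor of shape (4, 16)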
apps/life_sci/python/dgllife/model/model_zoo/gin_predictor.py (122 additions, 0 deletions)
@@ -0,0 +1,122 @@
"""GIN-based model for regression and classification on graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch.nn as nn

from dgl.nn.pytorch.glob import GlobalAttentionPooling, SumPooling, AvgPooling, MaxPooling

from ..gnn.gin import GIN

__all__ = ['GINPredictor']

# pylint: disable=W0221
class GINPredictor(nn.Module):
    """GIN-based model for regression and classification on graphs.

    GIN was first introduced in `How Powerful Are Graph Neural Networks
    <https://arxiv.org/abs/1810.00826>`__ for general graph property
    prediction problems. It was further extended in `Strategies for
    Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
    for pre-training and semi-supervised learning on large-scale datasets.

    For classification tasks, the output will be logits, i.e. values before
    sigmoid or softmax.

    Parameters
    ----------
    num_node_emb_list : list of int
        num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variable. E.g. num_node_emb_list[0] can be
        the number of atom types and num_node_emb_list[1] can be the number of
        atom chirality types.
    num_edge_emb_list : list of int
        num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
        the number of bond types and num_edge_emb_list[1] can be the number of
        bond direction types.
    num_layers : int
        Number of GIN layers to use. Default to 5.
    emb_dim : int
        The size of each embedding vector. Default to 300.
    JK : str
        JK for jumping knowledge as in `Representation Learning on Graphs with
        Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how to combine the node representations from all layers for the final output.
        There are four options for this argument, ``'concat'``, ``'last'``, ``'max'``
        and ``'sum'``. Default to 'last'.

        * ``'concat'``: concatenate the output node representations from all GIN layers
        * ``'last'``: use the node representations from the last GIN layer
        * ``'max'``: apply max pooling to the node representations across all GIN layers
        * ``'sum'``: sum the output node representations from all GIN layers
    dropout : float
        Dropout to apply to the output of each GIN layer. Default to 0.5.
    readout : str
        Readout for computing graph representations out of node representations, which
        can be ``'sum'``, ``'mean'``, ``'max'`` or ``'attention'``. Default to 'mean'.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    """
    def __init__(self, num_node_emb_list, num_edge_emb_list, num_layers=5,
                 emb_dim=300, JK='last', dropout=0.5, readout='mean', n_tasks=1):
        super(GINPredictor, self).__init__()

        if num_layers < 2:
            raise ValueError('Number of GNN layers must be greater '
                             'than 1, got {:d}'.format(num_layers))

        self.gnn = GIN(num_node_emb_list=num_node_emb_list,
                       num_edge_emb_list=num_edge_emb_list,
                       num_layers=num_layers,
                       emb_dim=emb_dim,
                       JK=JK,
                       dropout=dropout)

        # Readout for computing graph representations out of node representations
        if readout == 'sum':
            self.readout = SumPooling()
        elif readout == 'mean':
            self.readout = AvgPooling()
        elif readout == 'max':
            self.readout = MaxPooling()
        elif readout == 'attention':
            if JK == 'concat':
                self.readout = GlobalAttentionPooling(
                    gate_nn=nn.Linear((num_layers + 1) * emb_dim, 1))
            else:
                self.readout = GlobalAttentionPooling(
                    gate_nn=nn.Linear(emb_dim, 1))
        else:
            raise ValueError("Expect readout to be 'sum', 'mean', "
                             "'max' or 'attention', got {}".format(readout))

        # With JK == 'concat', node representations from all layers are
        # concatenated, so the input size of the final projection grows accordingly
        if JK == 'concat':
            self.predict = nn.Linear((num_layers + 1) * emb_dim, n_tasks)
        else:
            self.predict = nn.Linear(emb_dim, n_tasks)
    def forward(self, g, categorical_node_feats, categorical_edge_feats):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        categorical_node_feats : list of LongTensor of shape (N)
            * Input categorical node features
            * len(categorical_node_feats) should be the same as len(num_node_emb_list)
            * N is the total number of nodes in the batch of graphs
        categorical_edge_feats : list of LongTensor of shape (E)
            * Input categorical edge features
            * len(categorical_edge_feats) should be the same as
              len(num_edge_emb_list) in the arguments
            * E is the total number of edges in the batch of graphs

        Returns
        -------
        FloatTensor of shape (B, n_tasks)
            * Predictions on graphs
            * B for the number of graphs in the batch
        """
        node_feats = self.gnn(g, categorical_node_feats, categorical_edge_feats)
        graph_feats = self.readout(g, node_feats)
        return self.predict(graph_feats)
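Similarly, a minimal usage sketch for graph-level prediction on a batch of graphs; the vocabulary sizes and toy graphs are made up for illustration:

    import dgl
    import torch
    from dgllife.model.model_zoo.gin_predictor import GINPredictor

    model = GINPredictor(num_node_emb_list=[3, 2], num_edge_emb_list=[2, 2],
                         num_layers=2, emb_dim=16, readout='mean', n_tasks=1)

    # Two toy graphs batched into a single DGLGraph
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g1.add_edges([0, 1], [1, 2])
    g2 = dgl.DGLGraph()
    g2.add_nodes(2)
    g2.add_edges([0], [1])
    bg = dgl.batch([g1, g2])

    # One LongTensor per categorical feature over the whole batch
    node_feats = [torch.zeros(bg.number_of_nodes()).long(),
                  torch.zeros(bg.number_of_nodes()).long()]
    edge_feats = [torch.zeros(bg.number_of_edges()).long(),
                  torch.zeros(bg.number_of_edges()).long()]
    preds = model(bg, node_feats, edge_feats)  # FloatTensor of shape (2, 1)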