[Model]PPI dataloader and inductive learning script. (dmlc#395)
* Create ppi.py
* Create train_ppi.py
* Create gat.py
* Update train.py
* Update ppi.py and train_ppi.py (several iterations)
* Update docs and readme
Commit 788d8dd (1 parent: 1ea0bcf)
Showing 7 changed files with 447 additions and 122 deletions.
gat.py (new file) @@ -0,0 +1,130 @@
""" | ||
Graph Attention Networks in DGL using SPMV optimization. | ||
References | ||
---------- | ||
Paper: https://arxiv.org/abs/1710.10903 | ||
Author's code: https://github.com/PetarV-/GAT | ||
Pytorch implementation: https://github.com/Diego999/pyGAT | ||
""" | ||

import torch
import torch.nn as nn
import dgl.function as fn


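# Attention logits follow the GAT paper: e_ij = LeakyReLU(a^T [W h_i || W h_j]).
# Splitting the attention vector a into a left half a_l and a right half a_r gives
# e_ij = LeakyReLU(a_l^T W h_i + a_r^T W h_j), so each half is computed once per
# node (a1/a2 below) and the two are combined per edge via DGL message passing,
# which is the SPMV optimization referenced above.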
class GraphAttention(nn.Module):
    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 num_heads,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual=False):
        super(GraphAttention, self).__init__()
        self.g = g
        self.num_heads = num_heads
        self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
        if feat_drop:
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x: x
        if attn_drop:
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            self.attn_drop = lambda x: x
        # per-head attention vectors for the source (left) and destination (right) nodes
        self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(alpha)
        self.residual = residual
        if residual:
            if in_dim != out_dim:
                self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
                nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414)
            else:
                self.res_fc = None

    def forward(self, inputs):
        # prepare
        h = self.feat_drop(inputs)  # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
        head_ft = ft.transpose(0, 1)  # HxNxD'
        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
        self.edge_softmax()
        # 3. compute the aggregated node features scaled by the dropped,
        #    unnormalized attention values
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
        # 4. apply the softmax normalizer
        ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
        # 5. residual connection
        if self.residual:
            if self.res_fc is not None:
                resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
            else:
                resval = torch.unsqueeze(h, 1)  # Nx1xD'
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        # an edge UDF to compute unnormalized attention values from src and dst
        a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
        return {'a': a}

    def edge_softmax(self):
        # compute the max over each node's incoming edges (for numerical stability)
        self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
        # subtract the max and exponentiate
        self.g.apply_edges(lambda edges: {'a': torch.exp(edges.data['a'] - edges.dst['a_max'])})
        # apply attention dropout
        self.g.apply_edges(lambda edges: {'a_drop': self.attn_drop(edges.data['a'])})
        # compute the softmax normalizer
        self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))

class GAT(nn.Module):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GraphAttention(
            g, in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GraphAttention(
                g, num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, alpha, residual))
        # output projection
        self.gat_layers.append(GraphAttention(
            g, num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, alpha, residual))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            # concatenate the heads of every layer except the last
            h = self.gat_layers[l](h).flatten(1)
            h = self.activation(h)
        # output projection: average over heads
        logits = self.gat_layers[-1](h).mean(1)
        return logits
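
A minimal usage sketch for the GAT model above (not part of the commit). The graph and feature tensor are toy stand-ins, the layer sizes loosely follow the paper's PPI configuration (50 input features, 121 labels, 256 hidden units, 4/4/6 heads), and the graph construction assumes the same generation of the DGL API as the model code:

import torch
import torch.nn.functional as F
from dgl import DGLGraph

# toy graph in which every node has at least one incoming edge
g = DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3, 0, 1], [1, 2, 3, 0, 2, 3])
feats = torch.randn(4, 50)  # stand-in for 50-dim PPI node features

model = GAT(g,
            num_layers=2,          # input layer + 1 hidden layer (plus the output layer)
            in_dim=50,
            num_hidden=256,
            num_classes=121,
            heads=[4, 4, 6],       # len(heads) == num_layers + 1
            activation=F.elu,
            feat_drop=0.0,
            attn_drop=0.0,
            alpha=0.2,             # LeakyReLU negative slope
            residual=True)
logits = model(feats)              # shape: (num_nodes, num_classes)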