[Model] update gat (dmlc#390)
* update gat: add minus max for softmax

* small fix
jermainewang authored Feb 13, 2019
1 parent 6c3dba8 commit 91b7382
Showing 1 changed file with 20 additions and 16 deletions.
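The commit message refers to the standard max-subtraction trick for computing softmax without overflow: subtracting the maximum before exponentiating leaves the result unchanged but keeps exp() in a safe range. A minimal PyTorch sketch of the idea (illustrative only, not taken from the diff below):

    import torch

    def stable_softmax(x, dim=-1):
        # subtracting the max does not change the softmax, but prevents exp() overflow
        x = x - x.max(dim=dim, keepdim=True).values
        e = torch.exp(x)
        return e / e.sum(dim=dim, keepdim=True)

    scores = torch.tensor([1000.0, 1001.0, 1002.0])   # naive exp() would overflow here
    print(torch.allclose(stable_softmax(scores), torch.softmax(scores, dim=-1)))  # True

The hunks below apply the same idea per destination node, over the scores of its incoming edges.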
36 changes: 20 additions & 16 deletions examples/pytorch/gat/train.py
@@ -39,11 +39,11 @@ def __init__(self,
         if feat_drop:
             self.feat_drop = nn.Dropout(feat_drop)
         else:
-            self.feat_drop = None
+            self.feat_drop = lambda x : x
         if attn_drop:
             self.attn_drop = nn.Dropout(attn_drop)
         else:
-            self.attn_drop = None
+            self.attn_drop = lambda x : x
         self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
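This first hunk replaces None with an identity function, so the forward pass can call self.feat_drop and self.attn_drop unconditionally instead of guarding each call with an if. A small sketch of the pattern (the helper name make_dropout is hypothetical, not part of the example):

    import torch
    import torch.nn as nn

    def make_dropout(p):
        # fall back to the identity so callers never need an `if drop:` branch
        return nn.Dropout(p) if p else (lambda x: x)

    x = torch.rand(4, 8)
    print(torch.equal(make_dropout(0.0)(x), x))   # True: a no-op when p == 0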
@@ -60,22 +60,19 @@ def __init__(self,
 
     def forward(self, inputs):
         # prepare
-        h = inputs # NxD
-        if self.feat_drop:
-            h = self.feat_drop(h)
+        h = self.feat_drop(inputs) # NxD
         ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD'
         head_ft = ft.transpose(0, 1) # HxNxD'
         a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1) # NxHx1
         a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1) # NxHx1
-        if self.feat_drop:
-            ft = self.feat_drop(ft)
         self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
         # 1. compute edge attention
         self.g.apply_edges(self.edge_attention)
-        # 2. compute two results: one is the node features scaled by the dropped,
-        # unnormalized attention values; another is the normalizer of the attention values.
-        self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
-                          [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
+        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
+        self.edge_softmax()
+        # 2. compute the aggregated node features scaled by the dropped,
+        # unnormalized attention values.
+        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
         # 3. apply normalizer
         ret = self.g.ndata['ft'] / self.g.ndata['z'] # NxHxD'
         # 4. residual
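For a single head, steps 2 and 3 above amount to a weighted average of source-node features: the dropped attention values scale the numerator, while the undropped sum serves as the normalizer z. A rough dense-matrix equivalent, under the assumption that A[i, j] stands for the attention on edge j -> i (the names N, D, A, ft here are illustrative, not from the example):

    import torch
    import torch.nn.functional as F

    N, D = 4, 8
    A = torch.rand(N, N)                          # exp(a - max), as produced by edge_softmax
    ft = torch.rand(N, D)                         # projected source-node features
    A_drop = F.dropout(A, p=0.6, training=True)   # attention dropout, the role of 'a_drop'

    z = A.sum(dim=1, keepdim=True)                # normalizer, the role of g.ndata['z']
    ret = (A_drop @ ft) / z                       # scaled aggregation, then normalization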
@@ -90,10 +87,17 @@ def forward(self, inputs):
     def edge_attention(self, edges):
         # an edge UDF to compute unnormalized attention values from src and dst
         a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
-        a = torch.exp(a).clamp(-10, 10) # use clamp to avoid overflow
-        if self.attn_drop:
-            a_drop = self.attn_drop(a)
-        return {'a' : a, 'a_drop' : a_drop}
+        return {'a' : a}
+
+    def edge_softmax(self):
+        # compute the max
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
+        # minus the max and exp
+        self.g.apply_edges(lambda edges : {'a' : torch.exp(edges.data['a'] - edges.dst['a_max'])})
+        # compute dropout
+        self.g.apply_edges(lambda edges : {'a_drop' : self.attn_drop(edges.data['a'])})
+        # compute normalizer
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))
 
 class GAT(nn.Module):
     def __init__(self,
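The new edge_softmax computes a softmax over the incoming edges of each destination node: a per-destination max (fn.max), exp of the shifted scores, and the sum that later acts as the normalizer. For the edges entering a single node this reduces to an ordinary softmax, as the following toy check suggests (the scores are made up):

    import torch

    a = torch.tensor([2.0, 0.5, -1.0])     # unnormalized scores of edges entering one node

    a_max = a.max()                        # per-destination max, like fn.max('a', 'a_max')
    e = torch.exp(a - a_max)               # minus the max and exp
    alpha = e / e.sum()                    # divide by the normalizer z

    print(torch.allclose(alpha, torch.softmax(a, dim=0)))  # True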
@@ -247,7 +251,7 @@ def main(args):
     register_data_args(parser)
     parser.add_argument("--gpu", type=int, default=-1,
                         help="which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--epochs", type=int, default=300,
+    parser.add_argument("--epochs", type=int, default=200,
                         help="number of training epochs")
     parser.add_argument("--num-heads", type=int, default=8,
                         help="number of hidden attention heads")
