diff --git a/examples/pytorch/gat/train.py b/examples/pytorch/gat/train.py
index d76ba396a7d6..ceb6b5bbc634 100644
--- a/examples/pytorch/gat/train.py
+++ b/examples/pytorch/gat/train.py
@@ -39,11 +39,11 @@ def __init__(self,
         if feat_drop:
             self.feat_drop = nn.Dropout(feat_drop)
         else:
-            self.feat_drop = None
+            self.feat_drop = lambda x : x
         if attn_drop:
             self.attn_drop = nn.Dropout(attn_drop)
         else:
-            self.attn_drop = None
+            self.attn_drop = lambda x : x
         self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
         nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
@@ -60,22 +60,19 @@ def __init__(self,

     def forward(self, inputs):
         # prepare
-        h = inputs  # NxD
-        if self.feat_drop:
-            h = self.feat_drop(h)
+        h = self.feat_drop(inputs)  # NxD
         ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
         head_ft = ft.transpose(0, 1)  # HxNxD'
         a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
         a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
-        if self.feat_drop:
-            ft = self.feat_drop(ft)
         self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
         # 1. compute edge attention
         self.g.apply_edges(self.edge_attention)
-        # 2. compute two results: one is the node features scaled by the dropped,
-        # unnormalized attention values; another is the normalizer of the attention values.
-        self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
-                          [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
+        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
+        self.edge_softmax()
+        # 2. compute the aggregated node features scaled by the dropped,
+        # unnormalized attention values.
+        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
         # 3. apply normalizer
         ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
         # 4. residual
@@ -90,10 +87,17 @@ def forward(self, inputs):
     def edge_attention(self, edges):
         # an edge UDF to compute unnormalized attention values from src and dst
         a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
-        a = torch.exp(a).clamp(-10, 10)  # use clamp to avoid overflow
-        if self.attn_drop:
-            a_drop = self.attn_drop(a)
-        return {'a' : a, 'a_drop' : a_drop}
+        return {'a' : a}
+
+    def edge_softmax(self):
+        # compute the max
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
+        # minus the max and exp
+        self.g.apply_edges(lambda edges : {'a' : torch.exp(edges.data['a'] - edges.dst['a_max'])})
+        # compute dropout
+        self.g.apply_edges(lambda edges : {'a_drop' : self.attn_drop(edges.data['a'])})
+        # compute normalizer
+        self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))

 class GAT(nn.Module):
     def __init__(self,
@@ -247,7 +251,7 @@ def main(args):
     register_data_args(parser)
     parser.add_argument("--gpu", type=int, default=-1,
                         help="which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--epochs", type=int, default=300,
+    parser.add_argument("--epochs", type=int, default=200,
                         help="number of training epochs")
     parser.add_argument("--num-heads", type=int, default=8,
                         help="number of hidden attention heads")
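
Note on the change: the new edge_softmax() replaces the old clamp-based overflow guard with the standard numerically stable decomposition exp(x - max(x)) / sum(exp(x - max(x))), computed per destination node with message passing. The snippet below is an illustrative sketch only, not part of the patch and independent of DGL's API: a hypothetical stable_softmax helper on a plain tensor showing why subtracting the max avoids overflow while leaving the result unchanged.

import torch

def stable_softmax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)), so subtracting
    # the max changes nothing mathematically but keeps exp() in range
    m = scores.max(dim=dim, keepdim=True).values
    e = torch.exp(scores - m)
    return e / e.sum(dim=dim, keepdim=True)

scores = torch.tensor([1000.0, 1001.0, 1002.0])      # logits large enough to overflow float32 exp()
print(torch.exp(scores) / torch.exp(scores).sum())   # naive softmax -> tensor([nan, nan, nan])
print(stable_softmax(scores))                        # -> tensor([0.0900, 0.2447, 0.6652])

In the patch the same three steps appear as graph operations: a max-reduce into 'a_max', an apply_edges that exponentiates the shifted scores (with dropout applied to the unnormalized values), and a sum-reduce into the normalizer 'z' that forward() divides by.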