Skip to content

Commit

Permalink
Update model.py (dmlc#2796)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Mar 29, 2021
1 parent 2952f3c commit 66f7fe8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/pytorch/hgt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ def forward(self, G, h):

sub_graph.srcdata['k'] = k
sub_graph.dstdata['q'] = q
sub_graph.srcdata['v'] = v
sub_graph.srcdata['v_%d' % e_id] = v

sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')

sub_graph.edata['t'] = attn_score.unsqueeze(-1)

G.multi_update_all({etype : (fn.u_mul_e('v', 't', 'm'), fn.sum('m', 't')) \
for etype in edge_dict}, cross_reducer = 'mean')
G.multi_update_all({etype : (fn.u_mul_e('v_%d' % e_id, 't', 'm'), fn.sum('m', 't')) \
for etype, e_id in edge_dict.items()}, cross_reducer = 'mean')

new_h = {}
for ntype in G.ntypes:
Expand Down

0 comments on commit 66f7fe8

Please sign in to comment.