
Commit

modify code with built-in function (dmlc#2394)
* modify code with built-in function

* use einsum

Co-authored-by: Zihao Ye <[email protected]>
Maybewuss and yzh119 authored Dec 8, 2020
1 parent 013d145 commit 6c0cc1f
Showing 2 changed files with 40 additions and 48 deletions.
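
The gist of the change: the per-edge UDFs (edge_attention, message_func, reduce_func) are replaced by DGL built-in message functions plus edge_softmax, and the relation-specific bmm/transpose projections are replaced by torch.einsum. A minimal sketch (not part of the commit; the sizes are made up) checking that the einsum matches the old bmm formulation:

import torch

# Hypothetical sizes: N rows (nodes or edges), n_heads heads, d_k features per head.
N, n_heads, d_k = 6, 4, 8
k = torch.randn(N, n_heads, d_k)
relation_att = torch.randn(n_heads, d_k, d_k)

# Old formulation: move the head axis first for bmm, then move it back.
old = torch.bmm(k.transpose(1, 0), relation_att).transpose(1, 0)
# New formulation: a single einsum over the (row b, head i, feature j) axes.
new = torch.einsum("bij,ijk->bik", k, relation_att)

assert torch.allclose(old, new, atol=1e-6)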
86 changes: 39 additions & 47 deletions examples/pytorch/hgt/model.py
@@ -4,6 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.ops import edge_softmax

class HGTLayer(nn.Module):
def __init__(self,
@@ -27,77 +28,67 @@ def __init__(self,
self.d_k = out_dim // n_heads
self.sqrt_dk = math.sqrt(self.d_k)
self.att = None

self.k_linears = nn.ModuleList()
self.q_linears = nn.ModuleList()
self.v_linears = nn.ModuleList()
self.a_linears = nn.ModuleList()
self.norms = nn.ModuleList()
self.use_norm = use_norm

for t in range(self.num_types):
self.k_linears.append(nn.Linear(in_dim, out_dim))
self.q_linears.append(nn.Linear(in_dim, out_dim))
self.v_linears.append(nn.Linear(in_dim, out_dim))
self.a_linears.append(nn.Linear(out_dim, out_dim))
if use_norm:
self.norms.append(nn.LayerNorm(out_dim))

self.relation_pri = nn.Parameter(torch.ones(self.num_relations, self.n_heads))
self.relation_att = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
self.relation_msg = nn.Parameter(torch.Tensor(self.num_relations, n_heads, self.d_k, self.d_k))
self.skip = nn.Parameter(torch.ones(self.num_types))
self.drop = nn.Dropout(dropout)

nn.init.xavier_uniform_(self.relation_att)
nn.init.xavier_uniform_(self.relation_msg)

def edge_attention(self, edges):
etype = edges.data['id'][0]

'''
Step 1: Heterogeneous Mutual Attention
'''
relation_att = self.relation_att[etype]
relation_pri = self.relation_pri[etype]
key = torch.bmm(edges.src['k'].transpose(1,0), relation_att).transpose(1,0)
att = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk

'''
Step 2: Heterogeneous Message Passing
'''
relation_msg = self.relation_msg[etype]
val = torch.bmm(edges.src['v'].transpose(1,0), relation_msg).transpose(1,0)
return {'a': att, 'v': val}

def message_func(self, edges):
return {'v': edges.data['v'], 'a': edges.data['a']}

def reduce_func(self, nodes):
'''
Softmax based on target node's id (edge_index_i).
NOTE: With DGL's API, this softmax differs slightly from the original implementation:
it is computed only over edges belonging to the same relation type, rather than over all incoming edges.
'''
att = F.softmax(nodes.mailbox['a'], dim=1)
h = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
return {'t': h.view(-1, self.out_dim)}

def forward(self, G, h):
with G.local_scope():
node_dict, edge_dict = self.node_dict, self.edge_dict
for srctype, etype, dsttype in G.canonical_etypes:
sub_graph = G[srctype, etype, dsttype]

k_linear = self.k_linears[node_dict[srctype]]
v_linear = self.v_linears[node_dict[srctype]]
q_linear = self.q_linears[node_dict[dsttype]]

G.nodes[srctype].data['k'] = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
G.nodes[srctype].data['v'] = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
G.nodes[dsttype].data['q'] = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

G.apply_edges(func=self.edge_attention, etype=etype)
G.multi_update_all({etype : (self.message_func, self.reduce_func) \

k = k_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
v = v_linear(h[srctype]).view(-1, self.n_heads, self.d_k)
q = q_linear(h[dsttype]).view(-1, self.n_heads, self.d_k)

e_id = self.edge_dict[etype]

relation_att = self.relation_att[e_id]
relation_pri = self.relation_pri[e_id]
relation_msg = self.relation_msg[e_id]

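# Project keys and values with the relation-specific matrices: one einsum over the
# (node, head, feature) axes replaces the earlier per-edge bmm + transpose.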
k = torch.einsum("bij,ijk->bik", k, relation_att)
v = torch.einsum("bij,ijk->bik", v, relation_msg)

sub_graph.srcdata['k'] = k
sub_graph.dstdata['q'] = q
sub_graph.srcdata['v'] = v

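# Attention logits via the built-in dot(q_dst, k_src) on each edge, scaled by sqrt(d_k)
# and normalized over every destination node's incoming edges with edge_softmax.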
sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
attn_score = sub_graph.edata.pop('t').sum(-1) * relation_pri / self.sqrt_dk
attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')

sub_graph.edata['t'] = attn_score.unsqueeze(-1)

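# For each destination node type, sum the score-weighted source values within each
# relation, then average the per-relation results (cross_reducer = 'mean').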
G.multi_update_all({etype : (fn.u_mul_e('v', 't', 'm'), fn.sum('m', 't')) \
for etype in edge_dict}, cross_reducer = 'mean')

new_h = {}
for ntype in G.ntypes:
'''
@@ -106,14 +97,15 @@ def forward(self, G, h):
'''
n_id = node_dict[ntype]
alpha = torch.sigmoid(self.skip[n_id])
trans_out = self.drop(self.a_linears[n_id](G.nodes[ntype].data['t']))
t = G.nodes[ntype].data['t'].view(-1, self.out_dim)
trans_out = self.drop(self.a_linears[n_id](t))
trans_out = trans_out * alpha + h[ntype] * (1-alpha)
if self.use_norm:
new_h[ntype] = self.norms[n_id](trans_out)
else:
new_h[ntype] = trans_out
return new_h

class HGT(nn.Module):
def __init__(self, G, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
super(HGT, self).__init__()
@@ -167,8 +159,8 @@ def forward(self, G, feat_dict):
G.multi_update_all(funcs, 'sum')
# return the updated node feature dictionary
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}


class HeteroRGCN(nn.Module):
def __init__(self, G, in_size, hidden_size, out_size):
super(HeteroRGCN, self).__init__()
2 changes: 1 addition & 1 deletion examples/pytorch/hgt/train_acm.py
@@ -13,7 +13,7 @@
import argparse

torch.manual_seed(0)
data_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/ACM.mat'
data_url = 'https://data.dgl.ai/dataset/ACM.mat'
data_file_path = '/tmp/ACM.mat'

urllib.request.urlretrieve(data_url, data_file_path)

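For reference, here is how the built-ins used above fit together on a plain homogeneous graph — a sketch with made-up toy sizes, not code from this commit. Because the commit applies edge_softmax per relation subgraph, the normalization runs only over a destination node's incoming edges of that relation, matching the note in the removed reduce_func.

import torch
import dgl
import dgl.function as fn
from dgl.ops import edge_softmax

# Toy graph: 3 nodes, 4 directed edges (hypothetical sizes).
g = dgl.graph(([0, 1, 2, 0], [1, 2, 0, 2]))
n_heads, d_k = 2, 4
g.ndata['k'] = torch.randn(3, n_heads, d_k)
g.ndata['v'] = torch.randn(3, n_heads, d_k)
g.ndata['q'] = torch.randn(3, n_heads, d_k)

# Attention logits: dot(q_dst, k_src) on every edge, scaled by sqrt(d_k).
g.apply_edges(fn.v_dot_u('q', 'k', 't'))
score = edge_softmax(g, g.edata.pop('t').sum(-1) / d_k ** 0.5, norm_by='dst')

# Weight each source value by its edge score and sum per destination node.
g.edata['t'] = score.unsqueeze(-1)
g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
out = g.ndata['h']  # (3, n_heads, d_k)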