diff --git a/examples/pytorch/rgcn/layers.py b/examples/pytorch/rgcn/layers.py index 974b8d4a3e5a..f96a16d9a745 100644 --- a/examples/pytorch/rgcn/layers.py +++ b/examples/pytorch/rgcn/layers.py @@ -89,10 +89,10 @@ def msg_func(edges): # an embedding lookup using source node id embed = weight.view(-1, self.out_feat) index = edges.data['type'] * self.in_feat + edges.src['id'] - return {'msg': embed[index] * edges.data['norm']} + return {'msg': embed.index_select(0, index) * edges.data['norm']} else: def msg_func(edges): - w = weight[edges.data['type']] + w = weight.index_select(0, edges.data['type']) msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze() msg = msg * edges.data['norm'] return {'msg': msg} @@ -119,7 +119,7 @@ def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=None, nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) def msg_func(self, edges): - weight = self.weight[edges.data['type']].view( + weight = self.weight.index_select(0, edges.data['type']).view( -1, self.submat_in, self.submat_out) node = edges.src['h'].view(-1, 1, self.submat_in) msg = torch.bmm(node, weight).view(-1, self.out_feat) diff --git a/examples/pytorch/rgcn/link_predict.py b/examples/pytorch/rgcn/link_predict.py index af27723ccffd..bb961bd9835a 100644 --- a/examples/pytorch/rgcn/link_predict.py +++ b/examples/pytorch/rgcn/link_predict.py @@ -149,7 +149,7 @@ def main(args): # set node/edge feature node_id = torch.from_numpy(node_id).view(-1, 1) - edge_type = torch.from_numpy(edge_type).view(-1, 1) + edge_type = torch.from_numpy(edge_type) node_norm = torch.from_numpy(node_norm).view(-1, 1) data, labels = torch.from_numpy(data), torch.from_numpy(labels) deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1) diff --git a/python/dgl/contrib/data/knowledge_graph.py b/python/dgl/contrib/data/knowledge_graph.py index 3ede042c949e..af583e5eed5e 100644 --- a/python/dgl/contrib/data/knowledge_graph.py +++ b/python/dgl/contrib/data/knowledge_graph.py @@ -410,8 +410,10 @@ def _load_data(dataset_str='aifb', dataset_path=None): dst = nodes_dict[o] assert src < num_node and dst < num_node rel = relations_dict[p] - edge_list.append((src, dst, 2 * rel)) - edge_list.append((dst, src, 2 * rel + 1)) + # relation id 0 is self-relation, so others should start with 1 + edge_list.append((src, dst, 2 * rel + 1)) + # reverse relation + edge_list.append((dst, src, 2 * rel + 2)) # sort indices by destination edge_list = sorted(edge_list, key=lambda x: (x[1], x[0], x[2]))