Skip to content

Commit

Permalink
[BugFix] Fix bug in RGCN data processing and use index_select to impr…
Browse files Browse the repository at this point in the history
…ove speed (dmlc#429)

* use index_select instead of __getitem__

* fix bug in dataset processing

* fix edge_type shape bug

* comments
  • Loading branch information
lingfanyu authored Mar 4, 2019
1 parent fb4246e commit d3c24cc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions examples/pytorch/rgcn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def msg_func(edges):
# an embedding lookup using source node id
embed = weight.view(-1, self.out_feat)
index = edges.data['type'] * self.in_feat + edges.src['id']
return {'msg': embed[index] * edges.data['norm']}
return {'msg': embed.index_select(0, index) * edges.data['norm']}
else:
def msg_func(edges):
w = weight[edges.data['type']]
w = weight.index_select(0, edges.data['type'])
msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
msg = msg * edges.data['norm']
return {'msg': msg}
Expand All @@ -119,7 +119,7 @@ def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=None,
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))

def msg_func(self, edges):
weight = self.weight[edges.data['type']].view(
weight = self.weight.index_select(0, edges.data['type']).view(
-1, self.submat_in, self.submat_out)
node = edges.src['h'].view(-1, 1, self.submat_in)
msg = torch.bmm(node, weight).view(-1, self.out_feat)
Expand Down
2 changes: 1 addition & 1 deletion examples/pytorch/rgcn/link_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def main(args):

# set node/edge feature
node_id = torch.from_numpy(node_id).view(-1, 1)
edge_type = torch.from_numpy(edge_type).view(-1, 1)
edge_type = torch.from_numpy(edge_type)
node_norm = torch.from_numpy(node_norm).view(-1, 1)
data, labels = torch.from_numpy(data), torch.from_numpy(labels)
deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1)
Expand Down
6 changes: 4 additions & 2 deletions python/dgl/contrib/data/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,10 @@ def _load_data(dataset_str='aifb', dataset_path=None):
dst = nodes_dict[o]
assert src < num_node and dst < num_node
rel = relations_dict[p]
edge_list.append((src, dst, 2 * rel))
edge_list.append((dst, src, 2 * rel + 1))
# relation id 0 is self-relation, so others should start with 1
edge_list.append((src, dst, 2 * rel + 1))
# reverse relation
edge_list.append((dst, src, 2 * rel + 2))

# sort indices by destination
edge_list = sorted(edge_list, key=lambda x: (x[1], x[0], x[2]))
Expand Down

0 comments on commit d3c24cc

Please sign in to comment.