Skip to content

Commit

Permalink
fix link prediction (dmlc#3485)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Nov 9, 2021
1 parent db78fac commit f7360c3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tutorials/blitz/4_link_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@

neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[train_size:]], neg_v[neg_eids[train_size:]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]


######################################################################
Expand Down
2 changes: 1 addition & 1 deletion tutorials/large/L2_large_link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def closure():
# You also need to label the edges, 1 if positive and 0 if negative.
#

test_src = torch.cat([test_pos_src, test_neg_src])
test_src = torch.cat([test_pos_src, test_pos_dst])
test_dst = torch.cat([test_neg_src, test_neg_dst])
test_graph = dgl.graph((test_src, test_dst), num_nodes=graph.num_nodes())
test_graph.edata['label'] = torch.cat(
Expand Down

0 comments on commit f7360c3

Please sign in to comment.