[Bugfix][Model] Fixing RGCN evaluation bug (dmlc#778)
sharifza authored and jermainewang committed Aug 20, 2019
1 parent 788420d commit 77c5828
Showing 2 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions examples/pytorch/rgcn/link_predict.py
@@ -182,7 +182,7 @@ def main(args):
                 model.cpu()
             model.eval()
             print("start eval")
-            mrr = utils.evaluate(test_graph, model, valid_data, num_nodes,
+            mrr = utils.evaluate(test_graph, model, valid_data,
                                  hits=[1, 3, 10], eval_bz=args.eval_batch_size)
             # save best model
             if mrr < best_mrr:
@@ -207,7 +207,7 @@ def main(args):
     model.eval()
     model.load_state_dict(checkpoint['state_dict'])
     print("Using best epoch: {}".format(checkpoint['epoch']))
-    utils.evaluate(test_graph, model, test_data, num_nodes, hits=[1, 3, 10],
+    utils.evaluate(test_graph, model, test_data, hits=[1, 3, 10],
                    eval_bz=args.eval_batch_size)
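
For context on the fix: the dropped num_nodes argument had been used as the batch-loop bound inside utils.perturb_and_get_rank, so evaluation was batched over the number of entities in the graph instead of the number of triplets being evaluated. A minimal sketch with hypothetical sizes (not from the commit) shows how the old bound silently truncates evaluation whenever there are more test triplets than entities:

import torch

num_entity = 100   # entities in the graph: the old, incorrect loop bound
test_size = 240    # triplets actually under evaluation
batch_size = 100

old_n_batch = (num_entity + batch_size - 1) // batch_size   # 1 batch
new_n_batch = (test_size + batch_size - 1) // batch_size    # 3 batches

triplet_ids = torch.arange(test_size)
ranked = triplet_ids[: min(num_entity, old_n_batch * batch_size)]
print("{} of {} triplets ranked".format(ranked.numel(), test_size))  # 100 of 240

In the opposite case (fewer test triplets than entities) the old bound instead produced trailing empty batches.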


13 changes: 7 additions & 6 deletions examples/pytorch/rgcn/utils.py
@@ -163,15 +163,15 @@ def sort_and_rank(score, target):
     indices = indices[:, 1].view(-1)
     return indices

-def perturb_and_get_rank(embedding, w, a, r, b, num_entity, batch_size=100):
+def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100):
     """ Perturb one element in the triplets
     """
-    n_batch = (num_entity + batch_size - 1) // batch_size
+    n_batch = (test_size + batch_size - 1) // batch_size
     ranks = []
     for idx in range(n_batch):
         print("batch {} / {}".format(idx, n_batch))
         batch_start = idx * batch_size
-        batch_end = min(num_entity, (idx + 1) * batch_size)
+        batch_end = min(test_size, (idx + 1) * batch_size)
         batch_a = a[batch_start: batch_end]
         batch_r = r[batch_start: batch_end]
         emb_ar = embedding[batch_a] * w[batch_r]
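
Below this point the function body is folded. For orientation, the remainder scores each batch of perturbed triplets against every candidate entity with a DistMult-style product and ranks the true entity; a sketch of that pattern (assumed, since the folded lines are not shown in this diff), reusing the tensors above:

# embedding: V x D node embeddings; emb_ar: B x D for the current batch
emb_ar = emb_ar.transpose(0, 1).unsqueeze(2)        # D x B x 1
emb_c = embedding.transpose(0, 1).unsqueeze(1)      # D x 1 x V
out_prod = torch.bmm(emb_ar, emb_c)                 # D x B x V
score = torch.sigmoid(torch.sum(out_prod, dim=0))   # B x V
target = b[batch_start: batch_end]
ranks.append(sort_and_rank(score, target))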
@@ -187,17 +187,18 @@ def perturb_and_get_rank(embedding, w, a, r, b, num_entity, batch_size=100):

 # TODO (lingfan): implement filtered metrics
 # return MRR (raw), and Hits @ (1, 3, 10)
-def evaluate(test_graph, model, test_triplets, num_entity, hits=[], eval_bz=100):
+def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100):
     with torch.no_grad():
         embedding, w = model.evaluate(test_graph)
         s = test_triplets[:, 0]
         r = test_triplets[:, 1]
         o = test_triplets[:, 2]
+        test_size = test_triplets.shape[0]

         # perturb subject
-        ranks_s = perturb_and_get_rank(embedding, w, o, r, s, num_entity, eval_bz)
+        ranks_s = perturb_and_get_rank(embedding, w, o, r, s, test_size, eval_bz)
         # perturb object
-        ranks_o = perturb_and_get_rank(embedding, w, s, r, o, num_entity, eval_bz)
+        ranks_o = perturb_and_get_rank(embedding, w, s, r, o, test_size, eval_bz)

         ranks = torch.cat([ranks_s, ranks_o])
         ranks += 1 # change to 1-indexed
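The rest of evaluate() is folded here; it turns the concatenated 1-indexed ranks into the reported metrics. A minimal sketch of that standard computation (assumed, since the folded lines are not shown in this diff), given ranks and hits as above:

mrr = torch.mean(1.0 / ranks.float())
print("MRR (raw): {:.6f}".format(mrr.item()))
for hit in hits:
    avg_count = torch.mean((ranks <= hit).float())
    print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))

Because perturb_and_get_rank is now batched over test_size, ranks holds one entry per perturbed test triplet regardless of how the triplet count compares to the number of entities.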
