diff --git a/examples/pytorch/rgcn/link_predict.py b/examples/pytorch/rgcn/link_predict.py index 62cacce926a7..c2fa19ab7838 100644 --- a/examples/pytorch/rgcn/link_predict.py +++ b/examples/pytorch/rgcn/link_predict.py @@ -182,7 +182,7 @@ def main(args): model.cpu() model.eval() print("start eval") - mrr = utils.evaluate(test_graph, model, valid_data, num_nodes, + mrr = utils.evaluate(test_graph, model, valid_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size) # save best model if mrr < best_mrr: @@ -207,7 +207,7 @@ def main(args): model.eval() model.load_state_dict(checkpoint['state_dict']) print("Using best epoch: {}".format(checkpoint['epoch'])) - utils.evaluate(test_graph, model, test_data, num_nodes, hits=[1, 3, 10], + utils.evaluate(test_graph, model, test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size) diff --git a/examples/pytorch/rgcn/utils.py b/examples/pytorch/rgcn/utils.py index 8c33ce13aef8..15f21aaf635c 100644 --- a/examples/pytorch/rgcn/utils.py +++ b/examples/pytorch/rgcn/utils.py @@ -163,15 +163,15 @@ def sort_and_rank(score, target): indices = indices[:, 1].view(-1) return indices -def perturb_and_get_rank(embedding, w, a, r, b, num_entity, batch_size=100): +def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100): """ Perturb one element in the triplets """ - n_batch = (num_entity + batch_size - 1) // batch_size + n_batch = (test_size + batch_size - 1) // batch_size ranks = [] for idx in range(n_batch): print("batch {} / {}".format(idx, n_batch)) batch_start = idx * batch_size - batch_end = min(num_entity, (idx + 1) * batch_size) + batch_end = min(test_size, (idx + 1) * batch_size) batch_a = a[batch_start: batch_end] batch_r = r[batch_start: batch_end] emb_ar = embedding[batch_a] * w[batch_r] @@ -187,17 +187,18 @@ def perturb_and_get_rank(embedding, w, a, r, b, num_entity, batch_size=100): # TODO (lingfan): implement filtered metrics # return MRR (raw), and Hits @ (1, 3, 10) -def evaluate(test_graph, model, test_triplets, num_entity, hits=[], eval_bz=100): +def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100): with torch.no_grad(): embedding, w = model.evaluate(test_graph) s = test_triplets[:, 0] r = test_triplets[:, 1] o = test_triplets[:, 2] + test_size = test_triplets.shape[0] # perturb subject - ranks_s = perturb_and_get_rank(embedding, w, o, r, s, num_entity, eval_bz) + ranks_s = perturb_and_get_rank(embedding, w, o, r, s, test_size, eval_bz) # perturb object - ranks_o = perturb_and_get_rank(embedding, w, s, r, o, num_entity, eval_bz) + ranks_o = perturb_and_get_rank(embedding, w, s, r, o, test_size, eval_bz) ranks = torch.cat([ranks_s, ranks_o]) ranks += 1 # change to 1-indexed