diff --git a/examples/pytorch/rgcn/README.md b/examples/pytorch/rgcn/README.md index 4a547427e705..283984e2b0d5 100644 --- a/examples/pytorch/rgcn/README.md +++ b/examples/pytorch/rgcn/README.md @@ -40,5 +40,9 @@ python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --test ### Link Prediction FB15k-237: MRR 0.151 (DGL), 0.158 (paper) ``` -python3 link_predict.py -d FB15k-237 --gpu 0 +python3 link_predict.py -d FB15k-237 --gpu 0 --eval-protocol raw +``` +FB15k-237: Filtered-MRR 0.2044 +``` +python3 link_predict.py -d FB15k-237 --gpu 0 --eval-protocol filtered ``` diff --git a/examples/pytorch/rgcn/link_predict.py b/examples/pytorch/rgcn/link_predict.py index 7925233fdae6..328e338e2971 100644 --- a/examples/pytorch/rgcn/link_predict.py +++ b/examples/pytorch/rgcn/link_predict.py @@ -186,8 +186,9 @@ def main(args): model.eval() print("start eval") embed = model(test_graph, test_node_id, test_rel, test_norm) - mrr = utils.calc_mrr(embed, model.w_relation, valid_data, - hits=[1, 3, 10], eval_bz=args.eval_batch_size) + mrr = utils.calc_mrr(embed, model.w_relation, torch.LongTensor(train_data), + valid_data, test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size, + eval_p=args.eval_protocol) # save best model if mrr < best_mrr: if epoch >= args.n_epochs: @@ -212,8 +213,8 @@ def main(args): model.load_state_dict(checkpoint['state_dict']) print("Using best epoch: {}".format(checkpoint['epoch'])) embed = model(test_graph, test_node_id, test_rel, test_norm) - utils.calc_mrr(embed, model.w_relation, test_data, - hits=[1, 3, 10], eval_bz=args.eval_batch_size) + utils.calc_mrr(embed, model.w_relation, torch.LongTensor(train_data), valid_data, + test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size, eval_p=args.eval_protocol) if __name__ == '__main__': parser = argparse.ArgumentParser(description='RGCN') @@ -235,6 +236,8 @@ def main(args): help="dataset to use") parser.add_argument("--eval-batch-size", type=int, default=500, help="batch size when evaluating") + 
parser.add_argument("--eval-protocol", type=str, default="filtered", + help="type of evaluation protocol: 'raw' or 'filtered' mrr") parser.add_argument("--regularization", type=float, default=0.01, help="regularization weight") parser.add_argument("--grad-norm", type=float, default=1.0, diff --git a/examples/pytorch/rgcn/utils.py b/examples/pytorch/rgcn/utils.py index 6c67c1441efe..62ff682e47ed 100644 --- a/examples/pytorch/rgcn/utils.py +++ b/examples/pytorch/rgcn/utils.py @@ -165,7 +165,7 @@ def negative_sampling(pos_samples, num_entity, negative_rate): ####################################################################### # -# Utility function for evaluations +# Utility functions for evaluations (raw) # ####################################################################### @@ -175,7 +175,7 @@ def sort_and_rank(score, target): indices = indices[:, 1].view(-1) return indices -def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100): +def perturb_and_get_raw_rank(embedding, w, a, r, b, test_size, batch_size=100): """ Perturb one element in the triplets """ n_batch = (test_size + batch_size - 1) // batch_size @@ -197,9 +197,8 @@ def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100): ranks.append(sort_and_rank(score, target)) return torch.cat(ranks) -# TODO (lingfan): implement filtered metrics # return MRR (raw), and Hits @ (1, 3, 10) -def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100): +def calc_raw_mrr(embedding, w, test_triplets, hits=[], eval_bz=100): with torch.no_grad(): s = test_triplets[:, 0] r = test_triplets[:, 1] @@ -207,9 +206,9 @@ def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100): test_size = test_triplets.shape[0] # perturb subject - ranks_s = perturb_and_get_rank(embedding, w, o, r, s, test_size, eval_bz) + ranks_s = perturb_and_get_raw_rank(embedding, w, o, r, s, test_size, eval_bz) # perturb object - ranks_o = perturb_and_get_rank(embedding, w, s, r, o, test_size, eval_bz) + ranks_o 
#######################################################################
#
# Utility functions for evaluations (filtered)
#
#######################################################################

def filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities):
    """Return the candidate objects for ranking ``(target_s, target_r, ?)``.

    An entity ``o`` in ``range(num_entities)`` is kept when it is the target
    object itself, or when ``(target_s, target_r, o)`` is not a member of
    ``triplets_to_filter``.

    Parameters
    ----------
    triplets_to_filter : set of (int, int, int)
        Known triplets that must not compete with the target. The set is
        only read; it is never modified.
    target_s, target_r, target_o : int or one-element tensor
        The triplet being evaluated (converted with ``int()``).
    num_entities : int
        Number of entity ids to scan.

    Returns
    -------
    torch.LongTensor
        Kept entity ids in ascending order.
    """
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # The target is kept explicitly instead of being removed from the shared
    # set: removing it would stop it from being filtered for every later
    # test triplet that shares the same (subject, relation).
    filtered_o = [
        o for o in range(num_entities)
        if o == target_o or (target_s, target_r, o) not in triplets_to_filter
    ]
    return torch.LongTensor(filtered_o)


def filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities):
    """Return the candidate subjects for ranking ``(?, target_r, target_o)``.

    An entity ``s`` in ``range(num_entities)`` is kept when it is the target
    subject itself, or when ``(s, target_r, target_o)`` is not a member of
    ``triplets_to_filter``.

    Parameters
    ----------
    triplets_to_filter : set of (int, int, int)
        Known triplets that must not compete with the target. The set is
        only read; it is never modified.
    target_s, target_r, target_o : int or one-element tensor
        The triplet being evaluated (converted with ``int()``).
    num_entities : int
        Number of entity ids to scan.

    Returns
    -------
    torch.LongTensor
        Kept entity ids in ascending order.
    """
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # See filter_o: keep the target without mutating the shared filter set.
    filtered_s = [
        s for s in range(num_entities)
        if s == target_s or (s, target_r, target_o) not in triplets_to_filter
    ]
    return torch.LongTensor(filtered_s)


def _rank_of_target(scores, target_idx):
    """Return the 0-indexed position of ``target_idx`` after sorting
    ``scores`` in descending order."""
    _, indices = torch.sort(scores, descending=True)
    return int((indices == target_idx).nonzero())


def perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """ Perturb object in the triplets

    For each of the first ``test_size`` triplets, score every candidate
    object returned by ``filter_o`` with
    ``sigmoid(sum(embedding[s] * w[r] * embedding[candidate]))`` and record
    the 0-indexed rank of the true object among those candidates.

    Returns a ``torch.LongTensor`` holding one rank per triplet. Progress is
    printed every 100 triplets.
    """
    num_entities = embedding.shape[0]
    ranks = []
    for idx in range(test_size):
        if idx % 100 == 0:
            print("test triplet {} / {}".format(idx, test_size))
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        filtered_o = filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities)
        # Position of the true object inside the candidate list; the scores
        # below are indexed by candidate position, not by entity id.
        target_o_idx = int((filtered_o == target_o).nonzero())
        emb_s = embedding[target_s]
        emb_r = w[target_r]
        emb_o = embedding[filtered_o]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
        ranks.append(_rank_of_target(scores, target_o_idx))
    return torch.LongTensor(ranks)


def perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """ Perturb subject in the triplets

    For each of the first ``test_size`` triplets, score every candidate
    subject returned by ``filter_s`` with
    ``sigmoid(sum(embedding[candidate] * w[r] * embedding[o]))`` and record
    the 0-indexed rank of the true subject among those candidates.

    Returns a ``torch.LongTensor`` holding one rank per triplet. Progress is
    printed every 100 triplets.
    """
    num_entities = embedding.shape[0]
    ranks = []
    for idx in range(test_size):
        if idx % 100 == 0:
            print("test triplet {} / {}".format(idx, test_size))
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        filtered_s = filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities)
        # Position of the true subject inside the candidate list.
        target_s_idx = int((filtered_s == target_s).nonzero())
        emb_s = embedding[filtered_s]
        emb_r = w[target_r]
        emb_o = embedding[target_o]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
        ranks.append(_rank_of_target(scores, target_s_idx))
    return torch.LongTensor(ranks)


def calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=()):
    """Compute and print filtered MRR and Hits@k over ``test_triplets``.

    The filter set is the union of ``train_triplets``, ``valid_triplets`` and
    ``test_triplets`` (concatenated with ``torch.cat``). Every test triplet
    is ranked twice, once with the subject perturbed and once with the
    object perturbed, and both rank lists contribute to the metrics.

    Parameters
    ----------
    embedding : tensor indexed by entity id along dim 0
    w : tensor indexed by relation id along dim 0
    train_triplets, valid_triplets, test_triplets : integer tensors whose
        columns 0, 1, 2 are subject, relation, object
    hits : iterable of int
        Cut-offs ``k`` for which Hits@k is printed.

    Returns
    -------
    float
        The filtered mean reciprocal rank.
    """
    with torch.no_grad():
        s = test_triplets[:, 0]
        r = test_triplets[:, 1]
        o = test_triplets[:, 2]
        test_size = test_triplets.shape[0]

        all_triplets = torch.cat([train_triplets, valid_triplets, test_triplets]).tolist()
        # A set of tuples gives constant-time membership tests in the filters.
        triplets_to_filter = {tuple(triplet) for triplet in all_triplets}
        print('Perturbing subject...')
        ranks_s = perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)
        print('Perturbing object...')
        ranks_o = perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1 # change to 1-indexed

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (filtered): {:.6f}".format(mrr.item()))

        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr.item()

#######################################################################
#
# Main evaluation function
#
#######################################################################

def calc_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=(), eval_bz=100, eval_p="filtered"):
    """Evaluate ``test_triplets`` and return the MRR for the chosen protocol.

    ``eval_p == "filtered"`` dispatches to ``calc_filtered_mrr``; any other
    value dispatches to ``calc_raw_mrr``. ``train_triplets`` and
    ``valid_triplets`` are only used by the filtered protocol, and
    ``eval_bz`` only by the raw one.
    """
    if eval_p == "filtered":
        mrr = calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits)
    else:
        mrr = calc_raw_mrr(embedding, w, test_triplets, hits, eval_bz)
    return mrr