[Feature] Filtered MRR metrics for R-GCN example (dmlc#1298)
* Add filtered metrics for R-GCN example

* Add new line to end of file

* Add evaluation protocol argument option for R-GCN example

* Update README

Co-authored-by: xiang song(charlie.song) <[email protected]>
Co-authored-by: Quan (Andy) Gan <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
4 people authored Mar 13, 2020
1 parent 87842db commit 6c7c403
Showing 3 changed files with 131 additions and 11 deletions.
6 changes: 5 additions & 1 deletion examples/pytorch/rgcn/README.md
@@ -40,5 +40,9 @@ python3 entity_classify.py -d am --n-bases=40 --n-hidden=10 --l2norm=5e-4 --test
 ### Link Prediction
 FB15k-237: MRR 0.151 (DGL), 0.158 (paper)
 ```
-python3 link_predict.py -d FB15k-237 --gpu 0
+python3 link_predict.py -d FB15k-237 --gpu 0 --eval-protocol raw
 ```
+FB15k-237: Filtered-MRR 0.2044
+```
+python3 link_predict.py -d FB15k-237 --gpu 0 --eval-protocol filtered
+```
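Raw and filtered evaluation differ only in which candidate entities get ranked: the filtered protocol drops candidates that form known true triplets, so other correct answers cannot push the test entity down the list. The metric itself is identical in both cases. As a minimal sketch (the helper name `mrr_and_hits` is illustrative, not part of the example code), MRR and Hits@k fall directly out of the 1-indexed ranks:

```python
import torch

def mrr_and_hits(ranks, hits=(1, 3, 10)):
    # ranks: 1-indexed positions of the true entity among the candidates
    ranks = ranks.float()
    mrr = torch.mean(1.0 / ranks).item()
    hit_rates = {k: torch.mean((ranks <= k).float()).item() for k in hits}
    return mrr, hit_rates

print(mrr_and_hits(torch.tensor([1, 2, 7, 50])))
# MRR ~ 0.416; Hits@1 = 0.25, Hits@3 = 0.5, Hits@10 = 0.75
```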
11 changes: 7 additions & 4 deletions examples/pytorch/rgcn/link_predict.py
@@ -186,8 +186,9 @@ def main(args):
             model.eval()
             print("start eval")
             embed = model(test_graph, test_node_id, test_rel, test_norm)
-            mrr = utils.calc_mrr(embed, model.w_relation, valid_data,
-                                 hits=[1, 3, 10], eval_bz=args.eval_batch_size)
+            mrr = utils.calc_mrr(embed, model.w_relation, torch.LongTensor(train_data),
+                                 valid_data, test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size,
+                                 eval_p=args.eval_protocol)
             # save best model
             if mrr < best_mrr:
                 if epoch >= args.n_epochs:
@@ -212,8 +213,8 @@ def main(args):
     model.load_state_dict(checkpoint['state_dict'])
     print("Using best epoch: {}".format(checkpoint['epoch']))
     embed = model(test_graph, test_node_id, test_rel, test_norm)
-    utils.calc_mrr(embed, model.w_relation, test_data,
-                   hits=[1, 3, 10], eval_bz=args.eval_batch_size)
+    utils.calc_mrr(embed, model.w_relation, torch.LongTensor(train_data), valid_data,
+                   test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size, eval_p=args.eval_protocol)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='RGCN')
@@ -235,6 +236,8 @@ def main(args):
                         help="dataset to use")
     parser.add_argument("--eval-batch-size", type=int, default=500,
                         help="batch size when evaluating")
+    parser.add_argument("--eval-protocol", type=str, default="filtered",
+                        help="type of evaluation protocol: 'raw' or 'filtered' mrr")
     parser.add_argument("--regularization", type=float, default=0.01,
                         help="regularization weight")
     parser.add_argument("--grad-norm", type=float, default=1.0,
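Two details worth noting in this change: `--eval-protocol` defaults to `"filtered"`, so omitting the flag now reports filtered metrics, and `train_data` is wrapped in `torch.LongTensor` because the filtered path concatenates the train, valid, and test triplets into one tensor before building its filter set. A toy sketch of that set construction (`train`, `valid`, and `test` here are stand-ins for the loader's real splits):

```python
import torch

# toy (subject, relation, object) triplets standing in for the real splits
train = torch.LongTensor([[0, 0, 1], [1, 0, 2]])
valid = torch.LongTensor([[2, 0, 3]])
test  = torch.LongTensor([[0, 0, 2]])

# calc_filtered_mrr collects every known true triplet into one set
triplets_to_filter = {tuple(t) for t in torch.cat([train, valid, test]).tolist()}
print(triplets_to_filter)  # {(0, 0, 1), (1, 0, 2), (2, 0, 3), (0, 0, 2)} (set order may vary)
```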
125 changes: 119 additions & 6 deletions examples/pytorch/rgcn/utils.py
@@ -165,7 +165,7 @@ def negative_sampling(pos_samples, num_entity, negative_rate):
 
 #######################################################################
 #
-# Utility function for evaluations
+# Utility functions for evaluations (raw)
 #
 #######################################################################
 
@@ -175,7 +175,7 @@ def sort_and_rank(score, target):
     indices = indices[:, 1].view(-1)
     return indices
 
-def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100):
+def perturb_and_get_raw_rank(embedding, w, a, r, b, test_size, batch_size=100):
     """ Perturb one element in the triplets
     """
     n_batch = (test_size + batch_size - 1) // batch_size
@@ -197,19 +197,18 @@ def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100):
         ranks.append(sort_and_rank(score, target))
     return torch.cat(ranks)
 
-# TODO (lingfan): implement filtered metrics
 # return MRR (raw), and Hits @ (1, 3, 10)
-def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
+def calc_raw_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
     with torch.no_grad():
         s = test_triplets[:, 0]
         r = test_triplets[:, 1]
         o = test_triplets[:, 2]
         test_size = test_triplets.shape[0]
 
         # perturb subject
-        ranks_s = perturb_and_get_rank(embedding, w, o, r, s, test_size, eval_bz)
+        ranks_s = perturb_and_get_raw_rank(embedding, w, o, r, s, test_size, eval_bz)
         # perturb object
-        ranks_o = perturb_and_get_rank(embedding, w, s, r, o, test_size, eval_bz)
+        ranks_o = perturb_and_get_raw_rank(embedding, w, s, r, o, test_size, eval_bz)
 
         ranks = torch.cat([ranks_s, ranks_o])
         ranks += 1 # change to 1-indexed
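Both the raw and filtered rankers score candidates with the DistMult decoder this example trains: the element-wise product of subject embedding, relation weight, and object embedding, summed over the hidden dimension and passed through a sigmoid (the filtered helpers below spell this computation out). A self-contained sketch with toy sizes and random weights:

```python
import torch

torch.manual_seed(0)
num_entities, dim = 5, 4
embedding = torch.randn(num_entities, dim)  # stand-in for the trained entity embeddings
w = torch.randn(2, dim)                     # stand-in for model.w_relation

# score every candidate object for a fixed (subject, relation) pair
s, r = 0, 1
emb_triplet = embedding[s] * w[r] * embedding          # broadcasts over all entities
scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))  # one score per candidate object
print(scores.argsort(descending=True))                 # candidate objects, best first
```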
@@ -221,3 +220,117 @@ def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
             avg_count = torch.mean((ranks <= hit).float())
             print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))
         return mrr.item()
+
+#######################################################################
+#
+# Utility functions for evaluations (filtered)
+#
+#######################################################################
+
+def filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities):
+    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
+    filtered_o = []
+    # Do not filter out the test triplet, since we want to predict on it
+    if (target_s, target_r, target_o) in triplets_to_filter:
+        triplets_to_filter.remove((target_s, target_r, target_o))
+    # Do not consider an object if it is part of a triplet to filter
+    for o in range(num_entities):
+        if (target_s, target_r, o) not in triplets_to_filter:
+            filtered_o.append(o)
+    return torch.LongTensor(filtered_o)
+
+def filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities):
+    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
+    filtered_s = []
+    # Do not filter out the test triplet, since we want to predict on it
+    if (target_s, target_r, target_o) in triplets_to_filter:
+        triplets_to_filter.remove((target_s, target_r, target_o))
+    # Do not consider a subject if it is part of a triplet to filter
+    for s in range(num_entities):
+        if (s, target_r, target_o) not in triplets_to_filter:
+            filtered_s.append(s)
+    return torch.LongTensor(filtered_s)
+
+def perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
+    """ Perturb object in the triplets
+    """
+    num_entities = embedding.shape[0]
+    ranks = []
+    for idx in range(test_size):
+        if idx % 100 == 0:
+            print("test triplet {} / {}".format(idx, test_size))
+        target_s = s[idx]
+        target_r = r[idx]
+        target_o = o[idx]
+        filtered_o = filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities)
+        target_o_idx = int((filtered_o == target_o).nonzero())
+        emb_s = embedding[target_s]
+        emb_r = w[target_r]
+        emb_o = embedding[filtered_o]
+        emb_triplet = emb_s * emb_r * emb_o
+        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
+        _, indices = torch.sort(scores, descending=True)
+        rank = int((indices == target_o_idx).nonzero())
+        ranks.append(rank)
+    return torch.LongTensor(ranks)
+
+def perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
+    """ Perturb subject in the triplets
+    """
+    num_entities = embedding.shape[0]
+    ranks = []
+    for idx in range(test_size):
+        if idx % 100 == 0:
+            print("test triplet {} / {}".format(idx, test_size))
+        target_s = s[idx]
+        target_r = r[idx]
+        target_o = o[idx]
+        filtered_s = filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities)
+        target_s_idx = int((filtered_s == target_s).nonzero())
+        emb_s = embedding[filtered_s]
+        emb_r = w[target_r]
+        emb_o = embedding[target_o]
+        emb_triplet = emb_s * emb_r * emb_o
+        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
+        _, indices = torch.sort(scores, descending=True)
+        rank = int((indices == target_s_idx).nonzero())
+        ranks.append(rank)
+    return torch.LongTensor(ranks)
+
+def calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[]):
+    with torch.no_grad():
+        s = test_triplets[:, 0]
+        r = test_triplets[:, 1]
+        o = test_triplets[:, 2]
+        test_size = test_triplets.shape[0]
+
+        triplets_to_filter = torch.cat([train_triplets, valid_triplets, test_triplets]).tolist()
+        triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter}
+        print('Perturbing subject...')
+        ranks_s = perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)
+        print('Perturbing object...')
+        ranks_o = perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)
+
+        ranks = torch.cat([ranks_s, ranks_o])
+        ranks += 1 # change to 1-indexed
+
+        mrr = torch.mean(1.0 / ranks.float())
+        print("MRR (filtered): {:.6f}".format(mrr.item()))
+
+        for hit in hits:
+            avg_count = torch.mean((ranks <= hit).float())
+            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
+    return mrr.item()
+
+#######################################################################
+#
+# Main evaluation function
+#
+#######################################################################
+
+def calc_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[], eval_bz=100, eval_p="filtered"):
+    if eval_p == "filtered":
+        mrr = calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits)
+    else:
+        mrr = calc_raw_mrr(embedding, w, test_triplets, hits, eval_bz)
+    return mrr
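To make the filtering concrete, here is a toy run of `filter_o` (a sketch; it assumes utils.py is importable, e.g. when run from examples/pytorch/rgcn):

```python
import torch
from utils import filter_o  # assumes the example directory is on the path

# four entities; three known true triplets (s, r, o)
triplets_to_filter = {(0, 0, 1), (0, 0, 2), (3, 0, 1)}

# evaluating test triplet (0, 0, 1): which objects remain candidates for (0, 0, ?)
candidates = filter_o(triplets_to_filter, 0, 0, 1, 4)
print(candidates)  # tensor([0, 1, 3]) -- object 2 is dropped because (0, 0, 2)
                   # is a known triplet; the test triplet's own object is kept
```

One subtlety: `filter_o` and `filter_s` remove the test triplet from the shared set in place, so after a test triplet has been evaluated it is no longer filtered when it appears as a corrupting candidate for later test triplets.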
