-
Notifications
You must be signed in to change notification settings - Fork 2
/
evalute_score.py
97 lines (77 loc) · 3 KB
/
evalute_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from prettytable import PrettyTable
from scipy.stats import spearmanr, pearsonr, kendalltau
import json
import re
import argparse
def calculate_correlation(pred_score, human_score, result):
    """Accumulate one document's correlation coefficients into ``result``.

    Computes Pearson, Spearman and Kendall-tau coefficients between the two
    score sequences and adds each one to the matching key of ``result``.

    Args:
        pred_score: Predicted scores for one document.
        human_score: Human scores for the same document, same length.
        result: Running totals keyed by ``'pearson'``, ``'spearman'`` and
            ``'kendalltau'``. A non-empty dict is updated in place. An empty
            one is replaced by a fresh zeroed dict, and the caller's object is
            left untouched.

    Returns:
        The updated totals dict. Callers should rebind to this value.
    """
    assert len(pred_score) == len(human_score)
    if not result:
        # Start a new accumulator rather than filling the caller's empty dict.
        result = dict.fromkeys(('pearson', 'spearman', 'kendalltau'), 0)
    for key, corr_fn in (('pearson', pearsonr),
                         ('spearman', spearmanr),
                         ('kendalltau', kendalltau)):
        # Element 0 of each scipy result is the coefficient (element 1 is the p-value).
        result[key] += corr_fn(pred_score, human_score)[0]
    return result
def extract_numbers(text):
    """Return every run of decimal digits in ``text`` as an int, in order of appearance.

    Leading zeros are dropped by the int conversion (``"007"`` -> ``7``), and
    an input without digits yields an empty list.
    """
    return list(map(int, re.findall(r'\d+', text)))
def print_correlations(result, n):
    """Print the accumulated coefficients divided by ``n`` as a one-row table.

    Args:
        result: Totals keyed by ``'pearson'``, ``'spearman'`` and ``'kendalltau'``.
        n: Number of contributions summed into ``result``. Zero is treated as
            1, so the totals print undivided instead of raising
            ZeroDivisionError.
    """
    divisor = 1 if n == 0 else n
    row = [round(result[key] / divisor, 4)
           for key in ('pearson', 'spearman', 'kendalltau')]
    table = PrettyTable(['Pearson', 'Spearman', 'Kendall'])
    table.add_row(row)
    print(table)
def parse_output(output):
    """Parse a numeric score from the very start of a model response.

    The response may start with at most one space, followed by a run of
    digits and dots (e.g. ``"4"`` or ``" 3.5 because ..."``).

    Args:
        output: The response text.

    Returns:
        The parsed value as a float. Returns the int ``0`` when the response
        does not start with such a run, or when the run is not a valid float
        (e.g. ``"1.2.3"`` or ``"."``).
    """
    # Raw string: the previous non-raw "\d"/"\." were invalid escape sequences
    # (warned about by Python); the regex pattern itself is unchanged.
    matched = re.search(r"^ ?([\d\.]+)", output)
    if not matched:
        return 0
    try:
        return float(matched.group(1))
    except ValueError:
        # Runs like "1.2.3" match the character class but are not floats.
        return 0
def score_from_numbers(numbers, dimension):
    """Reduce the integers found in a model response to a single-digit score.

    For the ``'fluency'`` dimension the score is the last digit of the last
    number. For every other dimension it is the leading digit of the first
    number. For example, ``"10/10"`` scores ``0`` for fluency and ``1``
    otherwise. NOTE(review): confirm this digit-picking is intended for
    two-digit scores.

    Args:
        numbers: Non-negative ints in order of appearance, as produced by
            ``extract_numbers``.
        dimension: Name of the evaluated dimension.

    Returns:
        An int in ``0..9``.

    Raises:
        IndexError: If ``numbers`` is empty (the response had no digits).
    """
    if dimension == 'fluency':
        # Last decimal digit; same as the former int(list(str(n))[-1]) for n >= 0.
        return numbers[-1] % 10
    return int(str(numbers[0])[0])


def collect_scores(records, dimension):
    """Group predicted and human scores by ``doc_id``.

    Every record's ``doc_id`` gets an entry, possibly empty, in both result
    dicts. A record is skipped, and its index recorded, when either:

    * its response yields no usable number, or
    * it lacks ``'human_score'``.

    The predicted and human lists for a document therefore always stay the
    same length.

    Args:
        records: Sequence of dicts with ``'doc_id'``, ``dimension`` and
            ``'human_score'`` keys. A missing ``'doc_id'`` or ``dimension``
            key raises ``KeyError``.
        dimension: Key of the model response to score.

    Returns:
        A tuple ``(pred_scores, human_scores, strange_idx)``: two dicts
        mapping ``doc_id`` to parallel score lists, and the indices of the
        skipped records.
    """
    pred_scores, human_scores = {}, {}
    strange_idx = []
    for idx, item in enumerate(records):
        doc_id = item["doc_id"]
        preds = pred_scores.setdefault(doc_id, [])
        humans = human_scores.setdefault(doc_id, [])
        response = item[dimension]
        try:
            score = score_from_numbers(extract_numbers(response), dimension)
            human_score = item['human_score']
        except (IndexError, TypeError, KeyError):
            # IndexError: no digits in the response; TypeError: response is
            # not a string; KeyError: record has no 'human_score'.
            strange_idx.append(idx)
            continue
        # Append only after both lookups succeed, so the two lists stay paired.
        preds.append(score)
        humans.append(human_score)
    return pred_scores, human_scores, strange_idx


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_fp', type=str, default='results/gpt4_flu_detailed_openai.json')
    parser.add_argument('--dimension', type=str, default='fluency')
    args = parser.parse_args()

    # Context manager closes the file; JSON interchange text is UTF-8.
    with open(args.input_fp, encoding='utf-8') as f:
        jobj = json.load(f)

    pred_scores, human_scores, strange_idx = collect_scores(jobj, args.dimension)
    print('len(pred_scores): {}'.format(len(pred_scores)))
    print('len(human_scores): {}'.format(len(human_scores)))
    print('len(strange_idx): ', len(strange_idx))

    results = {'pearson': 0, 'spearman': 0, 'kendalltau': 0}
    d_ctr = 0
    for doc_id, pred_scores_doc in pred_scores.items():
        human_scores_doc = human_scores[doc_id]
        # Correlation is undefined when either side is constant, so skip such docs.
        if len(set(human_scores_doc)) <= 1 or len(set(pred_scores_doc)) <= 1:
            continue
        results = calculate_correlation(pred_scores_doc, human_scores_doc, results)
        d_ctr += 1
    # Report the mean coefficient over the documents that were scored.
    print_correlations(results, n=d_ctr)