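"""Ensemble classifier predictions from multiple finetuned checkpoints.

Loads per-dataset prediction files from several checkpoint directories,
averages the class probabilities, optionally computes and/or applies a
decision threshold, and writes the ensembled predictions to TSV files.

Example invocation (directory names are illustrative, not from the repo):

    python ensemble_classifier.py \
        --paths run1 run2 run3 \
        --prediction-name test_predictions.pt \
        --outdir ensembled \
        --eval
"""
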
import os
import argparse
import collections

import numpy as np
import torch


def process_files(args):
    """Load predictions from each checkpoint directory and accumulate them.

    Each prediction file is expected to hold a list of
    (dataset_name, (predictions, labels, uid)) entries.
    """
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    for path in args.paths:
        path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(path)
            for dataset in data:
                name, d = dataset
                predictions, labels, uid = d
                if name not in all_predictions:
                    all_predictions[name] = np.array(predictions)
                    if args.labels is None:
                        args.labels = [i for i in range(all_predictions[name].shape[1])]
                    if args.eval:
                        all_labels[name] = np.array(labels)
                    all_uid[name] = np.array(uid)
                else:
                    # Sum probabilities across checkpoints; they are averaged
                    # later in postprocess_predictions.
                    all_predictions[name] += np.array(predictions)
                    assert np.allclose(all_uid[name], np.array(uid))
        except Exception as e:
            print(e)
            continue
    return all_predictions, all_labels, all_uid


def get_threshold(all_predictions, all_labels, one_threshold=False):
    if one_threshold:
        # Fit a single threshold over all datasets combined.
        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
    out_thresh = []
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds, labels))
    return out_thresh


def calc_threshold(p, l):
    """Grid-search thresholds in [0, 1) and return the most accurate one."""
    trials = [i * (1. / 100.) for i in range(100)]
    best_acc = float('-inf')
    best_thresh = 0
    for t in trials:
        acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh


def apply_threshold(preds, t):
    # Expects binary probabilities that sum to 1 per example; predicts the
    # positive (last) class when its probability is >= t.
    assert np.allclose(preds.sum(-1), np.ones(preds.shape[0]))
    prob = preds[:, -1]
    thresholded = (prob >= t).astype(int)
    preds = np.zeros_like(preds)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds


def threshold_predictions(all_predictions, threshold):
    if len(threshold) != len(all_predictions):
        # Pad with the last supplied threshold so every dataset gets one.
        threshold = threshold + [threshold[-1]] * (len(all_predictions) - len(threshold))
    for i, dataset in enumerate(all_predictions):
        thresh = threshold[i]
        preds = all_predictions[dataset]
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions


def postprocess_predictions(all_predictions, all_labels, args):
    # Average the summed probabilities over the number of ensembled checkpoints.
    for d in all_predictions:
        all_predictions[d] = all_predictions[d] / len(args.paths)

    if args.calc_threshold:
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels


def write_predictions(all_predictions, all_labels, all_uid, args):
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
        if args.eval:
            correct = (preds == all_labels[dataset]).sum()
            num = len(all_labels[dataset])
            accuracy = correct / num
            count += num
            all_correct += correct
            print(accuracy)
        if not os.path.exists(os.path.join(args.outdir, dataset)):
            os.makedirs(os.path.join(args.outdir, dataset))
        outpath = os.path.join(
            args.outdir, dataset, os.path.splitext(
                args.prediction_name)[0] + '.tsv')
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
                              for uid, p in zip(all_uid[dataset], preds.tolist())))
    if args.eval:
        # Overall accuracy across all datasets.
        print(all_correct / count)


def ensemble_predictions(args):
    all_predictions, all_labels, all_uid = process_files(args)
    all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
    write_predictions(all_predictions, all_labels, all_uid, args)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate classification threshold from labels')
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args()

    ensemble_predictions(args)


if __name__ == '__main__':
    main()