-
Notifications
You must be signed in to change notification settings - Fork 18
/
hyperparameters.py
executable file
·137 lines (109 loc) · 5.07 KB
/
hyperparameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python3
"""
From the results of kamiak_{train,eval}_tune.srun pick the best hyperparameters
for each dataset-method pair
Outputs/prints a hyperparameter dictionary to put in pick_multi_source.py,
specifying which hyperparameters to pass during training.
"""
from absl import app
from absl import flags
from analysis import get_tuning_files, all_stats
# absl flag values; main() reads the chosen selection mode as FLAGS.selection.
FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "selection",
    "best_source",
    ["best_source", "best_target"],
    "Which model to select",
)

# Hyperparameters that were varied during tuning. Every name listed here must
# be saved/available in the config.yaml file, since it is looked up in each
# result's config. The order matters: params_to_str() pairs values with these
# names by position.
parameter_list = ["batch_division", "train_batch", "lr"]
def params_to_str(params, names=None):
    """Format hyperparameter values as a string of command-line flags.

    Each value is paired by position with a flag name and rendered as
    ``--name=value``; the pieces are joined with single spaces.

    Args:
        params: Sequence of hyperparameter values, ordered like ``names``.
        names: Flag names to pair with ``params`` by position. Defaults to
            the module-level ``parameter_list``.

    Returns:
        The space-separated flag string, or ``""`` when ``params`` is empty.

    Raises:
        IndexError: If ``params`` has more entries than ``names``.
    """
    if names is None:
        # Resolved at call time so the default always reflects the current
        # module-level list.
        names = parameter_list
    return " ".join(
        "--" + names[i] + "=" + str(value) for i, value in enumerate(params)
    )
def _group_tuning_results(tuning_results):
    """Index runs as ``grouped[dataset][method][hyperparams] -> [results]``.

    The hyperparameter key is the tuple of config values for the names in
    ``parameter_list``, in that order.
    """
    grouped = {}
    for result in tuning_results:
        config = result["parameters"]["config"]
        dataset = config["dataset"]
        method = config["method"]
        hyperparams = tuple(config[p] for p in parameter_list)
        runs = grouped.setdefault(dataset, {}).setdefault(method, {}).setdefault(
            hyperparams, [])
        runs.append(result)
    return grouped


def _average_accuracies(runs, variant):
    """Return ``(selection_accuracy, target_accuracy)`` averaged over ``runs``.

    The selection accuracy is the source ("Test A") accuracy when ``variant``
    is "best_source" and the target ("Test B") accuracy otherwise. ``runs``
    must be non-empty.
    """
    selection_accs = []
    target_accs = []
    for result in runs:
        # NOTE(review): the indexing below assumes a pandas-DataFrame-like
        # table with "Dataset" and "Avg" columns -- confirm against all_stats.
        avgs = result["averages"]
        # Note: "Test" is actually the validation data since in
        # kamiak_eval_tune.srun we pass --notest to main_eval.py
        source_acc = avgs[avgs["Dataset"] == "Test A"]["Avg"].values[0]
        target_acc = avgs[avgs["Dataset"] == "Test B"]["Avg"].values[0]
        selection_accs.append(
            source_acc if variant == "best_source" else target_acc)
        target_accs.append(target_acc)
    return (sum(selection_accs) / len(selection_accs),
            sum(target_accs) / len(target_accs))


def _select_best(dataset, method, runs_by_hyperparams, variant):
    """Find the hyperparameter settings with the highest averaged accuracy.

    Returns ``(best_acc, best_target_accs, best_params)``. The two lists are
    parallel: one entry per hyperparameter tuple that reached ``best_acc``
    (ties are all kept). Both lists are empty if no setting's accuracy
    compared ``>`` or ``==`` to the running best, which starts at 0.
    """
    best_acc = 0
    best_target_accs = []
    best_params = []
    for hyperparams, runs in runs_by_hyperparams.items():
        if runs:
            accuracy, target_accuracy = _average_accuracies(runs, variant)
        else:
            accuracy = 0
            target_accuracy = 0
            print("Warning: no runs found for", dataset, method, hyperparams)

        # Update best parameters if this is better, if it's the same,
        # then add to the list of good parameters
        if accuracy > best_acc:
            best_acc = accuracy
            best_target_accs = [target_accuracy]
            best_params = [hyperparams]
        elif accuracy == best_acc:
            best_params.append(hyperparams)
            best_target_accs.append(target_accuracy)
    return best_acc, best_target_accs, best_params


def _print_hyperparameters(final_params):
    """Print the dictionary of chosen flags to paste into pick_multi_source.py.

    The first best setting of each method is printed as the active entry;
    settings that tied with it follow as commented-out alternatives.
    """
    print("hyperparameters = {")
    for dataset, methods in final_params.items():
        print(" \""+dataset+"\": {")
        for method, best_params in methods.items():
            print(" \""+method+"\": \""+params_to_str(best_params[0])+"\",")
            # Alternative parameters for equivalent accuracy
            for alternative in best_params[1:]:
                print(" #\""+method+"\": \""+params_to_str(alternative)+"\",")
        print(" },")
    print("}")


def main(argv):
    """Pick and print the best hyperparameters for each dataset-method pair.

    Loads the tuning results for the selected variant (``FLAGS.selection``),
    averages the accuracy over the runs of each hyperparameter setting, and
    keeps the setting(s) with the highest averaged source or target accuracy.
    Prints one semicolon-separated summary line per dataset-method pair,
    followed by a ``hyperparameters`` dictionary for pick_multi_source.py.

    The reported target accuracy is on the target validation set, not the
    test set.

    Args:
        argv: Positional command-line arguments passed by ``app.run``; unused.
    """
    tuning_name = "tune2"
    variant = FLAGS.selection
    files = get_tuning_files(
        "results", prefix="results_"+tuning_name+"_"+variant+"-")

    # Group by [dataset][method][hyperparams] since we want to select the best
    # hyperparameters for each dataset-method pair.
    grouped = _group_tuning_results(all_stats(files))

    print("Dataset;Method;BestAccuracy;TargetAccuracy;BestHyperParameters")
    final_params = {}
    for dataset, methods in grouped.items():
        for method, runs_by_hyperparams in methods.items():
            best_acc, best_target_accs, best_params = _select_best(
                dataset, method, runs_by_hyperparams, variant)

            # Nothing is selected when every averaged accuracy fails both the
            # ">" and "==" comparisons (e.g. NaN averages). Skip the pair
            # rather than dividing by zero or indexing an empty list below.
            if not best_params:
                print("Warning: no valid accuracy found for", dataset, method)
                continue

            # Print out the best we found
            best_target_acc = sum(best_target_accs)/len(best_target_accs)
            print(dataset, method, best_acc, best_target_acc, *best_params,
                  sep=";")

            # Save final parameters to put into pick_multi_source.py
            final_params.setdefault(dataset, {})[method] = best_params

    # Output dictionary for final parameters for pick_multi_source.py
    _print_hyperparameters(final_params)
# Script entry point: hand control to absl, which calls main(argv).
if __name__ == "__main__":
    app.run(main)