forked from yewsiang/ConceptBottleneck
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiments.py
101 lines (77 loc) · 2.81 KB
/
experiments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import pdb
import sys
def run_experiments(dataset, args):
    """Run the experiment named by ``args[0].exp`` for the given dataset.

    Args:
        dataset: ``'OAI'`` or ``'CUB'``; selects which ``<dataset>.train``
            module provides the entry point.
        args: Sequence whose first element exposes an ``exp`` attribute
            holding the experiment name. The whole sequence is unpacked
            positionally into the selected entry point.

    Raises:
        ValueError: If ``dataset`` is not recognised, if the experiment name
            is not recognised, or if the experiment has no entry point for
            the chosen dataset.
    """
    # Experiment name -> name of the entry point inside `<dataset>.train`.
    entry_points = {
        'Concept_XtoC': 'train_X_to_C',
        'Independent_CtoY': 'train_oracle_C_to_y_and_test_on_Chat',
        'Sequential_CtoY': 'train_Chat_to_y_and_test_on_Chat',
        'Joint': 'train_X_to_C_to_y',
        'Standard': 'train_X_to_y',
        'StandardWithAuxC': 'train_X_to_y_with_aux_C',
        'Multitask': 'train_X_to_Cy',
        'Probe': 'train_probe',
        'TTI': 'test_time_intervention',
        'Robustness': 'robustness',
        'HyperparameterSearch': 'hyperparameter_optimization',
    }
    # Experiments for which a dataset has no entry point: 'robustness' is only
    # used with CUB and 'train_X_to_y_with_aux_C' only with OAI.
    unsupported = {
        'OAI': {'Robustness'},
        'CUB': {'StandardWithAuxC'},
    }

    if dataset not in unsupported:
        raise ValueError(
            'Unknown dataset %r; expected one of %s'
            % (dataset, sorted(unsupported))
        )

    experiment = args[0].exp
    if experiment not in entry_points:
        raise ValueError(
            'Unknown experiment %r; expected one of %s'
            % (experiment, sorted(entry_points))
        )
    if experiment in unsupported[dataset]:
        raise ValueError(
            'Experiment %r is not available for the %s dataset'
            % (experiment, dataset)
        )

    # Import inside the function, after validation, so that only the module
    # of the selected dataset is loaded.
    if dataset == 'OAI':
        from OAI import train as train_module
    else:
        from CUB import train as train_module

    getattr(train_module, entry_points[experiment])(*args)
def parse_arguments():
    """Validate ``sys.argv`` and delegate parsing to the dataset's parser.

    ``sys.argv[1]`` must name the dataset (``OAI`` or ``CUB``, matched
    case-insensitively) and ``sys.argv[2]`` must name the experiment. The
    remaining parsing is done by ``parse_arguments`` from the selected
    ``<dataset>.train`` module.

    Returns:
        A ``(dataset, args)`` pair: the upper-cased dataset name and whatever
        the dataset-specific parser returned.

    Raises:
        AssertionError: If the dataset or experiment is missing or invalid.
            The checks are explicit ``raise`` statements rather than
            ``assert`` so they still run under ``python -O``.
    """
    valid_datasets = ('OAI', 'CUB')
    valid_experiments = (
        'Concept_XtoC', 'Independent_CtoY', 'Sequential_CtoY',
        'Standard', 'StandardWithAuxC', 'Multitask', 'Joint', 'Probe',
        'TTI', 'Robustness', 'HyperparameterSearch',
    )

    # First arg must be the dataset; which one it is decides how the rest of
    # the command line is parsed.
    if len(sys.argv) <= 2:
        raise AssertionError('You need to specify dataset and experiment')
    dataset = sys.argv[1].upper()
    if dataset not in valid_datasets:
        raise AssertionError('Please specify OAI or CUB dataset')
    if sys.argv[2] not in valid_experiments:
        raise AssertionError(
            'Please specify valid experiment. Current: %s' % sys.argv[2]
        )

    # NOTE(review): the experiment name is upper-cased before being handed to
    # the dataset parser, although it was validated in its original casing.
    # Confirm the dataset parsers expect the upper-cased form.
    experiment = sys.argv[2].upper()

    # Aliased so the import does not shadow this function's own name.
    if dataset == 'OAI':
        from OAI.train import parse_arguments as dataset_parse_arguments
    else:
        from CUB.train import parse_arguments as dataset_parse_arguments

    args = dataset_parse_arguments(experiment=experiment)
    return dataset, args
if __name__ == '__main__':
    # Script entry point: parse the command line, seed the random number
    # generators, then run the selected experiment.
    import torch
    import numpy as np

    dataset, args = parse_arguments()

    # Seed both NumPy and PyTorch from the seed carried by the first
    # parsed-arguments object.
    seed = args[0].seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    run_experiments(dataset, args)