-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path intersubjective_training.py
126 lines (105 loc) · 7.4 KB
/
intersubjective_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from src import collect_data_intersubjective, resample_transform, intersubjective_training, save_resutls, intersubjective_shallow, finetune
import os
from os.path import isfile, join
from keras.models import load_model, clone_model
from keras.optimizers import SGD
import sys
#import csv
import pdb
#%%
# All subject folders found under the 50 Hz data directory.
subjects = os.listdir("./data/50/subjects/")
# The held-out test subject is passed on the command line, e.g.:
#   python intersubjective_training.py ab82
test_subjects = [sys.argv[1]]

# Training hyper-parameters shared by every deep run.
batch_size = 64
lr = 0.001
early_stopping = True
epochs = 500
patience = 20
# Default architecture switches (each variant in the sweep below overrides this).
model_config={'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'}

# Fine-tuning controls: disabled by default; when enabled, the branched model
# is fine-tuned on an increasing number of the test subject's own trials.
ft = False
ft_mode = 'all'
ft_trials = [10, 20, 30, 40, 50, 60, 70]

# Dataset (sampling / preprocessing) variants to sweep over.
datasets = ['50_avg', '250']
for dataset in datasets:
model_names = ['deep_intersubjective_branched_'+str(dataset)+'_thesis1',
'deep_intersubjective_eegnet_'+str(dataset)+'_thesis1',
'deep_intersubjective_cnn_'+str(dataset)+'_thesis1',
'lda_intersubjective_shrinkage_'+str(dataset)+'_thesis1',
'lda_intersubjective_'+str(dataset)+'_thesis1',
'deep_intersubjective_branched_no_bn_'+str(dataset)+'_thesis1',
'deep_intersubjective_branched_no_dropout_'+str(dataset)+'_thesis1',
'deep_intersubjective_branched_no_branched_'+str(dataset)+'_thesis1',
'deep_intersubjective_branched_no_deep_'+str(dataset)+'_thesis1',
'deep_intersubjective_branched_relu_'+str(dataset)+'_thesis1',
'deep_intersubjective_branched_elu_'+str(dataset)+'_thesis1']
model_configs = [{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
{'bn':False, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
{'bn':True, 'dropout':False, 'branched':True, 'deep':True, 'nonlinear':'tanh'},
{'bn':True, 'dropout':True, 'branched':False, 'deep':True, 'nonlinear':'tanh'},
{'bn':True, 'dropout':True, 'branched':True, 'deep':False, 'nonlinear':'tanh'},
{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'relu'},
{'bn':True, 'dropout':True, 'branched':True, 'deep':True, 'nonlinear':'elu'}]
for model_name, model_config in zip(model_names, model_configs):
for test_subject in test_subjects:
print 'working on subject', test_subject
#test_subject = subject
#collecting the data
print 'Collecting data ...'
x_train, y_train, x_test, y_test, o_t_test, o_tr_test = collect_data_intersubjective(subjects,
test_subject,
mode='eeg',
channels=range(29),
frequency=dataset)
#Train
print "Training ..."
if model_name.startswith('deep'):
metrics, history, cnf_matrix = intersubjective_training((x_train, y_train, x_test, y_test, o_t_test, o_tr_test),
model_name,
test_subject,
epochs=epochs,
lr=lr,
batch_size=batch_size,
model_config=model_config,
early_stopping=early_stopping,
patience=patience)
super_final_results = save_resutls([metrics], [history], test_subject,
suffix=model_name,
early_stopping=early_stopping,
patience=patience)
else:
metrics, history, cnf_matrix, clf = intersubjective_shallow((x_train, y_train, x_test, y_test, o_t_test, o_tr_test),
model_name)
super_final_results = save_resutls(metrics, history, test_subject, suffix=model_name, clf=clf)
print 'DA:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_recognition_acc']['mean']
print 'BA:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_balanced_acc']['mean']
print 'Recall:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_recalls']['mean']
print 'precision:', model_name,'_', test_subject, ' is:', super_final_results[0]['val_precisions']['mean']
if model_name.startswith('deep_intersubjective_branched') and ft:
model=clone_model(history.model)
weights = history.model.get_weights()
for i in ft_trials:
model_name_modified = 'deep_intersubjective_branched_ft_'+str(i)+'_trials_'+str(dataset)+'_thesis2'
model_name_modified = model_name+'_ft_'+str(i)+'_trials'
model.set_weights(weights)
metrics, history, cnf_matrix = finetune(model,
(x_test, y_test, o_t_test, o_tr_test),
model_name_modified,
test_subject,
epochs=epochs,
train_trials=i,
mode=ft_mode,
early_stopping=early_stopping,
patience=patience)
super_final_results = save_resutls([metrics], [history], test_subject,
suffix=model_name_modified,
early_stopping=early_stopping,
patience=patience)
print 'DA:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_recognition_acc']['mean']
print 'BA:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_balanced_acc']['mean']
print 'Recall:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_recalls']['mean']
print 'precision:', model_name_modified,'_', test_subject, ' is:', super_final_results[0]['val_precisions']['mean']