Skip to content

Commit

Permalink
metrics added
Browse files Browse the repository at this point in the history
  • Loading branch information
Kearlay committed Sep 18, 2018
1 parent 411a016 commit f1c237b
Showing 1 changed file with 75 additions and 8 deletions.
83 changes: 75 additions & 8 deletions eeg_tensorflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Tensorflow Style Guide\n",
Expand All @@ -41,7 +43,10 @@
"from keras.layers import Conv2D, BatchNormalization, Activation, Flatten, Dense, Dropout, LSTM, Input, TimeDistributed\n",
"from keras import initializers, Model, optimizers, callbacks\n",
"from keras.utils.training_utils import multi_gpu_model\n",
"from keras import backend as K\n",
"from keras.callbacks import Callback\n",
"from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
"from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score\n",
"\n",
"# Essential Data Handling\n",
"import numpy as np\n",
Expand All @@ -52,8 +57,15 @@
"\n",
"# EEG package\n",
"from mne import pick_types\n",
"from mne.io import read_raw_edf\n",
"\n",
"from mne.io import read_raw_edf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Eager Execution\n",
"import tensorflow as tf\n",
"\n",
Expand Down Expand Up @@ -305,7 +317,7 @@
},
"outputs": [],
"source": [
"X,y = get_data(FNAMES[:2], epoch_sec=0.0625)"
"X,y = get_data(FNAMES, epoch_sec=0.0625)"
]
},
{
Expand Down Expand Up @@ -523,7 +535,6 @@
"def timeDist(layer, prev_layer, name):\n",
" return TimeDistributed(layer, name=name)(prev_layer)\n",
" \n",
"\n",
"# Input layer\n",
"inputs = Input(shape=input_shape)\n",
"\n",
Expand Down Expand Up @@ -595,6 +606,62 @@
"model = multi_gpu_model(model, gpus=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"'''\n",
"This metrics calculate sensitivity and specificity batch-wise.\n",
"Keras development team removed this feature because\n",
"these metrics should be understood as global metrics.\n",
"\n",
"I am not using it this time.\n",
"\n",
"# Metrics - sensitivity, specificity, accuracy\n",
"def sens(y_true, y_pred): # Sensitivity\n",
" true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))\n",
" possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))\n",
" return true_positives / (possible_positives + K.epsilon())\n",
"\n",
"def prec(y_true, y_pred): # Precision\n",
" true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))\n",
" possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))\n",
" return true_negatives / (possible_negatives + K.epsilon())\n",
"'''"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Metrics(Callback):\n",
" '''\n",
" This metrics works as a callback recording f1, recall, and precision following each epoch.\n",
" '''\n",
" def on_train_begin(self, logs={}):\n",
" self.val_f1s = []\n",
" self.val_recalls = []\n",
" self.val_precisions = []\n",
" \n",
" def on_epoch_end(self, epoch, logs={}):\n",
" val_predict = (np.asarray(self.model.predict(self.validation_data[0]))).round()\n",
" val_targ = self.validation_data[1]\n",
" _val_f1 = f1_score(val_targ, val_predict, average='weighted', )\n",
" _val_recall = recall_score(val_targ, val_predict, average='weighted')\n",
" _val_precision = precision_score(val_targ, val_predict, average='weighted')\n",
" self.val_f1s.append(_val_f1)\n",
" self.val_recalls.append(_val_recall)\n",
" self.val_precisions.append(_val_precision)\n",
" print(\"— val_f1: %.2f — val_precision: %.2f — val_recall %.2f\"%(_val_f1, _val_precision, _val_recall))\n",
" return\n",
" \n",
"metrics = Metrics()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -604,11 +671,12 @@
"callbacks_list = [callbacks.ModelCheckpoint('model.h5', save_best_only=True, monitor='val_loss'),\n",
" callbacks.EarlyStopping(monitor='acc', patience=3),\n",
" callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5),\n",
" callbacks.TensorBoard(log_dir='./my_log_dir/', histogram=1)]\n",
" callbacks.TensorBoard(log_dir='./my_log_dir/', histogram_freq=1)]\n",
"\n",
"# Start training\n",
"model.compile(loss='categorical_crossentropy', optimizer=optimizers.adam(lr=0.001), metrics=['acc'])\n",
"model.fit(X_train, y_train, batch_size=64, epochs=5000, validation_data=(X_test, y_test))"
"history = model.fit(X_train[:100], y_train[:100], batch_size=64, epochs=5000, \n",
" validation_data=(X_test[:100], y_test[:100]), callbacks=[metrics])"
]
},
{
Expand Down Expand Up @@ -768,7 +836,6 @@
" return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(\n",
" logits=inference_fn(inputs), labels=labels))\n",
"\n",
"\n",
"# Calculate accuracy\n",
"def accuracy_fn(inference_fn, inputs, labels):\n",
" prediction = inference_fn(inputs)\n",
Expand Down

0 comments on commit f1c237b

Please sign in to comment.