Skip to content

Commit

Permalink
eval added
Browse files Browse the repository at this point in the history
  • Loading branch information
Kearlay committed Sep 26, 2018
1 parent b15c3db commit 8dba00d
Show file tree
Hide file tree
Showing 7 changed files with 724 additions and 419 deletions.
164 changes: 164 additions & 0 deletions eeg_tensorflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,170 @@
" average_loss = 0.\n",
" average_acc = 0."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load in libraries\n",
"import pickle\n",
"import itertools\n",
"from sklearn.metrics import confusion_matrix, classification_report, accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make directories\n",
"if not os.path.exists('./metrics/'):\n",
" os.makedirs('./metrics/')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_history(history):\n",
" loss_list = [s for s in history.keys() if 'loss' in s and 'val' not in s]\n",
" val_loss_list = [s for s in history.keys() if 'loss' in s and 'val' in s]\n",
" acc_list = [s for s in history.keys() if 'acc' in s and 'val' not in s]\n",
" val_acc_list = [s for s in history.keys() if 'acc' in s and 'val' in s]\n",
" \n",
" if len(loss_list) == 0:\n",
" print('Loss is missing in history')\n",
" return \n",
" \n",
" ## As loss always exists\n",
" epochs = range(1,len(history[loss_list[0]]) + 1)\n",
" \n",
" ## Loss\n",
" plt.figure(1)\n",
" for l in loss_list:\n",
" plt.plot(epochs, history[l], 'b', label='Training loss (' + str(str(format(history[l][-1],'.5f'))+')'))\n",
" for l in val_loss_list:\n",
" plt.plot(epochs, history[l], 'g', label='Validation loss (' + str(str(format(history[l][-1],'.5f'))+')'))\n",
" \n",
" plt.title('Loss')\n",
" plt.xlabel('Epochs')\n",
" plt.ylabel('Loss')\n",
" plt.legend()\n",
" plt.savefig(\"./metrics/loss.png\")\n",
" \n",
" ## Accuracy\n",
" plt.figure(2)\n",
" for l in acc_list:\n",
" plt.plot(epochs, history[l], 'b', label='Training accuracy (' + str(format(history[l][-1],'.5f'))+')')\n",
" for l in val_acc_list: \n",
" plt.plot(epochs, history[l], 'g', label='Validation accuracy (' + str(format(history[l][-1],'.5f'))+')')\n",
"\n",
" plt.title('Accuracy')\n",
" plt.xlabel('Epochs')\n",
" plt.ylabel('Accuracy')\n",
" plt.legend()\n",
" plt.show()\n",
" plt.savefig(\"./metrics/acc.png\")\n",
" \n",
"def plot_confusion_matrix(cm, classes,\n",
" normalize=False,\n",
" cmap=plt.cm.Blues):\n",
" \"\"\"\n",
" This function prints and plots the confusion matrix.\n",
" Normalization can be applied by setting `normalize=True`.\n",
" \"\"\"\n",
" if normalize:\n",
" cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
" title='Normalized confusion matrix'\n",
" else:\n",
" title='Confusion matrix'\n",
"\n",
" plt.figure(3)\n",
" plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
" plt.title(title)\n",
" plt.colorbar()\n",
" tick_marks = np.arange(len(classes))\n",
" plt.xticks(tick_marks, classes, rotation=45)\n",
" plt.yticks(tick_marks, classes)\n",
"\n",
" fmt = '.2f' if normalize else 'd'\n",
" thresh = cm.max() / 2.\n",
" for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
" plt.text(j, i, format(cm[i, j], fmt),\n",
" horizontalalignment=\"center\",\n",
" color=\"white\" if cm[i, j] > thresh else \"black\")\n",
"\n",
" plt.tight_layout()\n",
" plt.ylabel('True label')\n",
" plt.xlabel('Predicted label')\n",
" plt.savefig(\"./metrics/confuMat.png\")\n",
" plt.show()\n",
" \n",
"def full_multiclass_report(model,\n",
" x,\n",
" y_true,\n",
" classes):\n",
" \n",
" # 2. Predict classes and stores in y_pred\n",
" y_pred = model.predict(x).argmax(axis=1)\n",
" \n",
" # 3. Print accuracy score\n",
" print(\"Accuracy : \"+ str(accuracy_score(y_true,y_pred)))\n",
" \n",
" print(\"\")\n",
" \n",
" # 4. Print classification report\n",
" print(\"Classification Report\")\n",
" print(classification_report(y_true,y_pred,digits=4)) \n",
" \n",
" # 5. Plot confusion matrix\n",
" cnf_matrix = confusion_matrix(y_true,y_pred)\n",
" print(cnf_matrix)\n",
" plot_confusion_matrix(cnf_matrix,classes=classes) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load in the data\n",
"howManyTest = 0.2\n",
"\n",
"thisInd = np.random.randint(0, len(X_test), size=(len(X_test)//howManyTest))\n",
"X_conf, y_conf = X_test[[i for i in thisInd], :], y_test[[i for i in thisInd],:] \n",
"\n",
"'''\n",
"## Only if you have a previous model + history\n",
"# Get the model\n",
"model = models.load_model('./model/model0.h5')\n",
"\n",
"# Get the history\n",
"with open('./history/history0.pkl', 'rb') as hist:\n",
" history = pickle.load(hist)\n",
"'''\n",
"\n",
"# Get the graphics\n",
"plot_history(history)\n",
"X_test = X_test.reshape(X_test.shape[0], X_train.shape[1], X_train.shape[2], X_train.shape[3], 1)\n",
"full_multiclass_report(model,\n",
" X_test,\n",
" y_test.argmax(axis=1),\n",
" [1,2,3,4,5])"
]
}
],
"metadata": {
Expand Down
70 changes: 29 additions & 41 deletions py/eeg_data_downloads.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,30 @@
'''
Name: eeg_data_downloads.py
Author: Jim Chung
Dependencies: re, request, pathlib, urllib, os
Description: Executing this script will generate folders and start downloading
PhysioNet motor imagery data. Please refer to the official descriptions for
the dataset. Labels and recording protocols were elaborated on the website.
https://www.physionet.org/pn4/eegmmidb/
'''


import re
import requests
import pathlib
import urllib
import os

#%%
CONTEXT = 'pn4/'
MATERIAL = 'eegmmidb/'
URL = 'https://www.physionet.org/' + CONTEXT + MATERIAL

USERDIR = './data/' # Change this directory according to your setting

page = requests.get(URL).text
FOLDERS = sorted(list(set(re.findall(r'S[0-9]+', page))))

URLS = [URL+x+'/' for x in FOLDERS]

# Warning: Executing this block will create folders
for folder in FOLDERS:
pathlib.Path(USERDIR +'/'+ folder).mkdir(parents=True, exist_ok=True)

# Warning: Executing this block will start downloading data
for i, folder in enumerate(FOLDERS):
page = requests.get(URLS[i]).text
subs = list(set(re.findall(r'S[0-9]+R[0-9]+', page)))

print('Working on {}, {:.1%} completed'.format(folder, (i+1)/len(FOLDERS)))
for sub in subs:
import re
import requests
import pathlib
import urllib
import os

#%%
CONTEXT = 'pn4/'
MATERIAL = 'eegmmidb/'
URL = 'https://www.physionet.org/' + CONTEXT + MATERIAL

USERDIR = './data/' # Change this directory according to your setting

page = requests.get(URL).text
FOLDERS = sorted(list(set(re.findall(r'S[0-9]+', page))))

URLS = [URL+x+'/' for x in FOLDERS]

# Warning: Executing this block will create folders
for folder in FOLDERS:
pathlib.Path(USERDIR +'/'+ folder).mkdir(parents=True, exist_ok=True)

# Warning: Executing this block will start downloading data
for i, folder in enumerate(FOLDERS):
page = requests.get(URLS[i]).text
subs = list(set(re.findall(r'S[0-9]+R[0-9]+', page)))

print('Working on {}, {:.1%} completed'.format(folder, (i+1)/len(FOLDERS)))
for sub in subs:
urllib.request.urlretrieve(URLS[i]+sub+'.edf', os.path.join(USERDIR, folder, sub+'.edf'))
149 changes: 149 additions & 0 deletions py/eeg_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
'''
Name: eeg_eval.py
Author: Jim Chung
Description:
This script is written to monitor the performance of neural network trained
on PhysioNet EEG data. Please check out eeg_main.py or eeg_import_py for
further information.
'hitory.pkl' file requied in './history/' folder.
'''

# load in libraries
import pickle
import matplotlib.pyplot as plt
import numpy as np
import itertools
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from keras import models
from eeg_import import get_data, FNAMES
from eeg_preprocessing import prepare_data
import os

# make directories
if not os.path.exists('./metrics/'):
os.makedirs('./metrics/')

# functions defined


def plot_history(history):
loss_list = [s for s in history.keys() if 'loss' in s and 'val' not in s]
val_loss_list = [s for s in history.keys() if 'loss' in s and 'val' in s]
acc_list = [s for s in history.keys() if 'acc' in s and 'val' not in s]
val_acc_list = [s for s in history.keys() if 'acc' in s and 'val' in s]

if len(loss_list) == 0:
print('Loss is missing in history')
return

## As loss always exists
epochs = range(1,len(history[loss_list[0]]) + 1)

## Loss
plt.figure(1)
for l in loss_list:
plt.plot(epochs, history[l], 'b', label='Training loss (' + str(str(format(history[l][-1],'.5f'))+')'))
for l in val_loss_list:
plt.plot(epochs, history[l], 'g', label='Validation loss (' + str(str(format(history[l][-1],'.5f'))+')'))

plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig("./metrics/loss.png")

## Accuracy
plt.figure(2)
for l in acc_list:
plt.plot(epochs, history[l], 'b', label='Training accuracy (' + str(format(history[l][-1],'.5f'))+')')
for l in val_acc_list:
plt.plot(epochs, history[l], 'g', label='Validation accuracy (' + str(format(history[l][-1],'.5f'))+')')

plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
plt.savefig("./metrics/acc.png")

def plot_confusion_matrix(cm, classes,
normalize=False,
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
title='Normalized confusion matrix'
else:
title='Confusion matrix'

plt.figure(3)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.savefig("./metrics/confuMat.png")
plt.show()

def full_multiclass_report(model,
x,
y_true,
classes):

# 2. Predict classes and stores in y_pred
y_pred = model.predict(x).argmax(axis=1)

# 3. Print accuracy score
print("Accuracy : "+ str(accuracy_score(y_true,y_pred)))

print("")

# 4. Print classification report
print("Classification Report")
print(classification_report(y_true,y_pred,digits=4))

# 5. Plot confusion matrix
cnf_matrix = confusion_matrix(y_true,y_pred)
print(cnf_matrix)
plot_confusion_matrix(cnf_matrix,classes=classes)


# Load in the data
howManyTest = 10
this = np.random.randint(1, 100, size=howManyTest)
X,y = get_data([FNAMES[i] for i in this], epoch_sec=0.0625)
X_train, y_train, X_test, y_test = prepare_data(X, y)

print(X.shape)
print(y.shape)

# Get the model
model = models.load_model('./model/model0.h5')

# Get the history
with open('./history/history0.pkl', 'rb') as hist:
history = pickle.load(hist)

# Get the graphics
plot_history(history)
X_test = X_test.reshape(X_test.shape[0], X_train.shape[1], X_train.shape[2], X_train.shape[3], 1)
full_multiclass_report(model,
X_test,
y_test.argmax(axis=1),
[1,2,3,4,5])
Loading

0 comments on commit 8dba00d

Please sign in to comment.