MRG+1: refactor time decoding & temporal generalization (mne-tools#4103)
Showing 14 changed files with 621 additions and 300 deletions.
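At its core, the refactor swaps the Epochs-aware GeneralizationAcrossTime object for the scikit-learn-style GeneralizingEstimator, which wraps an arbitrary classifier and operates on plain data arrays. A condensed before/after sketch of the API change, distilled from the diff below (epochs_left, epochs_right, y_left and y_right are placeholder names for illustration, not names from the commit):

# Old API (removed): estimator tied to Epochs, with a predict_mode switch
gat = GeneralizationAcrossTime(predict_mode='mean-prediction', n_jobs=1)
gat.fit(epochs_left, y=y_left)      # fit directly on an Epochs object
gat.score(epochs_right, y=y_right)
gat.plot(title="...")               # built-in plotting

# New API: wrap any scikit-learn classifier, fit/score on plain arrays
clf = make_pipeline(StandardScaler(), LogisticRegression())
time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=1)
time_gen.fit(X=epochs_left.get_data(), y=y_left)
scores = time_gen.score(X=epochs_right.get_data(), y=y_right)
# plotting is now done explicitly with matplotlib (see the diff below)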
@@ -12,7 +12,7 @@
 ----------
 .. [1] King & Dehaene (2014) 'Characterizing the dynamics of mental
-       representations: the temporal generalization method', Trends In
+       representations: the Temporal Generalization method', Trends In
        Cognitive Sciences, 18(4), 203-210. doi: 10.1016/j.tics.2014.01.002.
 """
 # Authors: Jean-Remi King <[email protected]>

@@ -21,11 +21,15 @@
 #
 # License: BSD (3-clause)

 import numpy as np
 import matplotlib.pyplot as plt

+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.linear_model import LogisticRegression
+
 import mne
 from mne.datasets import sample
-from mne.decoding import GeneralizationAcrossTime
+from mne.decoding import GeneralizingEstimator

 print(__doc__)

@@ -36,36 +40,40 @@
 events_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
 raw = mne.io.read_raw_fif(raw_fname, preload=True)
 picks = mne.pick_types(raw.info, meg=True, exclude='bads')  # Pick MEG channels
-raw.filter(1, 30, method='fft')  # Band pass filtering signals
+raw.filter(1., 30., method='fft')  # Band pass filtering signals
 events = mne.read_events(events_fname)
-event_id = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4}
+event_id = {'Auditory/Left': 1, 'Auditory/Right': 2,
+            'Visual/Left': 3, 'Visual/Right': 4}
+tmin = -0.050
+tmax = 0.400
 decim = 2  # decimate to make the example faster to run
-epochs = mne.Epochs(raw, events, event_id, -0.050, 0.400, proj=True,
-                    picks=picks, baseline=None, preload=True,
-                    reject=dict(mag=5e-12), decim=decim, verbose=False)
+epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
+                    proj=True, picks=picks, baseline=None, preload=True,
+                    reject=dict(mag=5e-12), decim=decim)

-# In this case, because the test data is independent from the train data,
-# we test the classifier of each fold and average the respective predictions.
+# We will train the classifier on all left visual vs auditory trials
+# and test on all right visual vs auditory trials.
+clf = make_pipeline(StandardScaler(), LogisticRegression())
+time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=1)

-# Define events of interest
-triggers = epochs.events[:, 2]
-viz_vs_auditory = np.in1d(triggers, (1, 2)).astype(int)
-
-gat = GeneralizationAcrossTime(predict_mode='mean-prediction', n_jobs=1)
-
-# For our left events, which ones are visual?
-viz_vs_auditory_l = (triggers[np.in1d(triggers, (1, 3))] == 3).astype(int)
-# To make scikit-learn happy, we converted the bool array to integers
-# in the same line. This results in an array of zeros and ones:
-print("The unique classes' labels are: %s" % np.unique(viz_vs_auditory_l))
-
-gat.fit(epochs[('AudL', 'VisL')], y=viz_vs_auditory_l)
+# Fit classifiers on the epochs where the stimulus was presented to the left.
+# Note that the experimental condition y indicates auditory or visual
+time_gen.fit(X=epochs['Left'].get_data(),
+             y=epochs['Left'].events[:, 2] > 2)

-# For our right events, which ones are visual?
-viz_vs_auditory_r = (triggers[np.in1d(triggers, (2, 4))] == 4).astype(int)
+# Score on the epochs where the stimulus was presented to the right.
+scores = time_gen.score(X=epochs['Right'].get_data(),
+                        y=epochs['Right'].events[:, 2] > 2)

-gat.score(epochs[('AudR', 'VisR')], y=viz_vs_auditory_r)
-gat.plot(title="Temporal Generalization (visual vs auditory): left to right")
+# Plot
+fig, ax = plt.subplots(1)
+im = ax.matshow(scores, vmin=0, vmax=1., cmap='RdBu_r', origin='lower',
+                extent=epochs.times[[0, -1, 0, -1]])
+ax.axhline(0., color='k')
+ax.axvline(0., color='k')
+ax.xaxis.set_ticks_position('bottom')
+ax.set_xlabel('Testing Time (s)')
+ax.set_ylabel('Training Time (s)')
+ax.set_title('Generalization across time and condition')
+plt.colorbar(im, ax=ax)
+plt.show()
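For readers who want to try the refactored estimator without downloading the sample dataset, here is a minimal, self-contained sketch of the new API on synthetic data. It assumes an MNE version recent enough to ship GeneralizingEstimator; the array shapes and random labels are illustrative assumptions, not part of the commit:

import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from mne.decoding import GeneralizingEstimator

rng = np.random.RandomState(42)
n_epochs, n_channels, n_times = 60, 20, 30  # illustrative sizes (assumption)
X_train = rng.randn(n_epochs, n_channels, n_times)  # epochs x channels x times
y_train = rng.randint(0, 2, n_epochs)  # binary condition labels
X_test = rng.randn(n_epochs, n_channels, n_times)
y_test = rng.randint(0, 2, n_epochs)

clf = make_pipeline(StandardScaler(), LogisticRegression())
time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=1)

# One classifier is fit per training time point ...
time_gen.fit(X=X_train, y=y_train)
# ... and each one is scored against every testing time point, so the
# result is an (n_train_times, n_test_times) matrix of ROC-AUC scores.
scores = time_gen.score(X=X_test, y=y_test)
print(scores.shape)  # (30, 30)

# The diagonal is ordinary time-by-time decoding (train time == test time);
# off-diagonal entries measure how well a decoder generalizes across time.
diag_scores = np.diag(scores)

On random data like this the scores hover around chance (0.5). In the example script the same matrix is plotted with matshow, training time on the y-axis and testing time on the x-axis, following King & Dehaene (2014).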