Merge pull request mne-tools#3381 from kingjr/search_light_transformer
MRG: Add search and generalization lights
dengemann authored Jul 30, 2016
2 parents 50bee33 + ff043ad commit 0c57855
Showing 5 changed files with 735 additions and 23 deletions.
61 changes: 38 additions & 23 deletions examples/decoding/plot_decoding_time_generalization_conditions.py
@@ -25,10 +25,14 @@
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt

import mne
from mne.datasets import sample
from mne.decoding import GeneralizationAcrossTime
from mne.decoding.search_light import GeneralizationLight
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

print(__doc__)

@@ -38,40 +42,51 @@
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
events_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)
picks = mne.pick_types(raw.info, meg=True, exclude='bads') # Pick MEG channels
# Band pass filter signals
raw.filter(1, 30, method='fir', filter_length='auto',
           l_trans_bandwidth='auto', h_trans_bandwidth='auto', phase='zero')
picks = mne.pick_types(raw.info, meg='mag') # Pick magnetometers only
events = mne.read_events(events_fname)
event_id = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4}
decim = 2 # decimate to make the example faster to run
epochs = mne.Epochs(raw, events, event_id, -0.050, 0.400, proj=True,
                    picks=picks, baseline=None, preload=True,
                    reject=dict(mag=5e-12), decim=decim, verbose=False)
                    decim=decim, verbose=False)

# We will train the classifier on all left visual vs auditory trials
# and test on all right visual vs auditory trials.

# In this case, because the test data is independent of the training data,
# we test the classifier of each fold and average the respective predictions.
# we do not need cross-validation.

# Define events of interest
triggers = epochs.events[:, 2]
viz_vs_auditory = np.in1d(triggers, (1, 2)).astype(int)

gat = GeneralizationAcrossTime(predict_mode='mean-prediction', n_jobs=1)

# For our left events, which ones are visual?
viz_vs_auditory_l = (triggers[np.in1d(triggers, (1, 3))] == 3).astype(int)
# To make scikit-learn happy, we converted the bool array to integers
# in the same line. This results in an array of zeros and ones:
print("The unique classes' labels are: %s" % np.unique(viz_vs_auditory_l))

gat.fit(epochs[('AudL', 'VisL')], y=viz_vs_auditory_l)

# For our right events, which ones are visual?
viz_vs_auditory_r = (triggers[np.in1d(triggers, (2, 4))] == 4).astype(int)

gat.score(epochs[('AudR', 'VisR')], y=viz_vs_auditory_r)
gat.plot(
    title="Generalization Across Time (visual vs auditory): left to right")
# Each estimator fitted at each time point is an independent scikit-learn
# pipeline with ``fit`` and ``score`` methods.
gat = GeneralizationLight(
    make_pipeline(StandardScaler(), LogisticRegression()),
    n_jobs=1)

# Fit: for our left events, which ones are visual?
X = epochs[('AudL', 'VisL')].get_data()
y = triggers[np.in1d(triggers, (1, 3))] == 3
gat.fit(X, y)

# Generalize: for our right events, which ones are visual?
X = epochs[('AudR', 'VisR')].get_data()
y = triggers[np.in1d(triggers, (2, 4))] == 4
score = gat.score(X, y)

# Plot temporal generalization accuracies.
extent = epochs.times[[0, -1, 0, -1]]
fig, ax = plt.subplots(1)
im = ax.matshow(score, origin='lower', cmap='RdBu_r', vmin=0., vmax=1.,
                extent=extent)
ticks = np.arange(0., .401, .100)
ax.set_xticks(ticks)
ax.set_xticklabels(ticks)
ax.set_yticks(ticks)
ax.set_yticklabels(ticks)
ax.axvline(0, color='k')
ax.axhline(0, color='k')
plt.colorbar(im)
plt.show()
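
The diagonal of the temporal generalization matrix corresponds to training and testing on the same time sample, i.e. ordinary time-by-time decoding. A minimal follow-up sketch continuing the script above, assuming ``score`` is a square (n_times, n_times) array aligned with ``epochs.times`` (the train/test axis orientation is an assumption, not stated by this diff):

# Hypothetical extra plot: extract the diagonal of the generalization
# matrix, where each classifier is evaluated at its own training time.
diag_score = np.diag(score)  # assumes a square (n_times, n_times) score array
fig, ax = plt.subplots(1)
ax.plot(epochs.times, diag_score, label='train time == test time')
ax.axhline(0.5, color='k', linestyle='--', label='chance')
ax.axvline(0., color='k')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()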
1 change: 1 addition & 0 deletions mne/decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .csp import CSP
from .ems import compute_ems, EMS
from .time_gen import GeneralizationAcrossTime, TimeDecoding
from .search_light import SearchLight, GeneralizationLight
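
For the plain (diagonal-only) case, the newly exported SearchLight can be used on its own. A toy sketch on random data, assuming SearchLight wraps a scikit-learn estimator and exposes the same ``fit``/``score`` interface as GeneralizationLight in the example above, returning one score per time sample (both points are assumptions, not confirmed by this diff):

import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from mne.decoding import SearchLight

rng = np.random.RandomState(0)
X = rng.randn(40, 20, 30)  # toy data: 40 epochs, 20 channels, 30 time samples
y = rng.randint(0, 2, 40)  # toy binary labels

# One clone of the pipeline is assumed to be fitted per time sample and
# scored only at that sample, yielding a 1D array of length n_times.
sl = SearchLight(make_pipeline(StandardScaler(), LogisticRegression()), n_jobs=1)
sl.fit(X, y)
scores = sl.score(X, y)
print(np.round(scores, 2))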
13 changes: 13 additions & 0 deletions mne/decoding/base.py
@@ -704,3 +704,16 @@ def _set_cv(cv, estimator=None, X=None, y=None):
raise ValueError('Some folds do not have any train epochs.')

return cv, cv_splits


def _check_estimator(estimator, get_params=True):
    """Check whether an object has the fit, transform, fit_transform and
    get_params methods required by scikit-learn."""
    for attr in ('fit', 'transform', 'fit_transform'):
        if not hasattr(estimator, attr):
            raise ValueError('estimator must be a scikit-learn transformer or '
                             'an estimator with the %s method' % attr)
    if get_params and not hasattr(estimator, 'get_params'):
        raise ValueError('estimator must be a scikit-learn transformer or an '
                         'estimator with the get_params method that allows '
                         'cloning.')
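
A small illustrative sketch of how the helper above behaves (the calling code is hypothetical and not part of this commit, and assumes the helper is importable from mne.decoding.base as the file path suggests): a scikit-learn transformer passes, while an object lacking ``transform`` is rejected with a ValueError.

from sklearn.preprocessing import StandardScaler
from mne.decoding.base import _check_estimator

_check_estimator(StandardScaler())  # has fit/transform/fit_transform/get_params

class _NoTransform(object):
    """Toy object with a fit method but no transform/fit_transform."""
    def fit(self, X, y=None):
        return self

try:
    _check_estimator(_NoTransform())
except ValueError as err:
    print(err)  # complains about the missing transform method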
