Skip to content

Commit

Permalink
MRG+1: refactor time decoding & temporal generalization (mne-tools#4103)
Browse files Browse the repository at this point in the history
refactor time decoding & temporal generalization
  • Loading branch information
kingjr authored and agramfort committed Mar 31, 2017
1 parent 20cfc04 commit aa4e189
Show file tree
Hide file tree
Showing 14 changed files with 621 additions and 300 deletions.
63 changes: 38 additions & 25 deletions doc/manual/decoding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,42 +122,55 @@ To plot the corresponding filter, you can do::
Sensor-space decoding
=====================

Generalization Across Time
^^^^^^^^^^^^^^^^^^^^^^^^^^
Generalization Across Time (GAT) is a modern strategy to infer neuroscientific conclusions from decoding analysis of sensor-space data. An accuracy matrix is constructed where each point represents the performance of the model trained on one time window and tested on another.
Decoding over time
^^^^^^^^^^^^^^^^^^

.. image:: ../../_images/sphx_glr_plot_decoding_time_generalization_001.png
:align: center
:width: 400px

To use this functionality, simply do::

>>> gat = GeneralizationAcrossTime(predict_mode='cross-validation', n_jobs=1)
>>> gat.fit(epochs)
>>> gat.score(epochs)
>>> gat.plot(vmin=0.1, vmax=0.9, title="Generalization Across Time (faces vs. scrambled)")

.. topic:: Examples:
This strategy consists in fitting a multivariate predictive model on each
time instant and evaluating its performance at the same instant on new
epochs. The :class:`decoding.SlidingEstimator` will take as input a
pair of features :math:`X` and targets :math:`y`, where :math:`X` has
more than 2 dimensions. For decoding over time the data :math:`X`
is the epochs data of shape n_epochs x n_channels x n_times. As the
last dimension of :math:`X` is the time an estimator will be fit
on every time instant.

* :ref:`sphx_glr_auto_examples_decoding_plot_ems_filtering.py`
* :ref:`sphx_glr_auto_examples_decoding_plot_decoding_time_generalization_conditions.py`
This approach is analogous to SlidingEstimator-based approaches in fMRI,
where here we are interested in when one can discriminate experimental
conditions and therefore figure out when the effect of interest happens.

Time Decoding
^^^^^^^^^^^^^
In this strategy, a model trained on one time window is tested on the same time window. A moving time window will thus yield an accuracy curve similar to an ERP, but is considered more sensitive to effects in some situations. It is related to searchlight-based approaches in fMRI. This is also the diagonal of the GAT matrix.
When working with linear models as estimators, this approach boils
down to estimating a discriminative spatial filter for each time instant.

.. image:: ../../_images/sphx_glr_plot_decoding_sensors_001.png
:align: center
:width: 400px

To generate this plot, you need to initialize a GAT object and then use the method ``plot_diagonal``::
To generate this plot see our tutorial :ref:`sphx_glr_auto_tutorials_plot_sensors_decoding.py`.

>>> gat.plot_diagonal()
Temporal Generalization
^^^^^^^^^^^^^^^^^^^^^^^

.. topic:: Examples:
Temporal Generalization is an extension of the decoding over time approach.
It consists in evaluating whether the model estimated at a particular
time instant accurately predicts any other time instant. It is analogous to
transferring a trained model to a distinct learning problem, where the problems
correspond to decoding the patterns of brain activity recorded at distinct time
instants.

The object to for Temporal Generalization is
:class:`decoding.GeneralizingEstimator`. It expects as input :math:`X` and
:math:`y` (similarly to :class:`decoding.SlidingEstimator`) but, when generate
predictions from each model for all time instants. The class
:class:`decoding.GeneralizingEstimator` is generic and will treat the last
dimension as the one to be used for generalization testing. For convenience,
here, we refer to it different tasks. If :math:`X` corresponds to epochs data
then the last dimension is time.

.. image:: ../../_images/sphx_glr_plot_decoding_time_generalization_001.png
:align: center
:width: 400px

* :ref:`sphx_glr_auto_tutorials_plot_sensors_decoding.py`
* :ref:`sphx_glr_auto_examples_decoding_plot_decoding_time_generalization_conditions.py`
To generate this plot see our tutorial :ref:`sphx_glr_auto_tutorials_plot_sensors_decoding.py`.

Source-space decoding
=====================
Expand Down
10 changes: 5 additions & 5 deletions doc/python_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ Classes
preprocessing.Xdawn
decoding.CSP
decoding.FilterEstimator
decoding.GeneralizationAcrossTime
decoding.PSDEstimator
decoding.Scaler
decoding.TimeDecoding
decoding.ReceptiveField
decoding.Scaler
decoding.SlidingEstimator
decoding.GeneralizingEstimator
realtime.RtEpochs
realtime.RtClient
realtime.MockRtClient
Expand Down Expand Up @@ -1108,16 +1108,16 @@ Classes:
CSP
EMS
FilterEstimator
GeneralizationAcrossTime
LinearModel
PSDEstimator
Scaler
TemporalFilter
TimeDecoding
TimeFrequency
UnsupervisedSpatialFilter
Vectorizer
ReceptiveField
SlidingEstimator
GeneralizingEstimator

Functions:

Expand Down
12 changes: 10 additions & 2 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ Changelog

- Add .bvef extension (BrainVision Electrodes File) to :func:`mne.channels.read_montage` by `Jean-Baptiste Schiratti`_

- Add :func:`mne.decoding.cross_val_multiscore` to allow scoring of multiple tasks, typically used with :class:`mne.decoding.SlidingEstimator`, by `Jean-Remi King`_

- Add :class:`mne.decoding.ReceptiveField` module for modeling electrode response to input features by `Chris Holdgraf`_

- Add new :mod:`mne.datasets.mtrf` dataset by `Chris Holdgraf`_

- Add example of time-frequency decoding with CSP by `Laura Gwilliams`_

BUG
~~~

Expand All @@ -38,6 +42,12 @@ API

- Make the goodness of fit (GOF) of the dipoles returned by :func:`mne.beamformer.rap_music` consistent with the GOF of dipoles returned by :func:`mne.fit_dipole` by `Alex Gramfort`_.

- :class:`mne.decoding.SlidingEstimator` will now replace ``mne.decoding.TimeDecoding`` to make it generic and fully compatible with scikit-learn, by `Jean-Remi King`_ and `Alex Gramfort`_

- :class:`mne.decoding.GeneralizingEstimator` will now replace ``mne.decoding.GeneralizationAcrossTime`` to make it generic and fully compatible with scikit-learn, by `Jean-Remi King`_ and `Alex Gramfort`_

- ``mne.viz.decoding.plot_gat_times``, ``mne.viz.decoding.plot_gat_matrix`` are now deprecated. Use matplotlib instead as shown in the examples, by `Jean-Remi King`_ and `Alex Gramfort`_


.. _changes_0_14:

Expand All @@ -47,8 +57,6 @@ Version 0.14
Changelog
~~~~~~~~~

- Add example of time-frequency decoding with CSP by `Laura Gwilliams`_

- Automatically create a legend in :func:`mne.viz.evoked.plot_evoked_topo` by `Jussi Nurminen`_

- Add I/O support for Artemis123 infant/toddler MEG data by `Luke Bloy`_
Expand Down
64 changes: 36 additions & 28 deletions examples/decoding/plot_decoding_time_generalization_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
----------
.. [1] King & Dehaene (2014) 'Characterizing the dynamics of mental
representations: the temporal generalization method', Trends In
representations: the Temporal Generalization method', Trends In
Cognitive Sciences, 18(4), 203-210. doi: 10.1016/j.tics.2014.01.002.
"""
# Authors: Jean-Remi King <[email protected]>
Expand All @@ -21,11 +21,15 @@
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

import mne
from mne.datasets import sample
from mne.decoding import GeneralizationAcrossTime
from mne.decoding import GeneralizingEstimator

print(__doc__)

Expand All @@ -36,36 +40,40 @@
events_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)
picks = mne.pick_types(raw.info, meg=True, exclude='bads') # Pick MEG channels
raw.filter(1, 30, method='fft') # Band pass filtering signals
raw.filter(1., 30., method='fft') # Band pass filtering signals
events = mne.read_events(events_fname)
event_id = {'AudL': 1, 'AudR': 2, 'VisL': 3, 'VisR': 4}
event_id = {'Auditory/Left': 1, 'Auditory/Right': 2,
'Visual/Left': 3, 'Visual/Right': 4}
tmin = -0.050
tmax = 0.400
decim = 2 # decimate to make the example faster to run
epochs = mne.Epochs(raw, events, event_id, -0.050, 0.400, proj=True,
picks=picks, baseline=None, preload=True,
reject=dict(mag=5e-12), decim=decim, verbose=False)
epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
proj=True, picks=picks, baseline=None, preload=True,
reject=dict(mag=5e-12), decim=decim)

# We will train the classifier on all left visual vs auditory trials
# and test on all right visual vs auditory trials.
clf = make_pipeline(StandardScaler(), LogisticRegression())
time_gen = GeneralizingEstimator(clf, scoring='roc_auc', n_jobs=1)

# In this case, because the test data is independent from the train data,
# we test the classifier of each fold and average the respective predictions.

# Define events of interest
triggers = epochs.events[:, 2]
viz_vs_auditory = np.in1d(triggers, (1, 2)).astype(int)

gat = GeneralizationAcrossTime(predict_mode='mean-prediction', n_jobs=1)

# For our left events, which ones are visual?
viz_vs_auditory_l = (triggers[np.in1d(triggers, (1, 3))] == 3).astype(int)
# To make scikit-learn happy, we converted the bool array to integers
# in the same line. This results in an array of zeros and ones:
print("The unique classes' labels are: %s" % np.unique(viz_vs_auditory_l))

gat.fit(epochs[('AudL', 'VisL')], y=viz_vs_auditory_l)
# Fit classifiers on the epochs where the stimulus was presented to the left.
# Note that the experimental condition y indicates auditory or visual
time_gen.fit(X=epochs['Left'].get_data(),
y=epochs['Left'].events[:, 2] > 2)

# For our right events, which ones are visual?
viz_vs_auditory_r = (triggers[np.in1d(triggers, (2, 4))] == 4).astype(int)
# Score on the epochs where the stimulus was presented to the right.
scores = time_gen.score(X=epochs['Right'].get_data(),
y=epochs['Right'].events[:, 2] > 2)

gat.score(epochs[('AudR', 'VisR')], y=viz_vs_auditory_r)
gat.plot(title="Temporal Generalization (visual vs auditory): left to right")
# Plot
fig, ax = plt.subplots(1)
im = ax.matshow(scores, vmin=0, vmax=1., cmap='RdBu_r', origin='lower',
extent=epochs.times[[0, -1, 0, -1]])
ax.axhline(0., color='k')
ax.axvline(0., color='k')
ax.xaxis.set_ticks_position('bottom')
ax.set_xlabel('Testing Time (s)')
ax.set_ylabel('Training Time (s)')
ax.set_title('Generalization across time and condition')
plt.colorbar(im, ax=ax)
plt.show()
3 changes: 2 additions & 1 deletion mne/decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from .transformer import (PSDEstimator, Vectorizer,
UnsupervisedSpatialFilter, TemporalFilter)
from .mixin import TransformerMixin
from .base import BaseEstimator, LinearModel, get_coef
from .base import BaseEstimator, LinearModel, get_coef, cross_val_multiscore
from .csp import CSP
from .ems import compute_ems, EMS
from .time_gen import GeneralizationAcrossTime, TimeDecoding
from .time_frequency import TimeFrequency
from .receptive_field import ReceptiveField
from .search_light import SlidingEstimator, GeneralizingEstimator
Loading

0 comments on commit aa4e189

Please sign in to comment.