Skip to content

Commit

Permalink
MRG epochs.average(method=...) for robust averaging (mne-tools#5402)
Browse files Browse the repository at this point in the history
  • Loading branch information
jona-sassenhagen authored and agramfort committed Sep 21, 2018
1 parent 02e17ae commit c248bd4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 25 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ Changelog

- Add :func:`mne.open_report` to read back a :class:`mne.Report` object that was saved to an HDF5 file by `Marijn van Vliet`_

- :meth:`mne.Epochs.average` now supports custom, e.g. robust, averaging methods, by `Jona Sassenhagen`_

Bug
~~~

Expand Down
65 changes: 51 additions & 14 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,14 +813,22 @@ def __next__(self, *args, **kwargs):
"""Provide a wrapper for Py3k."""
return self.next(*args, **kwargs)

def average(self, picks=None):
"""Compute average of epochs.
def average(self, picks=None, method="mean"):
"""Compute an average over epochs.
Parameters
----------
picks : array-like of int | None
If None only MEG, EEG, SEEG, ECoG, and fNIRS channels are kept
otherwise the channels indices in picks are kept.
method : str | callable
How to combine the data. If "mean"/"median", the mean/median
are returned.
Otherwise, must be a callable which, when passed an array of shape
(n_epochs, n_channels, n_time) returns an array of shape
(n_channels, n_time).
Note that due to file type limitations, the kind for all
these will be "average".
Returns
-------
Expand All @@ -837,8 +845,18 @@ def average(self, picks=None):
are selected, resulting in an error. This is because ICA channels
are not considered data channels (they are of misc type) and only data
channels are selected when picks is None.
The `method` parameter allows e.g. robust averaging.
For example, one could do:
>>> from scipy.stats import trim_mean # doctest:+SKIP
>>> trim = lambda x: trim_mean(x, 10, axis=0) # doctest:+SKIP
>>> epochs.average(method=trim) # doctest:+SKIP
This would compute the trimmed mean.
"""
return self._compute_mean_or_stderr(picks, 'ave')
return self._compute_aggregate(picks=picks, mode=method)

def standard_error(self, picks=None):
"""Compute standard error over epochs.
Expand All @@ -854,12 +872,10 @@ def standard_error(self, picks=None):
evoked : instance of Evoked
The standard error over epochs.
"""
return self._compute_mean_or_stderr(picks, 'stderr')
return self._compute_aggregate(picks, "std")

def _compute_mean_or_stderr(self, picks, mode='ave'):
def _compute_aggregate(self, picks, mode='mean'):
"""Compute the mean or std over epochs and return Evoked."""
_do_std = True if mode == 'stderr' else False

# if instance contains ICA channels they won't be included unless picks
# is specified
if picks is None:
Expand All @@ -876,10 +892,29 @@ def _compute_mean_or_stderr(self, picks, mode='ave'):

if self.preload:
n_events = len(self.events)
fun = np.std if _do_std else np.mean
data = fun(self._data, axis=0)

if mode == "mean":
def fun(data):
return np.mean(data, axis=0)
elif mode == "std":
def fun(data):
return np.std(data, axis=0)
elif callable(mode):
fun = mode
else:
raise ValueError("mode must be mean, median, std, or callable"
", got %s (type %s)." % (mode, type(mode)))
data = fun(self._data)
assert len(self.events) == len(self._data)
if data.shape != self._data.shape[1:]:
raise RuntimeError("You passed a function that resulted "
"in data of shape {}, but it should be "
"{}.".format(data.shape,
self._data.shape[1:]))
else:
if mode not in {"mean", "std"}:
raise ValueError("If data are not preloaded, can only compute "
"mean or standard deviation.")
data = np.zeros((n_channels, n_times))
n_events = 0
for e in self:
Expand All @@ -893,18 +928,18 @@ def _compute_mean_or_stderr(self, picks, mode='ave'):

# convert to stderr if requested, could do in one pass but do in
# two (slower) in case there are large numbers
if _do_std:
if mode == "std":
data_mean = data.copy()
data.fill(0.)
for e in self:
data += (e - data_mean) ** 2
data = np.sqrt(data / n_events)

if not _do_std:
kind = 'average'
else:
if mode == "std":
kind = 'standard_error'
data /= np.sqrt(n_events)
else:
kind = "average"

return self._evoked_from_epoch_data(data, self.info, picks, n_events,
kind, self._name)
Expand Down Expand Up @@ -1572,7 +1607,7 @@ def _getitem(self, item, reason='IGNORED', copy=True, drop_event_id=True,
return epochs

def crop(self, tmin=None, tmax=None):
"""Crop a time interval from epochs object.
"""Crop a time interval from the epochs.
Parameters
----------
Expand All @@ -1590,6 +1625,8 @@ def crop(self, tmin=None, tmax=None):
-----
Unlike Python slices, MNE time intervals include both their end points;
crop(tmin, tmax) returns the interval tmin <= t <= tmax.
Note that the object is modified in place.
"""
# XXX this could be made to work on non-preloaded data...
_check_preload(self, 'Modifying data of epochs')
Expand Down
22 changes: 11 additions & 11 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import numpy as np

from ..utils import verbose, get_config, set_config, logger, warn
from ..utils import verbose, get_config, set_config, logger, warn, _pl
from ..io.pick import pick_types, channel_type, _get_channel_types
from ..time_frequency import psd_multitaper
from .utils import (tight_layout, figure_nobar, _toggle_proj, _toggle_options,
Expand Down Expand Up @@ -313,14 +313,15 @@ def _get_picks_and_types(picks, ch_types, group_by, combine):
n_picks = len(picks_)
if n_picks < 2:
raise ValueError(" ".join(
(name, "has only ", str(n_picks), "sensors.")))
(name, "has only ", str(n_picks),
"sensor{}.".format(_pl(n_picks)))))
all_ch_types = list()
for picks_, name in zip(all_picks, names):
this_ch_type = list(set((ch_types[pick] for pick in picks_)))
n_types = len(this_ch_type)
if n_types > 1: # we can only scale properly with 1 type
raise ValueError(
"Roi {} contains {} sensor types!".format(
"ROI {} contains more than one sensor type ({})!".format(
name, n_types))
all_ch_types.append(this_ch_type[0])
names.append(name)
Expand Down Expand Up @@ -354,19 +355,18 @@ def _pick_and_combine(epochs, combine, all_picks, all_ch_types, names):
if combine == "gfp":
def combine(data):
return np.sqrt((data * data).mean(axis=1))
elif combine == "mean":
def combine(data):
return np.mean(data, axis=1)
elif combine == "std":
def combine(data):
return np.std(data, axis=1)
elif combine == "median":

elif combine in {"mean", "median", "std"}:
func = getattr(np, combine)

def combine(data):
return np.median(data, axis=1)
return func(data, axis=1)

elif not callable(combine):
raise ValueError(
"``combine`` must be None, a callable or one out of 'mean' "
"or 'gfp'. Got " + str(type(combine)))

for ch_type, picks_, name in zip(all_ch_types, all_picks, names):
if len(np.atleast_1d(picks_)) < 2:
raise ValueError("Cannot combine over only one sensor. "
Expand Down

0 comments on commit c248bd4

Please sign in to comment.