Skip to content

Commit

Permalink
MRG Make combine_evoked transparently reorder channels if required (m…
Browse files Browse the repository at this point in the history
…ne-tools#5431)

* init

* test

* Pep8

* more tests

* remove redundant check

* ...

* whatsnew

* whatsnew
  • Loading branch information
jona-sassenhagen authored and larsoner committed Aug 16, 2018
1 parent 0d99a0d commit 35f54e7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Changelog

- Add :func:`mne.head_to_mri` to convert positions from head coordinates to MRI RAS coordinates, by `Joan Massich`_ and `Alex Gramfort`_

- :func:`mne.combine_evoked` and :func:`mne.grand_average` can now handle input with the same channels in different orders, if required, by `Jona Sassenhagen`_

Bug
~~~

Expand Down
18 changes: 13 additions & 5 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,13 +828,21 @@ def grand_average(all_evoked, interpolate_bads=True):
def _check_evokeds_ch_names_times(all_evoked):
evoked = all_evoked[0]
ch_names = evoked.ch_names
for ev in all_evoked[1:]:
for ii, ev in enumerate(all_evoked[1:]):
if ev.ch_names != ch_names:
raise ValueError(
"%s and %s do not contain the same channels" % (evoked, ev))
if set(ev.ch_names) != set(ch_names):
raise ValueError(
"%s and %s do not contain the same channels." % (evoked,
ev))
else:
warn("Order of channels differs, reordering channels ...")
ev = ev.copy()
ev.reorder_channels(ch_names)
all_evoked[ii + 1] = ev
if not np.max(np.abs(ev.times - evoked.times)) < 1e-7:
raise ValueError("%s and %s do not contain the same time instants"
% (evoked, ev))
return all_evoked


def combine_evoked(all_evoked, weights):
Expand All @@ -861,7 +869,6 @@ def combine_evoked(all_evoked, weights):
-----
.. versionadded:: 0.9.0
"""
evoked = all_evoked[0].copy()
if isinstance(weights, string_types):
if weights not in ('nave', 'equal'):
raise ValueError('weights must be a list of float, or "nave" or '
Expand All @@ -875,7 +882,8 @@ def combine_evoked(all_evoked, weights):
if weights.ndim != 1 or weights.size != len(all_evoked):
raise ValueError('weights must be the same size as all_evoked')

_check_evokeds_ch_names_times(all_evoked)
all_evoked = _check_evokeds_ch_names_times(all_evoked)
evoked = all_evoked[0].copy()

# use union of bad channels
bads = list(set(evoked.info['bads']).union(*(ev.info['bads']
Expand Down
12 changes: 12 additions & 0 deletions mne/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,18 @@ def test_arithmetic():
assert_equal(gave.nave, 2)
pytest.raises(TypeError, grand_average, [1, evoked1])

# test channel (re)ordering
evoked1, evoked2 = read_evokeds(fname, condition=[0, 1], proj=True)
data2 = evoked2.data # assumes everything is ordered to the first evoked
data = (evoked1.data + evoked2.data) / 2
evoked2.reorder_channels(evoked2.ch_names[::-1])
assert not np.allclose(data2, evoked2.data)
with pytest.warns(RuntimeWarning, match='reordering'):
ev3 = grand_average((evoked1, evoked2))
assert np.allclose(ev3.data, data)
assert evoked1.ch_names != evoked2.ch_names
assert evoked1.ch_names == ev3.ch_names


def test_array_epochs():
"""Test creating evoked from array."""
Expand Down

0 comments on commit 35f54e7

Please sign in to comment.