Skip to content

Commit

Permalink
Merge pull request mne-tools#2741 from wmvanvliet/add_reference_chann…
Browse files Browse the repository at this point in the history
…els_bugfix

[MRG] Fix bug in add_reference_channels
  • Loading branch information
larsoner committed Dec 23, 2015
2 parents 16e91fa + cef01cd commit 122b56f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mne/io/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def add_reference_channels(inst, ref_channels, copy=True):
raise TypeError("inst should be Raw, Epochs, or Evoked instead of %s."
% type(inst))
nchan = len(inst.info['ch_names'])
if ch in ref_channels:
for ch in ref_channels:
chan_info = {'ch_name': ch,
'coil_type': FIFF.FIFFV_COIL_EEG,
'kind': FIFF.FIFFV_EEG_CH,
Expand Down
30 changes: 28 additions & 2 deletions mne/io/tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from nose.tools import assert_true, assert_equal, assert_raises
from numpy.testing import assert_array_equal, assert_allclose

from mne import pick_types, Evoked, Epochs, read_events
from mne import pick_channels, pick_types, Evoked, Epochs, read_events
from mne.io.constants import FIFF
from mne.io import (set_eeg_reference, set_bipolar_reference,
add_reference_channels)
Expand Down Expand Up @@ -200,6 +200,18 @@ def test_set_bipolar_reference():
assert_raises(ValueError, set_bipolar_reference, raw,
'EEG 001', 'EEG 002', ch_name='EEG 003')

def _check_channel_names(inst, ref_names):
if isinstance(ref_names, str):
ref_names = [ref_names]

# Test that the names of the reference channels are present in `ch_names`
ref_idx = pick_channels(inst.info['ch_names'], ref_names)
assert_true(len(ref_idx), len(ref_names))

# Test that the names of the reference channels are present in the `chs`
# list
inst.info._check_consistency() # Should raise no exceptions


@testing.requires_testing_data
def test_add_reference():
Expand All @@ -212,26 +224,34 @@ def test_add_reference():
raw_ref = add_reference_channels(raw, 'Ref', copy=True)
assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 1)
assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
_check_channel_names(raw_ref, 'Ref')

orig_nchan = raw.info['nchan']
raw = add_reference_channels(raw, 'Ref', copy=False)
assert_array_equal(raw._data, raw_ref._data)
assert_equal(raw.info['nchan'], orig_nchan + 1)
_check_channel_names(raw, 'Ref')

ref_idx = raw.ch_names.index('Ref')
ref_data, _ = raw[ref_idx]
assert_array_equal(ref_data, 0)

# add two reference channels to Raw
raw = Raw(fif_fname, preload=True)
picks_eeg = pick_types(raw.info, meg=False, eeg=True)

# Test adding an existing channel as reference channel
assert_raises(ValueError, add_reference_channels, raw,
raw.info['ch_names'][0])

# add two reference channels to Raw
raw_ref = add_reference_channels(raw, ['M1', 'M2'], copy=True)
_check_channel_names(raw_ref, ['M1', 'M2'])
assert_equal(raw_ref._data.shape[0], raw._data.shape[0] + 2)
assert_array_equal(raw._data[picks_eeg, :], raw_ref._data[picks_eeg, :])
assert_array_equal(raw_ref._data[-2:, :], 0)

raw = add_reference_channels(raw, ['M1', 'M2'], copy=False)
_check_channel_names(raw, ['M1', 'M2'])
ref_idx = raw.ch_names.index('M1')
ref_idy = raw.ch_names.index('M2')
ref_data, _ = raw[[ref_idx, ref_idy]]
Expand All @@ -245,6 +265,7 @@ def test_add_reference():
picks=picks_eeg, preload=True)
epochs_ref = add_reference_channels(epochs, 'Ref', copy=True)
assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 1)
_check_channel_names(epochs_ref, 'Ref')
ref_idx = epochs_ref.ch_names.index('Ref')
ref_data = epochs_ref.get_data()[:, ref_idx, :]
assert_array_equal(ref_data, 0)
Expand All @@ -260,8 +281,11 @@ def test_add_reference():
picks=picks_eeg, preload=True)
epochs_ref = add_reference_channels(epochs, ['M1', 'M2'], copy=True)
assert_equal(epochs_ref._data.shape[1], epochs._data.shape[1] + 2)
_check_channel_names(epochs_ref, ['M1', 'M2'])
ref_idx = epochs_ref.ch_names.index('M1')
ref_idy = epochs_ref.ch_names.index('M2')
assert_equal(epochs_ref.info['chs'][ref_idx]['ch_name'], 'M1')
assert_equal(epochs_ref.info['chs'][ref_idy]['ch_name'], 'M2')
ref_data = epochs_ref.get_data()[:, [ref_idx, ref_idy], :]
assert_array_equal(ref_data, 0)
picks_eeg = pick_types(epochs.info, meg=False, eeg=True)
Expand All @@ -277,6 +301,7 @@ def test_add_reference():
evoked = epochs.average()
evoked_ref = add_reference_channels(evoked, 'Ref', copy=True)
assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 1)
_check_channel_names(evoked_ref, 'Ref')
ref_idx = evoked_ref.ch_names.index('Ref')
ref_data = evoked_ref.data[ref_idx, :]
assert_array_equal(ref_data, 0)
Expand All @@ -293,6 +318,7 @@ def test_add_reference():
evoked = epochs.average()
evoked_ref = add_reference_channels(evoked, ['M1', 'M2'], copy=True)
assert_equal(evoked_ref.data.shape[0], evoked.data.shape[0] + 2)
_check_channel_names(evoked_ref, ['M1', 'M2'])
ref_idx = evoked_ref.ch_names.index('M1')
ref_idy = evoked_ref.ch_names.index('M2')
ref_data = evoked_ref.data[[ref_idx, ref_idy], :]
Expand Down

0 comments on commit 122b56f

Please sign in to comment.