Skip to content

Commit

Permalink
ENH: Speed up set_bipolar_reference (mne-tools#9270)
Browse files Browse the repository at this point in the history
* Add cathode-location and fix doc

* Implementing matrix-multiplication approach (by @jasmainak)

* Fix returning ref_to from _check_before_reference

* Refine how Info is copied to new channels to avoid mixups and pass tests

* Adjust test comparing info of anode/cathode and bipolar-channel to also pass when bipolar-channels are appended

* Concatenation of Reference-Instances outside the loop

* Using @-operator for matrix-multiplication

* Update test_set_bipolar_reference to show info-keys responsible for the errors

* Improve performance by creating reference-instance from scratch

* channel-information including location is taken from anode

* Update test for info just taken from anode

* Fix import of create_info, improve docs

* Addition to latest.inc

* Fix latest.inc

* Refactor assert-statements in test_reference.py as suggested by @larsoner

* Update latest.inc as suggested by @jasmainak

Co-authored-by: Mainak Jas <[email protected]>

Co-authored-by: Mainak Jas <[email protected]>
  • Loading branch information
marsipu and jasmainak authored Apr 22, 2021
1 parent 723ac8f commit 350f76b
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 51 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ Enhancements

- :func:`mne.preprocessing.find_eog_events` and :func:`mne.preprocessing.create_eog_epochs` now accept a list of channel names, allowing you to specify multiple EOG channels at once (:gh:`9269` by `Richard Höchenberger`_)

- Improve performance of :func:`mne.set_bipolar_reference` (:gh:`9270` by `Martin Schulz`_)

Bugs
~~~~
- Fix bug with :func:`mne.viz.plot_evoked_topo` where set ylim parameters gets swapped across channel types. (:gh:`9207` by |Ram Pari|_)
Expand Down
123 changes: 80 additions & 43 deletions mne/io/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#
# License: BSD (3-clause)

from copy import deepcopy
import numpy as np

from .constants import FIFF
Expand Down Expand Up @@ -47,11 +46,9 @@ def _copy_channel(inst, ch_name, new_ch_name):
return inst


def _apply_reference(inst, ref_from, ref_to=None, forward=None,
ch_type='auto'):
"""Apply a custom EEG referencing scheme."""
def _check_before_reference(inst, ref_from, ref_to, ch_type):
"""Prepare instance for referencing."""
# Check to see that data is preloaded
from scipy import linalg
_check_preload(inst, "Applying a reference")

ch_type = _get_ch_type(inst, ch_type)
Expand Down Expand Up @@ -98,6 +95,21 @@ def _apply_reference(inst, ref_from, ref_to=None, forward=None,
inst._projector, _ = \
setup_proj(inst.info, add_eeg_ref=False, activate=False)

# If the reference touches EEG/ECoG/sEEG/DBS electrodes, note in the
# info that a non-CAR has been applied.
ref_to_channels = pick_channels(inst.ch_names, ref_to, ordered=True)
if len(np.intersect1d(ref_to_channels, eeg_idx)) > 0:
inst.info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_ON

return ref_to


def _apply_reference(inst, ref_from, ref_to=None, forward=None,
ch_type='auto'):
"""Apply a custom EEG referencing scheme."""
from scipy import linalg
ref_to = _check_before_reference(inst, ref_from, ref_to, ch_type)

# Compute reference
if len(ref_from) > 0:
# this is guaranteed below, but we should avoid the crazy pick_channels
Expand All @@ -113,10 +125,6 @@ def _apply_reference(inst, ref_from, ref_to=None, forward=None,
data[..., ref_to, :] -= ref_data
ref_data = ref_data[..., 0, :]

# If the reference touches EEG/ECoG/sEEG/DBS electrodes, note in the
# info that a non-CAR has been applied.
if len(np.intersect1d(ref_to, eeg_idx)) > 0:
inst.info['custom_ref_applied'] = FIFF.FIFFV_MNE_CUSTOM_REF_ON
# REST
if forward is not None:
# use ch_sel and the given forward
Expand Down Expand Up @@ -377,15 +385,14 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None,
A bipolar reference takes the difference between two channels (the anode
minus the cathode) and adds it as a new virtual channel. The original
channels will be dropped.
channels will be dropped by default.
Multiple anodes and cathodes can be specified, in which case multiple
virtual channels will be created. The 1st anode will be subtracted from the
1st cathode, the 2nd anode from the 2nd cathode, etc.
virtual channels will be created. The 1st cathode will be subtracted
from the 1st anode, the 2nd cathode from the 2nd anode, etc.
By default, the virtual channels will be annotated with channel info of
the anodes, their locations set to (0, 0, 0) and coil types set to
EEG_BIPOLAR.
By default, the virtual channels will be annotated with channel-info and
-location of the anodes and coil types will be set to EEG_BIPOLAR.
Parameters
----------
Expand Down Expand Up @@ -432,6 +439,11 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None,
.. versionadded:: 0.9.0
"""
from .meas_info import create_info
from ..io import RawArray
from ..epochs import EpochsArray
from ..evoked import EvokedArray

_check_can_reref(inst)
if not isinstance(anode, list):
anode = [anode]
Expand All @@ -444,7 +456,7 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None,
'of cathodes (got %d).' % (len(anode), len(cathode)))

if ch_name is None:
ch_name = ['%s-%s' % ac for ac in zip(anode, cathode)]
ch_name = [f'{a}-{c}' for (a, c) in zip(anode, cathode)]
elif not isinstance(ch_name, list):
ch_name = [ch_name]
if len(ch_name) != len(anode):
Expand All @@ -467,39 +479,64 @@ def set_bipolar_reference(inst, anode, cathode, ch_name=None, ch_info=None,
raise ValueError('Number of channel info dictionaries must equal the '
'number of anodes/cathodes.')

# Merge specified and anode channel information dictionaries
new_chs = []
for ci, (an, ch) in enumerate(zip(anode, ch_info)):
_check_ch_keys(ch, ci, name='ch_info', check_min=False)
an_idx = inst.ch_names.index(an)
this_chs = deepcopy(inst.info['chs'][an_idx])
if copy:
inst = inst.copy()

# Set channel location and coil type
this_chs['loc'] = np.zeros(12)
this_chs['coil_type'] = FIFF.FIFFV_COIL_EEG_BIPOLAR
anode = _check_before_reference(inst, ref_from=cathode,
ref_to=anode, ch_type='auto')

this_chs.update(ch)
new_chs.append(this_chs)
# Create bipolar reference channels by multiplying the data
# (channels x time) with a matrix (n_virtual_channels x channels)
# and add them to the instance.
multiplier = np.zeros((len(anode), len(inst.ch_names)))
for idx, (a, c) in enumerate(zip(anode, cathode)):
multiplier[idx, inst.ch_names.index(a)] = 1
multiplier[idx, inst.ch_names.index(c)] = -1

if copy:
inst = inst.copy()
ref_info = create_info(ch_names=ch_name, sfreq=inst.info['sfreq'],
ch_types=inst.get_channel_types(picks=anode))

for i, (an, ca, name, chs) in enumerate(
zip(anode, cathode, ch_name, new_chs)):
if an in anode[i + 1:] or an in cathode[i + 1:] or not drop_refs:
# Make a copy of the channel if it's still needed later
# otherwise it's modified inplace
_copy_channel(inst, an, 'TMP')
an = 'TMP'
_apply_reference(inst, [ca], [an]) # ensures preloaded
# Update "chs" in Reference-Info.
for ch_idx, (an, info) in enumerate(zip(anode, ch_info)):
_check_ch_keys(info, ch_idx, name='ch_info', check_min=False)
an_idx = inst.ch_names.index(an)
inst.info['chs'][an_idx] = chs
inst.info['chs'][an_idx]['ch_name'] = name
logger.info('Bipolar channel added as "%s".' % name)
inst.info._update_redundant()
# Copy everything from anode (except ch_name).
an_chs = {k: v for k, v in inst.info['chs'][an_idx].items()
if k != 'ch_name'}
ref_info['chs'][ch_idx].update(an_chs)
# Set coil-type to bipolar.
ref_info['chs'][ch_idx]['coil_type'] = FIFF.FIFFV_COIL_EEG_BIPOLAR
# Update with info from ch_info-parameter.
ref_info['chs'][ch_idx].update(info)

# Set other info-keys from original instance.
pick_info = {k: v for k, v in inst.info.items() if k not in
['chs', 'ch_names', 'bads', 'nchan', 'sfreq']}
ref_info.update(pick_info)

# Rereferencing of data.
ref_data = multiplier @ inst._data

if isinstance(inst, BaseRaw):
ref_inst = RawArray(ref_data, ref_info, first_samp=inst.first_samp,
copy=None)
elif isinstance(inst, BaseEpochs):
ref_inst = EpochsArray(ref_data, ref_info, events=inst.events,
tmin=inst.tmin, event_id=inst.event_id,
metadata=inst.metadata)
else:
ref_inst = EvokedArray(ref_data, ref_info, tmin=inst.tmin,
comment=inst.comment, nave=inst.nave,
kind='average')

# Add referenced instance to original instance.
inst.add_channels([ref_inst], force_update_info=True)

added_channels = ', '.join([name for name in ch_name])
logger.info(f'Added the following bipolar channels:\n{added_channels}')

if getattr(inst, 'picks', None) is not None:
del inst.picks # picks cannot be tracked anymore
for attr_name in ['picks', '_projector']:
setattr(inst, attr_name, None)

# Drop remaining channels.
if drop_refs:
Expand Down
14 changes: 6 additions & 8 deletions mne/io/tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,14 @@ def test_set_bipolar_reference(inst_type):

# Check channel information
bp_info = reref.info['chs'][reref.ch_names.index('bipolar')]
an_info = reref.info['chs'][inst.ch_names.index('EEG 001')]
an_info = inst.info['chs'][inst.ch_names.index('EEG 001')]
for key in bp_info:
if key == 'loc':
assert_array_equal(bp_info[key], 0)
elif key == 'coil_type':
assert_equal(bp_info[key], FIFF.FIFFV_COIL_EEG_BIPOLAR)
if key == 'coil_type':
assert bp_info[key] == FIFF.FIFFV_COIL_EEG_BIPOLAR, key
elif key == 'kind':
assert_equal(bp_info[key], FIFF.FIFFV_EOG_CH)
else:
assert_equal(bp_info[key], an_info[key])
assert bp_info[key] == FIFF.FIFFV_EOG_CH, key
elif key != 'ch_name':
assert_equal(bp_info[key], an_info[key], err_msg=key)

# Minimalist call
reref = set_bipolar_reference(inst, 'EEG 001', 'EEG 002')
Expand Down

0 comments on commit 350f76b

Please sign in to comment.