Skip to content

Commit

Permalink
MRG: Add nirs support to ICA topoplot (mne-tools#7427)
Browse files Browse the repository at this point in the history
* Add fnirs support to ica topomap plotting

* Fix latest docs for ica fnirs plotting

* Add tests for nirs ica topomap plotting

* Fix type in nirs ica topomap test

* Change nirs topomap test data to increase code coverage

* Add merge_nirs_data to plot_ica_components

* Add requires sklearn to nirs ica topomap test

* Add _merge_ch_data as a wrapper for merging nirs or grad data

* Change from _merge_grad_data to _merge_channel_data

* Revert public API change of merge_grads
  • Loading branch information
rob-luke authored Mar 17, 2020
1 parent aa7ea77 commit 586c09f
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 74 deletions.
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ Changelog
- Allow returning vector source estimates from sparse inverse solvers through ``pick_ori='vector'`` by `Christian Brodbeck`_
- Add NIRS support to :func:`mne.viz.plot_topomap` by `Robert Luke`_
- Add NIRS support to :func:`mne.viz.plot_topomap` and :func:`mne.viz.plot_ica_components` by `Robert Luke`_
- Add the ability to :func:`mne.channels.equalize_channels` to also re-order the channels and also operate on instances of :class:`mne.Info`, :class:`mne.Forward`, :class:`mne.Covariance` and :class:`mne.time_frequency.CrossSpectralDensity` by `Marijn van Vliet`_
Expand Down
29 changes: 29 additions & 0 deletions mne/channels/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,35 @@ def _pair_grad_sensors_ch_names_neuromag122(ch_names):
return grad_chs


def _merge_ch_data(data, ch_type, names, method='rms'):
"""Merge data from channel pairs.
Parameters
----------
data : array, shape = (n_channels, ..., n_times)
Data for channels, ordered in pairs.
ch_type : str
Channel type.
names : list
List of channel names.
method : str
Can be 'rms' or 'mean'.
Returns
-------
data : array, shape = (n_channels / 2, ..., n_times)
The root mean square or mean for each pair.
names : list
List of channel names.
"""
if ch_type == 'grad':
data = _merge_grad_data(data, method)
else:
assert ch_type in ('hbo', 'hbr', 'fnirs_raw', 'fnirs_od')
data, names = _merge_nirs_data(data, names)
return data, names


def _merge_grad_data(data, method='rms'):
"""Merge data from channel pairs using the RMS or mean.
Expand Down
4 changes: 2 additions & 2 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .baseline import rescale
from .channels.channels import (ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin)
from .channels.layout import _merge_grad_data, _pair_grad_sensors
from .channels.layout import _merge_ch_data, _pair_grad_sensors
from .filter import detrend, FilterMixin
from .utils import (check_fname, logger, verbose, _time_mask, warn, sizeof_fmt,
SizeMixin, copy_function_doc_to_method_doc, _validate_type,
Expand Down Expand Up @@ -602,7 +602,7 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None,
ch_names = [ch_names[k] for k in picks]

if merge_grads:
data = _merge_grad_data(data)
data, _ = _merge_ch_data(data, ch_type, [])
ch_names = [ch_name[:-1] + 'X' for ch_name in ch_names[::2]]

ch_idx, time_idx, max_amp = _get_peak(data, self.times, tmin,
Expand Down
3 changes: 2 additions & 1 deletion mne/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ def test_get_peak():
assert_equal(ch_name, 'MEG 1421')
assert_allclose(max_amp, 7.17057e-13, rtol=1e-5)

pytest.raises(ValueError, evoked.get_peak, ch_type='mag', merge_grads=True)
pytest.raises(ValueError, evoked.get_peak, ch_type='mag',
merge_grads=True)
ch_name, time_idx = evoked.get_peak(ch_type='grad', merge_grads=True)
assert_equal(ch_name, 'MEG 244X')

Expand Down
5 changes: 3 additions & 2 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,7 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
.. versionadded:: 0.16.0
""" # noqa: E501
from ..viz.topomap import _set_contour_locator, plot_topomap
from ..channels.layout import (find_layout, _merge_grad_data,
from ..channels.layout import (find_layout, _merge_ch_data,
_pair_grad_sensors)
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -1549,7 +1549,8 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
if layout is None:
pos = new_pos
method = combine or 'rms'
data = _merge_grad_data(data[pair_picks], method=method)
data, _ = _merge_ch_data(data[pair_picks], ch_type, [],
method=method)

all_pos.append(pos)

Expand Down
12 changes: 7 additions & 5 deletions mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,18 @@ def _line_plot_onselect(xmin, xmax, ch_types, info, data, times, text=None,
for idx, ch_type in enumerate(ch_types):
if ch_type not in ('eeg', 'grad', 'mag'):
continue
picks, pos, merge_grads, _, ch_type, this_sphere, clip_origin = \
picks, pos, merge_channels, _, ch_type, this_sphere, clip_origin = \
_prepare_topomap_plot(info, ch_type, sphere=sphere)
outlines = _make_head_outlines(this_sphere, pos, 'head', clip_origin)
if len(pos) < 2:
fig.delaxes(axarr[0][idx])
continue
this_data = data[picks, minidx:maxidx]
if merge_grads:
from ..channels.layout import _merge_grad_data
if merge_channels:
from ..channels.layout import _merge_ch_data
method = 'mean' if psd else 'rms'
this_data = _merge_grad_data(this_data, method=method)
this_data, _ = _merge_ch_data(this_data, ch_type, [],
method=method)
title = '%s %s' % (ch_type, method.upper())
else:
title = ch_type
Expand Down Expand Up @@ -840,7 +841,8 @@ def plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
fig_facecolor=fig_facecolor,
fig_background=fig_background,
axis_facecolor=axis_facecolor,
font_color=font_color, merge_grads=merge_grads,
font_color=font_color,
merge_channels=merge_grads,
legend=legend, axes=axes, show=show,
noise_cov=noise_cov)

Expand Down
9 changes: 5 additions & 4 deletions mne/viz/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ def _label_clicked(pos, params):
fig, axes = _prepare_trellis(len(types), max_col=3)
for ch_idx, ch_type in enumerate(types):
try:
data_picks, pos, merge_grads, _, _, this_sphere, clip_origin = \
data_picks, pos, merge_channels, _, _, this_sphere, clip_origin = \
_prepare_topomap_plot(ica, ch_type)
except Exception as exc:
warn(str(exc))
Expand All @@ -1077,11 +1077,12 @@ def _label_clicked(pos, params):
outlines = _make_head_outlines(this_sphere, pos, 'head', clip_origin)
this_data = data[:, data_picks]
ax = axes[ch_idx]
if merge_grads:
from ..channels.layout import _merge_grad_data
if merge_channels:
from ..channels.layout import _merge_ch_data
for ii, data_ in zip(ic_idx, this_data):
ax.set_title('%s %s' % (ica._ica_names[ii], ch_type), fontsize=12)
data_ = _merge_grad_data(data_) if merge_grads else data_
if merge_channels:
data_, _ = _merge_ch_data(data_, 'grad', [])
plot_topomap(data_.flatten(), pos, axes=ax, show=False,
sphere=this_sphere, outlines=outlines)
_hide_frame(ax)
Expand Down
23 changes: 21 additions & 2 deletions mne/viz/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@

import pytest
import numpy as np
import os.path as op

from mne import create_info, EvokedArray
from mne import create_info, EvokedArray, events_from_annotations, Epochs
from mne.channels import make_standard_montage
from mne.datasets.testing import data_path
from mne.preprocessing.nirs import optical_density, beer_lambert_law
from mne.io import read_raw_nirx


@pytest.fixture()
def fnirs_evoked():
"""Create a fnirs evoked."""
"""Create an fnirs evoked structure."""
montage = make_standard_montage('biosemi16')
ch_names = montage.ch_names
ch_types = ['eeg'] * 16
Expand All @@ -24,3 +28,18 @@ def fnirs_evoked():
evoked.set_channel_types({'Fp1': 'hbo', 'Fp2': 'hbo', 'F4': 'hbo',
'Fz': 'hbo'}, verbose='error')
return evoked


@pytest.fixture()
def fnirs_epochs():
"""Create an fnirs epoch structure."""
fname = op.join(data_path(download=False),
'NIRx', 'nirx_15_2_recording_w_overlap')
raw_intensity = read_raw_nirx(fname, preload=False)
raw_od = optical_density(raw_intensity)
raw_haemo = beer_lambert_law(raw_od)
evts, _ = events_from_annotations(raw_haemo, event_id={'1.0': 1})
evts_dct = {'A': 1}
tn, tx = -1, 2
epochs = Epochs(raw_haemo, evts, event_id=evts_dct, tmin=tn, tmax=tx)
return epochs
3 changes: 2 additions & 1 deletion mne/viz/tests/test_topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def test_plot_topo():
# Show topography
evoked = _get_epochs().average()
# should auto-find layout
plot_evoked_topo([evoked, evoked], merge_grads=True, background_color='w')
plot_evoked_topo([evoked, evoked], merge_grads=True,
background_color='w')

picked_evoked = evoked.copy().pick_channels(evoked.ch_names[:3])
picked_evoked_eeg = evoked.copy().pick_types(meg=False, eeg=True)
Expand Down
33 changes: 17 additions & 16 deletions mne/viz/tests/test_topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,22 @@
from matplotlib.patches import Circle

from mne import (read_evokeds, read_proj, make_fixed_length_events, Epochs,
compute_proj_evoked, find_layout, pick_types, create_info,
events_from_annotations)
compute_proj_evoked, find_layout, pick_types, create_info)
from mne.io.proj import make_eeg_average_ref_proj, Projection
from mne.io import read_raw_fif, read_info, RawArray, read_raw_nirx
from mne.io import read_raw_fif, read_info, RawArray
from mne.io.constants import FIFF
from mne.io.pick import pick_info, channel_indices_by_type
from mne.io.compensator import get_current_comp
from mne.channels import read_layout, make_dig_montage
from mne.datasets import testing
from mne.time_frequency.tfr import AverageTFR
from mne.utils import run_tests_if_main
from mne.datasets.testing import data_path

from mne.viz import plot_evoked_topomap, plot_projs_topomap
from mne.viz.topomap import (_get_pos_outlines, _onselect, plot_topomap,
plot_arrowmap, plot_psds_topomap)
from mne.viz.utils import _find_peaks, _fake_click
from mne.preprocessing.nirs import optical_density, beer_lambert_law
from mne.utils import requires_sklearn


data_dir = testing.data_path(download=False)
Expand Down Expand Up @@ -550,20 +548,23 @@ def test_plot_topomap_bads():
plt.close('all')


def test_plot_topomap_nirs_overlap():
def test_plot_topomap_nirs_overlap(fnirs_epochs):
"""Test plotting nirs topomap with overlapping channels (gh-7414)."""
fname = op.join(data_path(download=False),
'NIRx', 'nirx_15_2_recording_w_overlap')
raw_intensity = read_raw_nirx(fname, preload=False)
raw_od = optical_density(raw_intensity)
raw_haemo = beer_lambert_law(raw_od)
evts, _ = events_from_annotations(raw_haemo, event_id={'1.0': 1})
evts_dct = {'A': 1}
tn, tx = -1, 2
epochs = Epochs(raw_haemo, evts, event_id=evts_dct, tmin=tn, tmax=tx)
fig = epochs['A'].average(picks='hbo').plot_topomap()
fig = fnirs_epochs['A'].average(picks='hbo').plot_topomap()
assert len(fig.axes) == 5
plt.close('all')


@requires_sklearn
def test_plot_topomap_nirs_ica(fnirs_epochs):
"""Test plotting nirs ica topomap."""
from mne.preprocessing import ICA
fnirs_epochs = fnirs_epochs.load_data().pick(picks='hbo')
fnirs_epochs = fnirs_epochs.pick(picks=range(30))
ica = ICA().fit(fnirs_epochs)
fig = ica.plot_components()
assert len(fig[0].axes) == 20
plt.close('all')


run_tests_if_main()
12 changes: 6 additions & 6 deletions mne/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..io.pick import channel_type, pick_types
from ..utils import _clean_names, warn, _check_option, Bunch
from ..channels.layout import _merge_grad_data, _pair_grad_sensors, find_layout
from ..channels.layout import _merge_ch_data, _pair_grad_sensors, find_layout
from ..defaults import _handle_default
from .utils import (_check_delayed_ssp, _get_color_list, _draw_proj_checkbox,
add_background_image, plt_show, _setup_vmin_vmax,
Expand Down Expand Up @@ -568,7 +568,7 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
border='none', ylim=None, scalings=None, title=None,
proj=False, vline=(0.,), hline=(0.,), fig_facecolor='k',
fig_background=None, axis_facecolor='k', font_color='w',
merge_grads=False, legend=True, axes=None, show=True,
merge_channels=False, legend=True, axes=None, show=True,
noise_cov=None):
"""Plot 2D topography of evoked responses.
Expand Down Expand Up @@ -620,7 +620,7 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
The face color to be used for each sensor plot. Defaults to black.
font_color : color
The color of text in the colorbar and title. Defaults to white.
merge_grads : bool
merge_channels : bool
Whether to use RMS value of gradiometer pairs. Only works for Neuromag
data. Defaults to False.
legend : bool | int | string | tuple
Expand Down Expand Up @@ -679,7 +679,7 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
if not all(e.ch_names == ch_names for e in evoked):
raise ValueError('All evoked.picks must be the same')
ch_names = _clean_names(ch_names)
if merge_grads:
if merge_channels:
picks = _pair_grad_sensors(info, topomap_coords=False)
chs = list()
for pick in picks[::2]:
Expand All @@ -692,7 +692,7 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
info._check_consistency()
new_picks = list()
for e in evoked:
data = _merge_grad_data(e.data[picks])
data, _ = _merge_ch_data(e.data[picks], 'grad', [])
if noise_cov is None:
data *= scalings['grad']
e.data = data
Expand All @@ -705,7 +705,7 @@ def _plot_evoked_topo(evoked, layout=None, layout_scale=0.945, color=None,
if layout is None:
layout = find_layout(info)

if not merge_grads:
if not merge_channels:
# XXX. at the moment we are committed to 1- / 2-sensor-types layouts
chs_in_layout = set(layout.names) & set(ch_names)
types_used = {channel_type(info, ch_names.index(ch))
Expand Down
Loading

0 comments on commit 586c09f

Please sign in to comment.