Skip to content

Commit

Permalink
ENH: ICA support for qt-backend (mne-tools#10330)
Browse files Browse the repository at this point in the history
* adapt ica for qt-backend

* further test fixes

* move external plot initialization into backends

* fix flake

* fix docstring

* update latest.inc

* fix latest.inc

* skip tests and modules for mne-qt-browser<=0.2.0
  • Loading branch information
marsipu authored Feb 15, 2022
1 parent 7c0e36d commit 9315ed6
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 72 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ Enhancements

- The ``map_surface`` parameter of :meth:`mne.viz.Brain.add_foci` now works and allows you to add foci to a rendering of a brain that are positioned at the vertex of the mesh closest to the given coordinates (:gh:`10299` by `Marijn van Vliet`_)

- :meth:`mne.preprocessing.ICA.plot_sources()` is now also supported by the ``qt`` backend (:gh:`10330` by `Martin Schulz`_)

Bugs
~~~~

Expand Down
25 changes: 17 additions & 8 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,13 @@ def mpl_backend(garbage_collect):
backend._close_all()


# Skip functions or modules for mne-qt-browser < 0.2.0
pre_2_0_skip_modules = ['mne.viz.tests.test_epochs',
'mne.viz.tests.test_ica']
pre_2_0_skip_funcs = ['test_plot_raw_white',
'test_plot_raw_selection']


def _check_pyqtgraph(request):
# Check PyQt5
try:
Expand All @@ -407,14 +414,16 @@ def _check_pyqtgraph(request):
# Check mne-qt-browser
try:
import mne_qt_browser # noqa: F401
# Check if version is high enough for epochs
v_to_low = _compare_version(mne_qt_browser.__version__, '<', '0.2.0')
is_epochs = request.function.__module__ == 'mne.viz.tests.test_epochs'
is_ica = request.function.__module__ == 'mne.viz.tests.test_ica'
if v_to_low and is_epochs:
pytest.skip('No Epochs tests for mne-qt-browser < 0.2.0')
elif v_to_low and is_ica:
pytest.skip('No ICA tests for mne-qt-browser < 0.2.0')
# Check mne-qt-browser version
lower_2_0 = _compare_version(mne_qt_browser.__version__, '<', '0.2.0')
m_name = request.function.__module__
f_name = request.function.__name__
if lower_2_0 and m_name in pre_2_0_skip_modules:
pytest.skip(f'Test-Module "{m_name}" was skipped for'
f' mne-qt-browser < 0.2.0')
elif lower_2_0 and f_name in pre_2_0_skip_funcs:
pytest.skip(f'Test "{f_name}" was skipped for '
f'mne-qt-browser < 0.2.0')
except Exception:
pytest.skip('Requires mne_qt_browser')

Expand Down
6 changes: 4 additions & 2 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -1951,12 +1951,14 @@ def plot_properties(self, inst, picks=None, axes=None, dB=True,
def plot_sources(self, inst, picks=None, start=None,
stop=None, title=None, show=True, block=False,
show_first_samp=False, show_scrollbars=True,
time_format='float'):
time_format='float', precompute='auto',
use_opengl=None):
return plot_ica_sources(self, inst=inst, picks=picks,
start=start, stop=stop, title=title, show=show,
block=block, show_first_samp=show_first_samp,
show_scrollbars=show_scrollbars,
time_format=time_format)
time_format=time_format,
precompute=precompute, use_opengl=use_opengl)

@copy_function_doc_to_method_doc(plot_ica_scores)
def plot_scores(self, scores, exclude=None, labels=None, axhline=None,
Expand Down
5 changes: 5 additions & 0 deletions mne/viz/_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ def __init__(self, **kwargs):
self.mne.midpoints = np.convolve(self.mne.boundary_times,
np.ones(2), mode='valid') / 2

# initialize picks and projectors
self._update_picks()
if not self.mne.instance_type == 'ica':
self._update_projector()

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# ANNOTATIONS
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
Expand Down
18 changes: 18 additions & 0 deletions mne/viz/_mpl_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2222,6 +2222,7 @@ def _patched_canvas(fig):

def _init_browser(**kwargs):
"""Instantiate a new MNE browse-style figure."""
from mne.io import BaseRaw
fig = _figure(toolbar=False, FigureClass=MNEBrowseFigure, **kwargs)

# initialize zen mode
Expand All @@ -2236,4 +2237,21 @@ def _init_browser(**kwargs):
fig.mne.scrollbars_visible = True
fig._toggle_scrollbars()

# Initialize parts of the plot
is_ica = fig.mne.instance_type == 'ica'

if not is_ica:
# make channel selection dialog,
# if requested (doesn't work well in init)
if fig.mne.group_by in ('selection', 'position'):
fig._create_selection_fig()

# update data, and plot
fig._update_trace_offsets()
fig._redraw(update_data=True, annotations=False)

if isinstance(fig.mne.inst, BaseRaw):
fig._setup_annotation_colors()
fig._draw_annotations()

return fig
12 changes: 0 additions & 12 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,18 +906,6 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, n_channels=20,
use_opengl=use_opengl)

fig = _get_browser(**params)

fig._update_picks()

# make channel selection dialog, if requested (doesn't work well in init)
if group_by in ('selection', 'position'):
fig._create_selection_fig()

fig._update_projector()
fig._update_trace_offsets()
fig._update_data()
fig._draw_traces()

_show_browser(show, block=block, fig=fig)

return fig
Expand Down
32 changes: 13 additions & 19 deletions mne/viz/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import numpy as np

from .utils import (tight_layout, _make_event_color_dict,
from .utils import (_show_browser, tight_layout, _make_event_color_dict,
plt_show, _convert_psds, _compute_scalings)
from .topomap import _plot_ica_topomap
from .epochs import plot_epochs_image
Expand All @@ -29,7 +29,8 @@
def plot_ica_sources(ica, inst, picks=None, start=None,
stop=None, title=None, show=True, block=False,
show_first_samp=False, show_scrollbars=True,
time_format='float'):
time_format='float', precompute='auto',
use_opengl=None):
"""Plot estimated latent sources given the unmixing matrix.
Typical usecases:
Expand Down Expand Up @@ -65,6 +66,8 @@ def plot_ica_sources(ica, inst, picks=None, start=None,
If True, show time axis relative to the ``raw.first_samp``.
%(show_scrollbars)s
%(time_format)s
%(precompute)s
%(use_opengl)s
Returns
-------
Expand All @@ -91,7 +94,8 @@ def plot_ica_sources(ica, inst, picks=None, start=None,
show=show, title=title, block=block,
show_first_samp=show_first_samp,
show_scrollbars=show_scrollbars,
time_format=time_format)
time_format=time_format, precompute=precompute,
use_opengl=use_opengl)
elif isinstance(inst, Evoked):
if start is not None or stop is not None:
inst = inst.copy().crop(start, stop)
Expand Down Expand Up @@ -951,7 +955,8 @@ def _plot_ica_overlay_evoked(evoked, evoked_cln, title, show):


def _plot_sources(ica, inst, picks, exclude, start, stop, show, title, block,
show_scrollbars, show_first_samp, time_format):
show_scrollbars, show_first_samp, time_format,
precompute, use_opengl):
"""Plot the ICA components as a RawArray or EpochsArray."""
from ._figure import _get_browser
from .. import EpochsArray, BaseEpochs
Expand Down Expand Up @@ -1086,7 +1091,9 @@ def _plot_sources(ica, inst, picks, exclude, start, stop, show, title, block,
clipping=None,
scrollbars_visible=show_scrollbars,
scalebars_visible=False,
window_title=title)
window_title=title,
precompute=precompute,
use_opengl=use_opengl)
if is_epo:
params.update(n_epochs=n_epochs,
boundary_times=boundary_times,
Expand All @@ -1098,19 +1105,6 @@ def _plot_sources(ica, inst, picks, exclude, start, stop, show, title, block,
xlabel='Epoch number')

fig = _get_browser(**params)
_show_browser(show, block=block, fig=fig)

fig._update_picks()

# update data, and plot
fig._update_trace_offsets()
fig._update_data()
fig._draw_traces()

# plot annotations (if any)
if is_raw:
fig._setup_annotation_colors()
fig._update_annotation_segments()
fig._draw_annotations()

plt_show(show, block=block)
return fig
14 changes: 0 additions & 14 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,20 +352,6 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,

fig = _get_browser(**params)

fig._update_picks()

# make channel selection dialog, if requested (doesn't work well in init)
if group_by in ('selection', 'position'):
fig._create_selection_fig()

# update projector and data, and plot
fig._update_projector()
fig._update_trace_offsets()
fig._setup_annotation_colors()

# Draw Plot
fig._redraw(update_data=True, annotations=True)

# start with projectors dialog open, if requested
if show_options:
fig._toggle_proj_fig()
Expand Down
35 changes: 18 additions & 17 deletions mne/viz/tests/test_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
make_fixed_length_events)
from mne.io import read_raw_fif
from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs
from mne.utils import (requires_sklearn, _click_ch_name, catch_logging,
_record_warnings)
from mne.utils import (requires_sklearn, catch_logging, _record_warnings)
from mne.viz.ica import _create_properties_layout, plot_ica_properties
from mne.viz.utils import _fake_click

Expand Down Expand Up @@ -210,7 +209,7 @@ def test_plot_ica_properties():


@requires_sklearn
def test_plot_ica_sources(raw_orig, mpl_backend):
def test_plot_ica_sources(raw_orig, browser_backend):
"""Test plotting of ICA panel."""
raw = raw_orig.copy().crop(0, 1)
picks = _get_picks(raw)
Expand All @@ -222,48 +221,50 @@ def test_plot_ica_sources(raw_orig, mpl_backend):
ica.fit(raw, picks=ica_picks)
ica.exclude = [1]
fig = ica.plot_sources(raw)
assert mpl_backend._get_n_figs() == 1
assert browser_backend._get_n_figs() == 1
# change which component is in ICA.exclude (click data trace to remove
# current one; click name to add other one)
fig._redraw()
# ToDo: This will be different methods in pyqtgraph
x = fig.mne.traces[1].get_xdata()[5]
y = fig.mne.traces[1].get_ydata()[5]
fig._fake_click((x, y), xform='data') # exclude = []
_click_ch_name(fig, ch_index=0, button=1) # exclude = [0]
fig._click_ch_name(ch_index=0, button=1) # exclude = [0]
fig._fake_keypress(fig.mne.close_key)
fig._close_event()
assert mpl_backend._get_n_figs() == 0
assert browser_backend._get_n_figs() == 0
assert_array_equal(ica.exclude, [0])
# test when picks does not include ica.exclude.
fig = ica.plot_sources(raw, picks=[1])
assert len(plt.get_fignums()) == 1
mpl_backend._close_all()
ica.plot_sources(raw, picks=[1])
assert browser_backend._get_n_figs() == 1
browser_backend._close_all()

# dtype can change int->np.int64 after load, test it explicitly
ica.n_components_ = np.int64(ica.n_components_)

# test clicks on y-label (need >2 secs for plot_properties() to work)
long_raw = raw_orig.crop(0, 5)
fig = ica.plot_sources(long_raw)
assert len(plt.get_fignums()) == 1
assert browser_backend._get_n_figs() == 1
fig._redraw()
_click_ch_name(fig, ch_index=0, button=3)
fig._click_ch_name(ch_index=0, button=3)
assert len(fig.mne.child_figs) == 1
assert len(plt.get_fignums()) == 2
assert browser_backend._get_n_figs() == 2
# close child fig directly (workaround for mpl issue #18609)
fig._fake_keypress('escape', fig=fig.mne.child_figs[0])
assert len(plt.get_fignums()) == 1
assert browser_backend._get_n_figs() == 1
fig._fake_keypress(fig.mne.close_key)
assert len(plt.get_fignums()) == 0
assert browser_backend._get_n_figs() == 0
del long_raw

# test with annotations
orig_annot = raw.annotations
raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
fig = ica.plot_sources(raw)
assert len(fig.mne.ax_main.collections) == 1
assert len(fig.mne.ax_hscroll.collections) == 1
if browser_backend.name == 'matplotlib':
assert len(fig.mne.ax_main.collections) == 1
assert len(fig.mne.ax_hscroll.collections) == 1
else:
assert len(fig.mne.regions) == 1
raw.set_annotations(orig_annot)

# test error handling
Expand Down

0 comments on commit 9315ed6

Please sign in to comment.