Skip to content

Commit

Permalink
MRG, ENH: Scrape traces when available (mne-tools#7927)
Browse files Browse the repository at this point in the history
* ENH: Scrape traces when available [circle full]

* API: Disallow show_traces=True with time_viewer=False [circle full]
  • Loading branch information
larsoner authored Jul 16, 2020
1 parent 7de4420 commit bd6f426
Show file tree
Hide file tree
Showing 17 changed files with 186 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
echo "set -e" >> $BASH_ENV
echo "export DISPLAY=:99" >> $BASH_ENV
echo "export OPENBLAS_NUM_THREADS=4" >> $BASH_ENV
echo "export XDG_RUNTIME_DIR=/tmp/runtime-circleci" >> $BASH_ENV
source tools/get_minimal_commands.sh
echo "source ${PWD}/tools/get_minimal_commands.sh" >> $BASH_ENV
echo "export MNE_3D_BACKEND=pyvista" >> $BASH_ENV
echo "export _MNE_BRAIN_TRACES_AUTO=false" >> $BASH_ENV
echo "export PATH=~/.local/bin/:${MNE_ROOT}/bin:$PATH" >> $BASH_ENV
echo "BASH_ENV:"
cat $BASH_ENV
Expand Down
2 changes: 1 addition & 1 deletion doc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ html-noplot:
@echo "Build finished. The HTML pages are in _build/html_stable."

html_dev-front:
@PATTERN="\(plot_mne_dspm_source_localization.py\|plot_receptive_field.py\|plot_mne_inverse_label_connectivity.py\|plot_sensors_decoding.py\|plot_stats_cluster_spatio_temporal.py\|plot_visualize_evoked.py\)" make html_dev-pattern;
@PATTERN="\(plot_mne_dspm_source_localization.py\|plot_receptive_field.py\|plot_mne_inverse_label_connectivity.py\|plot_sensors_decoding.py\|plot_stats_cluster_spatio_temporal.py\|plot_20_visualize_evoked.py\)" make html_dev-pattern;

dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) _build/dirhtml
Expand Down
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ Changelog

- Speed up :meth:`mne.Epochs.copy` and :meth:`mne.Epochs.__getitem__` by avoiding copying immutable attributes by `Eric Larson`_

- Speed up and reduce memory usage of :meth:`mne.SourceEstimate.plot` and related functions/methods when ``show_traces=True`` by `Eric Larson`_

- Reduce memory usage of `~mne.io.Raw.plot_psd`, `~mne.time_frequency.psd_welch`, and `~mne.time_frequency.psd_array_welch` for long segments of data by `Eric Larson`_

- Support for saving movies of source time courses (STCs) with ``brain.save_movie`` method and from graphical user interface by `Guillaume Favelier`_
Expand Down
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,11 @@
scrapers += (report_scraper,)
else:
report_scraper = None
if 'pyvista' in scrapers:
brain_scraper = mne.viz._brain._BrainScraper()
scrapers = list(scrapers)
scrapers.insert(scrapers.index('pyvista'), brain_scraper)
scrapers = tuple(scrapers)


def append_attr_meth_examples(app, what, name, obj, options, lines):
Expand Down
3 changes: 1 addition & 2 deletions mne/tests/test_import_nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
out |= {'scipy submodules: %s' % list(bad)}
# check sklearn and others
_sklearn = _pandas = _mayavi = _matplotlib = False
for x in sys.modules.keys():
for key in ('sklearn', 'pandas', 'mayavi', 'pyvista', 'matplotlib',
'dipy', 'nibabel', 'cupy', 'picard'):
'dipy', 'nibabel', 'cupy', 'picard', 'pyvistaqt'):
if x.startswith(key):
out |= {key}
if len(out) > 0:
Expand Down
2 changes: 1 addition & 1 deletion mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@
show_traces : bool | str
If True, enable interactive picking of a point on the surface of the
brain and plot it's time course using the bottom 1/3 of the figure.
This feature is only available with the PyVista 3d backend when
This feature is only available with the PyVista 3d backend, and requires
``time_viewer=True``. Defaults to 'auto', which will use True if and
only if ``time_viewer=True``, the backend is PyVista, and there is more
than one time point.
Expand Down
8 changes: 4 additions & 4 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from distutils.version import LooseVersion
from itertools import cycle
import os
import os.path as op
import sys
import warnings
Expand Down Expand Up @@ -1773,9 +1772,7 @@ def _check_time_viewer_compatibility(brain, time_viewer, show_traces):
not using_mayavi and
time_viewer and
brain._times is not None and
len(brain._times) > 1 and
# XXX temporary hidden workaround for memory problems on CircleCI
os.getenv('_MNE_BRAIN_TRACES_AUTO', 'true').lower() != 'false'
len(brain._times) > 1
)

if _get_3d_backend() == "mayavi" and all([time_viewer, show_traces]):
Expand All @@ -1786,6 +1783,9 @@ def _check_time_viewer_compatibility(brain, time_viewer, show_traces):
raise RuntimeError('This function requires pysurfer version '
'>= 0.9')

if show_traces and not time_viewer:
raise ValueError('show_traces cannot be used when time_viewer=False')

if time_viewer:
if using_mayavi:
from surfer import TimeViewer
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@
from .backends.renderer import (set_3d_backend, get_3d_backend, use_3d_backend,
set_3d_view, set_3d_title, create_3d_figure,
get_brain_class)
from . import backends
from . import backends, _brain
1 change: 1 addition & 0 deletions mne/viz/_brain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# License: Simplified BSD

from ._brain import _Brain
from ._scraper import _BrainScraper
from ._timeviewer import _TimeViewer, _LinkViewer

__all__ = ['_Brain']
7 changes: 5 additions & 2 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from .._3d import _process_clim, _handle_time

from ...surface import mesh_edges
from ...morph import _hemi_morph
from ...label import read_label, _read_annot
from ...utils import (_check_option, logger, verbose, fill_doc, _validate_type,
use_log_level)

Expand Down Expand Up @@ -269,6 +267,7 @@ def __init__(self, subject_id, hemi, surf, title=None,
self._renderer.set_camera(azimuth=views_dict[v].azim,
elevation=views_dict[v].elev)

self._closed = False
if show:
self._renderer.show()

Expand Down Expand Up @@ -592,6 +591,7 @@ def add_label(self, label, color=None, alpha=1, scalar_thresh=None,
To remove previously added labels, run Brain.remove_labels().
"""
from matplotlib.colors import colorConverter
from ...label import read_label
if isinstance(label, str):
if color is None:
color = "crimson"
Expand Down Expand Up @@ -820,6 +820,7 @@ def add_annotation(self, annot, borders=True, alpha=1, hemi=None,
These are passed to the underlying
``mayavi.mlab.pipeline.surface`` call.
"""
from ...label import _read_annot
hemis = self._check_hemis(hemi)

# Figure out where the data is coming from
Expand Down Expand Up @@ -925,6 +926,7 @@ def resolve_coincident_topology(self, actor):

def close(self):
"""Close all figures and cleanup data structure."""
self._closed = True
self._renderer.close()

def show(self):
Expand Down Expand Up @@ -1028,6 +1030,7 @@ def set_data_smoothing(self, n_steps):
n_steps : int
Number of smoothing steps
"""
from ...morph import _hemi_morph
for hemi in ['lh', 'rh']:
hemi_data = self._data.get(hemi)
if hemi_data is not None:
Expand Down
63 changes: 63 additions & 0 deletions mne/viz/_brain/_scraper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os.path as op

import numpy as np

from ._brain import _Brain


class _BrainScraper(object):
"""Scrape Brain objects."""

def __repr__(self):
return '<BrainScraper>'

def __call__(self, block, block_vars, gallery_conf):
rst = ''
for brain in block_vars['example_globals'].values():
# Only need to process if it's a brain with a time_viewer
# with traces on and shown in the same window, otherwise
# PyVista and matplotlib scrapers can just do the work
if (not isinstance(brain, _Brain)) or brain._closed:
continue
from matplotlib.image import imsave
from sphinx_gallery.scrapers import figure_rst
img_fname = next(block_vars['image_path_iterator'])
img = brain.screenshot()
assert img.size > 0
if getattr(brain, 'time_viewer', None) is not None and \
brain.time_viewer.show_traces and \
not brain.time_viewer.separate_canvas:
canvas = brain.time_viewer.mpl_canvas.fig.canvas
canvas.draw_idle()
# In theory, one of these should work:
#
# trace_img = np.frombuffer(
# canvas.tostring_rgb(), dtype=np.uint8)
# trace_img.shape = canvas.get_width_height()[::-1] + (3,)
#
# or
#
# trace_img = np.frombuffer(
# canvas.tostring_rgb(), dtype=np.uint8)
# size = time_viewer.mpl_canvas.getSize()
# trace_img.shape = (size.height(), size.width(), 3)
#
# But in practice, sometimes the sizes does not match the
# renderer tostring_rgb() size. So let's directly use what
# matplotlib does in lib/matplotlib/backends/backend_agg.py
# before calling tobytes():
trace_img = np.asarray(
canvas.renderer._renderer).take([0, 1, 2], axis=2)
# need to slice into trace_img because generally it's a bit
# smaller
delta = trace_img.shape[1] - img.shape[1]
if delta > 0:
start = delta // 2
trace_img = trace_img[:, start:start + img.shape[1]]
img = np.concatenate([img, trace_img], axis=0)
imsave(img_fname, img)
assert op.isfile(img_fname)
rst += figure_rst(
[img_fname], gallery_conf['src_dir'], brain._title)
brain.close()
return rst
Loading

0 comments on commit bd6f426

Please sign in to comment.