Skip to content

Commit

Permalink
BUG: Fix add_figure for MNEQtBrowser (mne-tools#10485)
Browse files Browse the repository at this point in the history
* BUG: Fix add_figure for MNEQtBrowser

* FIX: Dont keep a ref

* FIX: Fix scraping
  • Loading branch information
larsoner authored Apr 2, 2022
1 parent 4359456 commit c992168
Show file tree
Hide file tree
Showing 14 changed files with 125 additions and 43 deletions.
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Enhancements

Bugs
~~~~
- None yet
- Fix bug where plots produced using the ``'qt'`` / ``mne_qt_browser`` backend could not be added using :meth:`mne.Report.add_figure` (:gh:`10485` by `Eric Larson`_)

API changes
~~~~~~~~~~~
Expand Down
9 changes: 7 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,7 @@
# unlinkable
'CoregistrationUI',
'IntracranialElectrodeLocator',
# TODO: fix the Renderer return type of create_3d_figure(scene=False)
'Renderer',
'mne_qt_browser.figure.MNEQtBrowser',
}
numpydoc_validate = True
numpydoc_validation_checks = {'all'} | set(error_ignores)
Expand Down Expand Up @@ -350,6 +349,12 @@ def __call__(self, gallery_conf, fname, when):
_assert_no_instances(_Renderer, when)
if PyQtGraphBrowser is not None and \
'pyqtgraphbrowser' not in skips:
# Ensure any manual fig.close() events get properly handled
from mne_qt_browser._pg_figure import QApplication
inst = QApplication.instance()
if inst is not None:
for _ in range(2):
inst.processEvents()
_assert_no_instances(PyQtGraphBrowser, when)
# This will overwrite some Sphinx printing but it's useful
# for memory timestamps
Expand Down
35 changes: 24 additions & 11 deletions mne/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
Figure3D, use_browser_backend)
from ..viz.misc import _plot_mri_contours, _get_bem_plotting_surfaces
from ..viz.utils import _ndarray_to_fig, tight_layout
from ..viz._scraper import _mne_qt_browser_screenshot
from ..forward import read_forward_solution, Forward
from ..epochs import read_epochs, BaseEpochs
from ..preprocessing.ica import read_ica
Expand Down Expand Up @@ -333,7 +334,6 @@ def _fig_to_img(fig, *, image_format='png', own_figure=True):
# fig can be ndarray, mpl Figure, PyVista Figure
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
_validate_type(fig, (np.ndarray, Figure, Figure3D), 'fig')
if isinstance(fig, np.ndarray):
# In this case, we are creating the fig, so we might as well
# auto-close in all cases
Expand All @@ -343,16 +343,29 @@ def _fig_to_img(fig, *, image_format='png', own_figure=True):
fig, max_width=MAX_IMG_WIDTH, max_res=MAX_IMG_RES
)
own_figure = True # close the figure we just created
elif not isinstance(fig, Figure):
from ..viz.backends.renderer import backend, MNE_3D_BACKEND_TESTING
backend._check_3d_figure(figure=fig)
if not MNE_3D_BACKEND_TESTING:
img = backend._take_3d_screenshot(figure=fig)
else: # Testing mode
img = np.zeros((2, 2, 3))

if own_figure:
backend._close_3d_figure(figure=fig)
elif isinstance(fig, Figure):
pass # nothing to do
else:
# Don't attempt a mne_qt_browser import here (it might pull in Qt
# libraries we don't want), so use a probably good enough class name
# check instead
if fig.__class__.__name__ in ('MNEQtBrowser', 'PyQtGraphBrowser'):
img = _mne_qt_browser_screenshot(fig, return_type='ndarray')
print(img.shape, img.max(), img.min(), img.mean())
elif isinstance(fig, Figure3D):
from ..viz.backends.renderer import backend, MNE_3D_BACKEND_TESTING
backend._check_3d_figure(figure=fig)
if not MNE_3D_BACKEND_TESTING:
img = backend._take_3d_screenshot(figure=fig)
else: # Testing mode
img = np.zeros((2, 2, 3))
if own_figure:
backend._close_3d_figure(figure=fig)
else:
raise TypeError(
'figure must be an instance of np.ndarray, matplotlib Figure, '
'mne_qt_browser.figure.MNEQtBrowser, or mne.viz.Figure3D, got '
f'{type(fig)}')
fig = _ndarray_to_fig(img)
if own_figure:
_constrain_fig_resolution(
Expand Down
22 changes: 20 additions & 2 deletions mne/report/tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
import pytest
from matplotlib import pyplot as plt

from mne import Epochs, read_events, read_evokeds, read_cov, pick_channels_cov
from mne import (Epochs, read_events, read_evokeds, read_cov,
pick_channels_cov, create_info)
from mne.report import report as report_mod
from mne.report.report import CONTENT_ORDER
from mne.io import read_raw_fif, read_info
from mne.io import read_raw_fif, read_info, RawArray
from mne.datasets import testing
from mne.report import Report, open_report, _ReportScraper, report
from mne.utils import (requires_nibabel, Bunch, requires_version,
Expand Down Expand Up @@ -207,6 +208,23 @@ def test_render_report(renderer_pyvistaqt, tmp_path, invisible_fig):

with pytest.raises(TypeError, match='It seems you passed a path'):
report.add_figure(fig='foo', title='title')
with pytest.raises(TypeError, match='.*MNEQtBrowser.*Figure3D.*got.*'):
report.add_figure(fig=1., title='title')


def test_render_mne_qt_browser(tmp_path, browser_backend):
"""Test adding a mne_qt_browser (and matplotlib) raw plot."""
report = Report()
info = create_info(1, 1000., 'eeg')
data = np.zeros((1, 1000))
raw = RawArray(data, info)
fig = raw.plot()
name = fig.__class__.__name__
if browser_backend.name == 'matplotlib':
assert 'MNEBrowseFigure' in name
else:
assert 'MNEQtBrowser' in name or 'PyQtGraphBrowser' in name
report.add_figure(fig, title='raw')


@testing.requires_testing_data
Expand Down
8 changes: 4 additions & 4 deletions mne/utils/_bunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def __init__(self, **kwargs): # noqa: D102
class BunchConst(Bunch):
"""Class to prevent us from re-defining constants (DRY)."""

def __setattr__(self, attr, val): # noqa: D105
if attr != '__dict__' and hasattr(self, attr):
raise AttributeError('Attribute "%s" already set' % attr)
super().__setattr__(attr, val)
def __setitem__(self, key, val): # noqa: D105
if key != '__dict__' and key in self:
raise AttributeError(f'Attribute {repr(key)} already set')
super().__setitem__(key, val)


###############################################################################
Expand Down
8 changes: 7 additions & 1 deletion mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from decorator import FunctionMaker

from ._bunch import BunchConst
from ..defaults import HEAD_SIZE_DEFAULT


Expand All @@ -38,7 +39,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
# are alphabetized (you can look up by the name of the argument). This way
# the same ``docdict`` entries are easier to reuse.

docdict = dict()
docdict = BunchConst()

# %%
# A
Expand Down Expand Up @@ -361,6 +362,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
``dict(silhouette=True)``).
"""

docdict['browser'] = """
fig : matplotlib.figure.Figure | mne_qt_browser.figure.MNEQtBrowser
Browser instance.
"""

docdict['buffer_size_clust'] = """
buffer_size : int | None
Block size to use when computing test statistics. This can significantly
Expand Down
51 changes: 41 additions & 10 deletions mne/viz/_scraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#
# License: Simplified BSD

from contextlib import contextmanager

import numpy as np

from ..utils import _pl


Expand All @@ -13,7 +17,6 @@ def __repr__(self):
def __call__(self, block, block_vars, gallery_conf):
import mne_qt_browser
from sphinx_gallery.scrapers import figure_rst
from PyQt5.QtWidgets import QApplication
if gallery_conf['builder_name'] != 'html':
return ''
img_fnames = list()
Expand All @@ -29,15 +32,7 @@ def __call__(self, block, block_vars, gallery_conf):
gui._scraped = True # monkey-patch but it's easy enough
n_plot += 1
img_fnames.append(next(block_vars['image_path_iterator']))
if getattr(gui, 'load_thread', None) is not None:
if gui.load_thread.isRunning():
gui.load_thread.wait(30000)
if inst is None:
inst = QApplication.instance()
# processEvents to make sure our progressBar is updated
for _ in range(2):
inst.processEvents()
pixmap = gui.grab()
pixmap, inst = _mne_qt_browser_screenshot(gui, inst)
pixmap.save(img_fnames[-1])
# child figures
for fig in gui.mne.child_figs:
Expand All @@ -56,3 +51,39 @@ def __call__(self, block, block_vars, gallery_conf):
return figure_rst(
img_fnames, gallery_conf['src_dir'],
f'Raw plot{_pl(n_plot)}')


@contextmanager
def _screenshot_mode(browser):
browser.mne.toolbar.setVisible(False)
browser.statusBar().setVisible(False)
try:
yield
finally:
browser.mne.toolbar.setVisible(True)
browser.statusBar().setVisible(True)


def _mne_qt_browser_screenshot(browser, inst=None, return_type='pixmap'):
from mne_qt_browser._pg_figure import QApplication
if getattr(browser, 'load_thread', None) is not None:
if browser.load_thread.isRunning():
browser.load_thread.wait(30000)
if inst is None:
inst = QApplication.instance()
# processEvents to make sure our progressBar is updated
with _screenshot_mode(browser):
for _ in range(2):
inst.processEvents()
pixmap = browser.grab()
assert return_type in ('pixmap', 'ndarray')
if return_type == 'ndarray':
img = pixmap.toImage()
img = img.convertToFormat(img.Format_RGBA8888)
ptr = img.bits()
ptr.setsize(img.height() * img.width() * 4)
data = np.frombuffer(ptr, dtype=np.uint8).copy()
data.shape = (img.height(), img.width(), 4)
return data / 255.
else:
return pixmap, inst
5 changes: 2 additions & 3 deletions mne/viz/backends/_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ...fixes import _get_args, _point_data, _cell_data, _compare_version
from ...transforms import apply_trans
from ...utils import (copy_base_doc_to_subclass_doc, _check_option,
_require_version)
_require_version, _validate_type)


with warnings.catch_warnings():
Expand Down Expand Up @@ -1050,8 +1050,7 @@ def _set_3d_title(figure, title, size=16):


def _check_3d_figure(figure):
if not isinstance(figure, PyVistaFigure):
raise TypeError('figure must be an instance of PyVistaFigure.')
_validate_type(figure, PyVistaFigure, 'figure')


def _close_3d_figure(figure):
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/backends/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def create_3d_figure(size, bgcolor=(0, 0, 0), smooth_shading=True,
Returns
-------
figure : instance of Figure3D or Renderer
figure : instance of Figure3D or ``Renderer``
The requested empty figure or renderer, depending on ``scene``.
"""
renderer = _get_renderer(
Expand Down
3 changes: 1 addition & 2 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, n_channels=20,
Returns
-------
fig : instance of matplotlib.figure.Figure
The figure.
%(browser)s
Notes
-----
Expand Down
3 changes: 1 addition & 2 deletions mne/viz/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def plot_ica_sources(ica, inst, picks=None, start=None,
Returns
-------
fig : instance of Figure
The figure.
%(browser)s
Notes
-----
Expand Down
3 changes: 1 addition & 2 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ def plot_raw(raw, events=None, duration=10.0, start=0.0, n_channels=20,
Returns
-------
fig : matplotlib.figure.Figure | ``PyQt5.QtWidgets.QMainWindow``
Browser instance.
%(browser)s
Notes
-----
Expand Down
8 changes: 6 additions & 2 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,19 @@ def tight_layout(pad=1.2, h_pad=None, w_pad=None, fig=None):

fig.canvas.draw()
constrained = fig.get_constrained_layout()
kwargs = dict(pad=pad, h_pad=h_pad, w_pad=w_pad)
if constrained:
return # no-op
try: # see https://github.com/matplotlib/matplotlib/issues/2654
with warnings.catch_warnings(record=True) as ws:
fig.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad)
fig.tight_layout(**kwargs)
except Exception:
try:
with warnings.catch_warnings(record=True) as ws:
fig.set_tight_layout(dict(pad=pad, h_pad=h_pad, w_pad=w_pad))
if hasattr(fig, 'set_layout_engine'):
fig.set_layout_engine('tight', **kwargs)
else:
fig.set_tight_layout(kwargs)
except Exception:
warn('Matplotlib function "tight_layout" is not supported.'
' Skipping subplot adjustment.')
Expand Down
9 changes: 9 additions & 0 deletions tutorials/intro/70_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@
image_format='PNG'
)
report.save('report_custom_figure.html', overwrite=True)
plt.close(fig)

# %%
# The :meth:`mne.Report.add_figure` method can add multiple figures at once. In
Expand Down Expand Up @@ -458,9 +459,17 @@
figs.append(fig)
captions.append(f'Rotation angle: {round(angle, 1)}°')

# can also be a MNEQtBrowser instance
figs.append(raw.plot())
captions.append('... plus a raw data plot')

report = mne.Report(title='Multiple figures example')
report.add_figure(fig=figs, title='Fun with figures! 🥳', caption=captions)
report.save('report_custom_figures.html', overwrite=True)
for fig in figs[:-1]:
plt.close(fig)
figs[-1].close()
del figs

# %%
# Adding image files
Expand Down

0 comments on commit c992168

Please sign in to comment.