Skip to content

Commit

Permalink
MRG, ENH: Add mne.viz.centers_to_edges for pcolormesh (mne-tools#8023)
Browse files Browse the repository at this point in the history
* ENH: Add mne.viz.centers_to_edges for pcolormesh

* FIX: Usage

* FIX: Old numpy print
  • Loading branch information
larsoner authored Jul 18, 2020
1 parent adc70d4 commit de5e369
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 13 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ Changelog

- Add ``mri`` and ``show_orientation`` arguments to :func:`mne.viz.plot_bem` by `Eric Larson`_

- Add :func:`mne.viz.centers_to_edges` to help when using :meth:`matplotlib.axes.Axes.pcolormesh` with flat shading by `Eric Larson`_

- Add "on_missing='raise'" to :meth:`mne.io.Raw.set_montage` and related functions to allow ignoring of missing electrode coordinates by `Adam Li`_

- Add better sanity checking of ``max_pca_components`` and ``n_components`` to provide more informative error messages for :class:`mne.preprocessing.ICA` by `Eric Larson`_
Expand Down
1 change: 1 addition & 0 deletions doc/python_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ Visualization

ClickableImage
add_background_image
centers_to_edges
compare_fiff
circular_layout
iter_topography
Expand Down
2 changes: 1 addition & 1 deletion examples/forward/plot_forward_sensitivity_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
ax.set_title(ch_type.upper())
ax.set_xlabel('sources')
ax.set_ylabel('sensors')
fig.colorbar(im, ax=ax, cmap='RdBu_r')
fig.colorbar(im, ax=ax)

fig_2, ax = plt.subplots()
ax.hist([grad_map.data.ravel(), mag_map.data.ravel(), eeg_map.data.ravel()],
Expand Down
5 changes: 3 additions & 2 deletions examples/time_frequency/plot_time_frequency_simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from mne.baseline import rescale
from mne.time_frequency import (tfr_multitaper, tfr_stockwell, tfr_morlet,
tfr_array_morlet)
from mne.viz import centers_to_edges

print(__doc__)

Expand Down Expand Up @@ -188,8 +189,8 @@
# Baseline the output
rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False)
fig, ax = plt.subplots()
mesh = ax.pcolormesh(epochs.times * 1000, freqs, power[0],
cmap='RdBu_r', vmin=vmin, vmax=vmax)
x, y = centers_to_edges(epochs.times * 1000, freqs)
mesh = ax.pcolormesh(x, y, power[0], cmap='RdBu_r', vmin=vmin, vmax=vmax)
ax.set_title('TFR calculated on a numpy array')
ax.set(ylim=freqs[[0, -1]], xlabel='Time (ms)')
fig.colorbar(mesh)
Expand Down
4 changes: 2 additions & 2 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _compute_tfr(epoch_data, freqs, sfreq=1.0, method='morlet',
n_freqs = len(freqs)
n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape
if output in ('power', 'phase', 'avg_power', 'itc'):
dtype = np.float
dtype = np.float64
elif output in ('complex', 'avg_power_itc'):
# avg_power_itc is stored as power + 1i * itc to keep a
# simple dimensionality
Expand Down Expand Up @@ -495,7 +495,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim):
The decimation slice: e.g. power[:, decim]
"""
# Set output type
dtype = np.float
dtype = np.float64
if output in ['complex', 'avg_power_itc']:
dtype = np.complex128

Expand Down
3 changes: 2 additions & 1 deletion mne/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
plot_epochs_psd_topomap, plot_layout)
from .topo import plot_topo_image_epochs, iter_topography
from .utils import (tight_layout, mne_analyze_colormap, compare_fiff,
ClickableImage, add_background_image, plot_sensors)
ClickableImage, add_background_image, plot_sensors,
centers_to_edges)
from ._3d import (plot_sparse_source_estimates, plot_source_estimates,
plot_vector_source_estimates, plot_evoked_field,
plot_dipole_locations, snapshot_brain_montage,
Expand Down
1 change: 0 additions & 1 deletion mne/viz/backends/_pysurfer_mayavi.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(self, fig=None, size=(600, 600), bgcolor='black',
name=None, show=False, shape=(1, 1), smooth_shading=True):
if bgcolor is not None:
bgcolor = _check_color(bgcolor)
print(bgcolor)
self.mlab = _import_mlab()
self.shape = shape
if fig is None:
Expand Down
10 changes: 9 additions & 1 deletion mne/viz/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from mne.viz.utils import (compare_fiff, _fake_click, _compute_scalings,
_validate_if_list_of_axes, _get_color_list,
_setup_vmin_vmax, center_cmap)
_setup_vmin_vmax, center_cmap, centers_to_edges)
from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap
from mne.utils import run_tests_if_main
from mne.io import read_raw_fif
Expand Down Expand Up @@ -171,4 +171,12 @@ def test_center_cmap():
assert not np.allclose(cmap(0.5), reference[1])


def test_centers_to_edges():
"""Test centers_to_edges."""
assert_allclose(centers_to_edges([0, 1, 2])[0], [-0.5, 0.5, 1.5, 2.5])
assert_allclose(centers_to_edges([0])[0], [-0.001, 0.001])
assert_allclose(centers_to_edges([1])[0], [0.999, 1.001])
assert_allclose(centers_to_edges([1000])[0], [999., 1001.])


run_tests_if_main()
1 change: 0 additions & 1 deletion mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,7 +2305,6 @@ def _animate(frame, ax, ax_line, params):
line.remove()
ylim = ax_line.get_ylim()
params['line'] = ax_line.axvline(all_times[time_idx], color='r')
print(all_times[time_idx])
ax_line.set_ylim(ylim)
items.append(params['line'])
params['frame'] = frame
Expand Down
41 changes: 37 additions & 4 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2896,10 +2896,7 @@ def _plot_masked_image(ax, data, times, mask=None, yvals=None,

if yscale == "log": # pcolormesh for log scale
# compute bounds between time samples
time_diff = np.diff(times) / 2. if len(times) > 1 else [0.0005]
time_lims = np.concatenate([[times[0] - time_diff[0]], times[:-1] +
time_diff, [times[-1] + time_diff[-1]]])

time_lims, = centers_to_edges(times)
log_yvals = np.concatenate([[yvals[0] / ratio[0]], yvals,
[yvals[-1] * ratio[0]]])
yval_lims = np.sqrt(log_yvals[:-1] * log_yvals[1:])
Expand Down Expand Up @@ -3282,3 +3279,39 @@ def _trim_ticks(ticks, _min, _max):
def _set_window_title(fig, title):
if fig.canvas.manager is not None:
fig.canvas.manager.set_window_title(title)


def centers_to_edges(*arrays):
"""Convert center points to edges.
Parameters
----------
*arrays : list of ndarray
Each input array should be 1D monotonically increasing,
and will be cast to float.
Returns
-------
arrays : list of ndarray
Given each input of shape (N,), the output will have shape (N+1,).
Examples
--------
>>> x = [0., 0.1, 0.2, 0.3]
>>> y = [20, 30, 40]
>>> centers_to_edges(x, y) # doctest: +SKIP
[array([-0.05, 0.05, 0.15, 0.25, 0.35]), array([15., 25., 35., 45.])]
"""
out = list()
for ai, arr in enumerate(arrays):
arr = np.asarray(arr, dtype=float)
_check_option(f'arrays[{ai}].ndim', arr.ndim, (1,))
if len(arr) > 1:
arr_diff = np.diff(arr) / 2.
else:
arr_diff = [abs(arr[0]) * 0.001] if arr[0] != 0 else [0.001]
out.append(np.concatenate([
[arr[0] - arr_diff[0]],
arr[:-1] + arr_diff,
[arr[-1] + arr_diff[-1]]]))
return out

0 comments on commit de5e369

Please sign in to comment.