Skip to content

Commit

Permalink
MRG+1: Merge grads before rescaling in tfr topomaps (mne-tools#5312)
Browse files Browse the repository at this point in the history
* Merge grads before rescaling in tfr topomaps

* Fix error with None method

* Touch the example

* Remove vmax from example

* Update _merge_grad_data to retain shape

* Touch example.

* FIX: Ellipses

* Update whats_new.rst

* Style.
  • Loading branch information
teekuningas authored and jona-sassenhagen committed Aug 7, 2018
1 parent 1d3cb1e commit de86b1d
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 18 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ Changelog
Bug
~~~

- Fix bug of not showing ERD's in baseline rescaled tfr topomaps if grads are combined by `Erkka Heinila`_

- Fix bug with reading measurement dates from BrainVision files by `Stefan Appelhoff`_

- Fix bug with ``mne flash_bem`` when ``flash30`` is not used by `Eric Larson`_
Expand Down
8 changes: 4 additions & 4 deletions mne/channels/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,24 +845,24 @@ def _merge_grad_data(data, method='rms'):
Parameters
----------
data : array, shape = (n_channels, n_times)
data : array, shape = (n_channels, ..., n_times)
Data for channels, ordered in pairs.
method : str
Can be 'rms' or 'mean'.
Returns
-------
data : array, shape = (n_channels / 2, n_times)
data : array, shape = (n_channels / 2, ..., n_times)
The root mean square or mean for each pair.
"""
data = data.reshape((len(data) // 2, 2, -1))
data, orig_shape = data.reshape((len(data) // 2, 2, -1)), data.shape
if method == 'mean':
data = np.mean(data, axis=1)
elif method == 'rms':
data = np.sqrt(np.sum(data ** 2, axis=1) / 2)
else:
raise ValueError('method must be "rms" or "mean, got %s.' % method)
return data
return data.reshape(data.shape[:1] + orig_shape[1:])


def generate_2d_layout(xy, w=.07, h=.05, pad=.02, ch_names=None,
Expand Down
29 changes: 23 additions & 6 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,8 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
""" # noqa: E501
from ..viz.topomap import _set_contour_locator
from ..channels.layout import (find_layout, _merge_grad_data,
_pair_grad_sensors)
import matplotlib.pyplot as plt

#####################################
Expand Down Expand Up @@ -1480,7 +1482,7 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
############

from ..viz import plot_topomap
titles, all_data, vlims = [], [], []
titles, all_data, all_pos, vlims = [], [], [], []

# the structure here is a bit complicated to allow aggregating vlims
# over all topomaps. First, one loop over all timefreqs to collect
Expand Down Expand Up @@ -1511,9 +1513,25 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
fmin = freq - freq_half_range
fmax = freq + freq_half_range

data = tfr.data

pos = find_layout(tfr.info).pos if layout is None else layout.pos

# merging grads here before rescaling makes ERDs visible
if ch_type == 'grad':
picks, new_pos = _pair_grad_sensors(tfr.info,
find_layout(tfr.info))
if layout is None:
pos = new_pos
method = combine or 'rms'
data = _merge_grad_data(data[picks], method=method)

all_pos.append(pos)

data, times, freqs, _, _ = _preproc_tfr(
tfr.data, tfr.times, tfr.freqs, tmin, tmax, fmin, fmax,
None, None, vmin, vmax, None, tfr.info['sfreq'])
data, tfr.times, tfr.freqs, tmin, tmax, fmin, fmax,
mode, baseline, vmin, vmax, None, tfr.info['sfreq'])

vlims.append(np.abs(data).max())
titles.append(sub_map_title)
all_data.append(data)
Expand All @@ -1529,10 +1547,9 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
vmin, vmax, topomap_args_pass["contours"])
topomap_args_pass['contours'] = contours

for ax, title, data in zip(map_ax, titles, all_data):
for ax, title, data, pos in zip(map_ax, titles, all_data, all_pos):
ax.set_title(title)
plot_topomap(data.mean(-1).mean(-1),
tfr.info if layout is None else layout.pos,
plot_topomap(data.mean(axis=(-1, -2)), pos,
cmap=cmap[0], axes=ax, show=False,
**topomap_args_pass)

Expand Down
14 changes: 8 additions & 6 deletions mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,13 @@ def plot_tfr_topomap(tfr, tmin=None, tmax=None, fmin=None, fmax=None,
if not show_names:
names = None

data = tfr.data
data = tfr.data[picks, :, :]

# merging grads before rescaling makes ERDs visible
if merge_grads:
from ..channels.layout import _merge_grad_data
data = _merge_grad_data(data)

data = rescale(data, tfr.times, baseline, mode, copy=True)

# crop time
Expand All @@ -1206,13 +1212,9 @@ def plot_tfr_topomap(tfr, tmin=None, tmax=None, fmin=None, fmax=None,
if fmax is not None:
ifmax = idx[-1] + 1

data = data[picks, ifmin:ifmax, itmin:itmax]
data = data[:, ifmin:ifmax, itmin:itmax]
data = np.mean(np.mean(data, axis=2), axis=1)[:, np.newaxis]

if merge_grads:
from ..channels.layout import _merge_grad_data
data = _merge_grad_data(data)

norm = False if np.min(data) < 0 else True
vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
cmap = _setup_cmap(cmap, norm=norm)
Expand Down
4 changes: 2 additions & 2 deletions tutorials/plot_sensors_time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@
fig, axis = plt.subplots(1, 2, figsize=(7, 4))
power.plot_topomap(ch_type='grad', tmin=0.5, tmax=1.5, fmin=8, fmax=12,
baseline=(-0.5, 0), mode='logratio', axes=axis[0],
title='Alpha', vmax=0.45, show=False)
title='Alpha', show=False)
power.plot_topomap(ch_type='grad', tmin=0.5, tmax=1.5, fmin=13, fmax=25,
baseline=(-0.5, 0), mode='logratio', axes=axis[1],
title='Beta', vmax=0.45, show=False)
title='Beta', show=False)
mne.viz.tight_layout()
plt.show()

Expand Down

0 comments on commit de86b1d

Please sign in to comment.