Skip to content

Commit

Permalink
[MRG] Float colors for plot_compare_evoked (mne-tools#4775)
Browse files Browse the repository at this point in the history
  • Loading branch information
jona-sassenhagen authored and agramfort committed Nov 29, 2017
1 parent 63e4e3b commit 857b1fe
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 22 deletions.
34 changes: 25 additions & 9 deletions examples/stats/plot_sensor_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
=====================================
Sensor space least squares regression
=====================================
============================================================================
Analysing continuous features with binning and regression in sensor space
============================================================================
Predict single trial activity from a continuous variable.
A single-trial regression is performed in each sensor and timepoint
Expand Down Expand Up @@ -35,20 +35,36 @@
import pandas as pd
import mne
from mne.stats import linear_regression
from mne.viz import plot_compare_evokeds
from mne.datasets import kiloword

# Load the data
path = kiloword.data_path() + '/kword_metadata-epo.fif'
epochs = mne.read_epochs(path)
print(epochs.metadata.head())

# Add intercept column
df = pd.DataFrame(epochs.metadata)
epochs.metadata = df.assign(Intercept=[1 for _ in epochs.events])
##############################################################################
# Psycholinguistically relevant word characteristics are continuous. I.e.,
# concreteness or imaginability is a graded property. In the metadata,
# we have concreteness ratings on a 5-point scale. We can show the dependence
# of the EEG on concreteness by dividing the data into bins and plotting the
# mean activity per bin, color coded.
name = "Concreteness"
df = epochs.metadata
df[name] = pd.cut(df[name], 11, labels=False) / 10
colors = {str(val): val for val in df[name].unique()}
epochs.metadata = df.assign(Intercept=1) # Add an intercept for later
evokeds = {val: epochs[name + " == " + val].average() for val in colors}
plot_compare_evokeds(evokeds, colors=colors, split_legend=True,
cmap=(name + " Percentile", "viridis"))

# Run and visualize the regression
names = ["Intercept", "Concreteness", "BigramFrequency"]
##############################################################################
# We observe that there appears to be a monotonic dependence of EEG on
# concreteness. We can also conduct a continuous analysis: single-trial level
# regression with concreteness as a continuous (although here, binned)
# feature. We can plot the resulting regression coefficient just like an
# Event-related Potential.
names = ["Intercept", name]
res = linear_regression(epochs, epochs.metadata[names], names=names)

for cond in names:
res[cond].beta.plot_joint(title=cond)
56 changes: 44 additions & 12 deletions mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,7 @@ def _setup_styles(conditions, styles, cmap, colors, linestyles):
import matplotlib.pyplot as plt
# continuous colors
the_colors, color_conds, color_order = None, None, None
colors_are_float = False
if cmap is not None:
for color_value in colors.values():
try:
Expand All @@ -1526,18 +1527,38 @@ def _setup_styles(conditions, styles, cmap, colors, linestyles):
cmapper = getattr(plt.cm, cmap, cmap)
color_conds = list(colors.keys())
all_colors = [colors[cond] for cond in color_conds]
n_colors = len(all_colors)
color_order = np.array(all_colors).argsort()
color_indices = color_order.argsort()

if all([isinstance(color, Integral) for color in all_colors]):
msg = "Integer colors detected, mapping to rank positions ..."
n_colors = len(all_colors)
colors_ = {cond: ind for cond, ind in
zip(color_conds, color_indices)}

def convert_colors(color):
return colors_[color]
else:
for color in all_colors:
if not 0 <= color <= 1:
raise ValueError("Values of colors must be all-integer or "
"floats between 0 and 1, got %s." % color)
msg = "Float colors detected, mapping to percentiles ..."
n_colors = 101 # percentiles plus 1 if we have 1.0s
colors_old = colors.copy()

def convert_colors(color):
return int(colors_old[color] * 100)
colors_are_float = True
logger.info(msg)
the_colors = cmapper(np.linspace(0, 1, n_colors))

colors_ = {cond: ind for cond, ind in zip(color_conds, color_indices)}
colors = dict()
for cond in conditions:
for cond_number, color in colors_.items():
if cond_number in cond:
colors[cond] = the_colors[color]
cond_ = cond.split("/")
for color in color_conds:
if color in cond_:
colors[cond] = the_colors[convert_colors(color)]
continue

# categorical colors
Expand Down Expand Up @@ -1572,7 +1593,7 @@ def _setup_styles(conditions, styles, cmap, colors, linestyles):
styles[condition]['linestyle'] = styles[condition].get(
'linestyle', linestyles[condition])

return styles, the_colors, color_conds, color_order
return styles, the_colors, color_conds, color_order, colors_are_float


def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,
Expand Down Expand Up @@ -1616,7 +1637,10 @@ def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,
If None (default), a sequence of desaturated colors is used.
If `cmap` is None, `colors` will indicate how each condition is
colored with reference to its position on the colormap - see `cmap`
below.
below. In that case, the values of colors must be either integers,
in which case they will be mapped to colors in rank order; or floats
between 0 and 1, in which case they will be mapped to percentiles of
the colormap.
linestyles : list | dict
If a list, will be sequentially and repeatedly used for evoked plot
linestyles.
Expand Down Expand Up @@ -1645,7 +1669,8 @@ def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,
value corresponds to the position on the colorbar.
If ``evokeds`` is a dict, ``colors`` should be a dict mapping from
(potentially HED-style) condition tags to numbers corresponding to
rank order positions on the colorbar. E.g., ::
positions on the colorbar - rank order for integers, or floats for
percentiles. E.g., ::
evokeds={"cond1/A": ev1, "cond2/A": ev2, "cond3/A": ev3, "B": ev4},
cmap='viridis', colors=dict(cond1=1 cond2=2, cond3=3),
Expand Down Expand Up @@ -1939,7 +1964,7 @@ def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,
legend_lines.append(line)
legend_labels.append(style)

styles, the_colors, color_conds, color_order =\
styles, the_colors, color_conds, color_order, colors_are_float =\
_setup_styles(data_dict.keys(), styles, cmap, colors, linestyles)

# We now have a 'styles' dict with one entry per condition, specifying at
Expand Down Expand Up @@ -2042,9 +2067,16 @@ def plot_compare_evokeds(evokeds, picks=None, gfp=False, colors=None,
from mpl_toolkits.axes_grid1 import make_axes_locatable
divider = make_axes_locatable(ax)
ax_cb = divider.append_axes("right", size="5%", pad=0.05)
ax_cb.imshow(the_colors[:, np.newaxis, :], interpolation='none')
ax_cb.set_yticks(np.arange(len(the_colors)))
ax_cb.set_yticklabels(np.array(color_conds)[color_order])
if colors_are_float:
ax_cb.imshow(the_colors[:, np.newaxis, :], interpolation='none',
aspect=.05)
color_ticks = np.array(list(set(colors.values()))) * 100
ax_cb.set_yticks(color_ticks)
ax_cb.set_yticklabels(color_ticks)
else:
ax_cb.imshow(the_colors[:, np.newaxis, :], interpolation='none')
ax_cb.set_yticks(np.arange(len(the_colors)))
ax_cb.set_yticklabels(np.array(color_conds)[color_order])
ax_cb.yaxis.tick_right()
ax_cb.set_xticks(())
ax_cb.set_ylabel(cmap_label)
Expand Down
9 changes: 8 additions & 1 deletion mne/viz/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def test_plot_evoked():
# various bad styles
params = [dict(picks=3, colors=dict(fake=1)),
dict(picks=3, styles=dict(fake=1)), dict(picks=3, gfp=True),
dict(picks=3, show_sensors="a")]
dict(picks=3, show_sensors="a"),
dict(colors=dict(red=10., blue=-2))]
for param in params:
assert_raises(ValueError, plot_compare_evokeds, evoked, **param)
assert_raises(TypeError, plot_compare_evokeds, evoked, picks='str')
Expand Down Expand Up @@ -217,6 +218,12 @@ def test_plot_evoked():
contrasts, colors=colors, picks=[0], cmap='Reds',
split_legend=split, linestyles=linestyles,
ci=False, show_sensors=False)
colors = {"a" + str(ii): ii / len(evokeds)
for ii, _ in enumerate(evokeds)}
plot_compare_evokeds(
contrasts, colors=colors, picks=[0], cmap='Reds',
split_legend=split, linestyles=linestyles, ci=False,
show_sensors=False)
red.info["chs"][0]["loc"][:2] = 0 # test plotting channel at zero
plot_compare_evokeds(red, picks=[0],
ci=lambda x: [x.std(axis=0), -x.std(axis=0)])
Expand Down

0 comments on commit 857b1fe

Please sign in to comment.