Skip to content

Commit

Permalink
MRG, MAINT: standardize the way we get channel types (mne-tools#7486)
Browse files Browse the repository at this point in the history
* MAINT: standardize the way we get channel types

* fix unused imports

* fix missing comma

* closes mne-tools#7487 (rename get_channel_types -> get_channel_type_constants)

* fix tests

* allow non-integer picks

* fix docdict key

* fix test

* proper deprecation

* add docstring to new test
  • Loading branch information
drammock authored Mar 22, 2020
1 parent 55b7820 commit 9e78506
Show file tree
Hide file tree
Showing 15 changed files with 74 additions and 55 deletions.
18 changes: 14 additions & 4 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from ..io.meas_info import anonymize_info, Info, MontageMixin
from ..io.pick import (channel_type, pick_info, pick_types, _picks_by_type,
_check_excludes_includes, _contains_ch_type,
channel_indices_by_type, pick_channels, _picks_to_idx)
channel_indices_by_type, pick_channels, _picks_to_idx,
_get_channel_types)


def _get_meg_system(info):
Expand Down Expand Up @@ -207,16 +208,25 @@ def compensation_grade(self):
"""The current gradient compensation grade."""
return get_current_comp(self.info)

def get_channel_types(self):
@fill_doc
def get_channel_types(self, picks=None, unique=False, only_data_chs=False):
"""Get a list of channel type for each channel.
Parameters
----------
%(picks_all)s
unique : bool
Whether to return only unique channel types. Default is ``False``.
only_data_chs : bool
Whether to ignore non-data channels. Default is ``False``.
Returns
-------
channel_types : list
The channel types.
"""
return [channel_type(self.info, n)
for n in range(len(self.info['ch_names']))]
return _get_channel_types(self.info, picks=picks, unique=unique,
only_data_chs=only_data_chs)


# XXX Eventually de-duplicate with _kind_dict of mne/io/meas_info.py
Expand Down
4 changes: 2 additions & 2 deletions mne/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .cov import read_cov, compute_whitener
from .io.constants import FIFF
from .io.pick import pick_types, channel_type
from .io.pick import pick_types
from .io.proj import make_projector, _needs_eeg_average_ref_proj
from .bem import _fit_sphere
from .evoked import _read_evoked, _aspect_rev, _write_evokeds
Expand Down Expand Up @@ -1283,7 +1283,7 @@ def fit_dipole(evoked, cov, bem, trans=None, min_dist=5., n_jobs=1,
logger.info('%d bad channels total' % len(info['bads']))

# Forward model setup (setup_forward_model from setup.c)
ch_types = [channel_type(info, idx) for idx in range(info['nchan'])]
ch_types = evoked.get_channel_types()

megcoils, compcoils, megnames, meg_info = [], [], [], None
eegels, eegnames = [], []
Expand Down
6 changes: 2 additions & 4 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
from .io.open import fiff_open
from .io.tag import read_tag
from .io.tree import dir_tree_find
from .io.pick import (channel_type, pick_types, _pick_data_channels,
_picks_to_idx)
from .io.pick import pick_types, _picks_to_idx
from .io.meas_info import read_meas_info, write_meas_info
from .io.proj import ProjMixin
from .io.write import (start_file, start_block, end_file, end_block,
Expand Down Expand Up @@ -552,8 +551,7 @@ def get_peak(self, ch_type=None, tmin=None, tmax=None,
""" # noqa: E501
supported = ('mag', 'grad', 'eeg', 'seeg', 'ecog', 'misc', 'hbo',
'hbr', 'None', 'fnirs_raw', 'fnirs_od')
data_picks = _pick_data_channels(self.info, with_ref_meg=False)
types_used = {channel_type(self.info, idx) for idx in data_picks}
types_used = self.get_channel_types(unique=True, only_data_chs=True)

_check_option('ch_type', str(ch_type), supported)

Expand Down
18 changes: 12 additions & 6 deletions mne/io/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@

from .constants import FIFF
from ..utils import (logger, verbose, _validate_type, fill_doc, _ensure_int,
_check_option)
_check_option, deprecated)


@deprecated('The function mne.io.pick.get_channel_types() has changed names. '
'Please use mne.io.pick.get_channel_type_constants() instead.')
def get_channel_types():
return get_channel_type_constants()


def get_channel_type_constants():
"""Return all known channel types.
Returns
Expand Down Expand Up @@ -1121,13 +1127,13 @@ def _pick_inst(inst, picks, exclude, copy=True):
return inst


def _get_channel_types(info, picks=None, unique=True,
restrict_data_types=False):
def _get_channel_types(info, picks=None, unique=False, only_data_chs=False):
"""Get the data channel types in an info instance."""
picks = range(info['nchan']) if picks is None else picks
none = 'data' if only_data_chs else 'all'
picks = _picks_to_idx(info, picks, none, (), allow_empty=False)
ch_types = [channel_type(info, idx) for idx in range(info['nchan'])
if idx in picks]
if restrict_data_types is True:
if only_data_chs:
ch_types = [ch_type for ch_type in ch_types
if ch_type in _DATA_CH_TYPES_SPLIT]
return set(ch_types) if unique is True else ch_types
return set(ch_types) if unique else ch_types
13 changes: 10 additions & 3 deletions mne/io/tests/test_pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from mne.io.pick import (channel_indices_by_type, channel_type,
pick_types_forward, _picks_by_type, _picks_to_idx,
get_channel_types, _DATA_CH_TYPES_SPLIT,
_contains_ch_type, pick_channels_cov)
_contains_ch_type, pick_channels_cov,
_get_channel_types, get_channel_type_constants)
from mne.io.constants import FIFF
from mne.datasets import testing
from mne.utils import run_tests_if_main, catch_logging, assert_object_equal
Expand Down Expand Up @@ -70,7 +71,7 @@ def _channel_type_old(info, idx):
# iterate through all defined channel types until we find a match with ch
# go in order from most specific (most rules entries) to least specific
channel_types = sorted(
get_channel_types().items(), key=lambda x: len(x[1]))[::-1]
get_channel_type_constants().items(), key=lambda x: len(x[1]))[::-1]
for t, rules in channel_types:
for key, vals in rules.items(): # all keys must match the values
if ch.get(key, None) not in np.array(vals):
Expand Down Expand Up @@ -245,7 +246,7 @@ def test_pick_chpi():
# Make sure we don't mis-classify cHPI channels
info = read_info(op.join(io_dir, 'tests', 'data', 'test_chpi_raw_sss.fif'))
_assert_channel_types(info)
channel_types = {channel_type(info, idx) for idx in range(info['nchan'])}
channel_types = _get_channel_types(info)
assert 'chpi' in channel_types
assert 'seeg' not in channel_types
assert 'ecog' not in channel_types
Expand Down Expand Up @@ -548,4 +549,10 @@ def test_pick_channels_cov():
assert 'loglik' not in cov_copy


def test_deprecation():
"""Test deprecated call."""
with pytest.deprecated_call():
_ = get_channel_types()


run_tests_if_main()
6 changes: 3 additions & 3 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..cov import compute_whitener
from .. import Covariance, Evoked
from ..io.pick import (pick_types, pick_channels, pick_info,
_picks_to_idx, _DATA_CH_TYPES_SPLIT)
_picks_to_idx, _get_channel_types, _DATA_CH_TYPES_SPLIT)
from ..io.write import (write_double_matrix, write_string,
write_name_list, write_int, start_block,
end_block)
Expand Down Expand Up @@ -58,7 +58,7 @@
from ..filter import filter_data
from .bads import _find_outliers
from .ctps_ import ctps
from ..io.pick import channel_type, pick_channels_regexp
from ..io.pick import pick_channels_regexp

__all__ = ('ICA', 'ica_find_ecg_events', 'ica_find_eog_events',
'get_score_funcs', 'read_ica', 'run_ica', 'read_ica_eeglab')
Expand Down Expand Up @@ -112,7 +112,7 @@ def _check_for_unsupported_ica_channels(picks, info, allow_ref_meg=False):
"""
types = _DATA_CH_TYPES_SPLIT + ('eog',)
types += ('ref_meg',) if allow_ref_meg else ()
chs = list({channel_type(info, j) for j in picks})
chs = _get_channel_types(info, picks, unique=True, only_data_chs=False)
check = all([ch in types for ch in chs])
if not check:
raise ValueError('Invalid channel type%s passed for ICA: %s.'
Expand Down
5 changes: 2 additions & 3 deletions mne/tests/test_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mne.cov import prepare_noise_cov
from mne.datasets import testing
from mne.io import read_raw_fif
from mne.io.pick import channel_type, _picks_by_type
from mne.io.pick import _picks_by_type, _get_channel_types
from mne.io.proj import _has_eeg_average_ref_proj
from mne.proj import compute_proj_raw
from mne.rank import (estimate_rank, compute_rank, _get_rank_sss,
Expand Down Expand Up @@ -154,8 +154,7 @@ def test_cov_rank_estimation(rank_method, proj, meg):
for proj in this_info['projs'])

# count channel types
ch_types = [channel_type(this_info, idx)
for idx in range(len(picks))]
ch_types = _get_channel_types(this_info)
n_eeg, n_mag, n_grad = [ch_types.count(k) for k in
['eeg', 'mag', 'grad']]
n_meg = n_mag + n_grad
Expand Down
4 changes: 2 additions & 2 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,7 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
# Nonetheless, it should be refactored for code reuse.
copy = any(var is not None for var in (exclude, picks, baseline))
tfr = _pick_inst(self, picks, exclude, copy=copy)
ch_types = _get_channel_types(tfr.info)
ch_types = _get_channel_types(tfr.info, unique=True)

# if multiple sensor types: one plot per channel type, recursive call
if len(ch_types) > 1:
Expand All @@ -1431,7 +1431,7 @@ def plot_joint(self, timefreqs=None, picks=None, baseline=None,
type_picks = [idx for idx in range(tfr.info['nchan'])
if channel_type(tfr.info, idx) == this_type]
tf_ = _pick_inst(tfr, type_picks, None, copy=True)
if len(_get_channel_types(tf_.info)) > 1:
if len(_get_channel_types(tf_.info, unique=True)) > 1:
raise RuntimeError(
'Possibly infinite loop due to channel selection '
'problem. This should never happen! Please check '
Expand Down
13 changes: 6 additions & 7 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def plot_epochs_image(epochs, picks=None, sigma=0., vmin=None,

# is picks a channel type (or None)?
picks, picked_types = _picks_to_idx(epochs.info, picks, return_kind=True)
ch_types = _get_channel_types(epochs.info, picks=picks, unique=False)
ch_types = _get_channel_types(epochs.info, picks)

# `combine` defaults to 'gfp' unless picks are specific channels and
# there was no group_by passed
Expand Down Expand Up @@ -980,17 +980,16 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
picks = _picks_to_idx(info, picks)
picks = sorted(picks)
# channel type string for every channel
types = [channel_type(info, ch) for ch in picks]
types = _get_channel_types(info, picks)
# list of unique channel types
ch_types = list(_get_channel_types(info))
unique_types = _get_channel_types(info, unique=True)
if order is None:
order = _DATA_CH_TYPES_ORDER_DEFAULT
inds = [pick_idx for order_type in order
for pick_idx, ch_type in zip(picks, types)
if order_type == ch_type]
if len(ch_types) > len(order):
ch_missing = [ch_type for ch_type in ch_types if ch_type not in order]
ch_missing = np.unique(ch_missing)
if len(unique_types) > len(order):
ch_missing = unique_types - set(order)
raise RuntimeError('%s are in picks but not in order.'
' Please specify all channel types picked.' %
(str(ch_missing)))
Expand Down Expand Up @@ -1174,7 +1173,7 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
'ev_texts': list(),
'ann': list(), # list for butterfly view annotations
'order': order,
'ch_types': ch_types})
'ch_types': unique_types})

params['plot_fun'] = partial(_plot_traces, params=params)

Expand Down
13 changes: 6 additions & 7 deletions mne/viz/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,10 @@ def _plot_evoked(evoked, picks, exclude, unit, show, ylim, proj, xlim, hline,
"found in `axes`")
ax = axes[sel]
# the unwieldy dict comp below defaults the title to the sel
titles = ({channel_type(evoked.info, idx): sel
for idx in group_by[sel]} if titles is None else titles)
_plot_evoked(evoked, group_by[sel], exclude, unit, show, ylim,
proj, xlim, hline, units, scalings,
(titles if titles is not None else
{channel_type(evoked.info, idx): sel
for idx in group_by[sel]}),
proj, xlim, hline, units, scalings, titles,
ax, plot_type, cmap=cmap, gfp=gfp,
window_title=window_title,
set_tight_layout=set_tight_layout,
Expand Down Expand Up @@ -290,7 +289,7 @@ def _plot_evoked(evoked, picks, exclude, unit, show, ylim, proj, xlim, hline,

picks = np.array([pick for pick in picks if pick not in exclude])

types = np.array([channel_type(info, idx) for idx in picks], np.unicode)
types = np.array(_get_channel_types(info, picks), np.unicode)
ch_types_used = list()
for this_type in _VALID_CHANNEL_TYPES:
if this_type in types:
Expand Down Expand Up @@ -1346,7 +1345,7 @@ def plot_evoked_joint(evoked, times="peaks", title='', picks=None,
# simply create a new evoked object with the desired channel selection
evoked = _pick_inst(evoked, picks, exclude, copy=True)
info = evoked.info
ch_types = _get_channel_types(info, restrict_data_types=True)
ch_types = _get_channel_types(info, unique=True, only_data_chs=True)

# if multiple sensor types: one plot per channel type, recursive call
if len(ch_types) > 1:
Expand All @@ -1359,7 +1358,7 @@ def plot_evoked_joint(evoked, times="peaks", title='', picks=None,
ev_ = evoked.copy().pick_channels(
[info['ch_names'][idx] for idx in range(info['nchan'])
if channel_type(info, idx) == this_type])
if len(_get_channel_types(ev_.info)) > 1:
if len(_get_channel_types(ev_.info, unique=True)) > 1:
raise RuntimeError('Possibly infinite loop due to channel '
'selection problem. This should never '
'happen! Please check your channel types.')
Expand Down
5 changes: 2 additions & 3 deletions mne/viz/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from ..utils import warn, _validate_type, fill_doc
from ..defaults import _handle_default
from ..io.meas_info import create_info
from ..io.pick import (pick_types, _picks_to_idx, _get_channel_types,
_DATA_CH_TYPES_ORDER_DEFAULT)
from ..io.pick import (pick_types, _picks_to_idx, _DATA_CH_TYPES_ORDER_DEFAULT)
from ..time_frequency.psd import psd_multitaper
from ..utils import _reject_data_segments

Expand Down Expand Up @@ -727,7 +726,7 @@ def plot_ica_overlay(ica, inst, exclude=None, picks=None, start=None,
title = 'Signals before (red) and after (black) cleaning'
picks = ica.ch_names if picks is None else picks
picks = _picks_to_idx(inst.info, picks, exclude=())
ch_types_used = _get_channel_types(inst.info, picks=picks, unique=True)
ch_types_used = inst.get_channel_types(picks=picks, unique=True)
if exclude is None:
exclude = ica.exclude
if not isinstance(exclude, (np.ndarray, list)):
Expand Down
5 changes: 2 additions & 3 deletions mne/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,9 +905,8 @@ def plot_topo_image_epochs(epochs, layout=None, sigma=0., vmin=None,
ch_names = set(layout.names) & set(epochs.ch_names)
idxs = [epochs.ch_names.index(ch_name) for ch_name in ch_names]
epochs = epochs.pick(idxs)
# iterate over a sequential index to get lists of chan. type & scale coef.
ch_idxs = range(epochs.info['nchan'])
ch_types = [channel_type(epochs.info, idx) for idx in ch_idxs]
# get lists of channel type & scale coefficient
ch_types = epochs.get_channel_types()
scale_coeffs = [scalings.get(ch_type, 1) for ch_type in ch_types]
# scale the data
epochs._data *= np.array(scale_coeffs)[:, np.newaxis]
Expand Down
9 changes: 4 additions & 5 deletions mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from ..channels.layout import (
_find_topomap_coords, find_layout, _pair_grad_sensors, Layout,
_merge_ch_data)
from ..io.pick import (pick_types, _picks_by_type, channel_type, pick_info,
_pick_data_channels, pick_channels, _picks_to_idx)
from ..io.pick import (pick_types, _picks_by_type, pick_info, pick_channels,
_pick_data_channels, _picks_to_idx, _get_channel_types)
from ..utils import (_clean_names, _time_mask, verbose, logger, warn, fill_doc,
_validate_type, _check_sphere)
from .utils import (tight_layout, _setup_vmin_vmax, _prepare_trellis,
Expand Down Expand Up @@ -326,7 +326,7 @@ def plot_projs_topomap(projs, info, cmap=None, sensors=True,
if vlim == 'joint':
ch_idxs = np.where(np.in1d(info['ch_names'],
proj['data']['col_names']))[0]
these_ch_types = set([channel_type(info, n) for n in ch_idxs])
these_ch_types = _get_channel_types(info, ch_idxs, unique=True)
# each projector should have only one channel type
assert len(these_ch_types) == 1
types.append(list(these_ch_types)[0])
Expand Down Expand Up @@ -809,8 +809,7 @@ def _plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
pos = pick_info(pos, picks)

# check if there is only 1 channel type, and n_chans matches the data
ch_type = {channel_type(pos, idx)
for idx, _ in enumerate(pos["chs"])}
ch_type = _get_channel_types(pos, unique=True)
info_help = ("Pick Info with e.g. mne.pick_info and "
"mne.io.pick.channel_indices_by_type.")
if len(ch_type) > 1:
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3214,7 +3214,7 @@ def _plot_psd(inst, fig, freqs, psd_list, picks_list, titles_list,
if not average:
picks = np.concatenate(picks_list)
psd_list = np.concatenate(psd_list)
types = np.array([channel_type(inst.info, idx) for idx in picks])
types = np.array(inst.get_channel_types(picks=picks))
# Needed because the data do not match the info anymore.
info = create_info([inst.ch_names[p] for p in picks],
inst.info['sfreq'], types)
Expand Down
8 changes: 6 additions & 2 deletions tutorials/intro/plot_30_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,13 @@

###############################################################################
# To obtain several channel types at once, you could embed
# :func:`~mne.channel_type` in a :term:`list comprehension`:
# :func:`~mne.channel_type` in a :term:`list comprehension`, or use the
# :meth:`~mne.io.Raw.get_channel_types` method of a :class:`~mne.io.Raw`,
# :class:`~mne.Epochs`, or :class:`~mne.Evoked` instance:

print([mne.channel_type(info, x) for x in (25, 76, 77, 319)])
picks = (25, 76, 77, 319)
print([mne.channel_type(info, x) for x in picks])
print(raw.get_channel_types(picks=picks))

###############################################################################
# Alternatively, you can get the indices of all channels of *all* channel types
Expand Down

0 comments on commit 9e78506

Please sign in to comment.