Skip to content

Commit

Permalink
add spectrum class (mne-tools#10184)
Browse files Browse the repository at this point in the history
* STY: alphabetize imports

* wip: first sketch of spectrum class [ci skip]

implement __repr__; add placeholder _repr_html_

add draft of _repr_html_

default to multitaper for Evokeds

make raw.plot_psd() use the new code path

unify viz.plot_raw_psd code path too

support unaggregated multitaper

add picks param to spectrum.plot()

fix(ish) the units() method

allow average=False as synonym for None

handle unaggregated estimates in combo with epochs

fix CI plotting

implement get_data method [ci skip]

* refactor aggregation

* fix instance type checking

* improve TODO notes

* WIP use new class in plot_psd_topo [ci skip]

* docdict additions

* implement Spectrum.to_data_frame

* test Spectrum.to_data_frame

* adapt to_data_frame for unaggregated spectra [ci skip]

* test unaggregated welch to df

* test unagg multitaper to df

* make tests more similar

* make DRY

* fix epoch test

* simplify test

* fix flake

* use requires_pandas

* fix bad rebase

* fix API for epochs

* use new API in example

* tiny docstring improvement

* fix unused import

* fix docdict key order

* do it the smarter/safer way

* update tests to avoid deprecated calls

* better deprectation message

* convert more deprecated func calls to new method

* unused imports

* fix tests

* more unused imports

* fix circular imports

also:
- apply isort to a couple files
- revert distracting isort on otherwise barely touched file

* make CIs pass

* get I/O working

* don't store verbose attr

* test IO

* fix D202

* fix flake

* better docstring for save method and read func

* fix compute_psd docstring

* return value descr

* decorate test (h5py)

* fixup after rebase

* reorder methods

* add __getitem__ functionality

* __getitem__ tests

* add __eq__, test for .copy(), refactor IO test

* refactor to separate epochs class

* EpochsSpectrum IO

* fix type checking, better variable naming

* test evoked IO too [ci skip]

* fix type checking some more

* adjust .units() for complex multitaper data

* fix flake

* docstring refactor

* working plot_topomap implementation

* fix docstring tests

* fix pydocstyle

* add EpochsSpectrum to the public API

* make test more DRY

* tweak deprecation message

* docdict/docstrings fixes

* add plot_psd_topomap to mixin

* pytest limitation workaround

* work toward unifying plot_topomap API

* don't silently overwrite units

* TODO comments [ci skip]

* WIP tutorial changes

* TODO: plot_psd_topo

* WIP plot_topo & docdict stuff

* plot_topomap docstring dedup

* dedup legacy n_fft default

* fix varname

* more WIP plot_psd_topo

* finish migrating plot_psd_topo to mixin

* use new code path for plot_topomap

* flake

* docstring tests

* unused imports

* whitespace

* flake

* flake again

* add plot_topo as spectrum method

* don't do too much

* fix test

* flake

* fix test

* fix tutorial

* WIP spectrum class tutorial

* codespell

* update tutorial and repr_html template

* better repr, better shape checking

* tweaks from self-review

* update changelog

* flake

* fix

* undo isort / other unrelated changes

* use new API in tutorial

* remove redundant plt_show

* standardize docstring order

* explain setup.cfg entry

* fix html repr of units

* simplify __eq__ by improving object_diff

* remove deepcopy override

* fix flake8 config

* remove superfluous BibTeX fields [ci skip]

Co-authored-by: Marijn van Vliet <[email protected]>

* misc fixes [ci skip]

Co-authored-by: Marijn van Vliet <[email protected]>

* Update mne/viz/utils.py [ci skip]

Co-authored-by: Eric Larson <[email protected]>

* update old tutorials more thoroughly

* file encoding / test comments

Co-authored-by: Marijn van Vliet <[email protected]>

* use __setstate__ and __getstate__

* flake

* fix reject_by_annot appearing where it shouldn't

* fix docstrings

Co-authored-by: Marijn van Vliet <[email protected]>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
3 people authored Aug 27, 2022
1 parent 6cebb13 commit 93485e0
Show file tree
Hide file tree
Showing 40 changed files with 2,397 additions and 713 deletions.
3 changes: 3 additions & 0 deletions doc/_static/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,6 @@ ul.icon-bullets {
img.hidden {
visibility: hidden;
}
td.justify {
text-align-last: justify;
}
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ API changes
~~~~~~~~~~~
- The ``bands`` parameter of :meth:`mne.Epochs.plot_psd_topomap` now accepts :class:`dict` input; legacy :class:`tuple` input is supported, but discouraged for new code (:gh:`11050` by `Daniel McCloy`_)
- The ``show_toolbar`` argument to :class:`mne.viz.Brain` is being removed by deprecation (:gh:`11049` by `Eric Larson`_)
- New classes :class:`~mne.time_frequency.Spectrum` and :class:`~mne.time_frequency.EpochsSpectrum`, created via new methods :meth:`Raw.compute_psd()<mne.io.Raw.compute_psd>`, :meth:`Epochs.compute_psd()<mne.Epochs.compute_psd>`, and :meth:`Evoked.compute_psd()<mne.Evoked.compute_psd>` (:gh:`10184` by `Daniel McCloy`_)
- The PSD functions that operate on Raw/Epochs/Evoked instances (``mne.time_frequency.psd_welch`` and ``mne.time_frequency.psd_multitaper``) are deprecated; for equivalent functionality create :class:`~mne.time_frequency.Spectrum` or :class:`~mne.time_frequency.EpochsSpectrum` objects instead and then run ``spectrum.get_data(return_freqs=True)`` (:gh:`10184` by `Daniel McCloy`_)
2 changes: 2 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@
'Transform': 'mne.transforms.Transform',
'Coregistration': 'mne.coreg.Coregistration',
'Figure3D': 'mne.viz.Figure3D',
'Spectrum': 'mne.time_frequency.Spectrum',
'EpochsSpectrum': 'mne.time_frequency.EpochsSpectrum',
# dipy
'dipy.align.AffineMap': 'dipy.align.imaffine.AffineMap',
'dipy.align.DiffeomorphicMap': 'dipy.align.imwarp.DiffeomorphicMap',
Expand Down
11 changes: 11 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,17 @@ @article{LuckGaspelin2017
doi = {10.1111/psyp.12639},
}

@article{Welch1967,
title = {The Use of Fast {{Fourier}} Transform for the Estimation of Power Spectra: {{A}} Method Based on Time Averaging over Short, Modified Periodograms},
author = {Welch, Peter D.},
year = {1967},
journal = {IEEE Transactions on Audio and Electroacoustics},
volume = {15},
number = {2},
pages = {70--73},
doi = {10.1109/TAU.1967.1161901},
}

@article{MaksymenkoEtAl2017,
title = {Strategies for statistical thresholding of source localization maps in magnetoencephalography and estimating source extent},
volume = {290},
Expand Down
3 changes: 3 additions & 0 deletions doc/time_frequency.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Time-Frequency
AverageTFR
EpochsTFR
CrossSpectralDensity
Spectrum
EpochsSpectrum

Functions that operate on mne-python objects:

Expand All @@ -36,6 +38,7 @@ Functions that operate on mne-python objects:
tfr_stockwell
read_tfrs
write_tfrs
read_spectrum

Functions that operate on ``np.ndarray`` objects:

Expand Down
11 changes: 8 additions & 3 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def set_meas_date(self, meas_date):


class UpdateChannelsMixin(object):
"""Mixin class for Raw, Evoked, Epochs, AverageTFR."""
"""Mixin class for Raw, Evoked, Epochs, Spectrum, AverageTFR."""

@verbose
def pick_types(self, meg=False, eeg=False, stim=False, eog=False,
Expand Down Expand Up @@ -791,6 +791,7 @@ def _pick_drop_channels(self, idx, *, verbose=None):
# avoid circular imports
from ..io import BaseRaw
from ..time_frequency import AverageTFR, EpochsTFR
from ..time_frequency.spectrum import BaseSpectrum

msg = 'adding, dropping, or reordering channels'
if isinstance(self, BaseRaw):
Expand All @@ -815,8 +816,12 @@ def _pick_drop_channels(self, idx, *, verbose=None):
if mat is not None:
setattr(self, key, mat[idx][:, idx])

# All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -3 if isinstance(self, (AverageTFR, EpochsTFR)) else -2
if isinstance(self, BaseSpectrum):
axis = self._dims.index('channel')
elif isinstance(self, (AverageTFR, EpochsTFR)):
axis = -3
else: # All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -2
if hasattr(self, '_data'): # skip non-preloaded Raw
self._data = self._data.take(idx, axis=axis)
else:
Expand Down
2 changes: 2 additions & 0 deletions mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def plot_topomap(self, info, ch_type=None, vmin=None,
----------
%(info_not_none)s
%(ch_type_topomap)s
.. versionadded:: 0.21
%(vmin_vmax_topomap)s
%(cmap_topomap)s
%(sensors_topomap)s
Expand Down
2 changes: 1 addition & 1 deletion mne/decoding/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .. import pick_types
from ..filter import filter_data
from ..time_frequency.psd import psd_array_multitaper
from ..time_frequency import psd_array_multitaper
from ..utils import fill_doc, _check_option, _validate_type, verbose
from ..io.pick import (pick_info, _pick_data_channels, _picks_by_type,
_picks_to_idx)
Expand Down
136 changes: 96 additions & 40 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
pick_channels, pick_info, _pick_data_channels,
_DATA_CH_TYPES_SPLIT, _picks_to_idx)
from .io.proj import setup_proj, ProjMixin
from .io.base import BaseRaw, _get_ch_factors
from .io.base import BaseRaw, TimeMixin, _get_ch_factors
from .bem import _check_origin
from .evoked import EvokedArray
from .baseline import rescale, _log_rescale, _check_baseline
Expand All @@ -49,13 +49,14 @@
from .event import (_read_events_fif, make_fixed_length_events,
match_event_names)
from .fixes import rng_uniform
from .viz import (plot_epochs, plot_epochs_psd, plot_epochs_psd_topomap,
plot_epochs_image, plot_topo_image_epochs, plot_drop_log)
from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin
from .viz import (plot_epochs, plot_epochs_image,
plot_topo_image_epochs, plot_drop_log)
from .utils import (_check_fname, check_fname, logger, verbose,
check_random_state, warn, _pl,
sizeof_fmt, SizeMixin, copy_function_doc_to_method_doc,
_check_pandas_installed,
_check_preload, GetEpochsMixin, TimeMixin,
_check_preload, GetEpochsMixin,
_prepare_read_metadata, _prepare_write_metadata,
_check_event_id, _gen_events, _check_option,
_check_combine, _build_data_frame,
Expand Down Expand Up @@ -340,7 +341,8 @@ def _handle_event_repeated(events, event_id, event_repeated, selection,
@fill_doc
class BaseEpochs(ProjMixin, ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin, FilterMixin,
TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin):
TimeMixin, SizeMixin, GetEpochsMixin, EpochAnnotationsMixin,
SpectrumMixin):
"""Abstract base class for `~mne.Epochs`-type classes.
.. warning:: This class provides basic functionality and should never be
Expand Down Expand Up @@ -1122,41 +1124,6 @@ def plot(self, picks=None, scalings=None, n_epochs=20, n_channels=20,
use_opengl=use_opengl, theme=theme,
overview_mode=overview_mode)

@copy_function_doc_to_method_doc(plot_epochs_psd)
def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None,
proj=False, bandwidth=None, adaptive=False, low_bias=True,
normalization='length', picks=None, ax=None, color='black',
xscale='linear', area_mode='std', area_alpha=0.33,
dB=True, estimate='auto', show=True, n_jobs=None,
average=False, line_alpha=None, spatial_colors=True,
sphere=None, exclude='bads', verbose=None):
return plot_epochs_psd(self, fmin=fmin, fmax=fmax, tmin=tmin,
tmax=tmax, proj=proj, bandwidth=bandwidth,
adaptive=adaptive, low_bias=low_bias,
normalization=normalization, picks=picks, ax=ax,
color=color, xscale=xscale, area_mode=area_mode,
area_alpha=area_alpha, dB=dB, estimate=estimate,
show=show, n_jobs=n_jobs, average=average,
line_alpha=line_alpha,
spatial_colors=spatial_colors, sphere=sphere,
exclude=exclude, verbose=verbose)

@copy_function_doc_to_method_doc(plot_epochs_psd_topomap)
def plot_psd_topomap(self, bands=None, tmin=None,
tmax=None, proj=False, bandwidth=None, adaptive=False,
low_bias=True, normalization='length', ch_type=None,
cmap=None, agg_fun=None, dB=True,
n_jobs=None, normalize=False, cbar_fmt='auto',
outlines='head', axes=None, show=True,
sphere=None, vlim=(None, None), verbose=None):
return plot_epochs_psd_topomap(
self, bands=bands, tmin=tmin, tmax=tmax,
proj=proj, bandwidth=bandwidth, adaptive=adaptive,
low_bias=low_bias, normalization=normalization, ch_type=ch_type,
cmap=cmap, agg_fun=agg_fun, dB=dB, n_jobs=n_jobs,
normalize=normalize, cbar_fmt=cbar_fmt, outlines=outlines,
axes=axes, show=show, sphere=sphere, vlim=vlim, verbose=verbose)

@copy_function_doc_to_method_doc(plot_topo_image_epochs)
def plot_topo_image(self, layout=None, sigma=0., vmin=None, vmax=None,
colorbar=None, order=None, cmap='RdBu_r',
Expand Down Expand Up @@ -2021,6 +1988,95 @@ def equalize_event_counts(self, event_ids=None, method='mintime'):
# actually remove the indices
return self, indices

@verbose
def compute_psd(self, method='multitaper', fmin=0, fmax=np.inf, tmin=None,
tmax=None, picks=None, proj=False, *, n_jobs=1,
verbose=None, **method_kw):
"""Perform spectral analysis on sensor data.
Parameters
----------
%(method_psd)s
Default is ``'multitaper'``.
%(fmin_fmax_psd)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
%(n_jobs)s
%(verbose)s
%(method_kw_psd)s
Returns
-------
spectrum : instance of EpochsSpectrum
The spectral representation of each epoch.
References
----------
.. footbibliography::
"""
return EpochsSpectrum(
self, method=method, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
picks=picks, proj=proj, n_jobs=n_jobs, verbose=verbose,
**method_kw)

@verbose
def plot_psd(self, fmin=0, fmax=np.inf, tmin=None, tmax=None, picks=None,
proj=False, *, method='auto', average=False, dB=True,
estimate='auto', xscale='linear', area_mode='std',
area_alpha=0.33, color='black', line_alpha=None,
spatial_colors=True, sphere=None, exclude='bads', ax=None,
show=True, n_jobs=1, verbose=None, **method_kw):
"""%(plot_psd_doc)s.
Parameters
----------
%(fmin_fmax_psd)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
%(method_plot_psd_auto)s
%(average_plot_psd)s
%(dB_plot_psd)s
%(estimate_plot_psd)s
%(xscale_plot_psd)s
%(area_mode_plot_psd)s
%(area_alpha_plot_psd)s
%(color_plot_psd)s
%(line_alpha_plot_psd)s
%(spatial_colors_psd)s
%(sphere_topomap_auto)s
.. versionadded:: 0.22.0
exclude : list of str | 'bads'
Channels names to exclude from being shown. If 'bads', the bad
channels are excluded. Pass an empty list to plot all channels
(including channels marked "bad", if any).
.. versionadded:: 0.24.0
%(ax_plot_psd)s
%(show)s
%(n_jobs)s
%(verbose)s
%(method_kw_psd)s
Returns
-------
fig : instance of Figure
Figure with frequency spectra of the data channels.
Notes
-----
%(notes_plot_psd_meth)s
"""
return super().plot_psd(
fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, picks=picks, proj=proj,
reject_by_annotation=False, method=method, average=average, dB=dB,
estimate=estimate, xscale=xscale, area_mode=area_mode,
area_alpha=area_alpha, color=color, line_alpha=line_alpha,
spatial_colors=spatial_colors, sphere=sphere, exclude=exclude,
ax=ax, show=show, n_jobs=n_jobs, verbose=verbose, **method_kw)

@verbose
def to_data_frame(self, picks=None, index=None,
scalings=None, copy=True, long_format=False,
Expand Down
Loading

0 comments on commit 93485e0

Please sign in to comment.