Skip to content

Commit

Permalink
updating the decoder code to use private functions with arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
choldgraf committed Jan 13, 2016
1 parent ceac67d commit b6c7e05
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 196 deletions.
58 changes: 58 additions & 0 deletions examples/time_frequency/plot_psd_sensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
==============================================================
PSD estimation for MEG sensors
==============================================================
PSD calculation with both multitaper and welch's method are displayed
"""
# Authors: Chris Holdgraf <[email protected]>
# Alexandre Gramfort <[email protected]>
# Denis Engemann <[email protected]>
#
# License: BSD (3-clause)

import matplotlib.pyplot as plt

import mne
from mne import io
from mne.time_frequency import psd_welch, psd_multitaper
from mne.datasets import somato

print(__doc__)

###############################################################################
# Set parameters
data_path = somato.data_path()
raw_fname = data_path + '/MEG/somato/sef_raw_sss.fif'
event_id, tmin, tmax = 1, -1., 3.
fmin, fmax = 2, 40
n_fft = 256

# Setup for reading the raw data
raw = io.Raw(raw_fname)
baseline = (None, 0)
events = mne.find_events(raw, stim_channel='STI 014')

# picks MEG gradiometers
picks = mne.pick_types(raw.info, meg='grad', eeg=False, eog=True, stim=False)

epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
baseline=baseline, reject=dict(grad=4000e-13, eog=350e-6))
picks_psd = picks[:5]

###############################################################################
# Calculate power spectral density
psds_we, freqs_we = psd_welch(epochs, tmin=tmin, tmax=tmax, fmin=fmin,
fmax=fmax, n_fft=n_fft, proj=False,
picks=picks_psd)
psds_mt, freqs_mt = psd_multitaper(epochs, tmin=tmin, tmax=tmax, fmin=fmin,
fmax=fmax, low_bias=True, proj=False,
picks=picks_psd)

f, axs = plt.subplots(1, 2)
for psd, freqs, ax in zip([psds_we, psds_mt], [freqs_we, freqs_mt], axs):
ax.plot(freqs, psd.mean(0).T)
axs[0].set(title='Welch PSD')
axs[1].set(title='Multitaper PSD')
plt.setp(axs, xlabel='Frequency', ylabel='Power Spectral Density (PSD)')
mne.viz.tight_layout()
13 changes: 5 additions & 8 deletions mne/decoding/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .. import pick_types
from ..filter import (low_pass_filter, high_pass_filter, band_pass_filter,
band_stop_filter)
from ..time_frequency.psd import psd_multitaper
from ..time_frequency.psd import _psd_multitaper
from ..externals import six
from ..utils import _check_type_picks

Expand Down Expand Up @@ -296,7 +296,7 @@ def __init__(self, sfreq=2 * np.pi, fmin=0, fmax=np.inf, bandwidth=None,
self.verbose = verbose
self.normalization = normalization

def fit(self, epochs_data, y=None):
def fit(self, epochs_data, y):
"""Compute power spectrum density (PSD) using a multi-taper method
Parameters
Expand All @@ -310,7 +310,6 @@ def fit(self, epochs_data, y=None):
-------
self : instance of PSDEstimator
returns the modified instance
"""
if not isinstance(epochs_data, np.ndarray):
raise ValueError("epochs_data should be of type ndarray (got %s)."
Expand All @@ -332,19 +331,17 @@ def transform(self, epochs_data, y=None):
Returns
-------
psd : array, shape (n_signals, len(freqs)) or (len(freqs),)
The computed PSD. This also creates the attribute `freqs_`, which
contains the frequencies in the PSD estimate.
The computed PSD.
"""

if not isinstance(epochs_data, np.ndarray):
raise ValueError("epochs_data should be of type ndarray (got %s)."
% type(epochs_data))
psd, freqs = psd_multitaper(
epochs_data, fmin=self.fmin, fmax=self.fmax, sfreq=self.sfreq,
psd, freqs = _psd_multitaper(
epochs_data, sfreq=self.sfreq, fmin=self.fmin, fmax=self.fmax,
bandwidth=self.bandwidth, adaptive=self.adaptive,
low_bias=self.low_bias, normalization=self.normalization,
n_jobs=self.n_jobs, verbose=self.verbose)
self.freqs_ = freqs
return psd


Expand Down
111 changes: 51 additions & 60 deletions mne/time_frequency/psd.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,39 @@ def _check_psd_data(inst, tmin, tmax, picks, proj):
return data, sfreq


@verbose
def _psd_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0,
n_jobs=1, verbose=None):
"""Helper function for calculating Welch PSD."""
from scipy.signal import welch
dshape = x.shape[:-1]
n_times = x.shape[-1]
x = x.reshape(np.product(dshape), -1)

# Prep the PSD
n_fft, n_overlap = _check_nfft(n_times, n_fft, n_overlap)
win_size = n_fft / float(sfreq)
logger.info("Effective window size : %0.3f (s)" % win_size)
freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
freq_mask = (freqs >= fmin) & (freqs <= fmax)
freqs = freqs[freq_mask]

# Parallelize across first N-1 dimensions
psds = np.empty(x.shape[:-1] + (freqs.size,))
parallel, my_pwelch, n_jobs = parallel_func(_pwelch, n_jobs=n_jobs,
verbose=verbose)
x_splits = np.array_split(x, n_jobs)
f_psd = parallel(my_pwelch(d, noverlap=n_overlap, nfft=n_fft,
fs=sfreq, freq_mask=freq_mask,
welch_fun=welch)
for d in x_splits)

# Combining/reshaping to original data shape
psds = np.concatenate(f_psd, axis=0)
psds = psds.reshape(np.hstack([dshape, -1]))
return psds, freqs


@verbose
def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256,
n_overlap=0, picks=None, proj=False, n_jobs=1, verbose=None):
Expand All @@ -159,8 +192,8 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256,
n_fft : int
The length of the tapers ie. the windows. The smaller
it is the smoother are the PSDs. The default value is 256.
If ``n_fft > len(epochs.times)``, it will be adjusted down to
``len(epochs.times)``.
If ``n_fft > len(inst.times)``, it will be adjusted down to
``len(inst.times)``.
n_overlap : int
The number of points of overlap between blocks. Will be adjusted
to be <= n_fft.
Expand Down Expand Up @@ -192,47 +225,31 @@ def psd_welch(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None, n_fft=256,
return psds, freqs


@verbose
def _psd_welch(x, sfreq, fmin=0, fmax=np.inf, n_fft=256, n_overlap=0,
n_jobs=1, verbose=None):
"""Helper function for calculating Welch PSD."""
from scipy.signal import welch
def _psd_multitaper(x, sfreq, fmin=0, fmax=np.inf, bandwidth=None,
adaptive=False, low_bias=True, normalization='length',
n_jobs=1, verbose=None):
"""Helper function for calculating Multitaper PSD."""
from .multitaper import multitaper_psd
dshape = x.shape[:-1]
n_times = x.shape[-1]

# This will return the same object if len(dshape) == 1
x = x.reshape(np.product(dshape), n_times)

# Prep the PSD
n_fft, n_overlap = _check_nfft(n_times, n_fft, n_overlap)
win_size = n_fft / float(sfreq)
logger.info("Effective window size : %0.3f (s)" % win_size)
freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
freq_mask = (freqs >= fmin) & (freqs <= fmax)
freqs = freqs[freq_mask]
n_freqs = len(freqs)
x = x.reshape(np.product(dshape), -1)

# Parallelize across first N-1 dimensions
psds = np.empty(x.shape[:-1] + (freqs.size,))
parallel, my_pwelch, n_jobs = parallel_func(_pwelch, n_jobs=n_jobs,
verbose=verbose)
x_splits = np.array_split(x, n_jobs)
f_psd = parallel(my_pwelch(d, noverlap=n_overlap, nfft=n_fft,
fs=sfreq, freq_mask=freq_mask,
welch_fun=welch)
for d in x_splits)
# Stack data so it's treated separately
psds, freqs = multitaper_psd(x=x, sfreq=sfreq, fmin=fmin, fmax=fmax,
bandwidth=bandwidth, adaptive=adaptive,
low_bias=low_bias,
normalization=normalization, n_jobs=n_jobs,
verbose=verbose)

# Combining/reshaping to original data shape
psds = np.concatenate(f_psd, axis=0)
psds = psds.reshape(np.hstack([dshape, n_freqs]))
psds = psds.reshape(np.hstack([dshape, -1]))
return psds, freqs


@verbose
def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None,
bandwidth=None, adaptive=False, low_bias=True,
normalization='length', picks=None, proj=False,
sfreq=None, n_jobs=1, verbose=None):
n_jobs=1, verbose=None):
"""Compute the PSD using multitapers.
Expand Down Expand Up @@ -267,9 +284,6 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None,
If None, take all channels.
proj : bool
Apply SSP projection vectors. If inst is ndarray this is not used.
sfreq : float
The sampling frequency of the data. Required if inst is an array,
otherwise this is not used and sfreq is pulled from inst.
n_jobs : int
Number of CPUs to use in the computation.
verbose : bool, str, int, or None
Expand All @@ -280,7 +294,7 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None,
psds : ndarray, shape ([n_epochs], n_channels, n_freqs)
The power spectral densities. If Raw is provided,
then psds will be 2-D.
freqs : ndarray (n_freqs)
freqs : ndarray, shape (n_freqs)
The frequencies.
"""
# Prep data
Expand All @@ -293,29 +307,6 @@ def psd_multitaper(inst, fmin=0, fmax=np.inf, tmin=None, tmax=None,
return psds, freqs


def _psd_multitaper(x, sfreq, fmin=0, fmax=np.inf, bandwidth=None,
adaptive=False, low_bias=True, normalization='length',
n_jobs=1, verbose=None):
"""Helper function for calculating Welch PSD."""
from .multitaper import multitaper_psd
dshape = x.shape[:-1]
n_times = x.shape[-1]

# This will return the same object if len(dshape) == 1
x = x.reshape(np.product(dshape), n_times)

# Stack data so it's treated separately
psds, freqs = multitaper_psd(x=x, sfreq=sfreq, fmin=fmin, fmax=fmax,
bandwidth=bandwidth, adaptive=adaptive,
low_bias=low_bias,
normalization=normalization, n_jobs=n_jobs,
verbose=verbose)

# Combining/reshaping to original data shape
psds = psds.reshape(np.hstack([dshape, len(freqs)]))
return psds, freqs


@verbose
@deprecated('This will be deprecated in release v0.13, see psd_ functions.')
def compute_epochs_psd(epochs, picks=None, fmin=0, fmax=np.inf, tmin=None,
Expand Down Expand Up @@ -355,7 +346,7 @@ def compute_epochs_psd(epochs, picks=None, fmin=0, fmax=np.inf, tmin=None,
-------
psds : ndarray (n_epochs, n_channels, n_freqs)
The power spectral densities.
freqs : ndarray (n_freqs)
freqs : ndarray, shape (n_freqs)
The frequencies.
"""
from scipy.signal import welch
Expand Down
Loading

0 comments on commit b6c7e05

Please sign in to comment.