Skip to content

Commit

Permalink
MRG, ENH: Use overlap-add in spectrum_fit mode (mne-tools#7609)
Browse files Browse the repository at this point in the history
* ENH: Use overlap-add in spectrum_fit mode

* DOC: Add to tutorial [skip travis]

* FIX: Fix inplace div
  • Loading branch information
larsoner authored Aug 11, 2020
1 parent 64e3dd0 commit 5898459
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 101 deletions.
4 changes: 4 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ Changelog

- Add :func:`mne.read_freesurfer_lut` to make it easier to work with volume atlases by `Eric Larson`_

- Add support for overlap-add processing when ``method='spectrum_fit'`` in :func:`mne.io.Raw.notch_filter` by `Eric Larson`_

- Add functionality to interpolate bad NIRS channels by `Robert Luke`_

- Add ability to interpolate EEG channels using minimum-norm projection in :meth:`mne.io.Raw.interpolate_bads` and related functions with ``method=dict(eeg='MNE')`` by `Eric Larson`_
Expand Down Expand Up @@ -315,6 +317,8 @@ API
- In :func:`mne.stats.permutation_cluster_test` and :func:`mne.stats.permutation_cluster_1samp_test` the default parameter value ``out_type='mask'`` has changed to ``None``, which in 0.21 means ``'mask'`` but will change to mean ``'indices'`` in the next version, by `Daniel McCloy`_
- The default window size set by ``filter_length`` when ``method='spectrum_fit'`` in :meth:`mne.io.Raw.notch_filter` will change from ``None`` (use whole file) to ``'10s'`` in 0.22, by `Eric Larson`_
- ``vmin`` and ``vmax`` parameters are deprecated in :meth:`mne.Epochs.plot_psd_topomap` and :func:`mne.viz.plot_epochs_psd_topomap`; use new ``vlim`` parameter instead, by `Daniel McCloy`_.
- The method ``stc_mixed.plot_surface`` for a :class:`mne.MixedSourceEstimate` has been deprecated in favor of :meth:`stc.surface().plot(...) <mne.MixedSourceEstimate.surface>` by `Eric Larson`_
Expand Down
206 changes: 134 additions & 72 deletions mne/filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""IIR and FIR filtering and resampling functions."""

from collections import Counter
from copy import deepcopy
from functools import partial

Expand All @@ -13,7 +14,8 @@
from .parallel import parallel_func, check_n_jobs
from .time_frequency.multitaper import _mt_spectra, _compute_mt_params
from .utils import (logger, verbose, sum_squared, check_version, warn, _pl,
_check_preload, _validate_type, _check_option)
_check_preload, _validate_type, _check_option, _ensure_int)
from ._ola import _COLA

# These values from Ifeachor and Jervis.
_length_factors = dict(hann=3.1, hamming=3.3, blackman=5.0)
Expand Down Expand Up @@ -288,6 +290,7 @@ def _firwin_design(N, freq, gain, window, sfreq):
assert freq[0] == 0
assert len(freq) > 1
assert len(freq) == len(gain)
assert N % 2 == 1
h = np.zeros(N)
prev_freq = freq[-1]
prev_gain = gain[-1]
Expand All @@ -309,6 +312,7 @@ def _firwin_design(N, freq, gain, window, sfreq):
# Construct a lowpass
this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
window=window, pass_zero=True, nyq=freq[-1])
assert this_h.shape == (this_N,)
offset = (N - this_N) // 2
if this_gain == 0:
h[offset:N - offset] -= this_h
Expand Down Expand Up @@ -1070,7 +1074,8 @@ def notch_filter(x, Fs, freqs, filter_length='auto', notch_widths=None,
trans_bandwidth=1, method='fir', iir_params=None,
mt_bandwidth=None, p_value=0.05, picks=None, n_jobs=1,
copy=True, phase='zero', fir_window='hamming',
fir_design='firwin', pad='reflect_limited', verbose=None):
fir_design='firwin', pad='reflect_limited',
verbose=None):
r"""Notch filter for the signal x.
Applies a zero-phase notch filter to the signal x, operating on the last
Expand All @@ -1086,7 +1091,7 @@ def notch_filter(x, Fs, freqs, filter_length='auto', notch_widths=None,
Frequencies to notch filter in Hz, e.g. np.arange(60, 241, 60).
None can only be used with the mode 'spectrum_fit', where an F
test is used to find sinusoidal components.
%(filter_length)s
%(filter_length_notch)s
notch_widths : float | array of float | None
Width of the stop band (centred at each freq in freqs) in Hz.
If None, freqs / 200 is used.
Expand Down Expand Up @@ -1187,76 +1192,120 @@ def notch_filter(x, Fs, freqs, filter_length='auto', notch_widths=None,
fir_design, pad=pad)
elif method == 'spectrum_fit':
xf = _mt_spectrum_proc(x, Fs, freqs, notch_widths, mt_bandwidth,
p_value, picks, n_jobs, copy)
p_value, picks, n_jobs, copy, filter_length)

return xf


def _mt_spectrum_proc(x, sfreq, line_freqs, notch_widths, mt_bandwidth,
p_value, picks, n_jobs, copy):
"""Call _mt_spectrum_remove."""
from scipy import stats
# set up array for filtering, reshape to 2D, operate on last axis
n_jobs = check_n_jobs(n_jobs)
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)

# XXX need to implement the moving window version for raw files
n_times = x.shape[1]

def _get_window_thresh(n_times, sfreq, mt_bandwidth, p_value):
# max taper size chosen because it has an max error < 1e-3:
# >>> np.max(np.diff(dpss_windows(953, 4, 100)[0]))
# 0.00099972447657578449
# so we use 1000 because it's the first "nice" number bigger than 953.
# but if we have a new enough scipy,
# it's only ~0.175 sec for 8 tapers even with 100000 samples
from scipy import stats
dpss_n_times_max = 100000 if check_version('scipy', '1.1') else 1000

# figure out what tapers to use
window_fun, eigvals, _ = _compute_mt_params(
window_fun, _, _ = _compute_mt_params(
n_times, sfreq, mt_bandwidth, False, False,
interp_from=min(n_times, dpss_n_times_max), verbose=False)

# F-stat of 1-p point
threshold = stats.f.ppf(1 - p_value / n_times, 2, 2 * len(window_fun) - 2)
return window_fun, threshold


def _mt_spectrum_proc(x, sfreq, line_freqs, notch_widths, mt_bandwidth,
p_value, picks, n_jobs, copy, filter_length):
"""Call _mt_spectrum_remove."""
# set up array for filtering, reshape to 2D, operate on last axis
n_jobs = check_n_jobs(n_jobs)
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
if isinstance(filter_length, str) and filter_length == 'auto':
filter_length = None
warn('The default for "filter_length" when using method="spectrum_fit"'
' is None in 0.21 but will change to 10. in 0.22, set it '
'explicitly to avoid this warning', DeprecationWarning)
if filter_length is None:
filter_length = x.shape[-1]
filter_length = min(_to_samples(filter_length, sfreq, '', ''), x.shape[-1])
get_wt = partial(
_get_window_thresh, sfreq=sfreq, mt_bandwidth=mt_bandwidth,
p_value=p_value)
window_fun, threshold = get_wt(filter_length)
if n_jobs == 1:
freq_list = list()
for ii, x_ in enumerate(x):
if ii in picks:
x[ii], f = _mt_spectrum_remove(x_, sfreq, line_freqs,
notch_widths, window_fun,
threshold)
x[ii], f = _mt_spectrum_remove_win(
x_, sfreq, line_freqs, notch_widths, window_fun, threshold,
get_wt)
freq_list.append(f)
else:
parallel, p_fun, _ = parallel_func(_mt_spectrum_remove, n_jobs)
parallel, p_fun, _ = parallel_func(_mt_spectrum_remove_win, n_jobs)
data_new = parallel(p_fun(x_, sfreq, line_freqs, notch_widths,
window_fun, threshold)
window_fun, threshold, get_wt)
for xi, x_ in enumerate(x)
if xi in picks)
freq_list = [d[1] for d in data_new]
data_new = np.array([d[0] for d in data_new])
x[picks, :] = data_new

# report found frequencies
for rm_freqs in freq_list:
if line_freqs is None:
if len(rm_freqs) > 0:
found_freqs = ', '.join(str(rm_f) for rm_f in rm_freqs)
else:
found_freqs = 'None'
logger.info(f'Detected notch frequencies:\n{found_freqs}')
# report found frequencies, but do some sanitizing first by binning into
# 1 Hz bins
counts = Counter(sum((np.unique(np.round(ff)).tolist()
for f in freq_list for ff in f), list()))
kind = 'Detected' if line_freqs is None else 'Removed'
found_freqs = '\n'.join(f' {freq:6.2f} : '
f'{counts[freq]:4d} window{_pl(counts[freq])}'
for freq in sorted(counts)) or ' None'
logger.info(f'{kind} notch frequencies (Hz):\n{found_freqs}')

x.shape = orig_shape
return x


def _mt_spectrum_remove_win(x, sfreq, line_freqs, notch_widths,
window_fun, threshold, get_thresh):
n_times = x.shape[-1]
n_samples = window_fun.shape[1]
n_overlap = (n_samples + 1) // 2
x_out = np.zeros_like(x)
rm_freqs = list()
idx = [0]

# Define how to process a chunk of data
def process(x_):
out = _mt_spectrum_remove(
x_, sfreq, line_freqs, notch_widths, window_fun, threshold,
get_thresh)
rm_freqs.append(out[1])
return (out[0],) # must return a tuple

# Define how to store a chunk of fully processed data (it's trivial)
def store(x_):
stop = idx[0] + x_.shape[-1]
x_out[..., idx[0]:stop] += x_
idx[0] = stop

_COLA(process, store, n_times, n_samples, n_overlap, sfreq,
verbose=False).feed(x)
assert idx[0] == n_times
return x_out, rm_freqs


def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths,
window_fun, threshold):
window_fun, threshold, get_thresh):
"""Use MT-spectrum to remove line frequencies.
Based on Chronux. If line_freqs is specified, all freqs within notch_width
of each line_freq is set to zero.
"""
assert x.ndim == 1
if x.shape[-1] != window_fun.shape[-1]:
window_fun, threshold = get_thresh(x.shape[-1])
# drop the even tapers
n_tapers = len(window_fun)
tapers_odd = np.arange(0, n_tapers, 2)
Expand Down Expand Up @@ -1303,8 +1352,7 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths,
# specify frequencies
indices_1 = np.unique([np.argmin(np.abs(freqs - lf))
for lf in line_freqs])
notch_widths /= 2.0
indices_2 = [np.logical_and(freqs > lf - nw, freqs < lf + nw)
indices_2 = [np.logical_and(freqs > lf - nw / 2., freqs < lf + nw / 2.)
for lf, nw in zip(line_freqs, notch_widths)]
indices_2 = np.where(np.any(np.array(indices_2), axis=0))[0]
indices = np.unique(np.r_[indices_1, indices_2])
Expand All @@ -1320,7 +1368,7 @@ def _mt_spectrum_remove(x, sfreq, line_freqs, notch_widths,
datafit = 0.0
else:
# fitted sinusoids are summed, and subtracted from data
datafit = np.sum(np.atleast_2d(fits), axis=0)
datafit = np.sum(fits, axis=0)

return x - datafit, rm_freqs

Expand Down Expand Up @@ -1586,6 +1634,34 @@ def detrend(x, order=1, axis=-1):
}


def _to_samples(filter_length, sfreq, phase, fir_design):
_validate_type(filter_length, (str, 'int-like'), 'filter_length')
if isinstance(filter_length, str):
filter_length = filter_length.lower()
err_msg = ('filter_length, if a string, must be a '
'human-readable time, e.g. "10s", or "auto", not '
'"%s"' % filter_length)
if filter_length.lower().endswith('ms'):
mult_fact = 1e-3
filter_length = filter_length[:-2]
elif filter_length[-1].lower() == 's':
mult_fact = 1
filter_length = filter_length[:-1]
else:
raise ValueError(err_msg)
# now get the number
try:
filter_length = float(filter_length)
except ValueError:
raise ValueError(err_msg)
filter_length = max(int(np.ceil(filter_length * mult_fact *
sfreq)), 1)
if fir_design == 'firwin':
filter_length += (filter_length - 1) % 2
filter_length = _ensure_int(filter_length, 'filter_length')
return filter_length


def _triage_filter_params(x, sfreq, l_freq, h_freq,
l_trans_bandwidth, h_trans_bandwidth,
filter_length, method, phase, fir_window,
Expand Down Expand Up @@ -1643,6 +1719,7 @@ def float_array(c):
elif phase == 'zero-double':
dB_cutoff = '-12 dB'

# we go to the next power of two when in FIR and zero-double mode
if method == 'iir':
# Ignore these parameters, effectively
l_stop, h_stop = l_freq, h_freq
Expand Down Expand Up @@ -1716,54 +1793,39 @@ def float_array(c):
raise ValueError('Effective band-stop frequency (%s) is too '
'high (maximum based on Nyquist is %s)'
% (h_stop, sfreq / 2.))
if isinstance(filter_length, str):

if isinstance(filter_length, str) and filter_length.lower() == 'auto':
filter_length = filter_length.lower()
if filter_length == 'auto':
h_check = h_trans_bandwidth if h_freq is not None else np.inf
l_check = l_trans_bandwidth if l_freq is not None else np.inf
mult_fact = 2. if fir_design == 'firwin2' else 1.
filter_length = max(int(round(
_length_factors[fir_window] * sfreq * mult_fact /
float(min(h_check, l_check)))), 1)
else:
err_msg = ('filter_length, if a string, must be a '
'human-readable time, e.g. "10s", or "auto", not '
'"%s"' % filter_length)
if filter_length.lower().endswith('ms'):
mult_fact = 1e-3
filter_length = filter_length[:-2]
elif filter_length[-1].lower() == 's':
mult_fact = 1
filter_length = filter_length[:-1]
else:
raise ValueError(err_msg)
# now get the number
try:
filter_length = float(filter_length)
except ValueError:
raise ValueError(err_msg)
if phase == 'zero-double': # old mode
filter_length = 2 ** int(np.ceil(np.log2(
filter_length * mult_fact * sfreq)))
else:
filter_length = max(int(np.ceil(filter_length * mult_fact *
sfreq)), 1)
if fir_design == 'firwin':
filter_length += (filter_length - 1) % 2
elif not isinstance(filter_length, int):
raise ValueError('filter_length must be a str, int, or None, got '
'%s' % (type(filter_length),))
h_check = h_trans_bandwidth if h_freq is not None else np.inf
l_check = l_trans_bandwidth if l_freq is not None else np.inf
mult_fact = 2. if fir_design == 'firwin2' else 1.
filter_length = '%ss' % (_length_factors[fir_window] * mult_fact /
float(min(h_check, l_check)),)
next_pow_2 = False # disable old behavior
else:
next_pow_2 = (
isinstance(filter_length, str) and phase == 'zero-double')

filter_length = _to_samples(filter_length, sfreq, phase, fir_design)

# use correct type of filter (must be odd length for firwin and for
# zero phase)
if fir_design == 'firwin' or phase == 'zero':
filter_length += (filter_length - 1) % 2

logger.info('- Filter length: %s samples (%0.3f sec)'
% (filter_length, filter_length / sfreq))
logger.info('')

if filter_length != 'auto':
if phase == 'zero' and method == 'fir':
filter_length += (filter_length % 2 == 0)
if filter_length <= 0:
raise ValueError('filter_length must be positive, got %s'
% (filter_length,))

if next_pow_2:
filter_length = 2 ** int(np.ceil(np.log2(filter_length)))
if fir_design == 'firwin':
filter_length += (filter_length - 1) % 2

# If we have data supplied, do a sanity check
if x is not None:
x = _check_filterable(x)
Expand Down
2 changes: 1 addition & 1 deletion mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ def notch_filter(self, freqs, picks=None, filter_length='auto',
Europe. None can only be used with the mode 'spectrum_fit',
where an F test is used to find sinusoidal components.
%(picks_all_data)s
%(filter_length)s
%(filter_length_notch)s
notch_widths : float | array of float | None
Width of each stop band (centred at each freq in freqs) in Hz.
If None, freqs / 200 is used.
Expand Down
4 changes: 3 additions & 1 deletion mne/io/fiff/tests/test_raw_fiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,10 @@ def test_filter():
assert_array_almost_equal(data_bs, data_notch, sig_dec_notch)

# now use the sinusoidal fitting
assert raw.times[-1] < 10 # catch error with filter_length > n_times
raw_notch = raw.copy().notch_filter(
None, picks=picks, n_jobs=2, method='spectrum_fit')
None, picks=picks, n_jobs=2, method='spectrum_fit',
filter_length='10s')
data_notch, _ = raw_notch[picks, :]
data, _ = raw[picks, :]
assert_array_almost_equal(data, data_notch, sig_dec_notch_fit)
Expand Down
Loading

0 comments on commit 5898459

Please sign in to comment.