Using n_overlap in raw.compute_psd() fails if good data segments are shorter than n_overlap #13039

Open
moritz-gerster opened this issue Dec 19, 2024 · 0 comments

Description of the problem

Passing n_overlap to the Welch PSD only works if every good data segment (the data between bad annotations) is longer than n_overlap samples:

n_fft = int(raw.info["sfreq"])
n_per_seg = n_fft
n_overlap = n_fft // 2
assert n_overlap < n_per_seg  # Obviously true
spectrum = raw.compute_psd(method="welch", n_fft=n_fft, n_overlap=n_overlap, n_per_seg=n_per_seg)

# -> ValueError: noverlap must be less than nperseg.

This is clearly a bug: 1-second Welch windows with 500 ms overlap are a very common analysis setup, yet the call fails whenever two bad segments are 500 ms or less apart.

Importantly, this was not an issue with my data in the past, so the bug must have been introduced by an MNE update (unfortunately, I don't know which one).

Steps to reproduce

import mne
from mne.time_frequency import psd_array_welch
import os

sample_data_folder = mne.datasets.sample.data_path()
sample_data_raw_file = os.path.join(sample_data_folder, 'MEG', 'sample', 'sample_audvis_raw.fif')
raw = mne.io.read_raw_fif(sample_data_raw_file)

# Two 1 s bad segments (5.0-6.0 s and 6.5-7.5 s) leave only 0.5 s of good data between them
annotations = mne.Annotations(onset=[5, 6.5], duration=[1, 1], description=['bad', 'bad'])
raw.set_annotations(annotations)

n_fft = int(raw.info["sfreq"])  # ~1 s window
n_overlap = n_fft // 2  # 50 % overlap

# psd_array_welch fails (its second positional argument is sfreq)
psds, freqs = psd_array_welch(raw.get_data(reject_by_annotation='NaN'), raw.info["sfreq"], n_fft=n_fft, n_overlap=n_overlap)
# -> ValueError: noverlap must be less than nperseg.

# raw.compute_psd fails
spectrum = raw.compute_psd(method="welch", n_fft=n_fft, n_overlap=n_overlap)
# -> ValueError: noverlap must be less than nperseg.

# specifying n_per_seg explicitly does not help
n_per_seg = n_fft
assert n_overlap < n_per_seg  # True
spectrum = raw.compute_psd(method="welch", n_fft=n_fft, n_overlap=n_overlap, n_per_seg=n_per_seg)
# -> ValueError: noverlap must be less than nperseg.
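To check in advance whether a recording will hit this error, one can measure the length of every NaN-free run in the annotation-masked data. The sketch below is a hypothetical helper (`good_span_lengths` is not part of MNE); it assumes a 1-D array such as one channel of `raw.get_data(reject_by_annotation='NaN')`, where NaNs mark the bad segments.

```python
import numpy as np


def good_span_lengths(data):
    """Hypothetical helper: lengths (in samples) of the NaN-free runs in a 1-D array."""
    good = ~np.isnan(data)
    # Pad with zeros so that runs touching either end still produce an edge.
    edges = np.diff(np.concatenate(([0], good.astype(int), [0])))
    onsets = np.flatnonzero(edges == 1)    # first sample of each good run
    offsets = np.flatnonzero(edges == -1)  # one past the last sample of each run
    return offsets - onsets


# Toy signal: two bad stretches leave a 2-sample good span in the middle
x = np.ones(20)
x[5:8] = np.nan
x[10:14] = np.nan
print(good_span_lengths(x))  # [5 2 6]
```

Any span with n_overlap samples or fewer is expected to trigger the ValueError above.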

Link to data

Jupyter Notebook Example

Expected results

The bad segments should be set to np.nan, and good data segments with fewer than n_fft samples should be discarded. This works with n_overlap=None, but not with n_overlap > 0.
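The expected behavior can be sketched with SciPy alone. This is a minimal illustration under stated assumptions, not MNE's implementation: `welch_skip_short_spans` is a hypothetical helper, it handles a single 1-D channel, and it calls `scipy.signal.welch` with its default Hann window (MNE's own Welch defaults may differ). Each NaN-free span shorter than `n_per_seg` is dropped instead of shrinking the window, and the per-span PSDs are averaged with weights equal to their number of windows, which matches a plain mean over all windows pooled together.

```python
import numpy as np
from scipy.signal import welch


def welch_skip_short_spans(x, sfreq, n_per_seg, n_overlap):
    """Hypothetical helper: Welch PSD over the NaN-free spans of a 1-D signal.

    Spans shorter than n_per_seg are skipped rather than analysed with a
    shortened window (which is what triggers the ValueError in SciPy).
    """
    good = ~np.isnan(x)
    edges = np.diff(np.concatenate(([0], good.astype(int), [0])))
    onsets = np.flatnonzero(edges == 1)
    offsets = np.flatnonzero(edges == -1)
    step = n_per_seg - n_overlap
    psds, weights = [], []
    for start, stop in zip(onsets, offsets):
        span = x[start:stop]
        if span.size < n_per_seg:
            continue  # not even one full window fits: discard this span
        freqs, psd = welch(span, fs=sfreq, nperseg=n_per_seg, noverlap=n_overlap)
        psds.append(psd)
        weights.append(1 + (span.size - n_per_seg) // step)  # windows in this span
    if not psds:
        raise ValueError("no good span is at least n_per_seg samples long")
    return freqs, np.average(psds, axis=0, weights=weights)


# Synthetic 10 s recording at 600 Hz with two bad stretches 0.5 s apart
sfreq = 600.0
x = np.random.default_rng(0).standard_normal(6000)
x[3000:3600] = np.nan
x[3900:4500] = np.nan  # leaves a 300-sample good span, which is skipped
freqs, psd = welch_skip_short_spans(x, sfreq, n_per_seg=600, n_overlap=300)
```

Weighting by window count keeps long spans from being under-represented relative to short ones.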

Actual results

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[11], line 1
----> 1 spectrum = raw.compute_psd(method="welch", **kwargs)

File <decorator-gen-336>:12, in compute_psd(self, method, fmin, fmax, tmin, tmax, picks, exclude, proj, remove_dc, reject_by_annotation, n_jobs, verbose, **method_kw)

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/io/base.py:2232, in BaseRaw.compute_psd(self, method, fmin, fmax, tmin, tmax, picks, exclude, proj, remove_dc, reject_by_annotation, n_jobs, verbose, **method_kw)
   2229 method = _validate_method(method, type(self).__name__)
   2230 self._set_legacy_nfft_default(tmin, tmax, method, method_kw)
-> 2232 return Spectrum(
   2233     self,
   2234     method=method,
   2235     fmin=fmin,
   2236     fmax=fmax,
   2237     tmin=tmin,
   2238     tmax=tmax,
   2239     picks=picks,
   2240     exclude=exclude,
   2241     proj=proj,
   2242     remove_dc=remove_dc,
   2243     reject_by_annotation=reject_by_annotation,
   2244     n_jobs=n_jobs,
   2245     verbose=verbose,
   2246     **method_kw,
   2247 )

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/time_frequency/spectrum.py:1146, in Spectrum.__init__(self, inst, method, fmin, fmax, tmin, tmax, picks, exclude, proj, remove_dc, reject_by_annotation, n_jobs, verbose, **method_kw)
   1144 self._nave = getattr(inst, "nave", None)
   1145 # compute the spectra
-> 1146 self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose)
   1147 # check for correct shape and bad values
   1148 self._check_values()

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/time_frequency/spectrum.py:445, in BaseSpectrum._compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose)
    443 def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose):
    444     # make the spectra
--> 445     result = self._psd_func(
    446         data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose
    447     )
    448     # assign ._data ._freqs, ._shape
    449     psds, freqs = result

File <decorator-gen-5>:12, in psd_array_welch(x, sfreq, fmin, fmax, n_fft, n_overlap, n_per_seg, n_jobs, average, window, remove_dc, output, verbose)

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/time_frequency/psd.py:272, in psd_array_welch(x, sfreq, fmin, fmax, n_fft, n_overlap, n_per_seg, n_jobs, average, window, remove_dc, output, verbose)
    270     agg_func = np.concatenate
    271     func = _func
--> 272 f_spect = parallel(
    273     my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
    274     for d in x_splits
    275 )
    276 psds = agg_func(f_spect, axis=0)
    277 shape = dshape + (len(freqs),)

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/time_frequency/psd.py:273, in <genexpr>(.0)
    270     agg_func = np.concatenate
    271     func = _func
    272 f_spect = parallel(
--> 273     my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
    274     for d in x_splits
    275 )
    276 psds = agg_func(f_spect, axis=0)
    277 shape = dshape + (len(freqs),)

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/time_frequency/psd.py:75, in _spect_func(epoch, func, freq_sl, average, output)
     73     spect = np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **kwargs)
     74 else:
---> 75     spect = _decomp_aggregate_mask(epoch, **kwargs)
     76 return spect

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/time_frequency/psd.py:52, in _decomp_aggregate_mask(epoch, func, average, freq_sl)
     51 def _decomp_aggregate_mask(epoch, func, average, freq_sl):
---> 52     _, _, spect = func(epoch)
     53     spect = spect[..., freq_sl, :]
     54     # Do the averaging here (per epoch) to save memory

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/mne/time_frequency/psd.py:266, in psd_array_welch.<locals>.func(*args, **kwargs)
    259 with warnings.catch_warnings():
    260     warnings.filterwarnings(
    261         action="ignore",
    262         module="scipy",
    263         category=UserWarning,
    264         message=r"nperseg = \d+ is greater than input length",
    265     )
--> 266     return _func(*args, **kwargs)

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/scipy/signal/_spectral_py.py:783, in spectrogram(x, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling, axis, mode)
    780     noverlap = nperseg // 8
    782 if mode == 'psd':
--> 783     freqs, time, Sxx = _spectral_helper(x, x, fs, window, nperseg,
    784                                         noverlap, nfft, detrend,
    785                                         return_onesided, scaling, axis,
    786                                         mode='psd')
    788 else:
    789     freqs, time, Sxx = _spectral_helper(x, x, fs, window, nperseg,
    790                                         noverlap, nfft, detrend,
    791                                         return_onesided, scaling, axis,
    792                                         mode='stft')

File ~/anaconda3/envs/local_sns013/lib/python3.12/site-packages/scipy/signal/_spectral_py.py:1851, in _spectral_helper(x, y, fs, window, nperseg, noverlap, nfft, detrend, return_onesided, scaling, axis, mode, boundary, padded)
   1849     noverlap = int(noverlap)
   1850 if noverlap >= nperseg:
-> 1851     raise ValueError('noverlap must be less than nperseg.')
   1852 nstep = nperseg - noverlap
   1854 # Padding occurs after boundary extension, so that the extended signal ends
   1855 # in zeros, instead of introducing an impulse at the end.
   1856 # I.e. if x = [..., 3, 2]
   1857 # extend then pad -> [..., 3, 2, 2, 3, 0, 0, 0]
   1858 # pad then extend -> [..., 3, 2, 0, 0, 0, 2, 3]

ValueError: noverlap must be less than nperseg.
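The traceback suggests the root cause sits in SciPy: when a span is shorter than `nperseg`, `scipy.signal.spectrogram` warns ("nperseg = ... is greater than input length"), shrinks `nperseg` to the span length, and then rejects the unchanged `noverlap`. MNE filters that warning, so only the ValueError surfaces. The sketch below reproduces this with SciPy alone; `try_spectrogram` is a hypothetical helper, and the 600-sample window with 300-sample overlap mirrors the values in the report.

```python
import warnings

import numpy as np
from scipy.signal import spectrogram


def try_spectrogram(n_samples, n_per_seg=600, n_overlap=300, sfreq=600.0):
    """Hypothetical helper: return None on success, or the ValueError message."""
    data = np.random.default_rng(0).standard_normal(n_samples)
    with warnings.catch_warnings():
        # SciPy warns that nperseg exceeds the input length and shrinks it;
        # MNE suppresses the same warning, so we do too.
        warnings.simplefilter("ignore", UserWarning)
        try:
            spectrogram(data, fs=sfreq, nperseg=n_per_seg, noverlap=n_overlap)
        except ValueError as err:
            return str(err)
    return None


print(try_spectrogram(1200))  # None: the span holds at least one full window
print(try_spectrogram(301))   # None: nperseg shrinks to 301, still > noverlap
print(try_spectrogram(300))   # prints the ValueError message
```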

Additional information

Platform macOS-15.1.1-arm64-arm-64bit
Python 3.12.5 | packaged by conda-forge | (main, Aug 8 2024, 18:32:50) [Clang 16.0.6 ]
Executable /Users/moritzgerster/anaconda3/envs/local_sns013/bin/python
CPU arm (10 cores)
Memory 16.0 GB

Core
├☒ mne 1.7.1 (outdated, release 1.9.0 is available!)
├☑ numpy 2.0.1 (OpenBLAS 0.3.27 with 10 threads)
├☑ scipy 1.14.0
└☑ matplotlib 3.9.1 (backend=module://matplotlib_inline.backend_inline)

Numerical (optional)
├☑ sklearn 1.5.1
├☑ numba 0.60.0
├☑ nibabel 5.2.1
├☑ nilearn 0.10.4
├☑ pandas 2.2.2
├☑ h5io 0.2.4
├☑ h5py 3.11.0
└☐ unavailable dipy, openmeeg, cupy

Visualization (optional)
└☐ unavailable pyvista, pyvistaqt, vtk, qtpy, ipympl, pyqtgraph, mne-qt-browser, ipywidgets, trame_client, trame_server, trame_vtk, trame_vuetify

Ecosystem (optional)
├☑ mne-bids 0.15.0
├☑ eeglabio 0.0.2-4
├☑ edfio 0.4.3
├☑ pybv 0.7.5
└☐ unavailable mne-nirs, mne-features, mne-connectivity, mne-icalabel, mne-bids-pipeline, neo, mffpy
