From fc1d25497fe2c47bd9aa129bf13ffb38aaaaf335 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Fri, 28 Jun 2019 13:28:28 +0200 Subject: [PATCH] FIX: `Evoked.decimate` is not updating `.first` and `.last` (#6504) * Fix Evoked.decimate not updating evoked.first/last * Fix unit test * Add entry to whats_new.rst --- doc/whats_new.rst | 1 + mne/evoked.py | 2 ++ mne/tests/test_evoked.py | 32 +++++++++++++++++++++----------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 8430570f61b..1415f5d556d 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -67,6 +67,7 @@ Bug - Fix bug in :func:`mne.Epochs.plot_psd` when some channels had zero/infinite ``psd`` values causing erroneous error messages by `Luke Bloy`_ +- Fix :func:`mne.Evoked.decimate` not setting ``inst.first`` and ``inst.last`` properly by `Marijn van Vliet`_ API ~~~ diff --git a/mne/evoked.py b/mne/evoked.py index 0399ca17637..6a4082b7de2 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -274,6 +274,8 @@ def decimate(self, decim, offset=0): self.info['sfreq'] = new_sfreq self.data = self.data[:, decim_slice].copy() self.times = self.times[decim_slice].copy() + self.first = int(self.times[0] * self.info['sfreq']) + self.last = len(self.times) + self.first - 1 return self def shift_time(self, tshift, relative=True): diff --git a/mne/tests/test_evoked.py b/mne/tests/test_evoked.py index d33048c2b5c..2c4b26ccde2 100644 --- a/mne/tests/test_evoked.py +++ b/mne/tests/test_evoked.py @@ -34,22 +34,32 @@ def test_decim(): """Test evoked decimation.""" rng = np.random.RandomState(0) - n_epochs, n_channels, n_times = 5, 10, 20 + n_channels, n_times = 10, 20 dec_1, dec_2 = 2, 3 decim = dec_1 * dec_2 - sfreq = 1000. + sfreq = 10. sfreq_new = sfreq / decim - data = rng.randn(n_epochs, n_channels, n_times) - events = np.array([np.arange(n_epochs), [0] * n_epochs, [1] * n_epochs]).T + data = rng.randn(n_channels, n_times) info = create_info(n_channels, sfreq, 'eeg') info['lowpass'] = sfreq_new / float(decim) - epochs = EpochsArray(data, info, events) - data_epochs = epochs.copy().decimate(decim).get_data() - data_epochs_2 = epochs.copy().decimate(decim, offset=1).get_data() - data_epochs_3 = epochs.decimate(dec_1).decimate(dec_2).get_data() - assert_array_equal(data_epochs, data[:, :, ::decim]) - assert_array_equal(data_epochs_2, data[:, :, 1::decim]) - assert_array_equal(data_epochs, data_epochs_3) + evoked = EvokedArray(data, info, tmin=-1) + evoked_dec = evoked.copy().decimate(decim) + evoked_dec_2 = evoked.copy().decimate(decim, offset=1) + evoked_dec_3 = evoked.decimate(dec_1).decimate(dec_2) + assert_array_equal(evoked_dec.data, data[:, ::decim]) + assert_array_equal(evoked_dec_2.data, data[:, 1::decim]) + assert_array_equal(evoked_dec.data, evoked_dec_3.data) + + # Check proper updating of various fields + assert evoked_dec.first == -1 + assert evoked_dec.last == 2 + assert_array_equal(evoked_dec.times, [-1, -0.4, 0.2, 0.8]) + assert evoked_dec_2.first == -1 + assert evoked_dec_2.last == 2 + assert_array_equal(evoked_dec_2.times, [-0.9, -0.3, 0.3, 0.9]) + assert evoked_dec_3.first == -1 + assert evoked_dec_3.last == 2 + assert_array_equal(evoked_dec_3.times, [-1, -0.4, 0.2, 0.8]) # Now let's do it with some real data raw = read_raw_fif(raw_fname)