diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 8b571306a4f..bd50ad4589b 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -119,6 +119,8 @@ Bugs - Fix bug with reading EDF and KIT files on big endian architectures such as s390x (:gh:`8618` by `Eric Larson`_) +- Fix bug with :func:`mne.beamformer.make_dics` where the ``rank`` parameter was not properly handled (:gh:`8594` by `Marijn van Vliet`_ and `Eric Larson`_) + - Fix bug with :func:`mne.beamformer.apply_dics` where the whitener was not properly applied (:gh:`8610` by `Eric Larson`_) - Fix bug with `~mne.viz.plot_epochs_image` when ``order`` is supplied and multiple conditions are plotted (:gh:`8377` by `Daniel McCloy`_ ) diff --git a/mne/beamformer/_dics.py b/mne/beamformer/_dics.py index 52c5b512dc9..28f9b0c5e29 100644 --- a/mne/beamformer/_dics.py +++ b/mne/beamformer/_dics.py @@ -15,6 +15,7 @@ _check_option, _validate_type) from ..forward import _subject_from_forward from ..minimum_norm.inverse import combine_xyz, _check_reference, _check_depth +from ..rank import compute_rank from ..source_estimate import _make_stc, _get_src_type from ..time_frequency import csd_fourier, csd_multitaper, csd_morlet from ._compute_beamformer import (_prepare_beamformer_input, @@ -166,7 +167,8 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None, frequencies = [np.mean(freq_bin) for freq_bin in csd.frequencies] n_freqs = len(frequencies) - _check_one_ch_type('dics', info, forward, csd, noise_csd) + _, _, allow_mismatch = _check_one_ch_type('dics', info, forward, csd, + noise_csd) # remove bads so that equalize_channels only keeps all good info = pick_info(info, pick_channels(info['ch_names'], [], info['bads'])) info, forward, csd = equalize_channels([info, forward, csd]) @@ -181,6 +183,23 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None, _prepare_beamformer_input( info, forward, label, pick_ori, noise_cov=noise_csd, rank=rank, pca=False, **depth) + + # Compute ranks + csd_int_rank = [] + if not allow_mismatch: + noise_rank = compute_rank(noise_csd, info=info, rank=rank) + for i in range(len(frequencies)): + csd_rank = compute_rank(csd.get_data(index=i, as_cov=True), + info=info, rank=rank) + if not allow_mismatch: + for key in csd_rank: + if key not in noise_rank or csd_rank[key] != noise_rank[key]: + raise ValueError('%s data rank (%s) did not match the ' + 'noise rank (%s)' + % (key, csd_rank[key], + noise_rank.get(key, None))) + csd_int_rank.append(sum(csd_rank.values())) + del noise_csd ch_names = list(info['ch_names']) @@ -203,8 +222,8 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None, n_orient = 3 if is_free_ori else 1 W, max_power_ori = _compute_beamformer( G, Cm, reg, n_orient, weight_norm, pick_ori, reduce_rank, - rank=rank, inversion=inversion, nn=nn, orient_std=orient_std, - whitener=whitener) + rank=csd_int_rank[i], inversion=inversion, nn=nn, + orient_std=orient_std, whitener=whitener) Ws.append(W) max_oris.append(max_power_ori) diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py index 38127a8be65..19ae2977781 100644 --- a/mne/beamformer/tests/test_dics.py +++ b/mne/beamformer/tests/test_dics.py @@ -18,7 +18,7 @@ from mne.beamformer._compute_beamformer import _prepare_beamformer_input from mne.beamformer._dics import _prepare_noise_csd from mne.time_frequency import csd_morlet -from mne.utils import object_diff, requires_h5py +from mne.utils import object_diff, requires_h5py, catch_logging from mne.proj import compute_proj_evoked, make_projector from mne.surface import _compute_nearest from mne.beamformer.tests.test_lcmv import _assert_weight_norm @@ -93,7 +93,7 @@ def _simulate_data(fwd, idx): # Somewhere on the frontal lobe by default evoked = epochs.average() # Compute the cross-spectral density matrix - csd = csd_morlet(epochs, frequencies=[10, 20], n_cycles=[5, 10], decim=10) + csd = csd_morlet(epochs, frequencies=[10, 20], n_cycles=[5, 10], decim=5) labels = mne.read_labels_from_annot( 'sample', hemi='lh', subjects_dir=subjects_dir) @@ -122,6 +122,19 @@ def _rand_csd(rng, info): return data +def _make_rand_csd(info, csd): + rng = np.random.RandomState(0) + data = _rand_csd(rng, info) + # now we need to have the same null space as the data csd + s, u = np.linalg.eigh(csd.get_data(csd.frequencies[0])) + mask = np.abs(s) >= s[-1] * 1e-7 + rank = mask.sum() + assert rank == len(data) == len(info['ch_names']) + noise_csd = CrossSpectralDensity( + _sym_mat_to_vector(data), info['ch_names'], 0., csd.n_fft) + return noise_csd, rank + + @pytest.mark.slowtest @testing.requires_testing_data @requires_h5py @@ -138,10 +151,8 @@ def test_make_dics(tmpdir, _load_forward, idx, whiten): with pytest.raises(ValueError, match='several sensor types'): make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None) if whiten: - rng = np.random.RandomState(0) - data = _rand_csd(rng, epochs.info) - noise_csd = CrossSpectralDensity( - _sym_mat_to_vector(data), epochs.ch_names, 0., csd.n_fft) + noise_csd, rank = _make_rand_csd(epochs.info, csd) + assert rank == len(epochs.info['ch_names']) == 62 else: noise_csd = None epochs.pick_types(meg='grad') @@ -724,3 +735,55 @@ def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm, # Compute the percentage of sources for which there is no loc bias: perc = (want == np.argmax(loc, axis=0)).mean() * 100 assert lower <= perc <= upper + + +@testing.requires_testing_data +@idx_param +@pytest.mark.parametrize('whiten', (False, True)) +def test_make_dics_rank(_load_forward, idx, whiten): + """Test making DICS beamformer filters with rank param.""" + _, fwd_surf, fwd_fixed, _ = _load_forward + epochs, _, csd, _, label, _, _ = _simulate_data(fwd_fixed, idx) + if whiten: + noise_csd, want_rank = _make_rand_csd(epochs.info, csd) + kind = 'mag + grad' + else: + noise_csd = None + epochs.pick_types(meg='grad') + want_rank = len(epochs.ch_names) + assert want_rank == 41 + kind = 'grad' + + with catch_logging() as log: + filters = make_dics( + epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, + verbose=True) + log = log.getvalue() + assert f'Estimated rank ({kind}): {want_rank}' in log, log + stc, _ = apply_dics_csd(csd, filters) + other_rank = want_rank - 1 # shouldn't make a huge difference + use_rank = dict(meg=other_rank) + if not whiten: + # XXX it's a bug that our rank functions don't treat "meg" + # properly here... + use_rank['grad'] = use_rank.pop('meg') + with catch_logging() as log: + filters_2 = make_dics( + epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, + rank=use_rank, verbose=True) + log = log.getvalue() + assert f'Computing rank from covariance with rank={use_rank}' in log, log + stc_2, _ = apply_dics_csd(csd, filters_2) + corr = np.corrcoef(stc_2.data.ravel(), stc.data.ravel())[0, 1] + assert 0.8 < corr < 0.99999 + + # degenerate conditions + if whiten: + # make rank deficient + data = noise_csd.get_data(0.) + data[0] = data[:0] = 0 + noise_csd._data[:, 0] = _sym_mat_to_vector(data) + with pytest.raises(ValueError, match='meg data rank.*the noise rank'): + filters = make_dics( + epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd, + verbose=True)