Skip to content

Commit

Permalink
MRG: Fix DICS rank handling (mne-tools#8594)
Browse files Browse the repository at this point in the history
* Fix DICS rank handling

* Fix case when no noise_csd is given

* TST: Add test

* DOC: latest

* FIX: Tol

* FIX: Tol

* FIX: Full rank

Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
wmvanvliet and larsoner authored Dec 17, 2020
1 parent ae9ed23 commit c23df98
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 9 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ Bugs

- Fix bug with reading EDF and KIT files on big endian architectures such as s390x (:gh:`8618` by `Eric Larson`_)

- Fix bug with :func:`mne.beamformer.make_dics` where the ``rank`` parameter was not properly handled (:gh:`8594` by `Marijn van Vliet`_ and `Eric Larson`_)

- Fix bug with :func:`mne.beamformer.apply_dics` where the whitener was not properly applied (:gh:`8610` by `Eric Larson`_)

- Fix bug with `~mne.viz.plot_epochs_image` when ``order`` is supplied and multiple conditions are plotted (:gh:`8377` by `Daniel McCloy`_ )
Expand Down
25 changes: 22 additions & 3 deletions mne/beamformer/_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_check_option, _validate_type)
from ..forward import _subject_from_forward
from ..minimum_norm.inverse import combine_xyz, _check_reference, _check_depth
from ..rank import compute_rank
from ..source_estimate import _make_stc, _get_src_type
from ..time_frequency import csd_fourier, csd_multitaper, csd_morlet
from ._compute_beamformer import (_prepare_beamformer_input,
Expand Down Expand Up @@ -166,7 +167,8 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None,
frequencies = [np.mean(freq_bin) for freq_bin in csd.frequencies]
n_freqs = len(frequencies)

_check_one_ch_type('dics', info, forward, csd, noise_csd)
_, _, allow_mismatch = _check_one_ch_type('dics', info, forward, csd,
noise_csd)
# remove bads so that equalize_channels only keeps all good
info = pick_info(info, pick_channels(info['ch_names'], [], info['bads']))
info, forward, csd = equalize_channels([info, forward, csd])
Expand All @@ -181,6 +183,23 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None,
_prepare_beamformer_input(
info, forward, label, pick_ori, noise_cov=noise_csd, rank=rank,
pca=False, **depth)

# Compute ranks
csd_int_rank = []
if not allow_mismatch:
noise_rank = compute_rank(noise_csd, info=info, rank=rank)
for i in range(len(frequencies)):
csd_rank = compute_rank(csd.get_data(index=i, as_cov=True),
info=info, rank=rank)
if not allow_mismatch:
for key in csd_rank:
if key not in noise_rank or csd_rank[key] != noise_rank[key]:
raise ValueError('%s data rank (%s) did not match the '
'noise rank (%s)'
% (key, csd_rank[key],
noise_rank.get(key, None)))
csd_int_rank.append(sum(csd_rank.values()))

del noise_csd
ch_names = list(info['ch_names'])

Expand All @@ -203,8 +222,8 @@ def make_dics(info, forward, csd, reg=0.05, noise_csd=None, label=None,
n_orient = 3 if is_free_ori else 1
W, max_power_ori = _compute_beamformer(
G, Cm, reg, n_orient, weight_norm, pick_ori, reduce_rank,
rank=rank, inversion=inversion, nn=nn, orient_std=orient_std,
whitener=whitener)
rank=csd_int_rank[i], inversion=inversion, nn=nn,
orient_std=orient_std, whitener=whitener)
Ws.append(W)
max_oris.append(max_power_ori)

Expand Down
75 changes: 69 additions & 6 deletions mne/beamformer/tests/test_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mne.beamformer._compute_beamformer import _prepare_beamformer_input
from mne.beamformer._dics import _prepare_noise_csd
from mne.time_frequency import csd_morlet
from mne.utils import object_diff, requires_h5py
from mne.utils import object_diff, requires_h5py, catch_logging
from mne.proj import compute_proj_evoked, make_projector
from mne.surface import _compute_nearest
from mne.beamformer.tests.test_lcmv import _assert_weight_norm
Expand Down Expand Up @@ -93,7 +93,7 @@ def _simulate_data(fwd, idx): # Somewhere on the frontal lobe by default
evoked = epochs.average()

# Compute the cross-spectral density matrix
csd = csd_morlet(epochs, frequencies=[10, 20], n_cycles=[5, 10], decim=10)
csd = csd_morlet(epochs, frequencies=[10, 20], n_cycles=[5, 10], decim=5)

labels = mne.read_labels_from_annot(
'sample', hemi='lh', subjects_dir=subjects_dir)
Expand Down Expand Up @@ -122,6 +122,19 @@ def _rand_csd(rng, info):
return data


def _make_rand_csd(info, csd):
rng = np.random.RandomState(0)
data = _rand_csd(rng, info)
# now we need to have the same null space as the data csd
s, u = np.linalg.eigh(csd.get_data(csd.frequencies[0]))
mask = np.abs(s) >= s[-1] * 1e-7
rank = mask.sum()
assert rank == len(data) == len(info['ch_names'])
noise_csd = CrossSpectralDensity(
_sym_mat_to_vector(data), info['ch_names'], 0., csd.n_fft)
return noise_csd, rank


@pytest.mark.slowtest
@testing.requires_testing_data
@requires_h5py
Expand All @@ -138,10 +151,8 @@ def test_make_dics(tmpdir, _load_forward, idx, whiten):
with pytest.raises(ValueError, match='several sensor types'):
make_dics(epochs.info, fwd_surf, csd, label=label, pick_ori=None)
if whiten:
rng = np.random.RandomState(0)
data = _rand_csd(rng, epochs.info)
noise_csd = CrossSpectralDensity(
_sym_mat_to_vector(data), epochs.ch_names, 0., csd.n_fft)
noise_csd, rank = _make_rand_csd(epochs.info, csd)
assert rank == len(epochs.info['ch_names']) == 62
else:
noise_csd = None
epochs.pick_types(meg='grad')
Expand Down Expand Up @@ -724,3 +735,55 @@ def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm,
# Compute the percentage of sources for which there is no loc bias:
perc = (want == np.argmax(loc, axis=0)).mean() * 100
assert lower <= perc <= upper


@testing.requires_testing_data
@idx_param
@pytest.mark.parametrize('whiten', (False, True))
def test_make_dics_rank(_load_forward, idx, whiten):
"""Test making DICS beamformer filters with rank param."""
_, fwd_surf, fwd_fixed, _ = _load_forward
epochs, _, csd, _, label, _, _ = _simulate_data(fwd_fixed, idx)
if whiten:
noise_csd, want_rank = _make_rand_csd(epochs.info, csd)
kind = 'mag + grad'
else:
noise_csd = None
epochs.pick_types(meg='grad')
want_rank = len(epochs.ch_names)
assert want_rank == 41
kind = 'grad'

with catch_logging() as log:
filters = make_dics(
epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd,
verbose=True)
log = log.getvalue()
assert f'Estimated rank ({kind}): {want_rank}' in log, log
stc, _ = apply_dics_csd(csd, filters)
other_rank = want_rank - 1 # shouldn't make a huge difference
use_rank = dict(meg=other_rank)
if not whiten:
# XXX it's a bug that our rank functions don't treat "meg"
# properly here...
use_rank['grad'] = use_rank.pop('meg')
with catch_logging() as log:
filters_2 = make_dics(
epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd,
rank=use_rank, verbose=True)
log = log.getvalue()
assert f'Computing rank from covariance with rank={use_rank}' in log, log
stc_2, _ = apply_dics_csd(csd, filters_2)
corr = np.corrcoef(stc_2.data.ravel(), stc.data.ravel())[0, 1]
assert 0.8 < corr < 0.99999

# degenerate conditions
if whiten:
# make rank deficient
data = noise_csd.get_data(0.)
data[0] = data[:0] = 0
noise_csd._data[:, 0] = _sym_mat_to_vector(data)
with pytest.raises(ValueError, match='meg data rank.*the noise rank'):
filters = make_dics(
epochs.info, fwd_surf, csd, label=label, noise_csd=noise_csd,
verbose=True)

0 comments on commit c23df98

Please sign in to comment.