Skip to content

Commit

Permalink
WIP: Protect against bad projections
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Mar 28, 2016
1 parent 1590106 commit 533906e
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 68 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ test: in
rm -f .coverage
$(NOSETESTS) -a '!ultra_slow_test' mne

test-verbose: in
rm -f .coverage
$(NOSETESTS) -a '!ultra_slow_test' mne --verbose

test-fast: in
rm -f .coverage
$(NOSETESTS) -a '!slow_test' mne
Expand Down
35 changes: 22 additions & 13 deletions mne/beamformer/_lcmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _setup_picks(picks, info, forward, noise_cov=None):
@verbose
def _apply_lcmv(data, info, tmin, forward, noise_cov, data_cov, reg,
label=None, picks=None, pick_ori=None, rank=None,
verbose=None):
stacklevel=10, verbose=None):
""" LCMV beamformer for evoked data, single epochs, and raw data
Parameters
Expand Down Expand Up @@ -88,6 +88,8 @@ def _apply_lcmv(data, info, tmin, forward, noise_cov, data_cov, reg,
detected automatically. If int, the rank is specified for the MEG
channels. A dictionary with entries 'eeg' and/or 'meg' can be used
to specify the rank for each modality.
stacklevel : int
The stack level for warnings.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Expand All @@ -98,10 +100,12 @@ def _apply_lcmv(data, info, tmin, forward, noise_cov, data_cov, reg,
"""
is_free_ori, ch_names, proj, vertno, G = (
_prepare_beamformer_input(
info, forward, label, picks, pick_ori))
info, forward, label, picks, pick_ori,
stacklevel=stacklevel - 5))

# Handle whitening + data covariance
whitener, _ = compute_whitener(noise_cov, info, picks, rank=rank)
whitener = compute_whitener(noise_cov, info, picks, rank=rank,
stacklevel=stacklevel)[0]

# whiten the leadfield
G = np.dot(whitener, G)
Expand Down Expand Up @@ -215,7 +219,8 @@ def _apply_lcmv(data, info, tmin, forward, noise_cov, data_cov, reg,
logger.info('[done]')


def _prepare_beamformer_input(info, forward, label, picks, pick_ori):
def _prepare_beamformer_input(info, forward, label, picks, pick_ori,
stacklevel=7):
"""Input preparation common for all beamformer functions.
Check input values, prepare channel list and gain matrix. For documentation
Expand Down Expand Up @@ -256,7 +261,8 @@ def _prepare_beamformer_input(info, forward, label, picks, pick_ori):
G = forward['sol']['data']

# Apply SSPs
proj, ncomp, _ = make_projector(info['projs'], ch_names)
proj, ncomp, _ = make_projector(info['projs'], ch_names,
stacklevel=stacklevel)
if info['projs']:
G = np.dot(proj, G)

Expand Down Expand Up @@ -337,7 +343,7 @@ def lcmv(evoked, forward, noise_cov, data_cov, reg=0.01, label=None,
stc = _apply_lcmv(
data=data, info=info, tmin=tmin, forward=forward, noise_cov=noise_cov,
data_cov=data_cov, reg=reg, label=label, picks=picks, rank=rank,
pick_ori=pick_ori)
pick_ori=pick_ori, stacklevel=12)

return six.advance_iterator(stc)

Expand Down Expand Up @@ -415,11 +421,10 @@ def lcmv_epochs(epochs, forward, noise_cov, data_cov, reg=0.01, label=None,
picks = _setup_picks(picks, info, forward, noise_cov)

data = epochs.get_data()[:, picks, :]

stcs = _apply_lcmv(
data=data, info=info, tmin=tmin, forward=forward, noise_cov=noise_cov,
data_cov=data_cov, reg=reg, label=label, picks=picks, rank=rank,
pick_ori=pick_ori)
pick_ori=pick_ori, stacklevel=9)

if not return_generator:
stcs = [s for s in stcs]
Expand Down Expand Up @@ -469,6 +474,8 @@ def lcmv_raw(raw, forward, noise_cov, data_cov, reg=0.01, label=None,
detected automatically. If int, the rank is specified for the MEG
channels. A dictionary with entries 'eeg' and/or 'meg' can be used
to specify the rank for each modality.
stacklevel : int
The stack level for warnings.
verbose : bool, str, int, or None
If not None, override default verbose level (see mne.verbose).
Expand Down Expand Up @@ -505,15 +512,15 @@ def lcmv_raw(raw, forward, noise_cov, data_cov, reg=0.01, label=None,
stc = _apply_lcmv(
data=data, info=info, tmin=tmin, forward=forward, noise_cov=noise_cov,
data_cov=data_cov, reg=reg, label=label, picks=picks, rank=rank,
pick_ori=pick_ori)
pick_ori=pick_ori, stacklevel=12)

return six.advance_iterator(stc)


@verbose
def _lcmv_source_power(info, forward, noise_cov, data_cov, reg=0.01,
label=None, picks=None, pick_ori=None,
rank=None, verbose=None):
rank=None, stacklevel=11, verbose=None):
"""Linearly Constrained Minimum Variance (LCMV) beamformer.
Calculate source power in a time window based on the provided data
Expand Down Expand Up @@ -571,13 +578,14 @@ def _lcmv_source_power(info, forward, noise_cov, data_cov, reg=0.01,

is_free_ori, ch_names, proj, vertno, G =\
_prepare_beamformer_input(
info, forward, label, picks, pick_ori)
info, forward, label, picks, pick_ori, stacklevel=stacklevel - 5)

# Handle whitening
info = pick_info(
info, [info['ch_names'].index(k) for k in ch_names
if k in info['ch_names']])
whitener, _ = compute_whitener(noise_cov, info, picks, rank=rank)
whitener, _ = compute_whitener(noise_cov, info, picks, rank=rank,
stacklevel=stacklevel)

# whiten the leadfield
G = np.dot(whitener, G)
Expand Down Expand Up @@ -788,7 +796,8 @@ def tf_lcmv(epochs, forward, noise_covs, tmin, tmax, tstep, win_lengths,

stc = _lcmv_source_power(epochs_band.info, forward, noise_cov,
data_cov, reg=reg, label=label,
pick_ori=pick_ori, verbose=verbose)
pick_ori=pick_ori, stacklevel=15,
verbose=verbose)
sol_single.append(stc.data[:, 0])

# Average over all time windows that contain the current time
Expand Down
1 change: 1 addition & 0 deletions mne/beamformer/tests/test_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _get_data(tmin=-0.11, tmax=0.15, read_all_forward=True, compute_csds=True):
label = mne.read_label(fname_label)
events = mne.read_events(fname_event)[:10]
raw = mne.io.Raw(fname_raw, preload=False)
raw.add_proj([], remove_existing=True) # we'll subselect so remove proj
forward = mne.read_forward_solution(fname_fwd)
if read_all_forward:
forward_surf_ori = read_forward_solution_meg(fname_fwd, surf_ori=True)
Expand Down
90 changes: 54 additions & 36 deletions mne/beamformer/tests/test_lcmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True,
selection=left_temporal_channels)

# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks, baseline=(None, 0),
preload=epochs_preload,
reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6))
with warnings.catch_warnings(record=True):
epochs = mne.Epochs(
raw, events, event_id, tmin, tmax, proj=True, picks=picks,
baseline=(None, 0), preload=epochs_preload,
reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6))
if epochs_preload:
epochs.resample(200, npad=0, n_jobs=2)
evoked = epochs.average()
Expand All @@ -79,8 +80,9 @@ def _get_data(tmin=-0.1, tmax=0.15, all_forward=True, epochs=True,
info = raw.info

noise_cov = mne.read_cov(fname_cov)
noise_cov = mne.cov.regularize(noise_cov, info, mag=0.05, grad=0.05,
eeg=0.1, proj=True)
with warnings.catch_warnings(record=True): # bad proj here
noise_cov = mne.cov.regularize(noise_cov, info, mag=0.05, grad=0.05,
eeg=0.1, proj=True)
if data_cov:
with warnings.catch_warnings(record=True):
data_cov = mne.compute_covariance(epochs, tmin=0.04, tmax=0.145)
Expand All @@ -100,7 +102,8 @@ def test_lcmv():
forward_surf_ori, forward_fixed, forward_vol = _get_data()

for fwd in [forward, forward_vol]:
stc = lcmv(evoked, fwd, noise_cov, data_cov, reg=0.01)
with warnings.catch_warnings(record=True):
stc = lcmv(evoked, fwd, noise_cov, data_cov, reg=0.01) # bad proj
stc.crop(0.02, None)

stc_pow = np.sum(stc.data, axis=1)
Expand All @@ -113,8 +116,9 @@ def test_lcmv():

if fwd is forward:
# Test picking normal orientation (surface source space only)
stc_normal = lcmv(evoked, forward_surf_ori, noise_cov, data_cov,
reg=0.01, pick_ori="normal")
with warnings.catch_warnings(record=True): # bad proj
stc_normal = lcmv(evoked, forward_surf_ori, noise_cov,
data_cov, reg=0.01, pick_ori="normal")
stc_normal.crop(0.02, None)

stc_pow = np.sum(np.abs(stc_normal.data), axis=1)
Expand All @@ -130,8 +134,9 @@ def test_lcmv():
assert_true((np.abs(stc_normal.data) <= stc.data).all())

# Test picking source orientation maximizing output source power
stc_max_power = lcmv(evoked, fwd, noise_cov, data_cov, reg=0.01,
pick_ori="max-power")
with warnings.catch_warnings(record=True): # bad proj
stc_max_power = lcmv(evoked, fwd, noise_cov, data_cov, reg=0.01,
pick_ori="max-power")
stc_max_power.crop(0.02, None)
stc_pow = np.sum(stc_max_power.data, axis=1)
idx = np.argmax(stc_pow)
Expand Down Expand Up @@ -164,10 +169,12 @@ def test_lcmv():

# Now test single trial using fixed orientation forward solution
# so we can compare it to the evoked solution
stcs = lcmv_epochs(epochs, forward_fixed, noise_cov, data_cov, reg=0.01)
stcs_ = lcmv_epochs(epochs, forward_fixed, noise_cov, data_cov, reg=0.01,
return_generator=True)
assert_array_equal(stcs[0].data, advance_iterator(stcs_).data)
with warnings.catch_warnings(record=True): # bad proj
stcs = lcmv_epochs(epochs, forward_fixed, noise_cov, data_cov,
reg=0.01)
stcs_ = lcmv_epochs(epochs, forward_fixed, noise_cov, data_cov,
reg=0.01, return_generator=True)
assert_array_equal(stcs[0].data, advance_iterator(stcs_).data)

epochs.drop_bad_epochs()
assert_true(len(epochs.events) == len(stcs))
Expand All @@ -179,13 +186,15 @@ def test_lcmv():
stc_avg /= len(stcs)

# compare it to the solution using evoked with fixed orientation
stc_fixed = lcmv(evoked, forward_fixed, noise_cov, data_cov, reg=0.01)
with warnings.catch_warnings(record=True): # bad proj
stc_fixed = lcmv(evoked, forward_fixed, noise_cov, data_cov, reg=0.01)
assert_array_almost_equal(stc_avg, stc_fixed.data)

# use a label so we have few source vertices and delayed computation is
# not used
stcs_label = lcmv_epochs(epochs, forward_fixed, noise_cov, data_cov,
reg=0.01, label=label)
with warnings.catch_warnings(record=True): # bad proj
stcs_label = lcmv_epochs(epochs, forward_fixed, noise_cov, data_cov,
reg=0.01, label=label)

assert_array_almost_equal(stcs_label[0].data, stcs[0].in_label(label).data)

Expand All @@ -207,8 +216,9 @@ def test_lcmv_raw():

data_cov = mne.compute_raw_covariance(raw, tmin=tmin, tmax=tmax)

stc = lcmv_raw(raw, forward, noise_cov, data_cov, reg=0.01, label=label,
start=start, stop=stop, picks=picks)
with warnings.catch_warnings(record=True): # bad proj
stc = lcmv_raw(raw, forward, noise_cov, data_cov, reg=0.01,
label=label, start=start, stop=stop, picks=picks)

assert_array_almost_equal(np.array([tmin, tmax]),
np.array([stc.times[0], stc.times[-1]]),
Expand All @@ -228,8 +238,9 @@ def test_lcmv_source_power():
raw, epochs, evoked, data_cov, noise_cov, label, forward,\
forward_surf_ori, forward_fixed, forward_vol = _get_data()

stc_source_power = _lcmv_source_power(epochs.info, forward, noise_cov,
data_cov, label=label)
with warnings.catch_warnings(record=True): # bad proj
stc_source_power = _lcmv_source_power(epochs.info, forward, noise_cov,
data_cov, label=label)

max_source_idx = np.argmax(stc_source_power.data)
max_source_power = np.max(stc_source_power.data)
Expand All @@ -238,8 +249,10 @@ def test_lcmv_source_power():
assert_true(0.4 < max_source_power < 2.4, max_source_power)

# Test picking normal orientation and using a list of CSD matrices
stc_normal = _lcmv_source_power(epochs.info, forward_surf_ori, noise_cov,
data_cov, pick_ori="normal", label=label)
with warnings.catch_warnings(record=True): # bad proj
stc_normal = _lcmv_source_power(
epochs.info, forward_surf_ori, noise_cov, data_cov,
pick_ori="normal", label=label)

# The normal orientation results should always be smaller than free
# orientation results
Expand Down Expand Up @@ -283,9 +296,10 @@ def test_tf_lcmv():
selection=left_temporal_channels)

# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks, baseline=None, preload=False,
reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6))
with warnings.catch_warnings(record=True): # bad proj
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks, baseline=None, preload=False,
reject=dict(grad=4000e-13, mag=4e-12, eog=150e-6))
epochs.drop_bad_epochs()

freq_bins = [(4, 12), (15, 40)]
Expand All @@ -299,28 +313,32 @@ def test_tf_lcmv():
for (l_freq, h_freq), win_length in zip(freq_bins, win_lengths):
raw_band = raw.copy()
raw_band.filter(l_freq, h_freq, method='iir', n_jobs=1, picks=picks)
epochs_band = mne.Epochs(raw_band, epochs.events, epochs.event_id,
tmin=tmin, tmax=tmax, baseline=None,
proj=True, picks=picks)
with warnings.catch_warnings(record=True): # bad proj
epochs_band = mne.Epochs(
raw_band, epochs.events, epochs.event_id, tmin=tmin, tmax=tmax,
baseline=None, proj=True, picks=picks)
with warnings.catch_warnings(record=True): # not enough samples
noise_cov = compute_covariance(epochs_band, tmin=tmin, tmax=tmin +
win_length)
noise_cov = mne.cov.regularize(noise_cov, epochs_band.info, mag=reg,
grad=reg, eeg=reg, proj=True)
with warnings.catch_warnings(record=True): # bad proj
noise_cov = mne.cov.regularize(
noise_cov, epochs_band.info, mag=reg, grad=reg, eeg=reg,
proj=True)
noise_covs.append(noise_cov)
del raw_band # to save memory

# Manually calculating source power in on frequency band and several
# time windows to compare to tf_lcmv results and test overlapping
if (l_freq, h_freq) == freq_bins[0]:
for time_window in time_windows:
with warnings.catch_warnings(record=True):
with warnings.catch_warnings(record=True): # bad samples
data_cov = compute_covariance(epochs_band,
tmin=time_window[0],
tmax=time_window[1])
stc_source_power = _lcmv_source_power(epochs.info, forward,
noise_cov, data_cov,
reg=reg, label=label)
with warnings.catch_warnings(record=True): # bad proj
stc_source_power = _lcmv_source_power(
epochs.info, forward, noise_cov, data_cov,
reg=reg, label=label)
source_power.append(stc_source_power.data)

with warnings.catch_warnings(record=True):
Expand Down
Loading

0 comments on commit 533906e

Please sign in to comment.