Skip to content

Commit

Permalink
MRG: Clean up and unify _prepare_forward (mne-tools#5999)
Browse files Browse the repository at this point in the history
* MAINT: Working refactor

* MAINT: Refactor _prepare_forward

* STY: Flake

* FIX: Make LGTM happy?
  • Loading branch information
larsoner authored Mar 1, 2019
1 parent 2797c6a commit 693bd2d
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 217 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ API

- Reading BDF and GDF files with :func:`mne.io.read_raw_edf` is deprecated and replaced by :func:`mne.io.read_raw_bdf` and :func:`mne.io.read_raw_gdf`, by `Clemens Brunner`_

- :func:`mne.forward.compute_depth_prior` has been reworked to operate directly on :class:`Forward` instance as ``forward`` rather than a representation scattered across the parameters ``G, is_fixed_ori, patch_info``, by `Eric Larson`_

- Deprecate ``method='extended-infomax'`` in :class:`mne.preprocessing.ICA`; Extended Infomax can now be computed with ``method='infomax'`` and ``fit_params=dict(extended=True)`` by `Clemens Brunner`_

.. _changes_0_17:
Expand Down
8 changes: 3 additions & 5 deletions examples/inverse/plot_custom_inverse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,15 @@ def apply_solver(solver, evoked, forward, noise_cov, loose=0.2, depth=0.8):
"""
# Import the necessary private functions
from mne.inverse_sparse.mxne_inverse import \
(_prepare_gain, _check_loose_forward, is_fixed_orient,
(_prepare_gain, is_fixed_orient,
_reapply_source_weighting, _make_sparse_stc)

all_ch_names = evoked.ch_names

loose, forward = _check_loose_forward(loose, forward)

# Handle depth weighting and whitening (here is no weights)
gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, evoked.info, noise_cov, pca=False, depth=depth,
loose=loose, weights=None, weights_min=None)
loose=loose, weights=None, weights_min=None, rank=None)

# Select channels of interest
sel = [all_ch_names.index(name) for name in gain_info['ch_names']]
Expand Down
83 changes: 41 additions & 42 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,21 +977,12 @@ def _select_orient_forward(forward, info, noise_cov=None, verbose=None):

n_chan = len(ch_names)
logger.info("Computing inverse operator with %d channels." % n_chan)

gain = forward['sol']['data']

# This actually reorders the gain matrix to conform to the info ch order
fwd_idx = [fwd_sol_ch_names.index(name) for name in ch_names]
gain = gain[fwd_idx]
# Any function calling this helper will be using the returned fwd_info
# dict, so fwd['sol']['row_names'] becomes obsolete and is NOT re-ordered

forward = pick_channels_forward(forward, ch_names, ordered=True)
info_idx = [info['ch_names'].index(name) for name in ch_names]
fwd_info = pick_info(info, info_idx)
info_picked = pick_info(info, info_idx)
forward['info']._check_consistency()
fwd_info._check_consistency()

return fwd_info, gain
info_picked._check_consistency()
return forward, info_picked


@verbose
Expand Down Expand Up @@ -1019,21 +1010,22 @@ def compute_orient_prior(forward, loose=0.2, verbose=None):
n_sources = forward['sol']['data'].shape[1]
loose = float(loose)
if not (0 <= loose <= 1):
raise ValueError('loose value should be smaller than 1 and bigger '
'than 0, got %s.' % (loose,))
if loose < 1 and not forward['surf_ori']:
raise ValueError('Forward operator is not oriented in surface '
'coordinates. loose parameter should be 1 '
'not %s.' % loose)
if is_fixed_ori and loose != 0:
raise ValueError('loose must be 0. with forward operator '
'with fixed orientation.')

raise ValueError('loose value should be between 0 and 1, '
'got %s.' % (loose,))
orient_prior = np.ones(n_sources, dtype=np.float)
if not is_fixed_ori and loose < 1:
logger.info('Applying loose dipole orientations. Loose value '
'of %s.' % loose)
orient_prior[np.mod(np.arange(n_sources), 3) != 2] *= loose
if loose > 0.:
if is_fixed_ori:
raise ValueError('loose must be 0. with forward operator '
'with fixed orientation, got %s' % (loose,))
if loose < 1:
if not forward['surf_ori']:
raise ValueError('Forward operator is not oriented in surface '
'coordinates. loose parameter should be 1 '
'not %s.' % (loose,))
logger.info('Applying loose dipole orientations. Loose value '
'of %s.' % loose)
orient_prior[0::3] *= loose
orient_prior[1::3] *= loose

return orient_prior

Expand Down Expand Up @@ -1065,27 +1057,28 @@ def _restrict_gain_matrix(G, info):


@verbose
def compute_depth_prior(G, gain_info, is_fixed_ori, exp=0.8, limit=10.0,
def compute_depth_prior(forward, info, is_fixed_ori=None,
exp=0.8, limit=10.0,
patch_areas=None, limit_depth_chs=False,
combine_xyz='spectral', noise_cov=None, rank=None,
verbose=None):
"""Compute depth prior for depth weighting.
Parameters
----------
G : ndarray, shape (n_channels, n_vertices)
The gain matrix.
gain_info : instance of Info
The info associated with the gain matrix.
is_fixed_ori : bool
Whether or not ``G`` is fixed orientation.
forward : instance of Forward
The forward solution.
info : instance of Info
The measurement info.
is_fixed_ori : bool | None
Deprecated, will be removed in 0.19.
exp : float
Exponent for the depth weighting, must be between 0 and 1.
limit : float | None
The upper bound on depth weighting.
Can be None to be bounded by the largest finite prior.
patch_areas : ndarray | None
Patch areas of the vertices from the forward solution.
Deprecated, will be removed in 0.19.
limit_depth_chs : bool | 'whiten'
How to deal with multiple channel types in depth weighting. Options:
Expand Down Expand Up @@ -1148,11 +1141,17 @@ def compute_depth_prior(G, gain_info, is_fixed_ori, exp=0.8, limit=10.0,
combine_xyz='fro')
"""
# XXX this perhaps should just take ``forward`` instead of ``G`` and
# ``gain_info``. However, it's not easy to do this given that the
# mixed norm code requires that ``G`` is whitened before this chunk
# of code executes.
from ..cov import Covariance, compute_whitener
if isinstance(forward, Forward):
patch_areas = forward.get('patch_areas', None)
is_fixed_ori = is_fixed_orient(forward)
G = forward['sol']['data']
else:
warn('Parameters G, is_fixed_ori, and patch_areas are '
'deprecated and will be removed in 0.19, pass in the forward '
'solution directly.', DeprecationWarning)
G = forward
_validate_type(is_fixed_ori, bool, 'is_fixed_ori')
logger.info('Creating the depth weighting matrix...')
_validate_type(noise_cov, (Covariance, None), 'noise_cov',
'Covariance or None')
Expand All @@ -1168,10 +1167,10 @@ def compute_depth_prior(G, gain_info, is_fixed_ori, exp=0.8, limit=10.0,

# If possible, pick best depth-weighting channels
if limit_depth_chs is True:
G = _restrict_gain_matrix(G, gain_info)
G = _restrict_gain_matrix(G, info)
elif limit_depth_chs == 'whiten':
whitener, _ = compute_whitener(noise_cov, gain_info, pca=True,
rank=rank)
whitener, _ = compute_whitener(noise_cov, info, pca=True, rank=rank,
verbose=False)
G = np.dot(whitener, G)

# Compute the gain matrix
Expand Down
49 changes: 40 additions & 9 deletions mne/forward/tests/test_forward.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import os.path as op
import gc

Expand All @@ -12,27 +11,22 @@
average_forward_solutions, write_forward_solution,
convert_forward_solution, SourceEstimate, pick_types_forward,
read_evokeds)
from mne.io import read_info
from mne.label import read_label
from mne.utils import (requires_mne, run_subprocess, _TempDir,
run_tests_if_main)
from mne.forward import (restrict_forward_to_stc, restrict_forward_to_label,
Forward, is_fixed_orient)
Forward, is_fixed_orient, compute_orient_prior,
compute_depth_prior)

data_path = testing.data_path(download=False)
fname_meeg = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif')
fname_meeg_grad = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-meg-eeg-oct-2-grad-fwd.fif')

fname_raw = op.join(op.dirname(__file__), '..', '..', 'io', 'tests', 'data',
'test_raw.fif')

fname_evoked = op.join(op.dirname(__file__), '..', '..', 'io', 'tests',
'data', 'test-ave.fif')
fname_mri = op.join(data_path, 'MEG', 'sample',
'sample_audvis_trunc-trans.fif')
subjects_dir = os.path.join(data_path, 'subjects')
fname_src = op.join(subjects_dir, 'sample', 'bem', 'sample-oct-4-src.fif')


def compare_forwards(f1, f2):
Expand Down Expand Up @@ -401,4 +395,41 @@ def test_average_forward_solution():
compare_forwards(fwd, fwd_ave)


@testing.requires_testing_data
def test_priors():
"""Test prior computations."""
# Depth prior
fwd = read_forward_solution(fname_meeg)
assert not is_fixed_orient(fwd)
n_sources = fwd['nsource']
info = read_info(fname_evoked)
depth_prior = compute_depth_prior(fwd, info, exp=0.8)
assert depth_prior.shape == (3 * n_sources,)
depth_prior = compute_depth_prior(fwd, info, exp=0.)
assert_array_equal(depth_prior, 1.)
with pytest.raises(ValueError, match='must be "whiten"'):
compute_depth_prior(fwd, info, limit_depth_chs='foo')
with pytest.raises(ValueError, match='noise_cov must be a Covariance'):
compute_depth_prior(fwd, info, limit_depth_chs='whiten')
fwd_fixed = convert_forward_solution(fwd, force_fixed=True)
with pytest.deprecated_call():
depth_prior = compute_depth_prior(
fwd_fixed['sol']['data'], info, is_fixed_ori=True)
assert depth_prior.shape == (n_sources,)
# Orientation prior
orient_prior = compute_orient_prior(fwd, 1.)
assert_array_equal(orient_prior, 1.)
orient_prior = compute_orient_prior(fwd_fixed, 0.)
assert_array_equal(orient_prior, 1.)
with pytest.raises(ValueError, match='oriented in surface coordinates'):
compute_orient_prior(fwd, 0.5)
fwd_surf_ori = convert_forward_solution(fwd, surf_ori=True)
orient_prior = compute_orient_prior(fwd_surf_ori, 0.5)
assert all(np.in1d(orient_prior, (0.5, 1.)))
with pytest.raises(ValueError, match='between 0 and 1'):
compute_orient_prior(fwd_surf_ori, -0.5)
with pytest.raises(ValueError, match='with fixed orientation'):
compute_orient_prior(fwd_fixed, 0.5)


run_tests_if_main()
20 changes: 5 additions & 15 deletions mne/inverse_sparse/_gamma_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np
from scipy import linalg

from ..forward import is_fixed_orient, convert_forward_solution
from ..forward import is_fixed_orient

from ..minimum_norm.inverse import _check_reference
from ..utils import logger, verbose, warn
from .mxne_inverse import (_make_sparse_stc, _prepare_gain,
_reapply_source_weighting, _compute_residual,
_make_dipoles_sparse, _check_loose_forward)
_make_dipoles_sparse)


@verbose
Expand Down Expand Up @@ -240,21 +240,11 @@ def gamma_map(evoked, forward, noise_cov, alpha, loose="auto", depth=0.8,
"""
_check_reference(evoked)

loose, forward = _check_loose_forward(loose, forward)

# make forward solution in fixed orientation if necessary
if loose == 0. and not is_fixed_orient(forward):
forward = convert_forward_solution(
forward, surf_ori=True, force_fixed=True, copy=True, use_cps=True)

if is_fixed_orient(forward) or not xyz_same_gamma:
group_size = 1
else:
group_size = 3

gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, evoked.info, noise_cov, pca, depth, loose, rank)

group_size = 1 if (is_fixed_orient(forward) or not xyz_same_gamma) else 3

# get the data
sel = [evoked.ch_names.index(name) for name in gain_info['ch_names']]
M = evoked.data[sel]
Expand Down
27 changes: 6 additions & 21 deletions mne/inverse_sparse/mxne_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from ..source_estimate import (SourceEstimate, VolSourceEstimate,
_BaseSourceEstimate)
from ..minimum_norm.inverse import (combine_xyz, _prepare_forward,
_check_reference, _check_loose_forward)
from ..forward import is_fixed_orient, convert_forward_solution
_check_reference)
from ..forward import is_fixed_orient
from ..io.pick import pick_channels_evoked
from ..io.proj import deactivate_proj
from ..utils import logger, verbose, warn, _check_depth
Expand Down Expand Up @@ -64,7 +64,7 @@ def _prepare_gain(forward, info, noise_cov, pca, depth, loose, rank,
gain, source_weighting, mask = _prepare_weights(
forward, gain, source_weighting, weights, weights_min)

return gain, gain_info, whitener, source_weighting, mask
return forward, gain, gain_info, whitener, source_weighting, mask


def _reapply_source_weighting(X, source_weighting, active_set):
Expand Down Expand Up @@ -353,14 +353,7 @@ def mixed_norm(evoked, forward, noise_cov, alpha, loose='auto', depth=0.8,
for i in range(1, len(evoked))):
raise Exception('All the datasets must have the same good channels.')

loose, forward = _check_loose_forward(loose, forward)

# put the forward solution in fixed orientation if it's not already
if loose == 0. and not is_fixed_orient(forward):
forward = convert_forward_solution(
forward, surf_ori=True, force_fixed=True, copy=True, use_cps=True)

gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, evoked[0].info, noise_cov, pca, depth, loose, rank,
weights, weights_min)

Expand Down Expand Up @@ -611,18 +604,10 @@ def tf_mixed_norm(evoked, forward, noise_cov,
'passed. Got tstep = %s and wsize = %s' %
(tstep, wsize))

loose, forward = _check_loose_forward(loose, forward)

# put the forward solution in fixed orientation if it's not already
if loose == 0. and not is_fixed_orient(forward):
forward = convert_forward_solution(
forward, surf_ori=True, force_fixed=True, copy=True, use_cps=True)

n_dip_per_pos = 1 if is_fixed_orient(forward) else 3

gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
forward, evoked.info, noise_cov, pca, depth, loose, rank,
weights, weights_min)
n_dip_per_pos = 1 if is_fixed_orient(forward) else 3

if window is not None:
evoked = _window_evoked(evoked, window)
Expand Down
Loading

0 comments on commit 693bd2d

Please sign in to comment.