Skip to content

Commit

Permalink
ENH Added SURE support for MxNE and irMxNE (mne-tools#9430)
Browse files Browse the repository at this point in the history
* ENH Added SURE support for mixed MxNE and irMxNE. Needs unitary tests (WIP).

* CLN Fixed flake styling issues

* FIX Passed the comments

* ENH Updated references

* CLN Revert previous commit + ENH Written tests for SURE (1/6 failing)

* ENH Rewritten tests for SURE on syntehtic data + CLN Passed remarks

* UPD Updated example mixed_norm_inverse

* ENH Implemented warm start + fixed pydocstyle errors (all tests pass)

* pass alex

* ENH Written tests for SURE on MEG

* FIX Solved merge conflicts

* FIX Solved merge conflicts? (second attempt)

* FIX Solved merge conflicts? (third attempt)

* CLN Completed docstrings based on Alex's comments

* ENH Fixed minor bugs

* ENH Minor fix in docstrings

* ENH Fix in docstrings

* TEST: Revert changes on example mixed_norm_inverse

* fix doc?

* fix doc?

* fix doc?

* fix doc?

* DOC: Links

* Revert "TEST: Revert changes on example mixed_norm_inverse"

This reverts commit 9473e72.

* Pass Eric's comments

* Set random state for mixed norm examples for reproducibility purposes

* update what's new following @larsoner's suggestion

* address comments by @larsoner

Co-authored-by: Pierre-Antoine Bannier <[email protected]>
Co-authored-by: Pierre-Antoine Bannier <[email protected]>
Co-authored-by: Alexandre Gramfort <[email protected]>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
5 people authored Jun 18, 2021
1 parent b3768fe commit ea1212d
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 36 deletions.
8 changes: 6 additions & 2 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,20 @@ Current (0.24.dev0)

.. |Marian Dovgialo| replace:: **Marian Dovgialo**

.. |Pierre-Antoine Bannier| replace:: **Pierre-Antoine Bannier**

Enhancements
~~~~~~~~~~~~
.. - Add something cool (:gh:`9192` **by new contributor** |New Contributor|_)

- New function :func:`mne.chpi.get_chpi_info` to retrieve basic information about the cHPI system used when recording MEG data (:gh:`9369` by `Richard Höchenberger`_)

- Add support for NIRSport and NIRSport2 devices to `mne.io.read_raw_nirx` (:gh:`9348` and :gh:`9401` **by new contributor** |David Julien|_, **new contributor** |Romain Derollepot|_, `Robert Luke`_, and `Eric Larson`_)

- New function :func:`mne.label.find_pos_in_annot` to get atlas label for MRI coordinates. (:gh:`9376` by **by new contributor** |Marian Dovgialo|_)

- Add support for SURE parameter selection in :func:`mne.inverse_sparse.mixed_norm` and make ``alpha`` parameter now default to ``'sure'`` (:gh:`9430` by **new contributor** |Pierre-Antoine Bannier|_ and `Alex Gramfort`_)

- New function :func:`mne.chpi.get_chpi_info` to retrieve basic information about the cHPI system used when recording MEG data (:gh:`9369` by `Richard Höchenberger`_)

- New namespace `mne.export` created to contain functions (such as `mne.export.export_raw` and `mne.export.export_epochs`) for exporting data to non-FIF formats (:gh:`9427` by `Eric Larson`_)

- Add support for Hitachi fNIRS devices in `mne.io.read_raw_hitachi` (:gh:`9391` by `Eric Larson`_)
Expand Down
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,5 @@
.. _Xiaokai Xia: https://github.com/dddd1007
.. _Marian Dovgialo: https://github.com/mdovgialo
.. _Pierre-Antoine Bannier: https://github.com/PABannier
12 changes: 12 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2087,3 +2087,15 @@ @article{PolonenkoMaddox2019
year = {2019},
pages = {2331216519871395}
}

@article{DeledalleEtAl2014,
author = {Deledalle, Charles-Alban and Vaiter, Samuel and Fadili, Jalal and Peyré, Gabriel},
title = {Stein Unbiased GrAdient estimator of the Risk (SUGAR) for Multiple Parameter Selection},
journal = {SIAM Journal on Imaging Sciences},
volume = {7},
number = {4},
pages = {2448-2487},
year = {2014},
doi = {10.1137/140968045},
URL = {https://doi.org/10.1137/140968045}
}
10 changes: 5 additions & 5 deletions examples/inverse/mixed_norm_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
forward = mne.read_forward_solution(fwd_fname)

###############################################################################
# Run solver
alpha = 55 # regularization parameter between 0 and 100 (100 is high)
loose, depth = 0.2, 0.9 # loose orientation & depth weighting
# Run solver with SURE criterion :footcite:`DeledalleEtAl2014`
alpha = "sure" # regularization parameter between 0 and 100 or SURE criterion
loose, depth = 0.9, 0.9 # loose orientation & depth weighting
n_mxne_iter = 10 # if > 1 use L0.5/L2 reweighted mixed norm solver
# if n_mxne_iter > 1 dSPM weighting can be avoided.

Expand All @@ -58,9 +58,9 @@
# Compute (ir)MxNE inverse solution with dipole output
dipoles, residual = mixed_norm(
evoked, forward, cov, alpha, loose=loose, depth=depth, maxit=3000,
tol=1e-4, active_set_size=10, debias=True, weights=stc_dspm,
tol=1e-4, active_set_size=10, debias=False, weights=stc_dspm,
weights_min=8., n_mxne_iter=n_mxne_iter, return_residual=True,
return_as_dipoles=True, verbose=True)
return_as_dipoles=True, verbose=True, random_state=0)

t = 0.083
tidx = evoked.time_as_index(t)
Expand Down
228 changes: 209 additions & 19 deletions mne/inverse_sparse/mxne_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from ..forward import is_fixed_orient
from ..io.pick import pick_channels_evoked
from ..io.proj import deactivate_proj
from ..utils import logger, verbose, _check_depth, _check_option, sum_squared
from ..utils import (logger, verbose, _check_depth, _check_option, sum_squared,
_validate_type, check_random_state)
from ..dipole import Dipole

from .mxne_optim import (mixed_norm_solver, iterative_mixed_norm_solver, _Phi,
tf_mixed_norm_solver, iterative_tf_mixed_norm_solver,
norm_l2inf, norm_epsilon_inf)
norm_l2inf, norm_epsilon_inf, groups_norm2)


def _check_ori(pick_ori, forward):
Expand Down Expand Up @@ -284,12 +285,12 @@ def make_stc_from_dipoles(dipoles, src, verbose=None):


@verbose
def mixed_norm(evoked, forward, noise_cov, alpha, loose='auto', depth=0.8,
maxit=3000, tol=1e-4, active_set_size=10,
def mixed_norm(evoked, forward, noise_cov, alpha='sure', loose='auto',
depth=0.8, maxit=3000, tol=1e-4, active_set_size=10,
debias=True, time_pca=True, weights=None, weights_min=0.,
solver='auto', n_mxne_iter=1, return_residual=False,
return_as_dipoles=False, dgap_freq=10, rank=None, pick_ori=None,
verbose=None):
sure_alpha_grid="auto", random_state=None, verbose=None):
"""Mixed-norm estimate (MxNE) and iterative reweighted MxNE (irMxNE).
Compute L1/L2 mixed-norm solution :footcite:`GramfortEtAl2012` or L0.5/L2
Expand All @@ -303,9 +304,14 @@ def mixed_norm(evoked, forward, noise_cov, alpha, loose='auto', depth=0.8,
Forward operator.
noise_cov : instance of Covariance
Noise covariance to compute whitener.
alpha : float in range [0, 100)
Regularization parameter. 0 means no regularization, 100 would give 0
active dipole.
alpha : float | str
Regularization parameter. If float it should be in the range [0, 100):
0 means no regularization, 100 would give 0 active dipole.
If ``'sure'`` (default), the SURE method from
:footcite:`DeledalleEtAl2014` will be used.
.. versionchanged:: 0.24
The default was changed to ``'sure'``.
%(loose)s
%(depth)s
maxit : int
Expand Down Expand Up @@ -345,6 +351,17 @@ def mixed_norm(evoked, forward, noise_cov, alpha, loose='auto', depth=0.8,
.. versionadded:: 0.18
%(pick_ori)s
sure_alpha_grid : array | str
If ``'auto'`` (default), the SURE is evaluated along 15 uniformly
distributed alphas between alpha_max and 0.1 * alpha_max. If array, the
grid is directly specified. Ignored if alpha is not "sure".
.. versionadded:: 0.24
random_state : int | None
The random state used in a random number generator for delta and
epsilon used for the SURE computation. Defaults to None.
.. versionadded:: 0.24
%(verbose)s
Returns
Expand All @@ -364,16 +381,25 @@ def mixed_norm(evoked, forward, noise_cov, alpha, loose='auto', depth=0.8,
.. footbibliography::
"""
from scipy import linalg
if not (0. <= alpha < 100.):
raise ValueError('alpha must be in [0, 100). '
_validate_type(alpha, ('numeric', str), 'alpha')
if isinstance(alpha, str):
_check_option('alpha', alpha, ('sure',))
elif not 0. <= alpha < 100:
raise ValueError('If not equal to "sure" alpha must be in [0, 100). '
'Got alpha = %s' % alpha)
if n_mxne_iter < 1:
raise ValueError('MxNE has to be computed at least 1 time. '
'Requires n_mxne_iter >= 1, got %d' % n_mxne_iter)
if dgap_freq <= 0.:
raise ValueError('dgap_freq must be a positive integer.'
' Got dgap_freq = %s' % dgap_freq)

if not(isinstance(sure_alpha_grid, (np.ndarray, list)) or
sure_alpha_grid == "auto"):
raise ValueError('If not equal to "auto" sure_alpha_grid must be an '
'array. Got %s' % type(sure_alpha_grid))
if sure_alpha_grid != "auto" and alpha != "sure":
raise Exception('If sure_alpha_grid is manually specified, alpha must '
'be "sure". Got %s' % alpha)
pca = True
if not isinstance(evoked, list):
evoked = [evoked]
Expand Down Expand Up @@ -413,16 +439,29 @@ def mixed_norm(evoked, forward, noise_cov, alpha, loose='auto', depth=0.8,
gain /= alpha_max
source_weighting /= alpha_max

if n_mxne_iter == 1:
X, active_set, E = mixed_norm_solver(
M, gain, alpha, maxit=maxit, tol=tol,
active_set_size=active_set_size, n_orient=n_dip_per_pos,
debias=debias, solver=solver, dgap_freq=dgap_freq, verbose=verbose)
else:
X, active_set, E = iterative_mixed_norm_solver(
M, gain, alpha, n_mxne_iter, maxit=maxit, tol=tol,
# Alpha selected automatically by SURE minimization
if alpha == "sure":
alpha_grid = (np.geomspace(100, 10, num=15)
if sure_alpha_grid == "auto" else sure_alpha_grid)
X, active_set, best_alpha_ = _compute_mxne_sure(
M, gain, alpha_grid, sigma=1, random_state=random_state,
n_mxne_iter=n_mxne_iter, maxit=maxit, tol=tol,
n_orient=n_dip_per_pos, active_set_size=active_set_size,
debias=debias, solver=solver, dgap_freq=dgap_freq, verbose=verbose)
logger.info('Selected alpha: %s' % best_alpha_)
else:
if n_mxne_iter == 1:
X, active_set, E = mixed_norm_solver(
M, gain, alpha, maxit=maxit, tol=tol,
active_set_size=active_set_size, n_orient=n_dip_per_pos,
debias=debias, solver=solver, dgap_freq=dgap_freq,
verbose=verbose)
else:
X, active_set, E = iterative_mixed_norm_solver(
M, gain, alpha, n_mxne_iter, maxit=maxit, tol=tol,
n_orient=n_dip_per_pos, active_set_size=active_set_size,
debias=debias, solver=solver, dgap_freq=dgap_freq,
verbose=verbose)

if time_pca:
X = np.dot(X, Vh)
Expand Down Expand Up @@ -700,3 +739,154 @@ def tf_mixed_norm(evoked, forward, noise_cov,
out = out, residual

return out


@verbose
def _compute_mxne_sure(M, gain, alpha_grid, sigma, n_mxne_iter, maxit, tol,
n_orient, active_set_size, debias, solver, dgap_freq,
random_state, verbose):
"""Stein Unbiased Risk Estimator (SURE).
Implements the finite-difference Monte-Carlo approximation
of the SURE for Multi-Task LASSO.
See reference :footcite:`DeledalleEtAl2014`.
Parameters
----------
M : array, shape (n_sensors, n_times)
The data.
gain : array, shape (n_sensors, n_dipoles)
The gain matrix a.k.a. lead field.
alpha_grid : array, shape (n_alphas,)
The grid of alphas used to evaluate the SURE.
sigma : float
The true or estimated noise level in the data. Usually 1 if the data
has been previously whitened using MNE whitener.
n_mxne_iter : int
The number of MxNE iterations. If > 1, iterative reweighting is
applied.
maxit : int
Maximum number of iterations.
tol : float
Tolerance parameter.
n_orient : int
The number of orientation (1 : fixed or 3 : free or loose).
active_set_size : int
Size of active set increase at each iteration.
debias : bool
Debias source estimates.
solver : 'prox' | 'cd' | 'bcd' | 'auto'
The algorithm to use for the optimization.
dgap_freq : int or np.inf
The duality gap is evaluated every dgap_freq iterations.
random_state : int | None
The random state used in a random number generator for delta and
epsilon used for the SURE computation.
Returns
-------
X : array, shape (n_active, n_times)
Coefficient matrix.
active_set : array, shape (n_dipoles,)
Array of indices of non-zero coefficients.
best_alpha_ : float
Alpha that minimizes the SURE.
References
----------
.. footbibliography::
"""
def g(w):
return np.sqrt(np.sqrt(groups_norm2(w.copy(), n_orient)))

def gprime(w):
return 2. * np.repeat(g(w), n_orient).ravel()

def _run_solver(alpha, M, n_mxne_iter, as_init=None, X_init=None,
w_init=None):
if n_mxne_iter == 1:
X, active_set, _ = mixed_norm_solver(
M, gain, alpha, maxit=maxit, tol=tol,
active_set_size=active_set_size, n_orient=n_orient,
debias=debias, solver=solver, dgap_freq=dgap_freq,
active_set_init=as_init, X_init=X_init, verbose=False)
else:
X, active_set, _ = iterative_mixed_norm_solver(
M, gain, alpha, n_mxne_iter, maxit=maxit, tol=tol,
n_orient=n_orient, active_set_size=active_set_size,
debias=debias, solver=solver, dgap_freq=dgap_freq,
weight_init=w_init, verbose=False)
return X, active_set

def _fit_on_grid(gain, M, eps, delta):
coefs_grid_1_0 = np.zeros((len(alpha_grid), gain.shape[1], M.shape[1]))
coefs_grid_2_0 = np.zeros((len(alpha_grid), gain.shape[1], M.shape[1]))
active_sets, active_sets_eps = [], []
M_eps = M + eps * delta
# warm start - first iteration (leverages convexity)
logger.info('Warm starting...')
for j, alpha in enumerate(alpha_grid):
logger.info('alpha: %s' % alpha)
X, a_set = _run_solver(alpha, M, 1)
X_eps, a_set_eps = _run_solver(alpha, M_eps, 1)
coefs_grid_1_0[j][a_set, :] = X
coefs_grid_2_0[j][a_set_eps, :] = X_eps
active_sets.append(a_set)
active_sets_eps.append(a_set_eps)
# next iterations
if n_mxne_iter == 1:
return coefs_grid_1_0, coefs_grid_2_0, active_sets
else:
coefs_grid_1 = coefs_grid_1_0.copy()
coefs_grid_2 = coefs_grid_2_0.copy()
logger.info('Fitting SURE on grid.')
for j, alpha in enumerate(alpha_grid):
logger.info('alpha: %s' % alpha)
if active_sets[j].sum() > 0:
w = gprime(coefs_grid_1[j])
X, a_set = _run_solver(alpha, M, n_mxne_iter - 1,
w_init=w)
coefs_grid_1[j][a_set, :] = X
active_sets[j] = a_set
if active_sets_eps[j].sum() > 0:
w_eps = gprime(coefs_grid_2[j])
X_eps, a_set_eps = _run_solver(alpha, M_eps,
n_mxne_iter - 1,
w_init=w_eps)
coefs_grid_2[j][a_set_eps, :] = X_eps
active_sets_eps[j] = a_set_eps

return coefs_grid_1, coefs_grid_2, active_sets

def _compute_sure_val(coef1, coef2, gain, M, sigma, delta, eps):
n_sensors, n_times = gain.shape[0], M.shape[1]
dof = (gain @ (coef2 - coef1) * delta).sum() / eps
df_term = np.linalg.norm(M - gain @ coef1) ** 2
sure = df_term - n_sensors * n_times * sigma ** 2
sure += 2 * dof * sigma ** 2
return sure

sure_path = np.empty(len(alpha_grid))

rng = check_random_state(random_state)
# See Deledalle et al. 20214 Sec. 5.1
eps = 2 * sigma / (M.shape[0] ** 0.3)
delta = rng.randn(*M.shape)

coefs_grid_1, coefs_grid_2, active_sets = _fit_on_grid(gain, M, eps, delta)

logger.info("Computing SURE values on grid.")
for i, (coef1, coef2) in enumerate(zip(coefs_grid_1, coefs_grid_2)):
sure_path[i] = _compute_sure_val(
coef1, coef2, gain, M, sigma, delta, eps)
if verbose:
logger.info("alpha %s :: sure %s" % (alpha_grid[i], sure_path[i]))
best_alpha_ = alpha_grid[np.argmin(sure_path)]

X = coefs_grid_1[np.argmin(sure_path)]
active_set = active_sets[np.argmin(sure_path)]

X = X[active_set, :]

return X, active_set, best_alpha_
Loading

0 comments on commit ea1212d

Please sign in to comment.