-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
_gamma_map.py
341 lines (289 loc) · 9.93 KB
/
_gamma_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import numpy as np
from ..fixes import _safe_svd
from ..forward import is_fixed_orient
from ..minimum_norm.inverse import _check_reference, _log_exp_var
from ..utils import logger, verbose, warn
from .mxne_inverse import (
_check_ori,
_compute_residual,
_make_dipoles_sparse,
_make_sparse_stc,
_prepare_gain,
_reapply_source_weighting,
)
@verbose
def _gamma_map_opt(
    M,
    G,
    alpha,
    maxit=10000,
    tol=1e-6,
    update_mode=1,
    group_size=1,
    gammas=None,
    verbose=None,
):
    """Hierarchical Bayes (Gamma-MAP).

    Parameters
    ----------
    M : array, shape=(n_sensors, n_times)
        Observation.
    G : array, shape=(n_sensors, n_sources)
        Forward operator.
    alpha : float
        Regularization parameter (noise variance).
    maxit : int
        Maximum number of iterations. Must be at least 1.
    tol : float
        Tolerance parameter for convergence.
    update_mode : int
        Update mode, 1: MacKay update (default), 2: Modified MacKay update.
    group_size : int
        Number of consecutive sources which use the same gamma.
    gammas : array, shape=(n_sources,)
        Initial values for posterior variances (gammas). If None, a
        variance of 1.0 is used. The values are copied; the caller's array
        is never modified.
    %(verbose)s

    Returns
    -------
    X : array, shape=(n_active, n_times)
        Estimated source time courses.
    active_set : array, shape=(n_active,)
        Indices of active sources.

    Raises
    ------
    ValueError
        If ``update_mode`` is not 1 or 2, if ``maxit`` is smaller than 1, or
        if the number of sources is not a multiple of ``group_size``.
    """
    # Validate options up front instead of failing mid-optimization.
    if update_mode not in (1, 2):
        raise ValueError("Invalid value for update_mode")
    if maxit < 1:
        raise ValueError(f"maxit must be at least 1, got {maxit}")

    G = G.copy()
    M = M.copy()
    n_sources = G.shape[1]
    n_sensors = M.shape[0]
    if n_sources % group_size != 0:
        raise ValueError(
            "Number of sources has to be evenly dividable by the group size"
        )

    if gammas is None:
        gammas = np.ones(n_sources, dtype=np.float64)
    else:
        # Private float64 copy: the NaN-zeroing in the loop below would
        # otherwise write into the caller's array (and fail on lists).
        gammas = np.array(gammas, dtype=np.float64)

    eps = np.finfo(float).eps

    # apply normalization so the numerical values are sane
    M_normalize_constant = np.linalg.norm(np.dot(M, M.T), ord="fro")
    M /= np.sqrt(M_normalize_constant)
    alpha /= M_normalize_constant
    G_normalize_constant = np.linalg.norm(G, ord=np.inf)
    G /= G_normalize_constant

    n_active = n_sources
    active_set = np.arange(n_sources)
    gammas_full_old = gammas.copy()

    if update_mode == 2:
        denom_fun = np.sqrt
    else:
        # MacKay update: the denominator is used as is
        def denom_fun(x):
            return x

    last_size = -1
    converged = False
    for itno in range(maxit):
        # NaN gammas are zeroed so that they are pruned just below
        gammas[np.isnan(gammas)] = 0.0
        gidx = np.abs(gammas) > eps
        active_set = active_set[gidx]
        gammas = gammas[gidx]

        # update only active gammas (once set to zero it stays at zero)
        if n_active > len(active_set):
            n_active = active_set.size
            G = G[:, gidx]

        # model data covariance: G diag(gammas) G^T + alpha * I
        CM = np.dot(G * gammas[np.newaxis, :], G.T)
        CM.flat[:: n_sensors + 1] += alpha

        # Invert CM keeping symmetry
        U, S, _ = _safe_svd(CM, full_matrices=False)
        S = S[np.newaxis, :]
        del CM
        CMinv = np.dot(U / (S + eps), U.T)
        CMinvG = np.dot(CMinv, G)
        A = np.dot(CMinvG.T, M)  # mult. w. Diag(gamma) in gamma update

        if update_mode == 1:
            # MacKay fixed point update (10) in [1]; [1] is presumably one of
            # the Wipf et al. papers cited by gamma_map -- TODO confirm
            numer = gammas**2 * np.mean((A * A.conj()).real, axis=1)
            denom = gammas * np.sum(G * CMinvG, axis=0)
        else:
            # modified MacKay fixed point update (11) in [1]
            numer = gammas * np.sqrt(np.mean((A * A.conj()).real, axis=1))
            denom = np.sum(G * CMinvG, axis=0)  # sqrt is applied below

        if group_size == 1:
            gammas = numer / np.maximum(denom_fun(denom), eps)
        else:
            # pool the statistics of each group of sources, then share the
            # resulting gamma equally among the group members; the eps floor
            # matches the ungrouped path and avoids division by zero
            numer_comb = np.sum(numer.reshape(-1, group_size), axis=1)
            denom_comb = np.sum(denom.reshape(-1, group_size), axis=1)
            gammas_comb = numer_comb / np.maximum(denom_fun(denom_comb), eps)
            gammas = np.repeat(gammas_comb / group_size, group_size)

        # compute convergence criterion (relative L1 change of all gammas)
        gammas_full = np.zeros(n_sources, dtype=np.float64)
        gammas_full[active_set] = gammas
        err = np.sum(np.abs(gammas_full - gammas_full_old)) / np.sum(
            np.abs(gammas_full_old)
        )
        gammas_full_old = gammas_full

        breaking = err < tol or n_active == 0
        if len(gammas) != last_size or breaking:
            logger.info(
                f"Iteration: {itno}\t active set size: {len(gammas)}\t convergence: "
                f"{err:.3e}"
            )
            last_size = len(gammas)

        if breaking:
            # track the break explicitly so that converging on the very last
            # iteration is not reported as a failure
            converged = True
            break

    if converged:
        logger.info("\nConvergence reached !\n")
    else:
        warn("\nConvergence NOT reached !\n")

    # undo normalization and compute final posterior mean
    n_const = np.sqrt(M_normalize_constant) / G_normalize_constant
    x_active = n_const * gammas[:, None] * A
    return x_active, active_set
@verbose
def gamma_map(
    evoked,
    forward,
    noise_cov,
    alpha,
    loose="auto",
    depth=0.8,
    xyz_same_gamma=True,
    maxit=10000,
    tol=1e-6,
    update_mode=1,
    gammas=None,
    pca=True,
    return_residual=False,
    return_as_dipoles=False,
    rank=None,
    pick_ori=None,
    verbose=None,
):
    """Hierarchical Bayes (Gamma-MAP) sparse source localization method.

    Models each source time course using a zero-mean Gaussian prior with an
    unknown variance (gamma) parameter. During estimation, most gammas are
    driven to zero, resulting in a sparse source estimate, as in
    :footcite:`WipfEtAl2007` and :footcite:`WipfNagarajan2009`.

    For fixed-orientation forward operators, a separate gamma is used for each
    source time course, while for free-orientation forward operators, the same
    gamma is used for the three source time courses at each source space point
    (separate gammas can be used in this case by using xyz_same_gamma=False).

    Parameters
    ----------
    evoked : instance of Evoked
        Evoked data to invert.
    forward : dict
        Forward operator.
    noise_cov : instance of Covariance
        Noise covariance to compute whitener.
    alpha : float
        Regularization parameter (noise variance).
    %(loose)s
    %(depth)s
    xyz_same_gamma : bool
        Use same gamma for xyz current components at each source space point.
        Recommended for free-orientation forward solutions.
    maxit : int
        Maximum number of iterations.
    tol : float
        Tolerance parameter for convergence.
    update_mode : int
        Update mode, 1: MacKay update (default), 2: Modified MacKay update.
    gammas : array, shape=(n_sources,)
        Initial values for posterior variances (gammas). If None, a
        variance of 1.0 is used.
    pca : bool
        If True the rank of the data is reduced to the true dimension.
    return_residual : bool
        If True, the residual is returned as an Evoked instance.
    return_as_dipoles : bool
        If True, the sources are returned as a list of Dipole instances.
    %(rank_none)s

        .. versionadded:: 0.18
    %(pick_ori)s
    %(verbose)s

    Returns
    -------
    stc : instance of SourceEstimate | list of Dipole
        Source time courses, or a list of Dipole instances when
        ``return_as_dipoles`` is True.
    residual : instance of Evoked
        The residual a.k.a. data not explained by the sources.
        Only returned if return_residual is True.

    Raises
    ------
    ValueError
        If no source remains active after the optimization (alpha is too
        large).

    References
    ----------
    .. footbibliography::
    """
    _check_reference(evoked)

    forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
        forward, evoked.info, noise_cov, pca, depth, loose, rank
    )
    _check_ori(pick_ori, forward)

    # share one gamma across the three components of each free-orientation
    # source point unless separate gammas were requested
    group_size = 1 if (is_fixed_orient(forward) or not xyz_same_gamma) else 3

    # get the data
    sel = [evoked.ch_names.index(name) for name in gain_info["ch_names"]]
    M = evoked.data[sel]

    # whiten the data
    logger.info("Whitening data matrix.")
    M = np.dot(whitener, M)

    # run the optimization
    X, active_set = _gamma_map_opt(
        M,
        gain,
        alpha,
        maxit=maxit,
        tol=tol,
        update_mode=update_mode,
        gammas=gammas,
        group_size=group_size,
        verbose=verbose,
    )

    if len(active_set) == 0:
        # ValueError subclasses Exception, so existing handlers still work
        raise ValueError("No active dipoles found. alpha is too big.")

    # data predicted by the active sources (still in whitened, weighted space)
    M_estimate = gain[:, active_set] @ X

    # Reapply weights to have correct unit
    X = _reapply_source_weighting(X, source_weighting, active_set)

    if return_residual:
        residual = _compute_residual(forward, evoked, X, active_set, gain_info)

    if group_size == 1 and not is_fixed_orient(forward):
        # make sure each source has 3 components
        idx, offset = divmod(active_set, 3)
        active_src = np.unique(idx)
        if len(X) < 3 * len(active_src):
            X_xyz = np.zeros((len(active_src), 3, X.shape[1]), dtype=X.dtype)
            idx = np.searchsorted(active_src, idx)
            X_xyz[idx, offset, :] = X
            # reshape (a view of the fresh contiguous array) instead of
            # assigning to .shape, which NumPy discourages
            X = X_xyz.reshape(len(active_src) * 3, X.shape[1])
            active_set = (active_src[:, np.newaxis] * 3 + np.arange(3)).ravel()

    source_weighting[source_weighting == 0] = 1  # zeros
    gain_active = gain[:, active_set] / source_weighting[active_set]
    del source_weighting

    tmin = evoked.times[0]
    tstep = 1.0 / evoked.info["sfreq"]

    if return_as_dipoles:
        out = _make_dipoles_sparse(
            X, active_set, forward, tmin, tstep, M, gain_active, active_is_idx=True
        )
    else:
        out = _make_sparse_stc(
            X,
            active_set,
            forward,
            tmin,
            tstep,
            active_is_idx=True,
            pick_ori=pick_ori,
            verbose=verbose,
        )

    _log_exp_var(M, M_estimate, prefix="")
    logger.info("[done]")

    if return_residual:
        out = out, residual

    return out