Skip to content

Commit

Permalink
ENH: Allow gradient compensated data in maxwell_filter (mne-tools#10554)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored Apr 14, 2023
1 parent 4a60599 commit fc981bd
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 51 deletions.
1 change: 1 addition & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Enhancements
- Allow an image with intracranial electrode contacts (e.g. computed tomography) to be used without the freesurfer recon-all surfaces to locate contacts so that it doesn't have to be downsampled to freesurfer dimensions (for microelectrodes) and show an example :ref:`ex-ieeg-micro` with :func:`mne.transforms.apply_volume_registration_points` added to aid this transform (:gh:`11567` by `Alex Rockhill`_)
- Use new :meth:`dipy.workflows.align.DiffeomorphicMap.transform_points` to transform a montage of intracranial contacts more efficiently (:gh:`11572` by `Alex Rockhill`_)
- Improve performance of raw data browsing with many annotations (:gh:`11614` by `Eric Larson`_)
- Add support for :func:`mne.preprocessing.maxwell_filter` with gradient-compensated CTF data, e.g., for tSSS-only mode (:gh:`10554` by `Eric Larson`_)
- Add support for eyetracking data using :func:`mne.io.read_raw_eyelink` (:gh:`11152` by `Dominik Welke`_ and `Scott Huberty`_)

Bugs
Expand Down
11 changes: 8 additions & 3 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3671,7 +3671,8 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
_check_usable, _col_norm_pinv,
_get_n_moments, _get_mf_picks_fix_mags,
_prep_mf_coils, _check_destination,
_remove_meg_projs, _get_coil_scale)
_remove_meg_projs_comps,
_get_coil_scale, _get_sensor_operator)
if head_pos is None:
raise TypeError('head_pos must be provided and cannot be None')
from .chpi import head_pos_to_trans_rot_t
Expand All @@ -3684,7 +3685,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
head_pos = head_pos_to_trans_rot_t(head_pos)
trn, rot, t = head_pos
del head_pos
_check_usable(epochs)
_check_usable(epochs, ignore_ref)
origin = _check_origin(origin, epochs.info, 'head')
recon_trans = _check_destination(destination, epochs.info, True)

Expand All @@ -3697,6 +3698,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
_get_mf_picks_fix_mags(info_to, int_order, ext_order, ignore_ref)
coil_scale, mag_scale = _get_coil_scale(
meg_picks, mag_picks, grad_picks, mag_scale, info_to)
mult = _get_sensor_operator(epochs, meg_picks)
n_channels, n_times = len(epochs.ch_names), len(epochs.times)
other_picks = np.setdiff1d(np.arange(n_channels), meg_picks)
data = np.zeros((n_channels, n_times))
Expand Down Expand Up @@ -3761,6 +3763,9 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
# (We would need to include external here for regularization to work)
exp['ext_order'] = 0
S_recon = _trans_sss_basis(exp, all_coils_recon, recon_trans)
if mult is not None:
S_decomp = mult @ S_decomp
S_recon = mult @ S_recon
exp['ext_order'] = ext_order
# We could determine regularization on basis of destination basis
# matrix, restricted to good channels, as regularizing individual
Expand All @@ -3779,7 +3784,7 @@ def average_movements(epochs, head_pos=None, orig_sfreq=None, picks=None,
evoked = epochs._evoked_from_epoch_data(data, info_to, picks,
n_events=count, kind='average',
comment=epochs._name)
_remove_meg_projs(evoked) # remove MEG projectors, they won't apply now
_remove_meg_projs_comps(evoked, ignore_ref)
logger.info('Created Evoked dataset from %s epochs' % (count,))
return (evoked, mapping) if return_mapping else evoked

Expand Down
37 changes: 15 additions & 22 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ def _dtype(self):
return self._dtype_

@verbose
def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None,
projector=None, verbose=None):
def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None, *,
verbose=None):
"""Read a chunk of raw data.
Parameters
Expand Down Expand Up @@ -344,26 +344,22 @@ def _read_segment(self, start=0, stop=None, sel=None, data_buffer=None,

# set up cals and mult (cals, compensation, and projector)
n_out = len(np.arange(len(self.ch_names))[idx])
cals = self._cals.ravel()[np.newaxis, :]
if projector is not None:
assert projector.shape[0] == projector.shape[1] == cals.shape[1]
if self._comp is not None:
cals = self._cals.ravel()
projector, comp = self._projector, self._comp
if comp is not None:
mult = comp
if projector is not None:
mult = self._comp * cals
mult = np.dot(projector[idx], mult)
else:
mult = self._comp[idx] * cals
elif projector is not None:
mult = projector[idx] * cals
mult = projector @ mult
else:
mult = None
del projector
mult = projector
del projector, comp

if mult is None:
cals = cals.T[idx]
cals = cals[idx, np.newaxis]
assert cals.shape == (n_out, 1)
need_idx = idx # sufficient just to read the given channels
else:
mult = mult[idx] * cals
cals = None # shouldn't be used
assert mult.shape == (n_out, len(self.ch_names))
# read all necessary for proj
Expand Down Expand Up @@ -504,8 +500,7 @@ def _preload_data(self, preload):
data_buffer = None
logger.info('Reading %d ... %d = %9.3f ... %9.3f secs...' %
(0, len(self.times) - 1, 0., self.times[-1]))
self._data = self._read_segment(
data_buffer=data_buffer, projector=self._projector)
self._data = self._read_segment(data_buffer=data_buffer)
assert len(self._data) == self.info['nchan']
self.preload = True
self._comp = None # no longer needed
Expand Down Expand Up @@ -752,8 +747,7 @@ def _getitem(self, item, return_times=True):
if self.preload:
data = self._data[sel, start:stop]
else:
data = self._read_segment(start=start, stop=stop, sel=sel,
projector=self._projector)
data = self._read_segment(start=start, stop=stop, sel=sel)

if return_times:
# Rather than compute the entire thing just compute the subset
Expand Down Expand Up @@ -1669,7 +1663,7 @@ def append(self, raws, preload=None):
nsamp = c_ns[-1]

if not self.preload:
this_data = self._read_segment(projector=self._projector)
this_data = self._read_segment()
else:
this_data = self._data

Expand All @@ -1681,8 +1675,7 @@ def append(self, raws, preload=None):
if not raws[ri].preload:
# read the data directly into the buffer
data_buffer = _data[:, c_ns[ri]:c_ns[ri + 1]]
raws[ri]._read_segment(data_buffer=data_buffer,
projector=self._projector)
raws[ri]._read_segment(data_buffer=data_buffer)
else:
_data[:, c_ns[ri]:c_ns[ri + 1]] = raws[ri]._data
self._data = _data
Expand Down
51 changes: 38 additions & 13 deletions mne/preprocessing/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
quat_to_rot, rot_to_quat)
from ..forward import _concatenate_coils, _prep_meg_channels, _create_meg_coils
from ..surface import _normalize_vectors
from ..io.compensator import make_compensator
from ..io.constants import FIFF, FWD
from ..io.meas_info import _simplify_info, Info
from ..io.proc_history import _read_ctc
Expand Down Expand Up @@ -376,7 +377,7 @@ def _prep_maxwell_filter(

# triage inputs ASAP to avoid late-thrown errors
_validate_type(raw, BaseRaw, 'raw')
_check_usable(raw)
_check_usable(raw, ignore_ref)
_check_regularize(regularize)
st_correlation = float(st_correlation)
if st_correlation <= 0. or st_correlation > 1.:
Expand Down Expand Up @@ -478,7 +479,6 @@ def _prep_maxwell_filter(
exp['extended_proj'] = extended_proj
del extended_proj
# Reconstruct data from internal space only (Eq. 38), and rescale S_recon
S_recon /= coil_scale
if recon_trans is not None:
# warn if we have translated too far
diff = 1000 * (info['dev_head_t']['trans'][:3, 3] -
Expand Down Expand Up @@ -520,13 +520,20 @@ def _prep_maxwell_filter(
np.zeros(3)])
else:
this_pos_quat = None

# Figure out our linear operator
mult = _get_sensor_operator(raw, meg_picks)
if mult is not None:
S_recon = mult @ S_recon
S_recon /= coil_scale

_get_this_decomp_trans = partial(
_get_decomp, all_coils=all_coils,
cal=calibration, regularize=regularize,
exp=exp, ignore_ref=ignore_ref, coil_scale=coil_scale,
grad_picks=grad_picks, mag_picks=mag_picks, good_mask=good_mask,
mag_or_fine=mag_or_fine, bad_condition=bad_condition,
mag_scale=mag_scale)
mag_scale=mag_scale, mult=mult)
update_kwargs.update(
nchan=good_mask.sum(), st_only=st_only, recon_trans=recon_trans)
params = dict(
Expand All @@ -536,15 +543,15 @@ def _prep_maxwell_filter(
this_pos_quat=this_pos_quat, meg_picks=meg_picks,
good_mask=good_mask, grad_picks=grad_picks, head_pos=head_pos,
info=info, _get_this_decomp_trans=_get_this_decomp_trans,
S_recon=S_recon, update_kwargs=update_kwargs)
S_recon=S_recon, update_kwargs=update_kwargs, ignore_ref=ignore_ref)
return params


def _run_maxwell_filter(
raw, skip_by_annotation, st_duration, st_correlation, st_only,
st_when, ctc, coil_scale, this_pos_quat, meg_picks, good_mask,
grad_picks, head_pos, info, _get_this_decomp_trans, S_recon,
update_kwargs,
update_kwargs, *, ignore_ref=False,
reconstruct='in', copy=True):
# Eventually find_bad_channels_maxwell could be sped up by moving this
# outside the loop (e.g., in the prep function) but regularization depends
Expand All @@ -564,7 +571,7 @@ def _run_maxwell_filter(
del raw
if not st_only:
# remove MEG projectors, they won't apply now
_remove_meg_projs(raw_sss)
_remove_meg_projs_comps(raw_sss, ignore_ref)
# Figure out which segments of data we can use
onsets, ends = _annotations_starts_stops(
raw_sss, skip_by_annotation, invert=True)
Expand Down Expand Up @@ -745,7 +752,19 @@ def _get_coil_scale(meg_picks, mag_picks, grad_picks, mag_scale, info):
return coil_scale, mag_scale


def _remove_meg_projs(inst):
def _get_sensor_operator(raw, meg_picks):
comp = raw.compensation_grade
if comp not in (0, None):
mult = make_compensator(raw.info, 0, comp)
logger.info(f' Accounting for compensation grade {comp}')
assert mult.shape[0] == mult.shape[1] == len(raw.ch_names)
mult = mult[np.ix_(meg_picks, meg_picks)]
else:
mult = None
return mult


def _remove_meg_projs_comps(inst, ignore_ref):
"""Remove inplace existing MEG projectors (assumes inactive)."""
meg_picks = pick_types(inst.info, meg=True, exclude=[])
meg_channels = [inst.ch_names[pi] for pi in meg_picks]
Expand All @@ -754,6 +773,10 @@ def _remove_meg_projs(inst):
if not any(c in meg_channels for c in proj['data']['col_names']):
non_meg_proj.append(proj)
inst.add_proj(non_meg_proj, remove_existing=True, verbose=False)
if ignore_ref and inst.info['comps']:
assert inst.compensation_grade in (None, 0)
with inst.info._unlock():
inst.info['comps'] = []


def _check_destination(destination, info, head_frame):
Expand Down Expand Up @@ -959,9 +982,9 @@ def _check_pos(pos, head_frame, raw, st_fixed, sfreq):
return pos


def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref,
def _get_decomp(trans, *, all_coils, cal, regularize, exp, ignore_ref,
coil_scale, grad_picks, mag_picks, good_mask, mag_or_fine,
bad_condition, t, mag_scale):
bad_condition, t, mag_scale, mult):
"""Get a decomposition matrix and pseudoinverse matrices."""
from scipy import linalg
#
Expand All @@ -970,6 +993,8 @@ def _get_decomp(trans, all_coils, cal, regularize, exp, ignore_ref,
S_decomp_full = _get_s_decomp(
exp, all_coils, trans, coil_scale, cal, ignore_ref, grad_picks,
mag_picks, mag_scale)
if mult is not None:
S_decomp_full = mult @ S_decomp_full
S_decomp = S_decomp_full[good_mask]
#
# Extended SSS basis (eSSS)
Expand Down Expand Up @@ -1143,16 +1168,16 @@ def _check_regularize(regularize):
raise ValueError('regularize must be None or "in"')


def _check_usable(inst):
def _check_usable(inst, ignore_ref):
"""Ensure our data are clean."""
if inst.proj:
raise RuntimeError('Projectors cannot be applied to data during '
'Maxwell filtering.')
current_comp = inst.compensation_grade
if current_comp not in (0, None):
if current_comp not in (0, None) and ignore_ref:
raise RuntimeError('Maxwell filter cannot be done on compensated '
'channels, but data have been compensated with '
'grade %s.' % current_comp)
'channels (data have been compensated with '
'grade {current_comp}) when ignore_ref=True')


def _col_norm_pinv(x):
Expand Down
48 changes: 35 additions & 13 deletions mne/preprocessing/tests/test_maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,27 +303,49 @@ def test_other_systems():
_assert_shielding(raw_sss_auto, power, 0.7)

# CTF
raw_ctf = read_crop(fname_ctf_raw)
assert raw_ctf.compensation_grade == 3
with pytest.raises(RuntimeError, match='compensated'):
maxwell_filter(raw_ctf)
raw_ctf.apply_gradient_compensation(0)
raw_ctf_3 = read_crop(fname_ctf_raw)
assert raw_ctf_3.compensation_grade == 3
raw_ctf_0 = raw_ctf_3.copy().apply_gradient_compensation(0)
assert raw_ctf_0.compensation_grade == 0
# 3rd-order gradient compensation works really well (better than MF here)
_assert_shielding(raw_ctf_3, raw_ctf_0, 20, 21)
origin = (0., 0., 0.04)
raw_sss_3 = maxwell_filter(raw_ctf_3, origin=origin, verbose=True)
_assert_n_free(raw_sss_3, 70)
_assert_shielding(raw_sss_3, raw_ctf_3, 0.12, 0.14)
_assert_shielding(raw_sss_3, raw_ctf_0, 2.63, 2.66)
assert raw_sss_3.compensation_grade == 3
raw_sss_3.apply_gradient_compensation(0)
assert raw_sss_3.compensation_grade == 0
_assert_shielding(raw_sss_3, raw_ctf_3, 0.15, 0.17)
_assert_shielding(raw_sss_3, raw_ctf_0, 3.18, 3.20)
with pytest.raises(ValueError, match='digitization points'):
maxwell_filter(raw_ctf)
raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04))
_assert_n_free(raw_sss, 68)
_assert_shielding(raw_sss, raw_ctf, 1.8)
maxwell_filter(raw_ctf_0)
raw_sss_0 = maxwell_filter(raw_ctf_0, origin=origin, verbose=True)
_assert_n_free(raw_sss_0, 68)
_assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09)
_assert_shielding(raw_sss_0, raw_ctf_0, 1.8, 1.9)
raw_sss_0.apply_gradient_compensation(3)
_assert_shielding(raw_sss_0, raw_ctf_3, 0.07, 0.09)
_assert_shielding(raw_sss_0, raw_ctf_0, 1.63, 1.67)
with pytest.raises(RuntimeError, match='ignore_ref'):
maxwell_filter(raw_ctf_3, ignore_ref=True)
# ignoring ref outperforms including it in maxwell filtering
with catch_logging() as log:
raw_sss = maxwell_filter(raw_ctf, origin=(0., 0., 0.04),
raw_sss = maxwell_filter(raw_ctf_0, origin=origin,
ignore_ref=True, verbose=True)
assert ', 12/15 out' in log.getvalue() # homogeneous fields removed
_assert_n_free(raw_sss, 70)
_assert_shielding(raw_sss, raw_ctf, 12)
raw_sss_auto = maxwell_filter(raw_ctf, origin=(0., 0., 0.04),
_assert_shielding(raw_sss, raw_ctf_0, 12, 13)
# if ignore_ref=True, we remove compensators because they will not
# work the way people expect (it puts noise back in the data!)
with pytest.raises(ValueError, match='Desired compensation.*not found'):
raw_sss.copy().apply_gradient_compensation(3)
raw_sss_auto = maxwell_filter(raw_ctf_0, origin=origin,
ignore_ref=True, mag_scale='auto')
assert_allclose(raw_sss._data, raw_sss_auto._data)
with catch_logging() as log:
maxwell_filter(raw_ctf, origin=(0., 0., 0.04), regularize=None,
maxwell_filter(raw_ctf_0, origin=origin, regularize=None,
ignore_ref=True, verbose=True)
assert '80/80 in, 12/15 out' in log.getvalue() # homogeneous fields

Expand Down

0 comments on commit fc981bd

Please sign in to comment.