Skip to content

Commit

Permalink
FIX: Fix raw sim with BEM and use_cps=True
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Nov 16, 2017
1 parent 0e5894f commit 1e3cd0c
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 32 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ BUG

- Fix bug in :meth:`mne.io.set_eeg_reference` to remove an average reference projector when setting the reference to ``[]`` (i.e. do not change the existing reference) by `Clemens Brunner`_

- Fix bug in :func:`mne.simulation.simulate_raw` where 1- and 3-layer BEMs were not properly transformed using ``trans`` by `Eric Larson`_

.. _changes_0_15:

Version 0.15
Expand Down
29 changes: 19 additions & 10 deletions mne/forward/_make_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,15 @@ def _create_eeg_els(chs):

@verbose
def _setup_bem(bem, bem_extra, neeg, mri_head_t, verbose=None):
"""Set up a BEM for forward computation."""
"""Set up a BEM for forward computation, making a copy and modifying."""
logger.info('')
if isinstance(bem, string_types):
logger.info('Setting up the BEM model using %s...\n' % bem_extra)
bem = read_bem_solution(bem)
if not isinstance(bem, ConductorModel):
raise TypeError('bem must be a string or ConductorModel')
else:
if not isinstance(bem, ConductorModel):
raise TypeError('bem must be a string or ConductorModel')
bem = bem.copy()
if bem['is_sphere']:
logger.info('Using the sphere model.\n')
if len(bem['layers']) == 0 and neeg > 0:
Expand All @@ -234,6 +236,10 @@ def _setup_bem(bem, bem_extra, neeg, mri_head_t, verbose=None):
if bem['coord_frame'] != FIFF.FIFFV_COORD_HEAD:
raise RuntimeError('Spherical model is not in head coordinates')
else:
if bem['surfs'][0]['coord_frame'] != FIFF.FIFFV_COORD_MRI:
raise RuntimeError(
'BEM is in %s coordinates, should be in MRI'
% (_coord_frame_name(bem['surfs'][0]['coord_frame']),))
if neeg > 0 and len(bem['surfs']) == 1:
raise RuntimeError('Cannot use a homogeneous model in EEG '
'calculations')
Expand Down Expand Up @@ -558,7 +564,10 @@ def make_forward_solution(info, trans, src, bem, meg=True, eeg=True,
# read the transformation from MRI to HEAD coordinates
# (could also be HEAD to MRI)
mri_head_t, trans = _get_trans(trans)
bem_extra = 'dict' if isinstance(bem, dict) else bem
if isinstance(bem, ConductorModel):
bem_extra = 'instance of ConductorModel'
else:
bem_extra = bem
if not isinstance(info, (Info, string_types)):
raise TypeError('info should be an instance of Info or string')
if isinstance(info, string_types):
Expand All @@ -569,15 +578,15 @@ def make_forward_solution(info, trans, src, bem, meg=True, eeg=True,
n_jobs = check_n_jobs(n_jobs)

# Report the setup
logger.info('Source space : %s' % src)
logger.info('MRI -> head transform source : %s' % trans)
logger.info('Measurement data : %s' % info_extra)
if isinstance(bem, dict) and bem['is_sphere']:
logger.info('Sphere model : origin at %s mm'
logger.info('Source space : %s' % src)
logger.info('MRI -> head transform : %s' % trans)
logger.info('Measurement data : %s' % info_extra)
if isinstance(bem, ConductorModel) and bem['is_sphere']:
logger.info('Sphere model : origin at %s mm'
% (bem['r0'],))
logger.info('Standard field computations')
else:
logger.info('BEM model : %s' % bem_extra)
logger.info('Conductor model : %s' % bem_extra)
logger.info('Accurate field computations')
logger.info('Do computations in %s coordinates',
_coord_frame_name(FIFF.FIFFV_COORD_HEAD))
Expand Down
3 changes: 1 addition & 2 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,7 @@ def convert_forward_solution(fwd, surf_ori=False, force_fixed=False,

if surf_ori:
if use_cps is True:
if ('patch_inds' in fwd['src'][0] and
fwd['src'][0]['patch_inds'] is not None):
if fwd['src'][0].get('patch_inds') is not None:
use_ave_nn = True
logger.info(' Average patch normals will be employed in '
'the rotation to the local surface coordinates..'
Expand Down
12 changes: 8 additions & 4 deletions mne/simulation/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
_prepare_for_forward, _transform_orig_meg_coils,
_compute_forwards, _to_forward_dict)
from ..transforms import _get_trans, transform_surface_to
from ..source_space import _ensure_src, _points_outside_surface
from ..source_space import (_ensure_src, _points_outside_surface,
_adjust_patch_info)
from ..source_estimate import _BaseSourceEstimate
from ..utils import logger, verbose, check_random_state, warn, _pl
from ..parallel import check_n_jobs
Expand Down Expand Up @@ -365,7 +366,7 @@ def simulate_raw(raw, stc, trans, src, bem, cov='simple',
# XXX eventually we could speed this up by allowing the forward
# solution code to only compute the normal direction
fwd = convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
verbose=False, use_cps=use_cps)
use_cps=use_cps, verbose=False)
if blink:
fwd_blink = fwd_blink['sol']['data']
for ii in range(len(blink_rrs)):
Expand Down Expand Up @@ -479,8 +480,9 @@ def _iter_forward_solutions(info, trans, src, bem, exg_bem, dev_head_ts,
idx = np.where(np.array([s['id'] for s in bem['surfs']]) ==
FIFF.FIFFV_BEM_SURF_ID_BRAIN)[0]
assert len(idx) == 1
# make a copy so it isn't mangled in use
bem_surf = transform_surface_to(bem['surfs'][idx[0]], coord_frame,
mri_head_t)
mri_head_t, copy=True)
for ti, dev_head_t in enumerate(dev_head_ts):
# Could be *slightly* more efficient not to do this N times,
# but the cost here is tiny compared to actual fwd calculation
Expand Down Expand Up @@ -538,7 +540,9 @@ def _restrict_source_space_to(src, vertices):
s['nuse'] = len(v)
s['vertno'] = v
s['inuse'][s['vertno']] = 1
for key in ('pinfo', 'nuse_tri', 'use_tris', 'patch_inds'):
for key in ('nuse_tri', 'use_tris'):
if key in s:
del s[key]
# This will fix 'patch_info' and 'pinfo'
_adjust_patch_info(s, verbose=False)
return src
44 changes: 43 additions & 1 deletion mne/simulation/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
from mne import (read_source_spaces, pick_types, read_trans, read_cov,
make_sphere_model, create_info, setup_volume_source_space,
find_events, Epochs, fit_dipole, transform_surface_to,
make_ad_hoc_cov, SourceEstimate, setup_source_space)
make_ad_hoc_cov, SourceEstimate, setup_source_space,
read_bem_solution, make_forward_solution,
convert_forward_solution)
from mne.chpi import _calculate_chpi_positions, read_head_pos, _get_hpi_info
from mne.tests.test_chpi import _assert_quats
from mne.datasets import testing
from mne.simulation import simulate_sparse_stc, simulate_raw
from mne.source_space import _compare_source_spaces
from mne.io import read_raw_fif, RawArray
from mne.time_frequency import psd_welch
from mne.utils import _TempDir, run_tests_if_main
Expand All @@ -38,6 +41,7 @@
bem_path = op.join(subjects_dir, 'sample', 'bem')
src_fname = op.join(bem_path, 'sample-oct-2-src.fif')
bem_fname = op.join(bem_path, 'sample-320-320-320-bem-sol.fif')
bem_1_fname = op.join(bem_path, 'sample-320-bem-sol.fif')

raw_chpi_fname = op.join(data_path, 'SSS', 'test_move_anon_raw.fif')
pos_fname = op.join(data_path, 'SSS', 'test_move_anon_raw_subsampled.pos')
Expand Down Expand Up @@ -249,6 +253,44 @@ def test_simulate_raw_bem():
assert_true(med_diff < tol, msg='%s: %s' % (bem, med_diff))


@testing.requires_testing_data
def test_simulate_round_trip():
"""Test simulate_raw round trip calculations."""
# Check a diagonal round-trip
raw, src, stc, trans, sphere = _get_data()
raw.pick_types(meg=True, stim=True)
bem = read_bem_solution(bem_1_fname)
old_bem = bem.copy()
old_src = src.copy()
old_trans = trans.copy()
fwd = make_forward_solution(raw.info, trans, src, bem)
# no omissions
assert (sum(len(s['vertno']) for s in src) ==
sum(len(s['vertno']) for s in fwd['src']) ==
36)
# make sure things were not modified
assert (old_bem['surfs'][0]['coord_frame'] ==
bem['surfs'][0]['coord_frame'])
assert trans == old_trans
_compare_source_spaces(src, old_src)
data = np.eye(fwd['nsource'])
raw.crop(0, (len(data) - 1) / raw.info['sfreq'])
stc = SourceEstimate(data, [s['vertno'] for s in fwd['src']],
0, 1. / raw.info['sfreq'])
for use_cps in (False, True):
this_raw = simulate_raw(raw, stc, trans, src, bem, cov=None,
use_cps=use_cps)
this_raw.pick_types(meg=True, eeg=True)
assert (old_bem['surfs'][0]['coord_frame'] ==
bem['surfs'][0]['coord_frame'])
assert trans == old_trans
_compare_source_spaces(src, old_src)
this_fwd = convert_forward_solution(fwd, force_fixed=True,
use_cps=use_cps)
assert_allclose(this_raw[:][0], this_fwd['sol']['data'],
atol=1e-12, rtol=1e-6)


@pytest.mark.slowtest
@testing.requires_testing_data
def test_simulate_raw_chpi():
Expand Down
21 changes: 14 additions & 7 deletions mne/source_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,16 +2171,23 @@ def _filter_source_spaces(surf, limit, mri_head_t, src, n_jobs=1,
logger.info('%d source space point%s omitted because of the '
'%6.1f-mm distance limit.' % tuple(extras))
# Adjust the patch inds as well if necessary
if omit + omit_outside > 0 and s.get('patch_inds') is not None:
if s['nearest'] is None:
# This shouldn't happen, but if it does, we can probably come
# up with a more clever solution
raise RuntimeError('Cannot adjust patch information properly, '
'please contact the mne-python developers')
_add_patch_info(s)
if omit + omit_outside > 0:
_adjust_patch_info(s)
logger.info('Thank you for waiting.')


@verbose
def _adjust_patch_info(s, verbose=None):
"""Adjust patch information in place after vertex omission."""
if s.get('patch_inds') is not None:
if s['nearest'] is None:
# This shouldn't happen, but if it does, we can probably come
# up with a more clever solution
raise RuntimeError('Cannot adjust patch information properly, '
'please contact the mne-python developers')
_add_patch_info(s)


@verbose
def _points_outside_surface(rr, surf, n_jobs=1, verbose=None):
"""Check whether points are outside a surface.
Expand Down
8 changes: 2 additions & 6 deletions mne/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def test_get_trans():
trans = read_trans(fname)
trans = invert_transform(trans) # starts out as head->MRI, so invert
trans_2 = _get_trans(fname_trans)[0]
assert_equal(trans['from'], trans_2['from'])
assert_equal(trans['to'], trans_2['to'])
assert_allclose(trans['trans'], trans_2['trans'], rtol=1e-5, atol=1e-5)
assert trans.__eq__(trans_2, atol=1e-5)


@testing.requires_testing_data
Expand All @@ -79,9 +77,7 @@ def test_io_trans():
trans1 = read_trans(fname1)

# check all properties
assert_true(trans0['from'] == trans1['from'])
assert_true(trans0['to'] == trans1['to'])
assert_array_equal(trans0['trans'], trans1['trans'])
assert trans0 == trans1

# check reading non -trans.fif files
assert_raises(IOError, read_trans, fname_eve)
Expand Down
46 changes: 44 additions & 2 deletions mne/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,48 @@ def __repr__(self): # noqa: D105
% (_coord_frame_name(self['from']),
_coord_frame_name(self['to']), self['trans']))

def __eq__(self, other, rtol=0., atol=0.):
"""Check for equality.
Parameter
---------
other : instance of Transform
The other transform.
rtol : float
Relative tolerance.
atol : float
Absolute tolerance.
Returns
-------
eq : bool
True if the transforms are equal.
"""
return (isinstance(other, Transform) and
self['from'] == other['from'] and
self['to'] == other['to'] and
np.allclose(self['trans'], other['trans'], rtol=rtol,
atol=atol))

def __ne__(self, other, rtol=0., atol=0.):
"""Check for inequality.
Parameter
---------
other : instance of Transform
The other transform.
rtol : float
Relative tolerance.
atol : float
Absolute tolerance.
Returns
-------
eq : bool
True if the transforms are not equal.
"""
return not self == other

@property
def from_str(self):
"""The "from" frame as a string."""
Expand Down Expand Up @@ -396,9 +438,9 @@ def _get_trans(trans, fro='mri', to='head'):
raise RuntimeError('File "%s" did not have 4x4 entries'
% trans)
fro_to_t = Transform(to, fro, t)
elif isinstance(trans, dict):
elif isinstance(trans, Transform):
fro_to_t = trans
trans = 'dict'
trans = 'instance of Transform'
elif trans is None:
fro_to_t = Transform(fro, to)
trans = 'identity'
Expand Down

0 comments on commit 1e3cd0c

Please sign in to comment.