Skip to content

Commit

Permalink
Make stc.times read-only and stc.data writable.
Browse files Browse the repository at this point in the history
Users sometimes attempt to modify an existing STC, which involved
hacking the `._data` property. This commit makes `.data` writable
(while performing some sanity checks).

.tmin and .tstep are now also properties. Writing to them
calls ._update_times().

Also fixes the time-range of stc.transform().
  • Loading branch information
wmvanvliet committed Jul 31, 2017
1 parent 6e61eb5 commit e3c3346
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 68 deletions.
2 changes: 1 addition & 1 deletion mne/simulation/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_metrics():
stc_bad = stc2.copy().crop(0, 0.5)
assert_raises(ValueError, source_estimate_quantification, stc1, stc_bad)
stc_bad = stc2.copy()
stc_bad.times -= 0.1
stc_bad.tmin -= 0.1
assert_raises(ValueError, source_estimate_quantification, stc1, stc_bad)
assert_raises(ValueError, source_estimate_quantification, stc1, stc2,
metric='foo')
Expand Down
114 changes: 83 additions & 31 deletions mne/source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,14 +457,14 @@ def __init__(self, data, vertices=None, tmin=None, tstep=None,
'must match' % (n_src, data.shape[0]))

self._data = data
self.tmin = tmin
self.tstep = tstep
self._tmin = tmin
self._tstep = tstep
self.vertices = vertices
self.verbose = verbose
self._kernel = kernel
self._sens_data = sens_data
self._kernel_removed = False
self.times = None
self._times = None
self._update_times()
self.subject = _check_subject(None, subject, False)

Expand Down Expand Up @@ -496,9 +496,8 @@ def crop(self, tmin=None, tmax=None):
if self._kernel is not None and self._sens_data is not None:
self._sens_data = self._sens_data[:, mask]
else:
self._data = self._data[:, mask]
self.data = self.data[:, mask]

self._update_times()
return self # return self for chaining methods

@verbose
Expand Down Expand Up @@ -535,11 +534,10 @@ def resample(self, sfreq, npad='auto', window='boxcar', n_jobs=1,
self._remove_kernel_sens_data_()

o_sfreq = 1.0 / self.tstep
self._data = resample(self._data, sfreq, o_sfreq, npad, n_jobs=n_jobs)
self.data = resample(self.data, sfreq, o_sfreq, npad, n_jobs=n_jobs)

# adjust indirectly affected variables
self.tstep = 1.0 / sfreq
self._update_times()
return self

@property
Expand All @@ -551,16 +549,73 @@ def data(self):
self._remove_kernel_sens_data_()
return self._data

@data.setter
def data(self, value):
value = np.asarray(value)
if self._data is not None and value.ndim != self._data.ndim:
raise ValueError('Data array should have %d dimensions.' %
self._data.ndim)

# vertices can be a single number, so cast to ndarray
if isinstance(self.vertices, list):
n_verts = sum([len(v) for v in self.vertices])
elif isinstance(self.vertices, np.ndarray):
n_verts = len(self.vertices)
else:
raise ValueError('Vertices must be a list or numpy array')

if value.shape[0] != n_verts:
raise ValueError('The first dimension of the data array must '
'match the number of vertices (%d != %d)' %
(value.shape[0], n_verts))

self._data = value
self._update_times()

@property
def shape(self):
"""Shape of the data."""
if self._data is not None:
return self._data.shape
return (self._kernel.shape[0], self._sens_data.shape[1])

@property
def tmin(self):
"""The first timestamp."""
return self._tmin

@tmin.setter
def tmin(self, value):
self._tmin = float(value)
self._update_times()

@property
def tstep(self):
"""The change in time between two consecutive samples (1 / sfreq)."""
return self._tstep

@tstep.setter
def tstep(self, value):
if value <= 0:
raise ValueError('.tstep must be greater than 0.')
self._tstep = float(value)
self._update_times()

@property
def times(self):
"""A timestamp for each sample."""
return self._times

@times.setter
def times(self, value):
raise ValueError('You cannot write to the .times attribute directly. '
'This property automatically updates whenever '
'.tmin, .tstep or .data changes.')

def _update_times(self):
"""Update the times attribute after changing tmin, tmax, or tstep."""
self.times = self.tmin + (self.tstep * np.arange(self.shape[1]))
self._times = self.tmin + (self.tstep * np.arange(self.shape[1]))
self._times.flags.writeable = False

def __add__(self, a):
"""Add source estimates."""
Expand All @@ -572,9 +627,9 @@ def __iadd__(self, a): # noqa: D105
self._remove_kernel_sens_data_()
if isinstance(a, _BaseSourceEstimate):
_verify_source_estimate_compat(self, a)
self._data += a.data
self.data += a.data
else:
self._data += a
self.data += a
return self

def mean(self):
Expand Down Expand Up @@ -604,9 +659,9 @@ def __isub__(self, a): # noqa: D105
self._remove_kernel_sens_data_()
if isinstance(a, _BaseSourceEstimate):
_verify_source_estimate_compat(self, a)
self._data -= a.data
self.data -= a.data
else:
self._data -= a
self.data -= a
return self

def __truediv__(self, a): # noqa: D105
Expand All @@ -625,9 +680,9 @@ def __idiv__(self, a): # noqa: D105
self._remove_kernel_sens_data_()
if isinstance(a, _BaseSourceEstimate):
_verify_source_estimate_compat(self, a)
self._data /= a.data
self.data /= a.data
else:
self._data /= a
self.data /= a
return self

def __mul__(self, a):
Expand All @@ -640,9 +695,9 @@ def __imul__(self, a): # noqa: D105
self._remove_kernel_sens_data_()
if isinstance(a, _BaseSourceEstimate):
_verify_source_estimate_compat(self, a)
self._data *= a.data
self.data *= a.data
else:
self._data *= a
self.data *= a
return self

def __pow__(self, a): # noqa: D105
Expand All @@ -652,7 +707,7 @@ def __pow__(self, a): # noqa: D105

def __ipow__(self, a): # noqa: D105
self._remove_kernel_sens_data_()
self._data **= a
self.data **= a
return self

def __radd__(self, a): # noqa: D105
Expand All @@ -671,7 +726,7 @@ def __neg__(self): # noqa: D105
"""Negate the source estimate."""
stc = copy.deepcopy(self)
stc._remove_kernel_sens_data_()
stc._data *= -1
stc.data *= -1
return stc

def __pos__(self): # noqa: D105
Expand Down Expand Up @@ -870,7 +925,8 @@ def transform(self, func, idx=None, tmin=None, tmax=None, copy=False):
if tmax is None:
tmax_idx = None
else:
tmax_idx = t_idx[-1]
# +1, because upper boundary needs to include the last sample
tmax_idx = t_idx[-1] + 1

data_t = self.transform_data(func, idx=idx, tmin_idx=tmin_idx,
tmax_idx=tmax_idx)
Expand All @@ -887,13 +943,8 @@ def transform(self, func, idx=None, tmin=None, tmax=None, copy=False):
verts = [verts_lh, verts_rh]

tmin_idx = 0 if tmin_idx is None else tmin_idx
tmax_idx = -1 if tmax_idx is None else tmax_idx

tmin = self.times[tmin_idx]

times = np.arange(self.times[tmin_idx],
self.times[tmax_idx] + self.tstep / 2, self.tstep)

if data_t.ndim > 2:
# return list of stcs if transformed data has dimensionality > 2
if copy:
Expand All @@ -906,8 +957,9 @@ def transform(self, func, idx=None, tmin=None, tmax=None, copy=False):
else:
# return new or overwritten stc
stcs = self if not copy else self.copy()
stcs._data, stcs.vertices = data_t, verts
stcs.tmin, stcs.times = tmin, times
stcs.vertices = verts
stcs.data = data_t
stcs.tmin = tmin

return stcs

Expand Down Expand Up @@ -1168,8 +1220,8 @@ def expand(self, vertices):
self.vertices[vi] = np.insert(v_old, inds, v_new)
inds = [ii + offset for ii, offset in zip(inserters, offsets[:-1])]
inds = np.concatenate(inds)
new_data = np.zeros((len(inds), self._data.shape[1]))
self._data = np.insert(self._data, inds, new_data, axis=0)
new_data = np.zeros((len(inds), self.data.shape[1]))
self.data = np.insert(self.data, inds, new_data, axis=0)
return self

@verbose
Expand Down Expand Up @@ -1379,7 +1431,7 @@ def to_original_src(self, src_orig, subject_orig=None,
subject_orig = _ensure_src_subject(src_orig, subject_orig)
data_idx, vertices = _get_morph_src_reordering(
self.vertices, src_orig, subject_orig, self.subject, subjects_dir)
return SourceEstimate(self._data[data_idx], vertices,
return SourceEstimate(self.data[data_idx], vertices,
self.tmin, self.tstep, subject_orig)

@verbose
Expand Down Expand Up @@ -2017,8 +2069,8 @@ def _morph_sparse(stc, subject_from, subject_to, subjects_dir=None):
vertno_k = _sparse_argmax_nnz_row(map_hemi[stc.vertices[k]])
order = np.argsort(vertno_k)
n_active_hemi = len(vertno_k)
data_hemi = stc_morph._data[cnt:cnt + n_active_hemi]
stc_morph._data[cnt:cnt + n_active_hemi] = data_hemi[order]
data_hemi = stc_morph.data[cnt:cnt + n_active_hemi]
stc_morph.data[cnt:cnt + n_active_hemi] = data_hemi[order]
stc_morph.vertices[k] = vertno_k[order]
cnt += n_active_hemi
else:
Expand Down
Loading

0 comments on commit e3c3346

Please sign in to comment.