Skip to content

Commit

Permalink
Merge pull request mne-tools#5147 from wmvanvliet/fix_get_peak
Browse files Browse the repository at this point in the history
[MRG] Fix SourceEstimate.get_peak() method
  • Loading branch information
wmvanvliet authored Apr 20, 2018
2 parents 8c07f60 + bb21937 commit a7d3c20
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ Bug

- Fix bug in :class:`mne.make_forward_solution` when passing data with compensation channels (e.g. CTF) that contain bad channels by `Alex Gramfort`_

- Fix bug in :meth:`mne.SourceEstimate.get_peak` and :meth:`mne.VolSourceEstimate.get_peak` when there is only a single time point by `Marijn van Vliet`_

API
~~~

Expand Down
6 changes: 3 additions & 3 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,10 +1263,10 @@ def _get_peak(data, times, tmin=None, tmax=None, mode='abs'):
raise ValueError('The tmin value is out of bounds. It must be '
'within {0} and {1}'.format(times.min(), times.max()))
if tmax > times.max():
raise ValueError('The tmin value is out of bounds. It must be '
raise ValueError('The tmax value is out of bounds. It must be '
'within {0} and {1}'.format(times.min(), times.max()))
if tmin >= tmax:
raise ValueError('The tmin must be smaller than tmax')
if tmin > tmax:
raise ValueError('The tmin must be smaller or equal to tmax')

time_win = (times >= tmin) & (times <= tmax)
mask = np.ones_like(data).astype(np.bool)
Expand Down
19 changes: 12 additions & 7 deletions mne/tests/test_source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,24 +822,29 @@ def test_get_peak():
data = rng.randn(n_vert, n_times)
stc_surf = SourceEstimate(data, vertices=vertices, tmin=0, tstep=1,
subject='sample')

stc_vol = VolSourceEstimate(data, vertices=vertices[0], tmin=0, tstep=1,
subject='sample')

for ii, stc in enumerate([stc_surf, stc_vol]):
# Versions with only one time point
stc_surf_1 = SourceEstimate(data[:, :1], vertices=vertices, tmin=0,
tstep=1, subject='sample')
stc_vol_1 = VolSourceEstimate(data[:, :1], vertices=vertices[0], tmin=0,
tstep=1, subject='sample')

for ii, stc in enumerate([stc_surf, stc_vol, stc_surf_1, stc_vol_1]):
assert_raises(ValueError, stc.get_peak, tmin=-100)
assert_raises(ValueError, stc.get_peak, tmax=90)
assert_raises(ValueError, stc.get_peak, tmin=0.002, tmax=0.001)

vert_idx, time_idx = stc.get_peak()
vertno = np.concatenate(stc.vertices) if ii == 0 else stc.vertices
vertno = np.concatenate(stc.vertices) if ii in [0, 2] else stc.vertices
assert_true(vert_idx in vertno)
assert_true(time_idx in stc.times)

ch_idx, time_idx = stc.get_peak(vert_as_index=True,
time_as_index=True)
assert_true(vert_idx < stc.data.shape[0])
assert_true(time_idx < len(stc.times))
data_idx, time_idx = stc.get_peak(vert_as_index=True,
time_as_index=True)
assert_equal(data_idx, np.argmax(np.abs(stc.data[:, time_idx])))
assert_equal(time_idx, np.argmax(np.abs(stc.data[data_idx, :])))


@testing.requires_testing_data
Expand Down

0 comments on commit a7d3c20

Please sign in to comment.