diff --git a/doc/changes/latest.inc b/doc/changes/latest.inc index 173f1b11679..43a89c00345 100644 --- a/doc/changes/latest.inc +++ b/doc/changes/latest.inc @@ -239,6 +239,8 @@ Bugs - Fix bug with :func:`mne.extract_label_time_course` where labels, STCs, and the source space were not checked for compatible ``subject`` attributes (:gh:`9284` by `Eric Larson`_) +- Fix bug with :func:`mne.grow_labels` where ``overlap=False`` could run forever or raise an error (:gh:`9317` by `Eric Larson`_) + - Fix compatibility bugs with :mod:`mne_realtime` (:gh:`8845` by `Eric Larson`_) - Fix bug with `mne.viz.Brain` where non-inflated surfaces had an X-offset imposed by default (:gh:`8794` by `Eric Larson`_) diff --git a/mne/label.py b/mne/label.py index 4842db6dafe..100122f27e1 100644 --- a/mne/label.py +++ b/mne/label.py @@ -1573,8 +1573,7 @@ def grow_labels(subject, seeds, extents, hemis, subjects_dir=None, n_jobs=1, # make sure the inputs are arrays if np.isscalar(seeds): seeds = [seeds] - # these can have different sizes so need to use object array - seeds = np.asarray([np.atleast_1d(seed) for seed in seeds], dtype='O') + seeds = [np.atleast_1d(seed) for seed in seeds] extents = np.atleast_1d(extents) hemis = np.atleast_1d(hemis) n_seeds = len(seeds) @@ -1636,7 +1635,7 @@ def grow_labels(subject, seeds, extents, hemis, subjects_dir=None, n_jobs=1, if overlap: # create the patches parallel, my_grow_labels, _ = parallel_func(_grow_labels, n_jobs) - seeds = np.array_split(seeds, n_jobs) + seeds = np.array_split(np.array(seeds, dtype='O'), n_jobs) extents = np.array_split(extents, n_jobs) hemis = np.array_split(hemis, n_jobs) names = np.array_split(names, n_jobs) @@ -1668,7 +1667,7 @@ def _grow_nonoverlapping_labels(subject, seeds_, extents_, hemis, vertices_, labels = [] for hemi in set(hemis): hemi_index = (hemis == hemi) - seeds = seeds_[hemi_index] + seeds = [seed for seed, h in zip(seeds_, hemis) if h == hemi] extents = extents_[hemi_index] names = names_[hemi_index] graph = graphs[hemi] # distance graph @@ -1698,6 +1697,10 @@ def _grow_nonoverlapping_labels(subject, seeds_, extents_, hemis, vertices_, # add neighbors within allowable distance row = graph[vert_from, :] for vert_to, dist in zip(row.indices, row.data): + # Prevent adding a point that has already been used + # (prevents infinite loop) + if (vert_to == seeds[label]).any(): + continue new_dist = old_dist + dist # abort if outside of extent diff --git a/mne/tests/test_label.py b/mne/tests/test_label.py index cd50c57cac1..9bc66d2afb7 100644 --- a/mne/tests/test_label.py +++ b/mne/tests/test_label.py @@ -845,6 +845,10 @@ def test_grow_labels(): l1 = l11 + l12 assert_array_equal(l1.vertices, l0.vertices) + # non-overlapping (gh-8848) + for overlap in (False, True): + grow_labels('fsaverage', [0], 1, 1, subjects_dir, overlap=overlap) + @testing.requires_testing_data def test_random_parcellation():