Skip to content

Commit

Permalink
MRG: Allow connectivity=False (mne-tools#5096)
Browse files Browse the repository at this point in the history
* ENH: Allow connectivity=False

* FIX: Fix skip

* DOC: Revert removal

* DOC: whats_new
  • Loading branch information
larsoner authored and agramfort committed Apr 19, 2018
1 parent 5ec3ac4 commit a5fa0ce
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 24 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ Changelog

- Add ability to supply a mask to the plot in :func:`mne.viz.plot_evoked_image` by `Jona Sassenhagen`_

- Add ``connectivity=False`` to cluster-based statistical functions to perform non-clustering stats by `Eric Larson`_

- Add :func:`mne.time_frequency.csd_morlet` and :func:`mne.time_frequency.csd_array_morlet` to estimate cross-spectral density using Morlet wavelets, by `Marijn van Vliet`_

- Add multidictionary time-frequency support to :func:`mne.inverse_sparse.tf_mixed_norm` by `Mathurin Massias`_ and `Daniel Strohmeier`_
Expand Down
49 changes: 28 additions & 21 deletions mne/stats/cluster_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,21 @@ def _get_clusters_st(x_in, neighbors, max_step=1):

def _get_components(x_in, connectivity, return_list=True):
"""Get connected components from a mask and a connectivity matrix."""
from scipy.sparse.csgraph import connected_components
mask = np.logical_and(x_in[connectivity.row], x_in[connectivity.col])
data = connectivity.data[mask]
row = connectivity.row[mask]
col = connectivity.col[mask]
shape = connectivity.shape
idx = np.where(x_in)[0]
row = np.concatenate((row, idx))
col = np.concatenate((col, idx))
data = np.concatenate((data, np.ones(len(idx), dtype=data.dtype)))
connectivity = sparse.coo_matrix((data, (row, col)), shape=shape)
_, components = connected_components(connectivity)
if connectivity is False:
components = np.arange(len(x_in))
else:
from scipy.sparse.csgraph import connected_components
mask = np.logical_and(x_in[connectivity.row], x_in[connectivity.col])
data = connectivity.data[mask]
row = connectivity.row[mask]
col = connectivity.col[mask]
shape = connectivity.shape
idx = np.where(x_in)[0]
row = np.concatenate((row, idx))
col = np.concatenate((col, idx))
data = np.concatenate((data, np.ones(len(idx), dtype=data.dtype)))
connectivity = sparse.coo_matrix((data, (row, col)), shape=shape)
_, components = connected_components(connectivity)
if return_list:
start = np.min(components)
stop = np.max(components)
Expand Down Expand Up @@ -258,6 +261,7 @@ def _find_clusters(x, threshold, tail=0, connectivity=None, max_step=1,
If connectivity is a list, it is assumed that each entry stores the
indices of the spatial neighbors in a spatio-temporal dataset x.
Default is None, i.e, a regular lattice connectivity.
False means no connectivity.
max_step : int
If connectivity is a list, this defines the maximal number of steps
between vertices along the second dimension (typically time) to be
Expand Down Expand Up @@ -383,7 +387,7 @@ def _find_clusters(x, threshold, tail=0, connectivity=None, max_step=1,
if tfce is True:
# each point gets treated independently
clusters = np.arange(x.size)
if connectivity is None:
if connectivity is None or connectivity is False:
if x.ndim == 1:
# slices
clusters = [slice(c, c + 1) for c in clusters]
Expand Down Expand Up @@ -451,7 +455,7 @@ def _find_clusters_1dir(x, x_in, connectivity, max_step, t_power, ndimage):
if x.ndim > 1:
raise Exception("Data should be 1D when using a connectivity "
"to define clusters.")
if isinstance(connectivity, sparse.spmatrix):
if isinstance(connectivity, sparse.spmatrix) or connectivity is False:
clusters = _get_components(x_in, connectivity)
elif isinstance(connectivity, list): # use temporal adjacency
clusters = _get_clusters_st(x_in, connectivity, max_step)
Expand Down Expand Up @@ -768,7 +772,7 @@ def _permutation_cluster_test(X, threshold, n_permutations, tail, stat_fun,
X = [np.reshape(x, (x.shape[0], -1)) for x in X]
n_tests = X[0].shape[1]

if connectivity is not None:
if connectivity is not None and connectivity is not False:
connectivity = _setup_connectivity(connectivity, n_tests, n_times)

if (exclude is not None) and not exclude.size == n_tests:
Expand All @@ -792,7 +796,7 @@ def _permutation_cluster_test(X, threshold, n_permutations, tail, stat_fun,
buffer_size = None

# The stat should have the same shape as the samples for no conn.
if connectivity is None:
if connectivity is None or connectivity is False:
t_obs.shape = sample_shape

if exclude is not None:
Expand All @@ -801,7 +805,8 @@ def _permutation_cluster_test(X, threshold, n_permutations, tail, stat_fun,
include = None

# determine if connectivity itself can be separated into disjoint sets
if check_disjoint is True and connectivity is not None:
if check_disjoint is True and (connectivity is not None and
connectivity is not False):
partitions = _get_partitions_from_connectivity(connectivity, n_times)
else:
partitions = None
Expand All @@ -818,7 +823,7 @@ def _permutation_cluster_test(X, threshold, n_permutations, tail, stat_fun,
logger.info('Found %d clusters' % len(clusters))

# convert clusters to old format
if connectivity is not None:
if connectivity is not None and connectivity is not False:
# our algorithms output lists of indices by default
if out_type == 'mask':
clusters = _cluster_indices_to_mask(clusters, n_tests)
Expand Down Expand Up @@ -907,7 +912,7 @@ def get_progress_bar(seeds):
step_down_include = np.ones(n_tests, dtype=bool)
for ti in to_remove:
step_down_include[clusters[ti]] = False
if connectivity is None:
if connectivity is None and connectivity is not False:
step_down_include.shape = sample_shape
n_step_downs += 1
if step_down_p > 0:
Expand Down Expand Up @@ -1001,6 +1006,7 @@ def permutation_cluster_test(
Defines connectivity between features. The matrix is assumed to
be symmetric and only the upper triangular half is used.
Default is None, i.e, a regular lattice connectivity.
Can also be False to assume no connectivity.
n_jobs : int
Number of permutations to run in parallel (requires joblib package).
seed : int | instance of RandomState | None
Expand Down Expand Up @@ -1111,13 +1117,14 @@ def permutation_cluster_1samp_test(
stat_fun : callable | None
Function used to compute the statistical map (default None will use
:func:`mne.stats.ttest_1samp_no_p`).
connectivity : sparse matrix or None
connectivity : sparse matrix | None | False
Defines connectivity between features. The matrix is assumed to
be symmetric and only the upper triangular half is used.
This matrix must be square with dimension (n_vertices * n_times) or
(n_vertices). Default is None, i.e, a regular lattice connectivity.
Use square n_vertices matrix for datasets with a large temporal
extent to save on memory and computation time.
extent to save on memory and computation time. Can also be False
to assume no connectivity. Can also be False to assume no connectivity.
verbose : bool, str, int, or None
If not None, override default verbose level (see :func:`mne.verbose`
and :ref:`Logging documentation <tut_logging>` for more).
Expand Down
17 changes: 14 additions & 3 deletions mne/stats/tests/test_permutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from numpy.testing import assert_array_equal, assert_allclose
import numpy as np
from scipy import stats
from scipy import stats, sparse

from mne.stats import permutation_cluster_1samp_test
from mne.stats.permutations import permutation_t_test, _ci, _bootstrap_ci
from mne.utils import run_tests_if_main

Expand Down Expand Up @@ -37,15 +38,25 @@ def test_permutation_t_test():
is_significant = p_values < 0.05
assert_array_equal(is_significant, [False, False, False, False, False])

X *= -1
t_obs, p_values, H0 = permutation_t_test(
-X, n_permutations=999, tail=-1, seed=0)
X, n_permutations=999, tail=-1, seed=0)
assert (p_values > 0).all()
assert len(H0) == 999
is_significant = p_values < 0.05
assert_array_equal(is_significant, [True, True, False, False, False])

# check equivalence with spatio_temporal_cluster_test
for connectivity in (sparse.eye(n_tests), False):
t_obs_clust, _, p_values_clust, _ = permutation_cluster_1samp_test(
X, n_permutations=999, seed=0, connectivity=connectivity)
# the cluster tests drop any clusters that don't get thresholded
keep = p_values < 1
assert_allclose(t_obs_clust, t_obs)
assert_allclose(p_values_clust, p_values[keep], atol=1e-2)

X = np.random.randn(18, 1)
t_obs, p_values, H0 = permutation_t_test(X[:, [0]], n_permutations='all')
t_obs, p_values, H0 = permutation_t_test(X, n_permutations='all')
t_obs_scipy, p_values_scipy = stats.ttest_1samp(X[:, 0], 0)
assert_allclose(t_obs[0], t_obs_scipy, 8)
assert_allclose(p_values[0], p_values_scipy, rtol=1e-2)
Expand Down
1 change: 1 addition & 0 deletions mne/tests/test_docstring_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def test_tabs():
minimum_phase
next_fast_len
parallel_func
permutation_t_test
pick_channels_evoked
plot_epochs_psd
plot_epochs_psd_topomap
Expand Down

0 comments on commit a5fa0ce

Please sign in to comment.