forked from nilearn/nistats
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added the thresholding functionality
- Loading branch information
Showing
2 changed files
with
140 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" Test the thresholding utilities | ||
""" | ||
import numpy as np | ||
from scipy.stats import norm | ||
from nose.tools import assert_true | ||
from numpy.testing import assert_array_almost_equal, assert_almost_equal, assert_equal | ||
import nibabel as nib | ||
from ..thresholding import (fdr_threshold, map_threshold) | ||
|
||
|
||
def test_fdr(): | ||
n = 100 | ||
x = np.linspace(.5 / n, 1. - .5 / n, n) | ||
x[:10] = .0005 | ||
x = norm.isf(x) | ||
np.random.shuffle(x) | ||
assert_almost_equal(fdr_threshold(x, .1), norm.isf(.0005)) | ||
assert_true(fdr_threshold(x, .001) == np.infty) | ||
|
||
|
||
def test_map_threshold(): | ||
shape = (9, 10, 11) | ||
data = np.random.randn(*shape) | ||
threshold = norm.sf(data.max() + 1) | ||
data[2:4, 5:7, 6:8] = data.max() + 2 | ||
stat_img = nib.Nifti1Image(data, np.eye(4)) | ||
mask_img = nib.Nifti1Image(np.ones(shape), np.eye(4)) | ||
|
||
# test 1 | ||
th_map = map_threshold( | ||
stat_img, mask_img, threshold, height_control='fpr', | ||
cluster_threshold=0) | ||
vals = th_map.get_data() | ||
assert_equal(np.sum(vals > 0), 8) | ||
|
||
# test 2:excessive size threshold | ||
th_map = map_threshold( | ||
stat_img, mask_img, threshold, height_control='fpr', | ||
cluster_threshold=10) | ||
vals = th_map.get_data() | ||
assert_true(np.sum(vals > 0) == 0) | ||
|
||
# test 3: excessive cluster forming threshold | ||
th_map = map_threshold( | ||
stat_img, mask_img, 100, height_control='fpr', | ||
cluster_threshold=0) | ||
vals = th_map.get_data() | ||
assert_true(np.sum(vals > 0) == 0) | ||
|
||
# test 4: fdr threshold | ||
for control in ['fdr', 'bonferoni']: | ||
th_map = map_threshold( | ||
stat_img, mask_img, .05, height_control='fdr', | ||
cluster_threshold=5) | ||
vals = th_map.get_data() | ||
assert_equal(np.sum(vals > 0), 8) | ||
|
||
# test 5: direct threshold | ||
th_map = map_threshold( | ||
stat_img, mask_img, 5., height_control=None, | ||
cluster_threshold=5) | ||
vals = th_map.get_data() | ||
assert_equal(np.sum(vals > 0), 8) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
""" Utilities to describe the result of cluster-level analysis of statistical | ||
maps. | ||
Author: Bertrand Thirion, 2015 | ||
""" | ||
import numpy as np | ||
from scipy.ndimage import label | ||
from scipy.stats import norm | ||
from nilearn.input_data import NiftiMasker | ||
|
||
|
||
def fdr_threshold(z_vals, alpha): | ||
""" return the BH fdr for the input z_vals""" | ||
z_vals_ = - np.sort(- z_vals) | ||
p_vals = norm.sf(z_vals_) | ||
n_samples = len(p_vals) | ||
pos = p_vals < alpha * np.linspace( | ||
.5 / n_samples, 1 - .5 / n_samples, n_samples) | ||
if pos.any(): | ||
return (z_vals_[pos][-1] - 1.e-8) | ||
else: | ||
return np.infty | ||
|
||
|
||
def map_threshold(stat_img, mask_img, threshold, height_control='fpr', | ||
cluster_threshold=0): | ||
""" Threshold the provvided map | ||
Parameters | ||
---------- | ||
stat_img : Niimg-like object, | ||
statistical image (presumably in z scale) | ||
mask_img : Niimg-like object, | ||
mask image | ||
threshold: float, | ||
cluster forming threshold (either a p-value or z-scale value) | ||
height_control: string | ||
false positive control meaning of cluster forming | ||
threshold: 'fpr'|'fdr'|'bonferroni'|'none' | ||
cluster_threshold : float, optional | ||
cluster size threshold | ||
Returns | ||
------- | ||
thresholded_map : Nifti1Image, | ||
the stat_map theresholded at the prescribed voxel- and cluster-level | ||
""" | ||
# Masking | ||
masker = NiftiMasker(mask_img=mask_img) | ||
stats = np.ravel(masker.fit_transform(stat_img)) | ||
n_voxels = np.size(stats) | ||
|
||
# Thresholding | ||
if height_control == 'fpr': | ||
z_th = norm.isf(threshold) | ||
elif height_control == 'fdr': | ||
z_th = fdr_threshold(stats, threshold) | ||
elif height_control == 'bonferroni': | ||
z_th = norm.isf(threshold / n_voxels) | ||
else: # Brute-force thresholding | ||
z_th = threshold | ||
stats *= (stats > z_th) | ||
|
||
stat_map = masker.inverse_transform(stats).get_data() | ||
|
||
# Extract connected components above threshold | ||
label_map, n_labels = label(stat_map > z_th) | ||
labels = label_map[(masker.mask_img.get_data() > 0)] | ||
for label_ in range(1, n_labels + 1): | ||
if np.sum(labels == label_) < cluster_threshold: | ||
stats[labels == label_] = 0 | ||
|
||
return masker.inverse_transform(stats) |