Skip to content

Commit

Permalink
Added the thresholding functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
bthirion committed Dec 18, 2015
1 parent 61c93f0 commit fcac819
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 0 deletions.
63 changes: 63 additions & 0 deletions nistats/tests/test_thresholding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
""" Test the thresholding utilities
"""
import numpy as np
from scipy.stats import norm
from nose.tools import assert_true
from numpy.testing import assert_array_almost_equal, assert_almost_equal, assert_equal
import nibabel as nib
from ..thresholding import (fdr_threshold, map_threshold)


def test_fdr():
n = 100
x = np.linspace(.5 / n, 1. - .5 / n, n)
x[:10] = .0005
x = norm.isf(x)
np.random.shuffle(x)
assert_almost_equal(fdr_threshold(x, .1), norm.isf(.0005))
assert_true(fdr_threshold(x, .001) == np.infty)


def test_map_threshold():
shape = (9, 10, 11)
data = np.random.randn(*shape)
threshold = norm.sf(data.max() + 1)
data[2:4, 5:7, 6:8] = data.max() + 2
stat_img = nib.Nifti1Image(data, np.eye(4))
mask_img = nib.Nifti1Image(np.ones(shape), np.eye(4))

# test 1
th_map = map_threshold(
stat_img, mask_img, threshold, height_control='fpr',
cluster_threshold=0)
vals = th_map.get_data()
assert_equal(np.sum(vals > 0), 8)

# test 2:excessive size threshold
th_map = map_threshold(
stat_img, mask_img, threshold, height_control='fpr',
cluster_threshold=10)
vals = th_map.get_data()
assert_true(np.sum(vals > 0) == 0)

# test 3: excessive cluster forming threshold
th_map = map_threshold(
stat_img, mask_img, 100, height_control='fpr',
cluster_threshold=0)
vals = th_map.get_data()
assert_true(np.sum(vals > 0) == 0)

# test 4: fdr threshold
for control in ['fdr', 'bonferoni']:
th_map = map_threshold(
stat_img, mask_img, .05, height_control='fdr',
cluster_threshold=5)
vals = th_map.get_data()
assert_equal(np.sum(vals > 0), 8)

# test 5: direct threshold
th_map = map_threshold(
stat_img, mask_img, 5., height_control=None,
cluster_threshold=5)
vals = th_map.get_data()
assert_equal(np.sum(vals > 0), 8)
77 changes: 77 additions & 0 deletions nistats/thresholding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
""" Utilities to describe the result of cluster-level analysis of statistical
maps.
Author: Bertrand Thirion, 2015
"""
import numpy as np
from scipy.ndimage import label
from scipy.stats import norm
from nilearn.input_data import NiftiMasker


def fdr_threshold(z_vals, alpha):
""" return the BH fdr for the input z_vals"""
z_vals_ = - np.sort(- z_vals)
p_vals = norm.sf(z_vals_)
n_samples = len(p_vals)
pos = p_vals < alpha * np.linspace(
.5 / n_samples, 1 - .5 / n_samples, n_samples)
if pos.any():
return (z_vals_[pos][-1] - 1.e-8)
else:
return np.infty


def map_threshold(stat_img, mask_img, threshold, height_control='fpr',
cluster_threshold=0):
""" Threshold the provvided map
Parameters
----------
stat_img : Niimg-like object,
statistical image (presumably in z scale)
mask_img : Niimg-like object,
mask image
threshold: float,
cluster forming threshold (either a p-value or z-scale value)
height_control: string
false positive control meaning of cluster forming
threshold: 'fpr'|'fdr'|'bonferroni'|'none'
cluster_threshold : float, optional
cluster size threshold
Returns
-------
thresholded_map : Nifti1Image,
the stat_map theresholded at the prescribed voxel- and cluster-level
"""
# Masking
masker = NiftiMasker(mask_img=mask_img)
stats = np.ravel(masker.fit_transform(stat_img))
n_voxels = np.size(stats)

# Thresholding
if height_control == 'fpr':
z_th = norm.isf(threshold)
elif height_control == 'fdr':
z_th = fdr_threshold(stats, threshold)
elif height_control == 'bonferroni':
z_th = norm.isf(threshold / n_voxels)
else: # Brute-force thresholding
z_th = threshold
stats *= (stats > z_th)

stat_map = masker.inverse_transform(stats).get_data()

# Extract connected components above threshold
label_map, n_labels = label(stat_map > z_th)
labels = label_map[(masker.mask_img.get_data() > 0)]
for label_ in range(1, n_labels + 1):
if np.sum(labels == label_) < cluster_threshold:
stats[labels == label_] = 0

return masker.inverse_transform(stats)

0 comments on commit fcac819

Please sign in to comment.