Skip to content

Commit

Permalink
ENH: Better test coverage for utils and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Nov 21, 2013
1 parent a50cb91 commit 49e2abc
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 88 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ $(CURDIR)/examples/MNE-sample-data/MEG/sample/sample_audvis_raw.fif:
ln -s ${PWD}/examples/MNE-sample-data ${PWD}/MNE-sample-data -f

test: in sample_data
rm .coverage
$(NOSETESTS) mne

test-no-sample: in
Expand Down
105 changes: 42 additions & 63 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,40 @@
from operator import itemgetter
import inspect

import warnings
import numpy as np
import scipy
from scipy import linalg
from math import ceil, log
from numpy.fft import irfft
from scipy.signal import filtfilt as sp_filtfilt
from distutils.version import LooseVersion
from functools import partial
import copy_reg

try:
Counter = collections.Counter
except AttributeError:
class Counter(collections.defaultdict):
"""Partial replacement for Python 2.7 collections.Counter."""
def __init__(self, iterable=(), **kwargs):
super(Counter, self).__init__(int, **kwargs)
self.update(iterable)

def most_common(self):
return sorted(self.iteritems(), key=itemgetter(1), reverse=True)

def update(self, other):
"""Adds counts for elements in other"""
if isinstance(other, self.__class__):
for x, n in other.iteritems():
self[x] += n
else:
for x in other:
self[x] += 1

class _Counter(collections.defaultdict):
"""Partial replacement for Python 2.7 collections.Counter."""
def __init__(self, iterable=(), **kwargs):
super(_Counter, self).__init__(int, **kwargs)
self.update(iterable)

def lsqr(X, y, tol=1e-3):
import scipy.sparse.linalg as sp_linalg
from ..utils.extmath import safe_sparse_dot
def most_common(self):
return sorted(self.iteritems(), key=itemgetter(1), reverse=True)

if hasattr(sp_linalg, 'lsqr'):
# scipy 0.8 or greater
return sp_linalg.lsqr(X, y)
else:
n_samples, n_features = X.shape
if n_samples > n_features:
coef, _ = sp_linalg.cg(safe_sparse_dot(X.T, X),
safe_sparse_dot(X.T, y),
tol=tol)
def update(self, other):
"""Adds counts for elements in other"""
if isinstance(other, self.__class__):
for x, n in other.iteritems():
self[x] += n
else:
coef, _ = sp_linalg.cg(safe_sparse_dot(X, X.T), y, tol=tol)
coef = safe_sparse_dot(X.T, coef)
for x in other:
self[x] += 1

residues = y - safe_sparse_dot(X, coef)
return coef, None, None, residues
try:
Counter = collections.Counter
except AttributeError:
Counter = _Counter


def _unique(ar, return_index=False, return_inverse=False):
Expand Down Expand Up @@ -110,15 +94,7 @@ def _unique(ar, return_index=False, return_inverse=False):
flag = np.concatenate(([True], ar[1:] != ar[:-1]))
return ar[flag]

np_version = []
for x in np.__version__.split('.'):
try:
np_version.append(int(x))
except ValueError:
# x may be of the form dev-1ea1592
np_version.append(x)

if np_version[:2] < (1, 5):
if LooseVersion(np.__version__) < LooseVersion('1.5'):
unique = _unique
else:
unique = np.unique
Expand All @@ -133,7 +109,7 @@ def _bincount(X, weights=None, minlength=None):
out[:len(result)] = result
return out

if np_version[:2] < (1, 6):
if LooseVersion(np.__version__) < LooseVersion('1.6'):
bincount = _bincount
else:
bincount = np.bincount
Expand Down Expand Up @@ -205,26 +181,29 @@ def _unravel_index(indices, dims):
return tuple(unraveled_coords.T)


if np_version[:2] < (1, 4):
if LooseVersion(np.__version__) < LooseVersion('1.4'):
unravel_index = _unravel_index
else:
unravel_index = np.unravel_index


def qr_economic(A, **kwargs):
"""Compat function for the QR-decomposition in economic mode
def _qr_economic_old(A, **kwargs):
"""
Compat function for the QR-decomposition in economic mode
Scipy 0.9 changed the keyword econ=True to mode='economic'
"""
import scipy.linalg
# trick: triangular solve has introduced in 0.9
if hasattr(scipy.linalg, 'solve_triangular'):
return scipy.linalg.qr(A, mode='economic', **kwargs)
else:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return scipy.linalg.qr(A, econ=True, **kwargs)
with warnings.catch_warnings(True):
return linalg.qr(A, econ=True, **kwargs)


def _qr_economic_new(A, **kwargs):
return linalg.qr(A, mode='economic', **kwargs)


if LooseVersion(scipy.__version__) < LooseVersion('0.9'):
qr_economic = _qr_economic_old
else:
qr_economic = _qr_economic_new


def savemat(file_name, mdict, oned_as="column", **kwargs):
Expand Down Expand Up @@ -372,7 +351,8 @@ def _firwin2(numtaps, freq, gain, nfreqs=None, window='hamming', nyq=1.0):

if nfreqs is not None and numtaps >= nfreqs:
raise ValueError('ntaps must be less than nfreqs, but firwin2 was '
'called with ntaps=%d and nfreqs=%s' % (numtaps, nfreqs))
'called with ntaps=%d and nfreqs=%s'
% (numtaps, nfreqs))

if freq[0] != 0 or freq[-1] != nyq:
raise ValueError('freq must start with 0 and end with `nyq`.')
Expand All @@ -385,7 +365,7 @@ def _firwin2(numtaps, freq, gain, nfreqs=None, window='hamming', nyq=1.0):

if numtaps % 2 == 0 and gain[-1] != 0.0:
raise ValueError("A filter with an even number of coefficients must "
"have zero gain at the Nyquist rate.")
"have zero gain at the Nyquist rate.")

if nfreqs is None:
nfreqs = 1 + 2 ** int(ceil(log(numtaps, 2)))
Expand Down Expand Up @@ -539,9 +519,8 @@ def _reduce_partial(p):

def normalize_colors(vmin, vmax, clip=False):
"""Helper to handle matplotlib API"""
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
if 'Normalize' in vars(plt):
return plt.Normalize(vmin, vmax, clip=clip)
else:
return plt.normalize(vmin, vmax, clip=clip)

70 changes: 63 additions & 7 deletions mne/tests/test_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,68 @@

from nose.tools import assert_equal
from numpy.testing import assert_array_equal
from distutils.version import LooseVersion
from scipy import signal

from ..fixes import _in1d, _tril_indices, _copysign, _unravel_index
from ..fixes import (_in1d, _tril_indices, _copysign, _unravel_index,
_Counter, _unique, _bincount)
from ..fixes import _firwin2 as mne_firwin2
from ..fixes import _filtfilt as mne_filtfilt


def test_counter():
"""Test Counter replacement"""
import collections
try:
Counter = collections.Counter
except:
pass
else:
a = Counter([1, 2, 1, 3])
b = _Counter([1, 2, 1, 3])
for key, count in zip([1, 2, 3], [2, 1, 1]):
assert_equal(a[key], b[key])


def test_unique():
"""Test unique() replacement
"""
# skip test for np version < 1.5
if LooseVersion(np.__version__) < LooseVersion('1.5'):
return
for arr in [np.array([]), np.random.rand(10), np.ones(10)]:
# basic
assert_array_equal(np.unique(arr), _unique(arr))
# with return_index=True
x1, x2 = np.unique(arr, return_index=True, return_inverse=False)
y1, y2 = _unique(arr, return_index=True, return_inverse=False)
assert_array_equal(x1, y1)
assert_array_equal(x2, y2)
# with return_inverse=True
x1, x2 = np.unique(arr, return_index=False, return_inverse=True)
y1, y2 = _unique(arr, return_index=False, return_inverse=True)
assert_array_equal(x1, y1)
assert_array_equal(x2, y2)
# with both:
x1, x2, x3 = np.unique(arr, return_index=True, return_inverse=True)
y1, y2, y3 = _unique(arr, return_index=True, return_inverse=True)
assert_array_equal(x1, y1)
assert_array_equal(x2, y2)
assert_array_equal(x3, y3)


def test_bincount():
"""Test bincount() replacement
"""
# skip test for np version < 1.6
if LooseVersion(np.__version__) < LooseVersion('1.6'):
return
for minlength in [None, 100]:
x = _bincount(np.ones(10, int), None, minlength)
y = np.bincount(np.ones(10, int), None, minlength)
assert_array_equal(x, y)


def test_in1d():
"""Test numpy.in1d() replacement"""
a = np.arange(10)
Expand All @@ -40,12 +95,12 @@ def test_tril_indices():
def test_unravel_index():
"""Test numpy.unravel_index() replacement"""
assert_equal(_unravel_index(2, (2, 3)), (0, 2))
assert_equal(_unravel_index(2,(2,2)), (1,0))
assert_equal(_unravel_index(254,(17,94)), (2,66))
assert_equal(_unravel_index((2*3 + 1)*6 + 4, (4,3,6)), (2,1,4))
assert_array_equal(_unravel_index(np.array([22, 41, 37]), (7,6)),
[[3, 6, 6],[4, 5, 1]])
assert_array_equal(_unravel_index(1621, (6,7,8,9)), (3,1,4,1))
assert_equal(_unravel_index(2, (2, 2)), (1, 0))
assert_equal(_unravel_index(254, (17, 94)), (2, 66))
assert_equal(_unravel_index((2 * 3 + 1) * 6 + 4, (4, 3, 6)), (2, 1, 4))
assert_array_equal(_unravel_index(np.array([22, 41, 37]), (7, 6)),
[[3, 6, 6], [4, 5, 1]])
assert_array_equal(_unravel_index(1621, (6, 7, 8, 9)), (3, 1, 4, 1))


def test_copysign():
Expand All @@ -64,6 +119,7 @@ def test_firwin2():
taps2 = signal.firwin2(150, [0.0, 0.5, 1.0], [1.0, 1.0, 0.0])
assert_array_equal(taps1, taps2)


def test_filtfilt():
"""Test IIR filtfilt replacement
"""
Expand Down
49 changes: 45 additions & 4 deletions mne/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from ..utils import (set_log_level, set_log_file, _TempDir,
get_config, set_config, deprecated, _fetch_file,
sum_squared, requires_mem_gb)
sum_squared, requires_mem_gb, estimate_rank,
_url_to_local_path, sizeof_fmt, _check_fname)
from ..fiff import Evoked, show_fiff

warnings.simplefilter('always') # enable b/c these tests throw warnings
Expand All @@ -27,6 +28,21 @@ def clean_lines(lines):
return [l if 'Reading ' not in l else 'Reading test file' for l in lines]


def test_tempdir():
"""Test TempDir
"""
tempdir2 = _TempDir()
assert_true(op.isdir(tempdir2))
tempdir2.cleanup()
assert_true(not op.isdir(tempdir2))


def test_estimate_rank():
data = np.eye(10)
data[0, 0] = 0
assert_equal(estimate_rank(data), 9)


def test_logging():
"""Test logging (to file)
"""
Expand Down Expand Up @@ -183,13 +199,38 @@ def test_fetch_file():
except urllib2.URLError:
from nose.plugins.skip import SkipTest
raise SkipTest('No internet connection, skipping download test.')
url = "http://github.com/mne-tools/mne-python/blob/master/README.rst"
archive_name = op.join(tempdir, "download_test")
_fetch_file(url, archive_name, print_destination=False)

urls = ['http://github.com/mne-tools/mne-python/blob/master/README.rst',
'ftp://surfer.nmr.mgh.harvard.edu/pub/data/bert.recon.md5sum.txt']
for url in urls:
archive_name = op.join(tempdir, "download_test")
_fetch_file(url, archive_name, print_destination=False)
assert_raises(Exception, _fetch_file, 'http://0.0',
op.join(tempdir, 'test'))
resume_name = op.join(tempdir, "download_resume")
# touch file
with file(resume_name + '.part', 'w'):
os.utime(resume_name + '.part', None)
_fetch_file(url, resume_name, print_destination=False, resume=True)


def test_sum_squared():
"""Optimized sum of squares
"""
X = np.random.randint(0, 50, (3, 3))
assert_equal(np.sum(X ** 2), sum_squared(X))


def test_sizeof_fmt():
"""Test sizeof_fmt
"""
assert_equal(sizeof_fmt(0), '0 bytes')
assert_equal(sizeof_fmt(1), '1 byte')
assert_equal(sizeof_fmt(1000), '1000 bytes')


def test_url_to_local_path():
"""Test URL to local path
"""
assert_equal(_url_to_local_path('http://google.com/home/why.html', '.'),
op.join('.', 'home', 'why.html'))
Loading

0 comments on commit 49e2abc

Please sign in to comment.