Skip to content

Commit

Permalink
MRG: a few more _validate_types (mne-tools#5259)
Browse files Browse the repository at this point in the history
* a few more

* fix test

* fix

* pep8

* improve messages
  • Loading branch information
jona-sassenhagen authored and larsoner committed Jun 6, 2018
1 parent a10845b commit 6614a5b
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 53 deletions.
6 changes: 2 additions & 4 deletions mne/beamformer/_lcmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..minimum_norm.inverse import combine_xyz, _check_reference
from ..cov import compute_whitener, compute_covariance
from ..source_estimate import _make_stc, SourceEstimate
from ..utils import logger, verbose, warn, estimate_rank
from ..utils import logger, verbose, warn, estimate_rank, _validate_type
from .. import Epochs
from ..externals import six
from ._compute_beamformer import (
Expand Down Expand Up @@ -182,9 +182,7 @@ def make_lcmv(info, forward, data_cov, reg=0.05, noise_cov=None, label=None,
'with rank reduction using reduce_rank '
'parameter is only implemented with '
'pick_ori=="max-power".')
if not isinstance(reduce_rank, bool):
raise ValueError('reduce_rank has to be True or False '
' (got %s).' % reduce_rank)
_validate_type(reduce_rank, bool, "reduce_rank", "a boolean")

# Compute spatial filters
W = np.dot(G.T, Cm_inv)
Expand Down
9 changes: 4 additions & 5 deletions mne/bem.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from .surface import (read_surface, write_surface, complete_surface_info,
_compute_nearest, _get_ico_surface, read_tri,
_fast_cross_nd_sum, _get_solids)
from .utils import verbose, logger, run_subprocess, get_subjects_dir, warn, _pl
from .utils import (verbose, logger, run_subprocess, get_subjects_dir, warn,
_pl, _validate_type)
from .fixes import einsum
from .externals.six import string_types

Expand Down Expand Up @@ -905,8 +906,7 @@ def get_fitting_dig(info, dig_kinds='auto', verbose=None):
.. versionadded:: 0.14
"""
if not isinstance(info, Info):
raise TypeError('info must be an instance of Info not %s' % type(info))
_validate_type(info, Info, "info", "Info")
if info['dig'] is None:
raise RuntimeError('Cannot fit headshape without digitization '
', info["dig"] is None')
Expand Down Expand Up @@ -1555,8 +1555,7 @@ def _prepare_env(subject, subjects_dir, requires_freesurfer):
raise RuntimeError('I cannot find freesurfer. The FREESURFER_HOME '
'environment variable is not set.')

if not isinstance(subject, string_types):
raise TypeError('The subject argument must be set')
_validate_type(subject, "str")

subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
if not op.isdir(subjects_dir):
Expand Down
26 changes: 10 additions & 16 deletions mne/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from os.path import splitext


from .utils import check_fname, logger, verbose, _get_stim_channel, warn
from .utils import (check_fname, logger, verbose, _get_stim_channel, warn,
_validate_type)
from .io.constants import FIFF
from .io.tree import dir_tree_find
from .io.tag import read_tag
Expand Down Expand Up @@ -720,9 +721,7 @@ def _mask_trigs(events, mask, mask_type):
raise ValueError('mask_type must be "not_and" or "and", got %s'
% (mask_type,))
if mask is not None:
if not isinstance(mask, int):
raise TypeError('You provided a(n) %s.' % type(mask) +
'Mask must be an int or None.')
_validate_type(mask, "int", "mask", "int or None")
n_events = len(events)
if n_events == 0:
return events.copy()
Expand Down Expand Up @@ -856,14 +855,10 @@ def make_fixed_length_events(raw, id=1, start=0, stop=None, duration=1.,
The new events.
"""
from .io.base import BaseRaw
if not isinstance(raw, BaseRaw):
raise ValueError('Input data must be an instance of Raw, got'
' %s instead.' % (type(raw)))
if not isinstance(id, int):
raise ValueError('id must be an integer')
if not isinstance(duration, (int, float)):
raise ValueError('duration must be an integer of a float, '
'got %s instead.' % (type(duration)))
_validate_type(raw, BaseRaw, "raw")
_validate_type(id, int, "id")
_validate_type(duration, "numeric", "duration")

start = raw.time_as_index(start, use_rounding=True)[0]
if stop is not None:
stop = raw.time_as_index(stop, use_rounding=True)[0]
Expand Down Expand Up @@ -913,8 +908,7 @@ def concatenate_events(events, first_samps, last_samps):
--------
mne.concatenate_raws
"""
if not isinstance(events, list):
raise ValueError('events must be a list of arrays')
_validate_type(events, list, "events")
if not (len(events) == len(last_samps) and
len(events) == len(first_samps)):
raise ValueError('events, first_samps, and last_samps must all have '
Expand Down Expand Up @@ -1103,8 +1097,8 @@ def __getitem__(self, item):
"""
if isinstance(item, str):
item = [item]
elif not isinstance(item, list):
raise ValueError('Keys must be category names')
else:
_validate_type(item, list, "Keys", "category names")
cats = list()
for it in item:
if it in self._categories:
Expand Down
3 changes: 1 addition & 2 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,7 @@ def __init__(self, data, info, tmin=0., comment='', nave=1, kind='average',
self.verbose = verbose
self.preload = True
self._projector = None
if not isinstance(self.kind, string_types):
raise TypeError('kind must be a string, not "%s"' % (type(kind),))
_validate_type(self.kind, "str", "kind")
if self.kind not in _aspect_dict:
raise ValueError('unknown kind "%s", should be "average" or '
'"standard_error"' % (self.kind,))
Expand Down
32 changes: 10 additions & 22 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
from ..transforms import (transform_surface_to, invert_transform,
write_trans)
from ..utils import (_check_fname, get_subjects_dir, has_mne_c, warn,
run_subprocess, check_fname, logger, verbose)
run_subprocess, check_fname, logger, verbose,
_validate_type)
from ..label import Label


Expand Down Expand Up @@ -1371,16 +1372,14 @@ def restrict_forward_to_label(fwd, labels):
--------
restrict_forward_to_stc
"""
message = 'labels must be instance of Label or a list of Label.'
vertices = [np.array([], int), np.array([], int)]

if not isinstance(labels, list):
labels = [labels]

# Get vertices separately of each hemisphere from all label
for label in labels:
if not isinstance(label, Label):
raise TypeError(message + ' Instead received %s' % type(label))
_validate_type(label, Label, "label", "Label or list")
i = 0 if label.hemi == 'lh' else 1
vertices[i] = np.append(vertices[i], label.vertices)
# Remove duplicates and sort
Expand Down Expand Up @@ -1544,9 +1543,7 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None,
if fname is None:
fname = op.join(temp_dir, 'temp-fwd.fif')
_check_fname(fname, overwrite)

if not isinstance(subject, string_types):
raise ValueError('subject must be a string')
_validate_type(subject, "str", "subject")

# check for meas to exist as string, or try to make evoked
if isinstance(meas, string_types):
Expand All @@ -1568,8 +1565,7 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None,
raise ValueError('Either trans or mri must be specified')

if trans is not None:
if not isinstance(trans, string_types):
raise ValueError('trans must be a string')
_validate_type(trans, "str", "trans")
if not op.isfile(trans):
raise IOError('trans file "%s" not found' % trans)
if mri is not None:
Expand Down Expand Up @@ -1607,15 +1603,9 @@ def _do_forward_solution(subject, meas, fname=None, src=None, spacing=None,
mindist = ['--mindist', '%g' % mindist]

# src, spacing, bem
if src is not None:
if not isinstance(src, string_types):
raise ValueError('src must be a string or None')
if spacing is not None:
if not isinstance(spacing, string_types):
raise ValueError('spacing must be a string or None')
if bem is not None:
if not isinstance(bem, string_types):
raise ValueError('bem must be a string or None')
for element, name in zip((src, spacing, bem), ("src", "spacing", "bem")):
if element is not None:
_validate_type(element, "str", name, "string or None")

# put together the actual call
cmd = ['mne_do_forward_solution',
Expand Down Expand Up @@ -1693,8 +1683,7 @@ def average_forward_solutions(fwds, weights=None):
The averaged forward solution.
"""
# check for fwds being a list
if not isinstance(fwds, list):
raise TypeError('fwds must be a list')
_validate_type(fwds, list, "fwds")
if not len(fwds) > 0:
raise ValueError('fwds must not be empty')

Expand All @@ -1714,8 +1703,7 @@ def average_forward_solutions(fwds, weights=None):
# check our forward solutions
for fwd in fwds:
# check to make sure it's a forward solution
if not isinstance(fwd, dict):
raise TypeError('Each entry in fwds must be a dict')
_validate_type(fwd, dict, "each entry in fwds", "dict")
# check to make sure the dict is actually a fwd
check_keys = ['info', 'sol_grad', 'nchan', 'src', 'source_nn', 'sol',
'source_rr', 'source_ori', 'surf_ori', 'coord_frame',
Expand Down
8 changes: 4 additions & 4 deletions mne/tests/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ def test_make_fixed_length_events():
assert_raises(ValueError, make_fixed_length_events, raw, 1,
tmin, tmax - 1e-3, duration)
# not raw, bad id or duration
assert_raises(ValueError, make_fixed_length_events, raw, 2.3)
assert_raises(ValueError, make_fixed_length_events, 'not raw', 2)
assert_raises(ValueError, make_fixed_length_events, raw, 23, tmin, tmax,
assert_raises(TypeError, make_fixed_length_events, raw, 2.3)
assert_raises(TypeError, make_fixed_length_events, 'not raw', 2)
assert_raises(TypeError, make_fixed_length_events, raw, 23, tmin, tmax,
'abc')

# Let's try some ugly sample rate/sample count combos
Expand Down Expand Up @@ -499,7 +499,7 @@ def test_acqparser():
assert_raises(KeyError, acqp.__getitem__, 'does not exist')
assert_raises(KeyError, acqp.get_condition, raw, 'foo')
# category not a string
assert_raises(ValueError, acqp.__getitem__, 0)
assert_raises(TypeError, acqp.__getitem__, 0)
# number of events / categories
assert_equal(len(acqp), 7)
assert_equal(len(acqp.categories), 7)
Expand Down

0 comments on commit 6614a5b

Please sign in to comment.