Skip to content

Commit

Permalink
ENH: Allow "cuda" if requested
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Feb 8, 2013
1 parent 5b3d158 commit ba18c3a
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions mne/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,29 +55,41 @@ def parallel_func(func, n_jobs, verbose=None):
return parallel, my_func, n_jobs


def check_n_jobs(n_jobs):
def check_n_jobs(n_jobs, allow_cuda=False):
"""Check n_jobs in particular for negative values
Parameters
----------
n_jobs : int
The number of jobs
The number of jobs.
allow_cuda : bool
Allow n_jobs to be 'cuda'. Default: False.
Returns
-------
n_jobs : int
The checked number of jobs. Always positive.
The checked number of jobs. Always positive (or 'cuda' if
applicable.)
"""
try:
import multiprocessing
n_cores = multiprocessing.cpu_count()
if n_cores + n_jobs <= 0:
raise ValueError('If n_jobs has a negative value it must not be less '
'than the number of CPUs present. You\'ve got '
'%s CPUs' % n_cores)
n_jobs = n_cores + n_jobs
except ImportError:
logger.warn('multiprocessing not installed. Cannot run in '
'parallel.')
n_jobs = 1
if not isinstance(n_jobs, int):
if not allow_cuda:
raise ValueError('n_jobs must be an integer')
elif not isinstance(n_jobs, basestring) or n_jobs != 'cuda':
raise ValueError('n_jobs must be an integer, or "cuda"')
#else, we have n_jobs='cuda' and this is okay, so do nothing
else:
try:
import multiprocessing
n_cores = multiprocessing.cpu_count()
n_jobs = n_cores + n_jobs
if n_jobs <= 0:
raise ValueError('If n_jobs has a negative value it must not '
'be less than the number of CPUs present. '
'You\'ve got %s CPUs' % n_cores)
except ImportError:
# only warn if they tried to use more than 1 job
if n_jobs != 1:
logger.warn('multiprocessing not installed. Cannot run in '
'parallel.')
n_jobs = 1
return n_jobs

0 comments on commit ba18c3a

Please sign in to comment.