diff --git a/mne/datasets/utils.py b/mne/datasets/utils.py index 4d530f2c9d9..33f4076c159 100644 --- a/mne/datasets/utils.py +++ b/mne/datasets/utils.py @@ -19,7 +19,7 @@ from .. import __version__ as mne_version from ..label import read_labels_from_annot, Label, write_labels_to_annot from ..utils import (get_config, set_config, _fetch_file, logger, warn, - verbose, get_subjects_dir) + verbose, get_subjects_dir, md5sum) from ..externals.six import string_types from ..externals.six.moves import input @@ -369,9 +369,25 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, assert len(url) == len(archive_name) assert len(url) == len(folder_orig) assert len(url) == len(folder_path) - for u, fp, an, h, fo in zip(url, folder_path, archive_name, hash_, + assert len(url) > 0 + # 1. Get all the archives + full_name = list() + for u, an, h, fo in zip(url, archive_name, hash_, folder_orig): + remove_archive, full = _download(path, u, an, h) + full_name.append(full) + del archive_name + # 2. Extract all of the files + remove_dir = True + for u, fp, an, h, fo in zip(url, folder_path, full_name, hash_, folder_orig): - _download_and_extract(path, name, u, fp, an, h, fo) + _extract(path, name, fp, an, fo, remove_dir) + remove_dir = False # only do on first iteration + # 3. Remove all of the archives + if remove_archive: + for an in full_name: + os.remove(op.join(path, an)) + + logger.info('Successfully extracted to: %s' % folder_path) _do_path_update(path, update_path, key, name) path = folder_path[0] @@ -389,35 +405,41 @@ def _data_path(path=None, force_update=False, update_path=True, download=True, return (path, data_version) if return_version else path -def _download_and_extract(path, name, url, folder_path, archive_name, hash_, - folder_orig): - """Download and extract an archive.""" - logger.info('Downloading or reinstalling ' - 'data archive %s at location %s' % (archive_name, path)) - rm_archive = False +def _download(path, url, archive_name, hash_): + """Download and extract an archive, completing the filename.""" martinos_path = '/cluster/fusion/sample_data/' + archive_name neurospin_path = '/neurospin/tmp/gramfort/' + archive_name + remove_archive = False if op.exists(martinos_path): - archive_name = martinos_path + full_name = martinos_path elif op.exists(neurospin_path): - archive_name = neurospin_path + full_name = neurospin_path else: - archive_name = op.join(path, archive_name) - rm_archive = True + full_name = op.join(path, archive_name) + remove_archive = True fetch_archive = True - if op.exists(archive_name): - msg = ('Archive already exists. Overwrite it (y/[n])? ') - answer = input(msg) - if answer.lower() == 'y': - os.remove(archive_name) - else: - fetch_archive = False - + if op.exists(full_name): + logger.info('Archive exists (%s), checking hash %s.' + % (archive_name, hash_,)) + md5 = md5sum(full_name) + fetch_archive = False + if md5 != hash_: + if input('Archive already exists but the hash does not match: ' + '%s\nOverwrite (y/[n])?' + % (archive_name,)).lower() == 'y': + os.remove(full_name) + fetch_archive = True if fetch_archive: - _fetch_file(url, archive_name, print_destination=False, + logger.info('Downloading archive %s to %s' % (archive_name, path)) + _fetch_file(url, full_name, print_destination=False, hash_=hash_) + return remove_archive, full_name + + +def _extract(path, name, folder_path, archive_name, folder_orig, remove_dir): + if op.exists(folder_path) and remove_dir: + logger.info('Removing old directory: %s' % (folder_path,)) - if op.exists(folder_path): def onerror(func, path, exc_info): """Deal with access errors (e.g. testing dataset read-only).""" # Is the error an access error ? @@ -459,10 +481,6 @@ def onerror(func, path, exc_info): if folder_orig is not None: shutil.move(op.join(path, folder_orig), folder_path) - if rm_archive: - os.remove(archive_name) - logger.info('Successfully extracted to: %s' % folder_path) - def _get_version(name): """Get a dataset version.""" diff --git a/mne/utils.py b/mne/utils.py index 9a5ab515142..f06426fbaf2 100644 --- a/mne/utils.py +++ b/mne/utils.py @@ -1929,7 +1929,7 @@ def _fetch_file(url, file_name, print_destination=True, resume=True, finally: u.close() del u - logger.info('Downloading data from %s (%s)\n' + logger.info('Downloading %s (%s)' % (url, sizeof_fmt(file_size))) # Triage resume @@ -1961,7 +1961,7 @@ def _fetch_file(url, file_name, print_destination=True, resume=True, # check md5sum if hash_ is not None: - logger.info('Verifying download hash.') + logger.info('Verifying hash %s.' % (hash_,)) md5 = md5sum(temp_file_name) if hash_ != md5: raise RuntimeError('Hash mismatch for downloaded file %s, '