From f84fe7ce1739b4df8b821e1898c17d7503c6af09 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Wed, 2 Sep 2015 20:15:14 -0700
Subject: [PATCH] Change cPickle import pattern in datasets

---
 keras/datasets/cifar.py   |  6 +++---
 keras/datasets/imdb.py    |  4 ++--
 keras/datasets/mnist.py   |  6 +++---
 keras/datasets/reuters.py | 10 +++++-----
 4 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/keras/datasets/cifar.py b/keras/datasets/cifar.py
index 3065c779dbf..b9ceb39855a 100644
--- a/keras/datasets/cifar.py
+++ b/keras/datasets/cifar.py
@@ -1,15 +1,15 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 import sys
-import six.moves.cPickle
+from six.moves import cPickle
 from six.moves import range


 def load_batch(fpath, label_key='labels'):
     f = open(fpath, 'rb')
     if sys.version_info < (3,):
-        d = six.moves.cPickle.load(f)
+        d = cPickle.load(f)
     else:
-        d = six.moves.cPickle.load(f, encoding="bytes")
+        d = cPickle.load(f, encoding="bytes")
         # decode utf8
         for k, v in d.items():
             del(d[k])
diff --git a/keras/datasets/imdb.py b/keras/datasets/imdb.py
index d589cab8959..70c19349b25 100644
--- a/keras/datasets/imdb.py
+++ b/keras/datasets/imdb.py
@@ -1,5 +1,5 @@
 from __future__ import absolute_import
-import six.moves.cPickle
+from six.moves import cPickle
 import gzip
 from .data_utils import get_file
 import random
@@ -17,7 +17,7 @@ def load_data(path="imdb.pkl", nb_words=None, skip_top=0, maxlen=None, test_spli
     else:
         f = open(path, 'rb')

-    X, labels = six.moves.cPickle.load(f)
+    X, labels = cPickle.load(f)
     f.close()

     np.random.seed(seed)
diff --git a/keras/datasets/mnist.py b/keras/datasets/mnist.py
index 3b77ba359cf..0cdc119b8ea 100644
--- a/keras/datasets/mnist.py
+++ b/keras/datasets/mnist.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import gzip
 from .data_utils import get_file
-import six.moves.cPickle
+from six.moves import cPickle
 import sys


@@ -14,9 +14,9 @@ def load_data(path="mnist.pkl.gz"):
         f = open(path, 'rb')

     if sys.version_info < (3,):
-        data = six.moves.cPickle.load(f)
+        data = cPickle.load(f)
     else:
-        data = six.moves.cPickle.load(f, encoding="bytes")
+        data = cPickle.load(f, encoding="bytes")

     f.close()

diff --git a/keras/datasets/reuters.py b/keras/datasets/reuters.py
index 40bb07bc6cc..4e0651587ec 100644
--- a/keras/datasets/reuters.py
+++ b/keras/datasets/reuters.py
@@ -5,7 +5,7 @@
 import string
 import random
 import os
-import six.moves.cPickle
+from six.moves import cPickle
 from six.moves import zip

 import numpy as np
@@ -78,8 +78,8 @@ def make_reuters_dataset(path=os.path.join('datasets', 'temp', 'reuters21578'),
     dataset = (X, labels)
     print('-')
     print('Saving...')
-    six.moves.cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w'))
-    six.moves.cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w'))
+    cPickle.dump(dataset, open(os.path.join('datasets', 'data', 'reuters.pkl'), 'w'))
+    cPickle.dump(tokenizer.word_index, open(os.path.join('datasets', 'data', 'reuters_word_index.pkl'), 'w'))


 def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113,
@@ -88,7 +88,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
     path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl")

     f = open(path, 'rb')
-    X, labels = six.moves.cPickle.load(f)
+    X, labels = cPickle.load(f)
     f.close()

     np.random.seed(seed)
@@ -140,7 +140,7 @@ def load_data(path="reuters.pkl", nb_words=None, skip_top=0, maxlen=None, test_s
 def get_word_index(path="reuters_word_index.pkl"):
     path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl")
     f = open(path, 'rb')
-    return six.moves.cPickle.load(f)
+    return cPickle.load(f)


 if __name__ == "__main__":