diff --git a/python/paddle/v2/dataset/imikolov.py b/python/paddle/v2/dataset/imikolov.py index 285d3eaca8317c..deb556942d9b04 100644 --- a/python/paddle/v2/dataset/imikolov.py +++ b/python/paddle/v2/dataset/imikolov.py @@ -17,7 +17,7 @@ import paddle.v2.dataset.common import tarfile -__all__ = ['train', 'test'] +__all__ = ['train', 'test', 'build_dict'] URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' MD5 = '30177ea32e27c525793142b6bf2c8e2d' @@ -37,7 +37,9 @@ def word_count(f, word_freq=None): return word_freq -def build_dict(train_filename, test_filename): +def build_dict(): + train_filename = './simple-examples/data/ptb.train.txt' + test_filename = './simple-examples/data/ptb.valid.txt' with tarfile.open( paddle.v2.dataset.common.download( paddle.v2.dataset.imikolov.URL, 'imikolov', @@ -45,27 +47,22 @@ def build_dict(train_filename, test_filename): trainf = tf.extractfile(train_filename) testf = tf.extractfile(test_filename) word_freq = word_count(testf, word_count(trainf)) + if '' in word_freq: + # remove for now, since we will set it as last index + del word_freq[''] TYPO_FREQ = 50 word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items()) - dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) - words, _ = list(zip(*dictionary)) + word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*word_freq_sorted)) word_idx = dict(zip(words, xrange(len(words)))) word_idx[''] = len(words) return word_idx -word_idx = {} - - -def reader_creator(filename, n): - global word_idx - if len(word_idx) == 0: - word_idx = build_dict('./simple-examples/data/ptb.train.txt', - './simple-examples/data/ptb.valid.txt') - +def reader_creator(filename, word_idx, n): def reader(): with tarfile.open( paddle.v2.dataset.common.download( @@ -84,9 +81,9 @@ def reader(): return reader -def train(n): - return reader_creator('./simple-examples/data/ptb.train.txt', n) +def train(word_idx, n): + return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n) -def test(n): - return reader_creator('./simple-examples/data/ptb.valid.txt', n) +def test(word_idx, n): + return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) diff --git a/python/paddle/v2/dataset/tests/imikolov_test.py b/python/paddle/v2/dataset/tests/imikolov_test.py index 9b1748eaaa7f91..009e55243a594e 100644 --- a/python/paddle/v2/dataset/tests/imikolov_test.py +++ b/python/paddle/v2/dataset/tests/imikolov_test.py @@ -1,6 +1,8 @@ import paddle.v2.dataset.imikolov import unittest +WORD_DICT = paddle.v2.dataset.imikolov.build_dict() + class TestMikolov(unittest.TestCase): def check_reader(self, reader, n): @@ -9,11 +11,15 @@ def check_reader(self, reader, n): def test_train(self): n = 5 - self.check_reader(paddle.v2.dataset.imikolov.train(n), n) + self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n) def test_test(self): n = 5 - self.check_reader(paddle.v2.dataset.imikolov.test(n), n) + self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n) + + def test_total(self): + _, idx = zip(*WORD_DICT.items()) + self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1) if __name__ == '__main__':