Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#1511 from helinwang/dataset
Browse files Browse the repository at this point in the history
expose build_dict in imikolov dataset, fix bug that len(word_dict) is…
  • Loading branch information
helinwang authored Mar 3, 2017
2 parents 247a2a4 + 4cbbb23 commit 8530d3c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
31 changes: 14 additions & 17 deletions python/paddle/v2/dataset/imikolov.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import paddle.v2.dataset.common
import tarfile

__all__ = ['train', 'test']
__all__ = ['train', 'test', 'build_dict']

URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
Expand All @@ -37,35 +37,32 @@ def word_count(f, word_freq=None):
return word_freq


def build_dict(train_filename, test_filename):
def build_dict():
train_filename = './simple-examples/data/ptb.train.txt'
test_filename = './simple-examples/data/ptb.valid.txt'
with tarfile.open(
paddle.v2.dataset.common.download(
paddle.v2.dataset.imikolov.URL, 'imikolov',
paddle.v2.dataset.imikolov.MD5)) as tf:
trainf = tf.extractfile(train_filename)
testf = tf.extractfile(test_filename)
word_freq = word_count(testf, word_count(trainf))
if '<unk>' in word_freq:
# remove <unk> for now, since we will set it as last index
del word_freq['<unk>']

TYPO_FREQ = 50
word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items())

dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(zip(words, xrange(len(words))))
word_idx['<unk>'] = len(words)

return word_idx


word_idx = {}


def reader_creator(filename, n):
global word_idx
if len(word_idx) == 0:
word_idx = build_dict('./simple-examples/data/ptb.train.txt',
'./simple-examples/data/ptb.valid.txt')

def reader_creator(filename, word_idx, n):
def reader():
with tarfile.open(
paddle.v2.dataset.common.download(
Expand All @@ -84,9 +81,9 @@ def reader():
return reader


def train(n):
return reader_creator('./simple-examples/data/ptb.train.txt', n)
def train(word_idx, n):
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n)


def test(n):
return reader_creator('./simple-examples/data/ptb.valid.txt', n)
def test(word_idx, n):
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n)
10 changes: 8 additions & 2 deletions python/paddle/v2/dataset/tests/imikolov_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import paddle.v2.dataset.imikolov
import unittest

WORD_DICT = paddle.v2.dataset.imikolov.build_dict()


class TestMikolov(unittest.TestCase):
def check_reader(self, reader, n):
Expand All @@ -9,11 +11,15 @@ def check_reader(self, reader, n):

def test_train(self):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.train(n), n)
self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)

def test_test(self):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.test(n), n)
self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)

def test_total(self):
_, idx = zip(*WORD_DICT.items())
self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)


if __name__ == '__main__':
Expand Down

0 comments on commit 8530d3c

Please sign in to comment.