Skip to content

Commit

Permalink
add wmt:en_de val set (facebookresearch#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
hadasah authored and alexholdenmiller committed Oct 24, 2017
1 parent 5f4fcfe commit 0646424
Showing 1 changed file with 37 additions and 18 deletions.
55 changes: 37 additions & 18 deletions parlai/tasks/wmt/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@

import parlai.core.build_data as build_data
import os
import numpy


def readFiles(dpath, rfnames):
en_fname, de_fname = rfnames
url_base = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/'
en_url = url_base + en_fname
de_url = url_base + de_fname
build_data.download(en_url, dpath, en_fname)
build_data.download(de_url, dpath, de_fname)
with open(os.path.join(dpath, en_fname), 'r') as f:
# We replace '##AT##-##AT##' as a workaround in order to use the
# nltk tokenizer specified by DictionaryAgent
en = [l[:-1].replace("##AT##-##AT##", "__AT__") for l in f]

with open(os.path.join(dpath, de_fname), 'r') as f:
de = [l[:-1].replace("##AT##-##AT##", "__AT__") for l in f]

return list(zip(de, en))


def build(opt):
Expand All @@ -22,26 +41,26 @@ def build(opt):
build_data.make_dir(dpath)

# Download the data.
fnames = [('train.en','train.de', 'en_de_train.txt'),
('newstest2014.en','newstest2014.de', 'en_de_test.txt')]
for (en_fname, de_fname, w_fname) in fnames:
url_base = 'https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/'
en_url = url_base + en_fname
de_url = url_base + de_fname
build_data.download(en_url, dpath, en_fname)
build_data.download(de_url, dpath, de_fname)
with open(os.path.join(dpath, en_fname), 'r') as f:
# We replace '##AT##-##AT##' as a workaround in order to use the
# nltk tokenizer specified by DictionaryAgent
en = [l[:-1].replace("##AT##-##AT##", "__AT__") for l in f]

with open(os.path.join(dpath, de_fname), 'r') as f:
de = [l[:-1].replace("##AT##-##AT##", "__AT__") for l in f]

with open(os.path.join(dpath, w_fname), 'w') as f:
for de_sent,en_sent in zip(de,en):

train_r_fnames = ('train.en', 'train.de')
train_w_fname = 'en_de_train.txt'
valid_w_fname = 'en_de_valid.txt'
test_r_fnames = ('newstest2014.en', 'newstest2014.de')
test_w_fname = 'en_de_test.txt'

train_zip = readFiles(dpath, train_r_fnames)
numpy.random.shuffle(train_zip)
with open(os.path.join(dpath, valid_w_fname), 'w') as f:
for de_sent, en_sent in train_zip[:30000]:
f.write("1 "+en_sent+"\t"+de_sent+"\n")
with open(os.path.join(dpath, train_w_fname), 'w') as f:
for de_sent, en_sent in train_zip[30000:]:
f.write("1 "+en_sent+"\t"+de_sent+"\n")

test_zip = readFiles(dpath, test_r_fnames)
with open(os.path.join(dpath, test_w_fname), 'w') as f:
for de_sent, en_sent in test_zip:
f.write("1 "+en_sent+"\t"+de_sent+"\n")

# Mark the data as built.
build_data.mark_done(dpath, version_string=version)

0 comments on commit 0646424

Please sign in to comment.