Skip to content

Commit

Permalink
Parallel preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
edunov authored and myleott committed Sep 25, 2018
1 parent ee46c63 commit 862cad1
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 34 deletions.
32 changes: 29 additions & 3 deletions fairseq/data/indexed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,15 @@ def data_file_path(prefix_path):
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset"""

def __init__(self, path, fix_lua_indexing=False):
def __init__(self, path, fix_lua_indexing=False, read_data=True):
super().__init__()
self.fix_lua_indexing = fix_lua_indexing
self.read_index(path)
self.data_file = None
if read_data:
self.read_data(path)

def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == b'TNTIDX\x00\x00'
Expand All @@ -66,7 +72,6 @@ def __init__(self, path, fix_lua_indexing=False):
self.dim_offsets = read_longs(f, self.size + 1)
self.data_offsets = read_longs(f, self.size + 1)
self.sizes = read_longs(f, self.s)
self.read_data(path)

def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
Expand All @@ -76,7 +81,8 @@ def check_index(self, i):
raise IndexError('index out of range')

def __del__(self):
self.data_file.close()
if self.data_file:
self.data_file.close()

def __getitem__(self, i):
self.check_index(i)
Expand Down Expand Up @@ -193,6 +199,26 @@ def add_item(self, tensor):
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

def merge_file_(self, another_file):
index = IndexedDataset(another_file, read_data=False)
assert index.dtype == self.dtype

begin = self.data_offsets[-1]
for offset in index.data_offsets[1:]:
self.data_offsets.append(begin + offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)

with open(data_file_path(another_file), 'rb') as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break

def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
Expand Down
84 changes: 71 additions & 13 deletions fairseq/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
# can be found in the PATENTS file in the same directory.

from collections import Counter
import re
import os, re

import torch

from multiprocessing import Pool

SPACE_NORMALIZER = re.compile("\s+")

Expand All @@ -20,28 +20,74 @@ def tokenize_line(line):
return line.split()


def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins

class Tokenizer:

@staticmethod
def add_file_to_dictionary(filename, dict, tokenize):
def add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r') as f:
for line in f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
dict.add_symbol(word)
dict.add_symbol(dict.eos_word)
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter

@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)
if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Tokenizer.add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))

@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line,
append_eos=True, reverse_order=False):
append_eos=True, reverse_order=False,
offset=0, end=-1):
nseq, ntok = 0, 0
replaced = Counter()

def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])

with open(filename, 'r') as f:
for line in f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = Tokenizer.tokenize(
line=line,
dict=dict,
Expand All @@ -52,10 +98,22 @@ def replaced_consumer(word, idx):
reverse_order=reverse_order,
)
nseq += 1

consumer(ids)
ntok += len(ids)
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}

@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets

@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
Expand Down
87 changes: 69 additions & 18 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
"""

import argparse
from collections import Counter
from itertools import zip_longest
import os
import shutil


from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool, Manager, Process



def get_parser():
Expand All @@ -41,6 +45,7 @@ def get_parser():
parser.add_argument('--only-source', action='store_true', help='Only process the source language')
parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
help='Pad dictionary size to be multiple of N')
parser.add_argument('--workers', metavar='N', default=1, type=int, help='number of parallel workers')
return parser


Expand All @@ -52,7 +57,7 @@ def main(args):
def build_dictionary(filenames):
d = dictionary.Dictionary()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, args.workers)
return d

def train_path(lang):
Expand All @@ -70,11 +75,6 @@ def dest_path(prefix, lang):
def dict_path(lang):
return dest_path('dict', lang) + '.txt'

def dataset_dest_path(output_prefix, lang, extension):
base = f'{args.destdir}/{output_prefix}'
lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
return f'{base}{lang_part}.{extension}'

if args.joined_dictionary:
assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
Expand Down Expand Up @@ -111,25 +111,54 @@ def dataset_dest_path(output_prefix, lang, extension):
)
tgt_dict.save(dict_path(args.target_lang))

def make_binary_dataset(input_prefix, output_prefix, lang):
def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
dict = dictionary.Dictionary.load(dict_path(lang))
print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
n_seq_tok = [0, 0]
replaced = Counter()

ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))

def consumer(tensor):
ds.add_item(tensor)
def merge_result(worker_result):
replaced.update(worker_result['replaced'])
n_seq_tok[0] += worker_result['nseq']
n_seq_tok[1] += worker_result['ntok']

input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
res = Tokenizer.binarize(input_file, dict, consumer)
offsets = Tokenizer.find_offsets(input_file, num_workers)
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers-1)
for worker_id in range(1, num_workers):
prefix = "{}{}".format(output_prefix, worker_id)
pool.apply_async(binarize, (args, input_file, dict, prefix, lang,
offsets[worker_id],
offsets[worker_id + 1]), callback=merge_result)
pool.close()

ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
merge_result(Tokenizer.binarize(input_file, dict, lambda t: ds.add_item(t),
offset=0, end=offsets[1]))
if num_workers > 1:
pool.join()
for worker_id in range(1, num_workers):
prefix = "{}{}".format(output_prefix, worker_id)
temp_file_path = dataset_dest_prefix(args, prefix, lang)
ds.merge_file_(temp_file_path)
os.remove(indexed_dataset.data_file_path(temp_file_path))
os.remove(indexed_dataset.index_file_path(temp_file_path))


ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))


print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
lang, input_file, res['nseq'], res['ntok'],
100 * res['nunk'] / res['ntok'], dict.unk_word))
ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
lang, input_file, n_seq_tok[0], n_seq_tok[1],
100 * sum(replaced.values()) / n_seq_tok[1], dict.unk_word))



def make_dataset(input_prefix, output_prefix, lang):
def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
if args.output_format == 'binary':
make_binary_dataset(input_prefix, output_prefix, lang)
make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
elif args.output_format == 'raw':
# Copy original text file to destination folder
output_text_file = dest_path(
Expand All @@ -140,7 +169,7 @@ def make_dataset(input_prefix, output_prefix, lang):

def make_all(lang):
if args.trainpref:
make_dataset(args.trainpref, 'train', lang)
make_dataset(args.trainpref, 'train', lang, num_workers=args.workers)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
Expand Down Expand Up @@ -196,6 +225,28 @@ def make_all(lang):
print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)



def binarize(args, filename, dict, output_prefix, lang, offset, end):

ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
def consumer(tensor):
ds.add_item(tensor)

res = Tokenizer.binarize(filename, dict, consumer, offset=offset, end=end)
ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))
return res

def dataset_dest_prefix(args, output_prefix, lang):
base = f'{args.destdir}/{output_prefix}'
lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
return f'{base}{lang_part}'


def dataset_dest_file(args, output_prefix, lang, extension):
base = dataset_dest_prefix(args, output_prefix, lang)
return f'{base}.{extension}'


if __name__ == '__main__':
parser = get_parser()
args = parser.parse_args()
Expand Down

0 comments on commit 862cad1

Please sign in to comment.