Faster dataloader merge (NVIDIA#1)
* threaded tf_dl+presplit sentences+shuffled dataset with resume

* elaborate in readme
raulpuric authored and shoeybi committed Apr 23, 2019
1 parent fb4cbdc commit 66719e9
Showing 8 changed files with 127 additions and 21 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -33,6 +33,7 @@ python pretrain_bert.py \
--tokenizer-model-type bert-large-uncased \
--vocab-size 30522 \
--train-data wikipedia \
--presplit-sentences \
--loose-json \
--text-key text \
--split 1000,1,1 \
@@ -79,7 +80,7 @@ This script runs BERT pretraining with a `sentencepiece` tokenizer. If no senten
# Collecting Wikipedia Training Data
We recommend following the wikipedia data extraction process specified by google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text."

We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase.
We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase. We further recommend preprocessing this json dataset with nltk punctuation standardization and presplitting each document into newline-separated sentences. This can be done with the provided script `./scripts/presplit_sentences_json.py` and allows for faster data processing during training. Pretraining with presplit data should be run with the `--presplit-sentences` flag as shown above.
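A minimal sketch of that presplitting step follows (the exact behavior of `presplit_sentences_json.py` may differ; the file paths and the `text` key are assumptions, and the nltk punctuation standardization is omitted here):

```python
# Sketch only: split each loose-json document's text into newline-separated
# sentences with nltk. Paths and the 'text' key are placeholders, and this
# skips the punctuation standardization mentioned above.
import json

import nltk
from nltk import tokenize

nltk.download('punkt')  # sentence tokenizer models

input_path = 'wikipedia_loose.json'      # one json object per line
output_path = 'wikipedia_presplit.json'  # same format, sentences joined by '\n'

with open(input_path) as fin, open(output_path, 'w') as fout:
    for line in fin:
        doc = json.loads(line)
        doc['text'] = '\n'.join(tokenize.sent_tokenize(doc['text']))
        fout.write(json.dumps(doc) + '\n')
```

With data in this form, `--presplit-sentences` lets the dataset split documents on newlines at training time instead of running nltk sentence tokenization on every document.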

Once the json dataset is ready, make sure to set the path in line 27 of `data_utils/corpora.py`.

6 changes: 6 additions & 0 deletions arguments.py
@@ -184,6 +184,9 @@ def add_data_args(parser):

group = parser.add_argument_group('data', 'data configurations')

group.add_argument('--shuffle', action='store_true',
help='Shuffle data. Shuffling is deterministic '
'based on seed and current epoch.')
group.add_argument('--train-data', nargs='+', required=True,
help='Filename (or whitespace separated filenames) '
'for training.')
@@ -208,6 +211,9 @@ def add_data_args(parser):
help='Use loose json (one json-formatted string per '
'newline), instead of tight json (data file is one '
'json string)')
group.add_argument('--presplit-sentences', action='store_true',
help='Dataset content consists of documents where '
'each document consists of newline separated sentences')
group.add_argument('--num-workers', type=int, default=2,
help="""Number of workers to use for dataloading""")
group.add_argument('--tokenizer-model-type', type=str,
12 changes: 7 additions & 5 deletions configure_data.py
@@ -46,7 +46,7 @@ def make_data_loader(dataset, batch_size, args):

shuffle = args.shuffle
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
world_size = args.world_size
@@ -81,8 +81,10 @@ def make_tfrecord_loaders(args):
'max_seq_len': args.seq_length,
'max_preds_per_seq': args.max_preds_per_seq,
'train': True,
'num_workers': args.num_workers,
'seed': args.seed+args.rank+1}
'num_workers': max(args.num_workers, 1),
'seed': args.seed + args.rank + 1,
'threaded_dl': args.num_workers > 0
}
train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
**data_set_args)
data_set_args['train'] = False
@@ -140,7 +142,8 @@ def make_loaders(args):
'vocab_size': args.vocab_size,
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir,
'max_preds_per_seq': args.max_preds_per_seq}
'max_preds_per_seq': args.max_preds_per_seq,
'presplit_sentences': args.presplit_sentences}

eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.]
@@ -218,7 +221,6 @@ def configure_data():
'rank': -1,
'persist_state': 0,
'lazy': False,
'shuffle': False,
'transpose': False,
'data_set_type': 'supervised',
'seq_length': 256,
8 changes: 5 additions & 3 deletions data_utils/__init__.py
@@ -46,7 +46,7 @@ def get_dataset(path, **kwargs):
if supported_corpus(path):
return corpora.NAMED_CORPORA[path](**kwargs)
ext = get_ext(path)
if ext =='.json':
if '.json' in ext:
text = json_dataset(path, **kwargs)
elif ext in ['.csv', '.tsv']:
text = csv_dataset(path, **kwargs)
@@ -108,8 +108,10 @@ def get_dataset_from_path(path_):
if should_split(split):
ds = split_ds(ds, split)
if ds_type.lower() == 'bert':
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length) for d in ds]
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) for d in ds]
else:
if ds_type.lower() == 'bert':
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length)
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
return ds, tokenizer
12 changes: 10 additions & 2 deletions data_utils/datasets.py
@@ -449,7 +449,7 @@ class bert_sentencepair_dataset(data.Dataset):
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, **kwargs):
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
@@ -464,6 +464,7 @@ def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None
self.dataset_size = dataset_size
if self.dataset_size is None:
self.dataset_size = self.ds_len * (self.ds_len-1)
self.presplit_sentences = presplit_sentences

def __len__(self):
return self.dataset_size
@@ -494,7 +495,14 @@ def __getitem__(self, idx):

def sentence_split(self, document):
"""split document into sentences"""
return tokenize.sent_tokenize(document)
lines = document.split('\n')
if self.presplit_sentences:
return [line for line in lines if line]
rtn = []
for line in lines:
if line != '':
rtn.extend(tokenize.sent_tokenize(line))
return rtn

def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
"""tokenize sentence and get token types"""
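To illustrate the two branches of `sentence_split` above, here is a small sketch with made-up text (it assumes nltk's `punkt` models are installed):

```python
# Sketch: how sentence_split treats the same document under each setting.
from nltk import tokenize  # requires nltk's 'punkt' models

doc = "First sentence. Second sentence.\nThird sentence."

# presplit_sentences=True: the data is trusted to already hold one sentence
# per line, so splitting is just a cheap newline split.
presplit = [line for line in doc.split('\n') if line]
# -> ['First sentence. Second sentence.', 'Third sentence.']

# presplit_sentences=False: nltk sentence-tokenizes every non-empty line.
unsplit = []
for line in doc.split('\n'):
    if line != '':
        unsplit.extend(tokenize.sent_tokenize(line))
# -> ['First sentence.', 'Second sentence.', 'Third sentence.']
```

Note that the fast branch only yields true sentences when the data really was presplit, e.g. with `scripts/presplit_sentences_json.py` as described in the README.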
51 changes: 51 additions & 0 deletions data_utils/samplers.py
@@ -21,6 +21,57 @@
from torch.utils import data
import numpy as np

class RandomSampler(data.sampler.Sampler):
r"""
Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
but this class lets the user set an epoch like DistributedSampler
Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``, default=False
"""

def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.epoch = -1

if self._num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not be specified, "
"since a random permute will be performed.")

if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))

@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples

def __iter__(self):
n = len(self.data_source)
g = torch.Generator()
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())

def __len__(self):
return self.num_samples

def set_epoch(self, epoch):
self.epoch = epoch

class DistributedBatchSampler(data.sampler.BatchSampler):
"""
similar to normal implementation of distributed sampler, except implementation is at the
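A minimal usage sketch of the sampler above, mirroring how `configure_data.py` constructs it and how `pretrain_bert.py` calls `set_epoch` (the dataset and sizes here are illustrative):

```python
# Sketch: deterministic, resumable shuffling with the epoch-seeded sampler.
import torch
from data_utils.samplers import RandomSampler

dataset = torch.utils.data.TensorDataset(torch.arange(10000))
batch_size, train_iters = 32, 1000

# With replacement=True the sampler yields exactly enough indices for the
# whole run (batch_size * train_iters), as in configure_data.make_data_loader.
sampler = RandomSampler(dataset, replacement=True,
                        num_samples=batch_size * train_iters)

for epoch in range(1, 3):
    sampler.set_epoch(epoch)  # seeds the generator, so the order is reproducible
    indices = iter(sampler)
    print(epoch, [next(indices) for _ in range(4)])
```

Re-running with the same epoch (and seed offset) reproduces the same index sequence, which is what allows shuffled training to resume deterministically.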
54 changes: 44 additions & 10 deletions data_utils/tf_dl.py
@@ -14,12 +14,16 @@
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""

import queue
import threading

import tensorflow as tf
tf.enable_eager_execution()
import torch
import numpy as np

class TFRecordDataLoader(object):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False):
assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
tf.set_random_seed(seed)
if isinstance(records, str):
@@ -55,11 +59,18 @@ def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, n
'num_parallel_batches': num_workers,
'drop_remainder': train}
self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args))
self.threaded_dl = threaded_dl
self.num_workers = num_workers

def __iter__(self):
data_iter = iter(self.dataloader)
for item in data_iter:
yield convert_tf_example_to_torch_tensors(item)
if self.threaded_dl:
data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers))
for item in data_iter:
yield item
else:
data_iter = iter(self.dataloader)
for item in data_iter:
yield convert_tf_example_to_torch_tensors(item)

class Record2Example(object):
def __init__(self, feature_map):
@@ -74,14 +85,37 @@ def __call__(self, record):
return example

def convert_tf_example_to_torch_tensors(example):
item = {k: torch.from_numpy(v.numpy()) for k,v in example.items()}
mask = torch.zeros_like(item['input_ids'])
mask_labels = torch.ones_like(item['input_ids'])*-1
for b, row in enumerate(item['masked_lm_positions'].long()):
item = {k: (v.numpy()) for k,v in example.items()}
mask = np.zeros_like(item['input_ids'])
mask_labels = np.ones_like(item['input_ids'])*-1
for b, row in enumerate(item['masked_lm_positions'].astype(int)):
for i, idx in enumerate(row):
if item['masked_lm_weights'][b, i] != 0:
mask[b, idx] = 1
mask_labels[b, idx] = item['masked_lm_ids'][b, i]
return {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
return {k: torch.from_numpy(v) for k,v in output.items()}

class MultiprocessLoader(object):
def __init__(self, dataloader, num_workers=2):
self.dl = dataloader
self.queue_size = 2*num_workers

def __iter__(self):
output_queue = queue.Queue(self.queue_size)
output_thread = threading.Thread(target=_multiproc_iter,
args=(self.dl, output_queue))
output_thread.daemon = True
output_thread.start()

while output_thread.is_alive():
yield output_queue.get(block=True)
else:
print(RuntimeError('TF record data loader thread exited unexpectedly'))

def _multiproc_iter(dl, output_queue):
data_iter = iter(dl)
for item in data_iter:
tensors = convert_tf_example_to_torch_tensors(item)
output_queue.put(tensors, block=True)
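A hedged usage sketch of the loader with the new threaded path (the record path and hyperparameters are placeholders; the module still depends on the TF 1.x eager/contrib APIs imported above):

```python
# Sketch: with threaded_dl=True a background thread converts TF examples to
# torch tensors and prefetches them into a bounded queue (2 * num_workers).
from data_utils.tf_dl import TFRecordDataLoader

train_loader = TFRecordDataLoader(
    records=['data/train.tfrecord'],  # placeholder path
    batch_size=32,
    max_seq_len=512,
    max_preds_per_seq=80,
    train=True,
    num_workers=2,
    seed=1,
    threaded_dl=True,
)

for batch in train_loader:
    # Each batch is a dict of torch tensors with keys:
    # 'text', 'types', 'is_random', 'pad_mask', 'mask', 'mask_labels'
    input_ids = batch['text']
    break
```

Keeping the queue small means the producer thread stays only a couple of batches ahead of the training loop, bounding host memory use.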
2 changes: 2 additions & 0 deletions pretrain_bert.py
@@ -434,6 +434,8 @@ def main():
train_data.batch_sampler.start_iter = total_iters % len(train_data)
# For all epochs.
for epoch in range(start_epoch, args.epochs+1):
if args.shuffle:
train_data.batch_sampler.sampler.set_epoch(epoch+args.seed)
timers('epoch time').start()
iteration, skipped = train_epoch(epoch, model, optimizer,
train_data, lr_scheduler,
