Commit
Integrate code from t5_main into existing code.
jaredcasper committed Apr 16, 2021
1 parent f32a638 commit 48a5e0d
Showing 10 changed files with 883 additions and 154 deletions.
9 changes: 6 additions & 3 deletions megatron/arguments.py
@@ -149,11 +149,11 @@ def parse_args(extra_args_provider=None, defaults={},
flush=True)

# If we do accumulation and all-reduces in fp32, we need to have
# local DDP and we should set the use-contiguous-buffers-in-ddp.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
args.use_contiguous_buffers_in_ddp = True

if args.dataloader_type is None:
args.dataloader_type = 'single'

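As background for the check in the hunk above: with --accumulate-allreduce-grads-in-fp32, gradients are accumulated and all-reduced in a single flat fp32 buffer, which is why local DDP and contiguous buffers are forced on. The sketch below only illustrates the idea of a contiguous gradient buffer; the main_grad attribute name and buffer layout are assumptions for illustration, not the actual megatron/model/distributed.py implementation.

import torch

# Minimal sketch (assumption): each parameter's gradient is a view into one
# flat fp32 tensor, so accumulation and a single all-reduce cover all params.
params = [torch.nn.Parameter(torch.zeros(4)), torch.nn.Parameter(torch.zeros(3, 2))]
flat_grads = torch.zeros(sum(p.numel() for p in params), dtype=torch.float32)
offset = 0
for p in params:
    # 'main_grad' is a hypothetical attribute name used here for illustration.
    p.main_grad = flat_grads[offset:offset + p.numel()].view_as(p)
    offset += p.numel()
# Once the backward pass has filled the views, one call such as
# torch.distributed.all_reduce(flat_grads) would reduce every gradient at once.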
@@ -212,7 +212,7 @@ def parse_args(extra_args_provider=None, defaults={},
else:
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length

assert args.hidden_size % args.num_attention_heads == 0
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
@@ -625,6 +625,9 @@ def _add_data_args(parser):
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--vocab-extra-ids', type=int, default=0,
help='Number of additional vocabulary tokens. '
'They are used for span masking in the T5 model')
group.add_argument('--seq-length', type=int, default=None,
help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
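For context on the new --vocab-extra-ids flag: T5-style span masking replaces each masked span in the input with a sentinel token drawn from these extra vocabulary entries, and the target reproduces the dropped spans delimited by the same sentinels. A rough sketch of the idea follows; the <extra_id_N> naming and the helper functions are illustrative assumptions, not code from this commit.

# Illustrative only: build sentinel tokens and apply a simple span corruption.
def make_extra_id_tokens(num_extra_ids):
    return ['<extra_id_{}>'.format(i) for i in range(num_extra_ids)]

def span_corrupt(tokens, spans, sentinels):
    """tokens: list of str; spans: sorted, non-overlapping (start, length) pairs."""
    inputs, targets = [], []
    cursor = 0
    for sentinel, (start, length) in zip(sentinels, spans):
        inputs.extend(tokens[cursor:start])
        inputs.append(sentinel)           # the span is replaced by one sentinel
        targets.append(sentinel)          # the target lists the dropped tokens
        targets.extend(tokens[start:start + length])
        cursor = start + length
    inputs.extend(tokens[cursor:])
    return inputs, targets

sentinels = make_extra_id_tokens(100)     # e.g. --vocab-extra-ids 100
inp, tgt = span_corrupt(['the', 'quick', 'brown', 'fox', 'jumps'],
                        [(1, 2), (4, 1)], sentinels)
# inp -> ['the', '<extra_id_0>', 'fox', '<extra_id_1>']
# tgt -> ['<extra_id_0>', 'quick', 'brown', '<extra_id_1>', 'jumps']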
164 changes: 56 additions & 108 deletions megatron/data/bert_dataset.py
@@ -15,24 +15,25 @@

"""BERT Style dataset."""

import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer, get_args
from megatron import print_rank_0
from megatron import mpu
from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments
from megatron.data.dataset_utils import create_tokens_and_tokentypes
from megatron.data.dataset_utils import pad_and_convert_to_numpy
from megatron.data.dataset_utils import create_masked_lm_predictions
from megatron import (
get_args,
get_tokenizer,
mpu,
print_rank_0
)
from megatron.data.dataset_utils import (
get_samples_mapping,
get_a_and_b_segments,
truncate_segments,
create_tokens_and_tokentypes,
create_masked_lm_predictions
)


class BertDataset(Dataset):
class BertDataset(torch.utils.data.Dataset):

def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
@@ -49,15 +50,15 @@ def __init__(self, name, indexed_dataset, data_prefix,
self.indexed_dataset = indexed_dataset

# Build the samples mapping.
self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length,
short_seq_prob,
self.seed,
self.name,
self.binary_head)
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 3, # account for added tokens
short_seq_prob,
self.seed,
self.name,
self.binary_head)

# Vocab stuff.
tokenizer = get_tokenizer()
@@ -87,91 +88,6 @@ def __getitem__(self, idx):
self.binary_head)


def get_samples_mapping_(indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
name,
binary_head):
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1

# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))

# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32

# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
# First compile and then import.
from megatron.data import helpers
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
short_seq_prob,
seed,
verbose,
2 if binary_head else 1)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))

return samples_mapping


def build_training_sample(sample,
@@ -225,7 +141,7 @@ def build_training_sample(sample,

# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

@@ -244,3 +160,35 @@ def build_training_sample(sample,
'truncated': int(truncated)}
return train_sample


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""

# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)

# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)

# Labels and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)

return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
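A minimal usage sketch for the pad_and_convert_to_numpy helper added above; the token ids, pad id, and sequence length below are invented for illustration.

tokens = [101, 7592, 2088, 102]          # hypothetical [CLS] hello world [SEP]
tokentypes = [0, 0, 0, 0]
masked_positions = [2]
masked_labels = [2088]
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = \
    pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id=0, max_seq_length=8)
# tokens_np       -> [101, 7592, 2088, 102, 0, 0, 0, 0]
# padding_mask_np -> [1, 1, 1, 1, 0, 0, 0, 0]
# labels_np       -> [-1, -1, 2088, -1, -1, -1, -1, -1]  (only masked positions are labeled)
# loss_mask_np    -> [0, 0, 1, 0, 0, 0, 0, 0]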