
Pretraining BERT

This section is under construction.

import collections
import d2l
import mxnet as mx
from mxnet import autograd, gluon, init, np, npx
from mxnet.contrib import text
import os
import random
import time
import zipfile

npx.set_np()

# Saved in the d2l package for later use
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

In the WikiText-2 training file, each line represents a paragraph. We keep only paragraphs with at least two sentences, since the next sentence prediction task requires pairs of consecutive sentences.

# Saved in the d2l package for later use
def read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # Each line represents a paragraph; sentences are split on ' . '
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs
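
To see the resulting structure, here is a toy example (hypothetical text, not actual WikiText-2 content): each paragraph becomes a list of lowercased sentence strings.

toy_line = 'the cat sat on the mat . the dog barked . \n'
# Splitting on ' . ' turns a paragraph line into a list of sentence strings
print(toy_line.strip().lower().split(' . '))
# ['the cat sat on the mat', 'the dog barked .']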

Prepare Next Sentence Prediction (NSP) Data

# Saved in the d2l package for later use
def get_next_sentence(sentence, next_sentence, paragraphs):
    if random.random() < 0.5:
        is_next = True
    else:
        # paragraphs is a list of lists of lists
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

# Saved in the d2l package for later use
def get_tokens_and_segments(tokens_a, tokens_b):
    tokens = ['<cls>'] + tokens_a + ['<sep>'] + tokens_b + ['<sep>']
    # Segment ids are 0 for '<cls>', tokens_a, and the first '<sep>', and 1
    # for tokens_b and the second '<sep>'
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids
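
The following toy call (tokens made up for illustration) shows the layout produced by get_tokens_and_segments: segment 0 covers '<cls>', the first sentence, and the first '<sep>', while segment 1 covers the second sentence and the final '<sep>'.

tokens, segment_ids = get_tokens_and_segments(['he', 'likes', 'cats'],
                                              ['so', 'do', 'i'])
print(tokens)       # ['<cls>', 'he', 'likes', 'cats', '<sep>', 'so', 'do', 'i', '<sep>']
print(segment_ids)  # [0, 0, 0, 0, 0, 1, 1, 1, 1]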

...

# Saved in the d2l package for later use
def get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    nsp_data_from_paragraph = []
    for i in range(len(paragraph) - 1):
        tokens_a, tokens_b, is_next = get_next_sentence(
            paragraph[i], paragraph[i + 1], paragraphs)
        # Consider 1 '<cls>' token and 2 '<sep>' tokens
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        tokens, segment_ids = get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segment_ids, is_next))
    return nsp_data_from_paragraph
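
For instance, with a hypothetical toy paragraph (the vocab argument is not used by this function, so None is passed here), each generated example is a (tokens, segment_ids, is_next) triple, where is_next is random.

toy_paragraph = [['hello', 'world'], ['nice', 'to', 'meet', 'you']]
for example in get_nsp_data_from_paragraph(toy_paragraph, [toy_paragraph],
                                           None, 64):
    print(example)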

Prepare Masked Language Model (MLM) Data

# Saved in the d2l package for later use
def replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                       vocab):
    # Make a new copy of tokens for the input of a masked language model,
    # where the input may contain replaced '<mask>' or random tokens
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_labels = []
    # Shuffle so that a random 15% of the tokens are picked for prediction in
    # the masked language model task
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token = None
        # 80% of the time: replace the word with the '<mask>' token
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            # 10% of the time: keep the word unchanged
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            # 10% of the time: replace the word with a random word
            else:
                masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append(
            (mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels
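
The toy call below illustrates the 80%/10%/10% replacement strategy. The vocabulary is built from the sentence itself purely for demonstration, so the exact output varies across runs.

toy_tokens = ['<cls>', 'the', 'quick', 'brown', 'fox', '<sep>']
toy_vocab = d2l.Vocab([toy_tokens])
# Pick 2 of the 4 non-special positions; most become '<mask>', while a few
# keep the original token or receive a random token from the vocabulary
print(replace_mlm_tokens(toy_tokens, [1, 2, 3, 4], 2, toy_vocab))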

...

# Saved in the d2l package for later use
def get_mlm_data_from_tokens(tokens, vocab):
    candidate_pred_positions = []
    # tokens is a list of strings
    for i, token in enumerate(tokens):
        # Special tokens are not predicted in the masked language model task
        if token in ['<cls>', '<sep>']:
            continue
        candidate_pred_positions.append(i)
    # 15% of random tokens will be predicted in the masked language model task
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    mlm_input_tokens, pred_positions_and_labels = replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    pred_positions_and_labels = sorted(pred_positions_and_labels,
                                       key=lambda x: x[0])

    zipped_positions_and_labels = list(zip(*pred_positions_and_labels))
    # e.g., [(1, 'an'), (12, 'car'), (25, '<unk>')] -> [1, 12, 25]
    pred_positions = list(zipped_positions_and_labels[0])
    # e.g., [(1, 'an'), (12, 'car'), (25, '<unk>')] -> ['an', 'car', '<unk>']
    mlm_pred_labels = list(zipped_positions_and_labels[1])
    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]

Prepare Training Data

...

# Saved in the d2l package for later use
def pad_bert_inputs(instances, max_len, vocab):
    max_num_mlm_preds = round(max_len * 0.15)
    X_tokens, X_segments, x_valid_lens, X_pred_positions = [], [], [], []
    X_mlm_weights, Y_mlm, y_nsp = [], [], []
    for (mlm_input_ids, pred_positions, mlm_pred_label_ids, segment_ids,
         is_next) in instances:
        X_tokens.append(np.array(mlm_input_ids + [vocab['<pad>']] * (
            max_len - len(mlm_input_ids)), dtype='int32'))
        X_segments.append(np.array(segment_ids + [0] * (
            max_len - len(segment_ids)), dtype='int32'))
        x_valid_lens.append(np.array(len(mlm_input_ids)))
        X_pred_positions.append(np.array(pred_positions + [0] * (
            max_num_mlm_preds - len(pred_positions)), dtype='int32'))
        # Predictions of padded tokens will be filtered out in the loss via
        # multiplication by 0 weights
        X_mlm_weights.append(np.array([1.0] * len(mlm_pred_label_ids) + [
            0.0] * (max_num_mlm_preds - len(pred_positions)), dtype='float32'))
        Y_mlm.append(np.array(mlm_pred_label_ids + [0] * (
            max_num_mlm_preds - len(mlm_pred_label_ids)), dtype='int32'))
        y_nsp.append(np.array(is_next))
    return (X_tokens, X_segments, x_valid_lens, X_pred_positions,
            X_mlm_weights, Y_mlm, y_nsp)
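
The zero weights are what remove padded prediction positions from the loss. Below is a minimal illustration with toy logits and labels, assuming the same call convention as in batch_loss_bert later in this section.

loss = gluon.loss.SoftmaxCELoss()
toy_preds = np.random.normal(0, 1, (4, 10))  # 4 prediction slots, 10 classes
toy_labels = np.array([3, 7, 0, 0])          # the last two slots are padding
toy_weights = np.array([1.0, 1.0, 0.0, 0.0]).reshape((-1, 1))
# The padded slots contribute exactly 0 to the loss
print(loss(toy_preds, toy_labels, toy_weights))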

...

# Saved in the d2l package for later use
class WikiTextDataset(gluon.data.Dataset):
    def __init__(self, paragraphs, max_len=128):
        # Input paragraphs[i] is a list of sentence strings representing a
        # paragraph, while output paragraphs[i] is a list of sentences
        # representing a paragraph, where each sentence is a list of tokens
        paragraphs = [d2l.tokenize(
            paragraph, token='word') for paragraph in paragraphs]
        sentences = [sentence for paragraph in paragraphs
                     for sentence in paragraph]
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # Get data for the next sentence prediction task
        instances = []
        for paragraph in paragraphs:
            instances.extend(get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))
        # Get data for the masked language model task
        instances = [(get_mlm_data_from_tokens(tokens, self.vocab)
                      + (segment_ids, is_next))
                     for tokens, segment_ids, is_next in instances]
        # Pad inputs
        (self.X_tokens, self.X_segments, self.x_valid_lens,
         self.X_pred_positions, self.X_mlm_weights, self.Y_mlm,
         self.y_nsp) = pad_bert_inputs(instances, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.X_tokens[idx], self.X_segments[idx],
                self.x_valid_lens[idx], self.X_pred_positions[idx],
                self.X_mlm_weights[idx], self.Y_mlm[idx], self.y_nsp[idx])

    def __len__(self):
        return len(self.X_tokens)

# Saved in the d2l package for later use
def load_data_wiki(batch_size, max_len):
    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    paragraphs = read_wiki(data_dir)
    train_set = WikiTextDataset(paragraphs, max_len)
    train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True)
    return train_iter, train_set.vocab

batch_size, max_len = 512, 128
train_iter, vocab = load_data_wiki(batch_size, max_len)

...

...

...

for (X_tokens, X_segments, x_valid_lens, X_pred_positions, X_mlm_weights,
     Y_mlm, y_nsp) in train_iter:
    print(X_tokens.shape, X_segments.shape, x_valid_lens.shape,
          X_pred_positions.shape, X_mlm_weights.shape, Y_mlm.shape,
          y_nsp.shape)
    break

Training BERT

net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
                    num_heads=2, num_layers=2, dropout=0.2)
ctx = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=ctx)
nsp_loss, mlm_loss = gluon.loss.SoftmaxCELoss(), gluon.loss.SoftmaxCELoss()
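
Before training, we can run a quick sanity check of the forward pass on dummy inputs on a single device. This is a minimal sketch: the (tokens, segments, valid_lens, pred_positions) call convention follows batch_loss_bert below, and the dummy values here are arbitrary.

X_tokens_demo = np.random.randint(
    0, len(vocab), (2, max_len), ctx=ctx[0]).astype('int32')
X_segments_demo = np.zeros((2, max_len), dtype='int32', ctx=ctx[0])
x_valid_lens_demo = np.array([max_len, max_len], dtype='float32', ctx=ctx[0])
X_pred_positions_demo = np.array([[1, 5, 9], [2, 6, 10]], dtype='int32',
                                 ctx=ctx[0])
# Returns the encoded sequence, the MLM predictions at the given positions,
# and the NSP predictions
_, decoded_demo, classified_demo = net(X_tokens_demo, X_segments_demo,
                                       x_valid_lens_demo,
                                       X_pred_positions_demo)
print(decoded_demo.shape, classified_demo.shape)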

...

# Saved in the d2l package for later use
def _get_batch_bert(batch, ctx):
    (X_tokens, X_segments, x_valid_lens, X_pred_positions, X_mlm_weights,
     Y_mlm, y_nsp) = batch
    split_and_load = gluon.utils.split_and_load
    return (split_and_load(X_tokens, ctx, even_split=False),
            split_and_load(X_segments, ctx, even_split=False),
            split_and_load(x_valid_lens.astype('float32'), ctx,
                           even_split=False),
            split_and_load(X_pred_positions, ctx, even_split=False),
            split_and_load(X_mlm_weights, ctx, even_split=False),
            split_and_load(Y_mlm, ctx, even_split=False),
            split_and_load(y_nsp, ctx, even_split=False))

...

# Saved in the d2l package for later use
def batch_loss_bert(net, nsp_loss, mlm_loss, X_tokens_shards,
                    X_segments_shards, x_valid_lens_shards,
                    X_pred_positions_shards, X_mlm_weights_shards,
                    Y_mlm_shards, y_nsp_shards, vocab_size):
    ls = []
    ls_mlm = []
    ls_nsp = []
    for (X_tokens_shard, X_segments_shard, x_valid_lens_shard,
         X_pred_positions_shard, X_mlm_weights_shard, Y_mlm_shard,
         y_nsp_shard) in zip(
        X_tokens_shards, X_segments_shards, x_valid_lens_shards,
        X_pred_positions_shards, X_mlm_weights_shards, Y_mlm_shards,
        y_nsp_shards):

        num_masks = X_mlm_weights_shard.sum() + 1e-8
        _, decoded, classified = net(
            X_tokens_shard, X_segments_shard, x_valid_lens_shard.reshape(-1),
            X_pred_positions_shard)
        l_mlm = mlm_loss(
            decoded.reshape((-1, vocab_size)), Y_mlm_shard.reshape(-1),
            X_mlm_weights_shard.reshape((-1, 1)))
        l_mlm = l_mlm.sum() / num_masks
        l_nsp = nsp_loss(classified, y_nsp_shard)
        l_nsp = l_nsp.mean()
        l = l_mlm + l_nsp
        ls.append(l)
        ls_mlm.append(l_mlm)
        ls_nsp.append(l_nsp)
        npx.waitall()
    return ls, ls_mlm, ls_nsp

...

# Saved in the d2l package for later use
def train_bert(train_iter, net, nsp_loss, mlm_loss, vocab_size, ctx,
               log_interval, max_step):
    trainer = gluon.Trainer(net.collect_params(), 'adam')
    step_num = 0
    while step_num < max_step:
        epoch_begin_time = time.time()
        running_mlm_loss = running_nsp_loss = 0
        total_mlm_loss = total_nsp_loss = 0
        for data_batch in train_iter:
            (X_tokens_shards, X_segments_shards, x_valid_lens_shards,
             X_pred_positions_shards, X_mlm_weights_shards,
             Y_mlm_shards, y_nsp_shards) = _get_batch_bert(data_batch, ctx)

            step_num += 1
            with autograd.record():
                ls, ls_mlm, ls_nsp = batch_loss_bert(
                    net, nsp_loss, mlm_loss, X_tokens_shards,
                    X_segments_shards, x_valid_lens_shards,
                    X_pred_positions_shards, X_mlm_weights_shards,
                    Y_mlm_shards, y_nsp_shards, vocab_size)
            for l in ls:
                l.backward()

            trainer.step(1)

            # Convert per-shard losses to floats so that losses computed on
            # different devices can be accumulated together
            running_mlm_loss += sum([float(l) for l in ls_mlm])
            running_nsp_loss += sum([float(l) for l in ls_nsp])

            # Fold the running losses into the totals every log_interval steps
            if (step_num + 1) % log_interval == 0:
                total_mlm_loss += running_mlm_loss
                total_nsp_loss += running_nsp_loss
                running_mlm_loss = running_nsp_loss = 0

        epoch_end_time = time.time()
        # Include any running losses that have not been folded in yet
        if running_mlm_loss != 0:
            total_mlm_loss += running_mlm_loss
            total_nsp_loss += running_nsp_loss
        total_mlm_loss /= step_num
        total_nsp_loss /= step_num
        print('mlm_loss={:.3f}\tnsp_loss={:.3f}'.format(
            total_mlm_loss, total_nsp_loss))
        print('time cost={:.1f}s'.format(epoch_end_time - epoch_begin_time))

...

train_bert(train_iter, net, nsp_loss, mlm_loss, len(vocab), ctx, 20, 1)

Exercises

  1. Try other sentence segmentation methods, such as spaCy and nltk.tokenize.sent_tokenize. For instance, after installing nltk, you need to run import nltk and nltk.download('punkt') first. A possible starting point is sketched below.
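
A hedged starting point for this exercise, swapping nltk.tokenize.sent_tokenize in for the ' . ' splitting of read_wiki. The function name read_wiki_nltk is made up here, and this is not the book's implementation.

import nltk

def read_wiki_nltk(data_dir):
    # Requires a one-time nltk.download('punkt') before first use
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    paragraphs = [nltk.tokenize.sent_tokenize(line.strip().lower())
                  for line in lines]
    # As before, keep only paragraphs with at least two sentences
    paragraphs = [paragraph for paragraph in paragraphs if len(paragraph) >= 2]
    random.shuffle(paragraphs)
    return paragraphs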