This section is under construction. It prepares the WikiText-2 dataset for pretraining BERT: raw paragraphs are turned into training examples for the next sentence prediction and masked language modeling tasks, padded to a fixed length, and wrapped in a data iterator.
import collections
import d2l
import mxnet as mx
from mxnet import autograd, gluon, init, np, npx
from mxnet.contrib import text
import os
import random
import time
import zipfile
npx.set_np()
# Saved in the d2l package for later use
d2l.DATA_HUB['wikitext-2'] = (
'https://s3.amazonaws.com/research.metamind.io/wikitext/'
'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
In the following read_wiki function, each line of the WikiText-2 training file represents a paragraph; we lowercase it, split it into sentences on ' . ', and keep only paragraphs with at least two sentences.
# Saved in the d2l package for later use
def read_wiki(data_dir):
file_name = os.path.join(data_dir, 'wiki.train.tokens')
with open(file_name, 'r') as f:
lines = f.readlines()
    # A line represents a paragraph
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs
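As a quick sanity check (a minimal sketch; d2l.download_extract fetches and unpacks the archive registered above if it is not already cached), we can load the paragraphs and inspect their structure:
data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
paragraphs = read_wiki(data_dir)
print(len(paragraphs))   # number of kept paragraphs
print(paragraphs[0][0])  # the first sentence string of the first paragraph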
# Saved in the d2l package for later use
def get_next_sentence(sentence, next_sentence, paragraphs):
if random.random() < 0.5:
is_next = True
else:
# paragraphs is a list of lists of lists
next_sentence = random.choice(random.choice(paragraphs))
is_next = False
return sentence, next_sentence, is_next
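As a toy sanity check (with made-up token lists), roughly half of the sampled pairs should keep the genuine next sentence:
# Toy check: the fraction of genuine next-sentence pairs should be near 0.5
toy_paragraphs = [[['a'], ['b']], [['c'], ['d']]]
flags = [get_next_sentence(['a'], ['b'], toy_paragraphs)[2]
         for _ in range(1000)]
print(sum(flags) / len(flags))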
# Saved in the d2l package for later use
def get_tokens_and_segments(tokens_a, tokens_b):
tokens = ['<cls>'] + tokens_a + ['<sep>'] + tokens_b + ['<sep>']
segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
return tokens, segment_ids
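For example (a toy input), the sentence pair is wrapped with '<cls>' and '<sep>' tokens, and the two sentences receive segment ids 0 and 1:
# Toy illustration of the special tokens and segment ids
tokens, segment_ids = get_tokens_and_segments(['hello', 'world'],
                                              ['nice', 'day'])
print(tokens)       # ['<cls>', 'hello', 'world', '<sep>', 'nice', 'day', '<sep>']
print(segment_ids)  # [0, 0, 0, 0, 1, 1, 1]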
...
# Saved in the d2l package for later use
def get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
nsp_data_from_paragraph = []
for i in range(len(paragraph) - 1):
tokens_a, tokens_b, is_next = get_next_sentence(
paragraph[i], paragraph[i + 1], paragraphs)
# Consider 1 '<cls>' token and 2 '<sep>' tokens
if len(tokens_a) + len(tokens_b) + 3 > max_len:
continue
tokens, segment_ids = get_tokens_and_segments(tokens_a, tokens_b)
nsp_data_from_paragraph.append((tokens, segment_ids, is_next))
return nsp_data_from_paragraph
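Below is a toy sketch of turning one already-tokenized paragraph into next sentence prediction examples; since the vocab argument is not used inside this function, we simply pass None here:
# Toy illustration with hypothetical tokenized paragraphs
toy_paragraphs = [[['hello', 'world'], ['nice', 'to', 'see', 'you']],
                  [['this', 'is', 'great'], ['see', 'you', 'soon']]]
for tokens, segment_ids, is_next in get_nsp_data_from_paragraph(
        toy_paragraphs[0], toy_paragraphs, None, max_len=16):
    print(tokens, segment_ids, is_next)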
# Saved in the d2l package for later use
def replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
vocab):
# Make a new copy of tokens for the input of a masked language model,
# where the input may contain replaced '<mask>' or random tokens
mlm_input_tokens = [token for token in tokens]
pred_positions_and_labels = []
    # Shuffle for getting 15% random tokens for prediction in the masked
    # language model task
random.shuffle(candidate_pred_positions)
for mlm_pred_position in candidate_pred_positions:
if len(pred_positions_and_labels) >= num_mlm_preds:
break
masked_token = None
# 80% of the time: replace the word with the '<mask>' token
if random.random() < 0.8:
masked_token = '<mask>'
else:
# 10% of the time: keep the word unchanged
if random.random() < 0.5:
masked_token = tokens[mlm_pred_position]
            # 10% of the time: replace the word with a random word
            else:
                masked_token = random.choice(vocab.idx_to_token)
mlm_input_tokens[mlm_pred_position] = masked_token
pred_positions_and_labels.append(
(mlm_pred_position, tokens[mlm_pred_position]))
return mlm_input_tokens, pred_positions_and_labels
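A toy sketch (hypothetical tokens and a small d2l.Vocab) shows what the masking produces; the output varies across runs because positions and replacements are sampled randomly:
# Toy illustration of masking a couple of positions
toy_tokens = ['<cls>', 'the', 'cat', 'sat', 'on', 'the', 'mat', '<sep>']
toy_vocab = d2l.Vocab([toy_tokens], reserved_tokens=['<mask>'])
candidate_positions = [i for i, t in enumerate(toy_tokens)
                       if t not in ('<cls>', '<sep>')]
print(replace_mlm_tokens(toy_tokens, candidate_positions,
                         num_mlm_preds=2, vocab=toy_vocab))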
...
# Saved in the d2l package for later use
def get_mlm_data_from_tokens(tokens, vocab):
candidate_pred_positions = []
# tokens is a list of strings
for i, token in enumerate(tokens):
# Special tokens are not predicted in the masked language model task
if token in ['<cls>', '<sep>']:
continue
candidate_pred_positions.append(i)
# 15% of random tokens will be predicted in the masked language model task
num_mlm_preds = max(1, round(len(tokens) * 0.15))
mlm_input_tokens, pred_positions_and_labels = replace_mlm_tokens(
tokens, candidate_pred_positions, num_mlm_preds, vocab)
pred_positions_and_labels = sorted(pred_positions_and_labels,
key=lambda x: x[0])
zipped_positions_and_labels = list(zip(*pred_positions_and_labels))
# e.g., [[1, 'an'], [12, 'car'], [25, '<unk>']] -> [1, 12, 25]
pred_positions = list(zipped_positions_and_labels[0])
# e.g., [[1, 'an'], [12, 'car'], [25, '<unk>']] -> ['an', 'car', '<unk>']
mlm_pred_labels = list(zipped_positions_and_labels[1])
return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]
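Again a toy sketch (hypothetical tokens; the real vocabulary is built in WikiTextDataset below) of the three outputs: token ids of the masked input, the prediction positions, and token ids of the original labels:
# Toy illustration of get_mlm_data_from_tokens
toy_tokens = ['<cls>', 'the', 'cat', 'sat', '<sep>', 'it', 'was', 'tired',
              '<sep>']
toy_vocab = d2l.Vocab([toy_tokens],
                      reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
print(get_mlm_data_from_tokens(toy_tokens, toy_vocab))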
...
# Saved in the d2l package for later use
def pad_bert_inputs(instances, max_len, vocab):
max_num_mlm_preds = round(max_len * 0.15)
X_tokens, X_segments, x_valid_lens, X_pred_positions = [], [], [], []
X_mlm_weights, Y_mlm, y_nsp = [], [], []
for (mlm_input_ids, pred_positions, mlm_pred_label_ids, segment_ids,
is_next) in instances:
X_tokens.append(np.array(mlm_input_ids + [vocab['<pad>']] * (
max_len - len(mlm_input_ids)), dtype='int32'))
X_segments.append(np.array(segment_ids + [0] * (
max_len - len(segment_ids)), dtype='int32'))
x_valid_lens.append(np.array(len(mlm_input_ids)))
        X_pred_positions.append(np.array(pred_positions + [0] * (
            max_num_mlm_preds - len(pred_positions)), dtype='int32'))
        # Predictions of padded tokens will be filtered out in the loss via
        # multiplication of 0 weights
        X_mlm_weights.append(np.array([1.0] * len(mlm_pred_label_ids) + [
            0.0] * (max_num_mlm_preds - len(pred_positions)), dtype='float32'))
        Y_mlm.append(np.array(mlm_pred_label_ids + [0] * (
            max_num_mlm_preds - len(mlm_pred_label_ids)), dtype='int32'))
y_nsp.append(np.array(is_next))
return (X_tokens, X_segments, x_valid_lens, X_pred_positions,
X_mlm_weights, Y_mlm, y_nsp)
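A toy sketch with a single hypothetical, already-preprocessed instance illustrates the padding; the tuple layout is (mlm_input_ids, pred_positions, mlm_pred_label_ids, segment_ids, is_next):
# Toy illustration: pad one hypothetical instance to max_len=16
toy_vocab = d2l.Vocab([['the', 'cat']],
                      reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
# '<cls> the <mask> <sep>', where the token masked at position 2 was 'cat'
toy_instance = (toy_vocab[['<cls>', 'the', '<mask>', '<sep>']], [2],
                toy_vocab[['cat']], [0, 0, 0, 0], False)
padded = pad_bert_inputs([toy_instance], max_len=16, vocab=toy_vocab)
print(padded[0][0])  # token ids padded with vocab['<pad>']
print(padded[4][0])  # MLM weights: 1.0 for real predictions, 0.0 for padding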
...
# Saved in the d2l package for later use
class WikiTextDataset(gluon.data.Dataset):
    def __init__(self, paragraphs, max_len=128):
        # Input paragraphs[i] is a list of sentence strings representing a
        # paragraph; after tokenization, paragraphs[i] is a list of
        # sentences, where each sentence is a list of tokens
        paragraphs = [d2l.tokenize(
            paragraph, token='word') for paragraph in paragraphs]
        sentences = [sentence for paragraph in paragraphs
                     for sentence in paragraph]
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # Get data for the next sentence prediction task
        instances = []
        for paragraph in paragraphs:
            instances.extend(get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))
# Get data for the masked language model task
instances = [(get_mlm_data_from_tokens(tokens, self.vocab)
+ (segment_ids, is_next))
for tokens, segment_ids, is_next in instances]
# Pad inputs
(self.X_tokens, self.X_segments, self.x_valid_lens,
self.X_pred_positions, self.X_mlm_weights, self.Y_mlm,
self.y_nsp) = pad_bert_inputs(instances, max_len, self.vocab)
def __getitem__(self, idx):
return (self.X_tokens[idx], self.X_segments[idx],
self.x_valid_lens[idx], self.X_pred_positions[idx],
self.X_mlm_weights[idx], self.Y_mlm[idx], self.y_nsp[idx])
def __len__(self):
return len(self.X_tokens)
# Saved in the d2l package for later use
def load_data_wiki(batch_size, max_len):
data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    paragraphs = read_wiki(data_dir)
    train_set = WikiTextDataset(paragraphs, max_len)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True)
return train_iter, train_set.vocab
batch_size, max_len = 512, 128
train_iter, vocab = load_data_wiki(batch_size, max_len)
...
...
...
for (X_tokens, X_segments, x_valid_lens, X_pred_positions, X_mlm_weights,
Y_mlm, y_nsp) in train_iter:
print(X_tokens.shape, X_segments.shape, x_valid_lens.shape,
X_pred_positions.shape, X_mlm_weights.shape, Y_mlm.shape,
y_nsp.shape)
break
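We can also print the size of the vocabulary, which counts only tokens appearing at least 5 times in WikiText-2 plus '<unk>' and the reserved special tokens:
print(len(vocab))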
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
num_heads=2, num_layers=2, dropout=0.2)
ctx = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=ctx)
nsp_loss, mlm_loss = gluon.loss.SoftmaxCELoss(), gluon.loss.SoftmaxCELoss()
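The third argument of SoftmaxCELoss is a per-example weight; batch_loss_bert below relies on it to zero out the loss at padded prediction positions. A minimal sketch with toy numbers:
# Toy check: weights of 0 remove the contribution of padded positions
toy_pred = np.random.normal(size=(4, len(vocab)))
toy_label = np.array([10, 20, 0, 0])
toy_weight = np.array([1.0, 1.0, 0.0, 0.0]).reshape((-1, 1))
print(mlm_loss(toy_pred, toy_label, toy_weight))  # last two losses are 0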
...
# Saved in the d2l package for later use
def _get_batch_bert(batch, ctx):
(X_tokens, X_segments, x_valid_lens, X_pred_positions, X_mlm_weights,
Y_mlm, y_nsp) = batch
split_and_load = gluon.utils.split_and_load
return (split_and_load(X_tokens, ctx, even_split=False),
split_and_load(X_segments, ctx, even_split=False),
split_and_load(x_valid_lens.astype('float32'), ctx,
even_split=False),
split_and_load(X_pred_positions, ctx, even_split=False),
split_and_load(X_mlm_weights, ctx, even_split=False),
split_and_load(Y_mlm, ctx, even_split=False),
split_and_load(y_nsp, ctx, even_split=False))
...
# Saved in the d2l package for later use
def batch_loss_bert(net, nsp_loss, mlm_loss, X_tokens_shards,
X_segments_shards, x_valid_lens_shards,
X_pred_positions_shards, X_mlm_weights_shards,
Y_mlm_shards, y_nsp_shards, vocab_size):
ls = []
ls_mlm = []
ls_nsp = []
for (X_tokens_shard, X_segments_shard, x_valid_lens_shard,
X_pred_positions_shard, X_mlm_weights_shard, Y_mlm_shard,
y_nsp_shard) in zip(
X_tokens_shards, X_segments_shards, x_valid_lens_shards,
X_pred_positions_shards, X_mlm_weights_shards, Y_mlm_shards,
y_nsp_shards):
num_masks = X_mlm_weights_shard.sum() + 1e-8
_, decoded, classified = net(
X_tokens_shard, X_segments_shard, x_valid_lens_shard.reshape(-1),
X_pred_positions_shard)
l_mlm = mlm_loss(
decoded.reshape((-1, vocab_size)), Y_mlm_shard.reshape(-1),
X_mlm_weights_shard.reshape((-1, 1)))
l_mlm = l_mlm.sum() / num_masks
l_nsp = nsp_loss(classified, y_nsp_shard)
l_nsp = l_nsp.mean()
l = l_mlm + l_nsp
ls.append(l)
ls_mlm.append(l_mlm)
ls_nsp.append(l_nsp)
npx.waitall()
return ls, ls_mlm, ls_nsp
...
# Saved in the d2l package for later use
def train_bert(data_eval, net, nsp_loss, mlm_loss, vocab_size, ctx,
log_interval, max_step):
trainer = gluon.Trainer(net.collect_params(), 'adam')
step_num = 0
while step_num < max_step:
eval_begin_time = time.time()
begin_time = time.time()
running_mlm_loss = running_nsp_loss = 0
total_mlm_loss = total_nsp_loss = 0
running_num_tks = 0
for _, data_batch in enumerate(data_eval):
(X_tokens_shards, X_segments_shards, x_valid_lens_shards,
X_pred_positions_shards, X_mlm_weights_shards,
Y_mlm_shards, y_nsp_shards) = _get_batch_bert(data_batch, ctx)
step_num += 1
with autograd.record():
ls, ls_mlm, ls_nsp = batch_loss_bert(
net, nsp_loss, mlm_loss, X_tokens_shards,
X_segments_shards, x_valid_lens_shards,
X_pred_positions_shards, X_mlm_weights_shards,
Y_mlm_shards, y_nsp_shards, vocab_size)
for l in ls:
l.backward()
trainer.step(1)
            running_mlm_loss += sum(ls_mlm)
            running_nsp_loss += sum(ls_nsp)
if (step_num + 1) % (log_interval) == 0:
total_mlm_loss += running_mlm_loss
total_nsp_loss += running_nsp_loss
begin_time = time.time()
running_mlm_loss = running_nsp_loss = 0
eval_end_time = time.time()
if running_mlm_loss != 0:
total_mlm_loss += running_mlm_loss
total_nsp_loss += running_nsp_loss
total_mlm_loss /= step_num
total_nsp_loss /= step_num
print('Eval mlm_loss={:.3f}\tnsp_loss={:.3f}\t'
.format(float(total_mlm_loss),
float(total_nsp_loss)))
print('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
...
train_bert(train_iter, net, nsp_loss, mlm_loss, len(vocab), ctx, 20, 1)
- Try other sentence segmentation methods, such as spaCy and
  `nltk.tokenize.sent_tokenize`. For instance, after installing nltk, you
  need to run `import nltk` and `nltk.download('punkt')` first (see the
  sketch below).
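A minimal sketch of the nltk-based alternative mentioned above (it assumes the nltk package is installed):
# Minimal sketch of sentence segmentation with nltk
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')
print(sent_tokenize('A line represents a paragraph. It may have many sentences.'))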