diff --git a/docs/code/data.rst b/docs/code/data.rst index 32afacc68..50c1c91bd 100644 --- a/docs/code/data.rst +++ b/docs/code/data.rst @@ -102,6 +102,19 @@ Data Iterators .. autoclass:: texar.data.TrainTestDataIterator :members: +:hidden:`BatchingStrategy` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: texar.data.BatchingStrategy + :members: + +:hidden:`TokenCountBatchingStrategy` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: texar.data.TokenCountBatchingStrategy + :members: + + Data Utilities =============== diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1fe36da74..2238971b6 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -55,4 +55,5 @@ PyTorch pytorch torch fastly +CUDA precompute diff --git a/examples/transformer/config_iwslt15.py b/examples/transformer/config_iwslt15.py index 40342a3ed..1d6d2269d 100644 --- a/examples/transformer/config_iwslt15.py +++ b/examples/transformer/config_iwslt15.py @@ -1,4 +1,4 @@ -batch_size = 2048 +max_batch_tokens = 2048 test_batch_size = 32 max_train_epoch = 20 diff --git a/examples/transformer/config_wmt14.py b/examples/transformer/config_wmt14.py index 42fd2a133..11f5e2719 100644 --- a/examples/transformer/config_wmt14.py +++ b/examples/transformer/config_wmt14.py @@ -1,4 +1,4 @@ -batch_size = 3072 +max_batch_tokens = 3072 test_batch_size = 32 max_train_epoch = 10 diff --git a/examples/transformer/model.py b/examples/transformer/model.py new file mode 100644 index 000000000..1b1a46a9f --- /dev/null +++ b/examples/transformer/model.py @@ -0,0 +1,212 @@ +from typing import Optional + +import torch +from torch import nn + +import texar as tx + + +class Transformer(nn.Module): + r"""A standalone sequence-to-sequence Transformer model, from "Attention + Is All You Need". The Transformer model consists of the word embedding + layer, position embedding layer, an encoder and a decoder. Both encoder + and decoder are stacks of self-attention layers followed by feed-forward + layers. See "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) + for the full description of the model. + """ + + def __init__(self, model_config, data_config, vocab: tx.data.Vocab): + super().__init__() + + self.config_model = model_config + self.config_data = data_config + self.vocab = vocab + self.vocab_size = vocab.size + + self.word_embedder = tx.modules.WordEmbedder( + vocab_size=self.vocab_size, + hparams=self.config_model.emb, + ) + self.pos_embedder = tx.modules.SinusoidsPositionEmbedder( + position_size=self.config_data.max_decoding_length, + hparams=self.config_model.position_embedder_hparams, + ) + + self.encoder = tx.modules.TransformerEncoder( + hparams=self.config_model.encoder + ) + self.decoder = tx.modules.TransformerDecoder( + vocab_size=self.vocab_size, + output_layer=self.word_embedder.embedding, + hparams=self.config_model.decoder, + ) + + self.smoothed_loss_func = LabelSmoothingLoss( + label_confidence=self.config_model.loss_label_confidence, + tgt_vocab_size=self.vocab_size, + ignore_index=0, + ) + + def forward( # type: ignore + self, + encoder_input: torch.Tensor, + decoder_input: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + beam_width: Optional[int] = None, + ): + r"""Compute the maximum likelihood loss or perform decoding, depending + on arguments. + + Args: + encoder_input: the source sentence embedding, with the shape of + `[batch_size, source_seq_length, input_dim]`. 
+ decoder_input: the target sentence embedding, with the shape of + `[batch_size, target_seq_length, input_dim]`. + labels: the target sentence labels, with the shape of + `[batch_size, target_seq_length]`. + beam_width: Used in beam search. + + :returns: + - If both :attr:`decoder_input` and :attr:`labels` are both + provided, the function enters training logic and returns the + maximum likelihood loss. + - Otherwise the function enters inference logic and returns the + decoded sequence. + - If `beam_width` > 1, beam search decoding is performed. Please + refer to :meth:`texar.modules.TransformerDecoder.forward` for + details on return types. + """ + + batch_size = encoder_input.size(0) + # (text sequence length excluding padding) + encoder_input_length = (encoder_input != 0).int().sum(dim=1) + + # Source word embedding + src_word_embeds = self.word_embedder(encoder_input) + src_word_embeds = src_word_embeds * self.config_model.hidden_dim ** 0.5 + + # Position embedding (shared b/w source and target) + src_seq_len = torch.full( + (batch_size,), encoder_input.size(1), dtype=torch.int32 + ) + src_seq_len = src_seq_len.to(device=encoder_input.device) + + src_pos_embeds = self.pos_embedder(sequence_length=src_seq_len) + src_input_embedding = src_word_embeds + src_pos_embeds + + encoder_output = self.encoder( + inputs=src_input_embedding, sequence_length=encoder_input_length + ) + + if decoder_input is not None and labels is not None: + # enter the training logic + + tgt_word_embeds = self.word_embedder(decoder_input) + tgt_word_embeds = ( + tgt_word_embeds * self.config_model.hidden_dim ** 0.5 + ) + tgt_seq_len = decoder_input.new_full( + (batch_size,), decoder_input.size(1), + ) + + tgt_pos_embeds = self.pos_embedder(sequence_length=tgt_seq_len) + + tgt_input_embedding = tgt_word_embeds + tgt_pos_embeds + + # For training + outputs = self.decoder( + memory=encoder_output, + memory_sequence_length=encoder_input_length, + inputs=tgt_input_embedding, + decoding_strategy="train_greedy", + ) + label_lengths = (labels != 0).long().sum(dim=1) + is_target = (labels != 0).float() + mle_loss = self.smoothed_loss_func( + outputs.logits, labels, label_lengths + ) + mle_loss = (mle_loss * is_target).sum() / is_target.sum() + return mle_loss + + else: + start_tokens = encoder_input.new_full( + (batch_size,), self.vocab.bos_token_id, + ) + + def _embedding_fn(x, y): + word_embed = self.word_embedder(x) + scale = self.config_model.hidden_dim ** 0.5 + pos_embed = self.pos_embedder(y) + return word_embed * scale + pos_embed + + predictions = self.decoder( + memory=encoder_output, + memory_sequence_length=encoder_input_length, + beam_width=beam_width, + length_penalty=self.config_model.length_penalty, + start_tokens=start_tokens, + end_token=self.vocab.eos_token_id, + embedding=_embedding_fn, + max_decoding_length=self.config_data.max_decoding_length, + decoding_strategy="infer_greedy", + ) + # Uses the best sample by beam search + return predictions + + +class LabelSmoothingLoss(nn.Module): + r"""With label smoothing, + KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + + Args: + label_confidence: the confidence weight on the ground truth label. + tgt_vocab_size: the size of the final classification. + ignore_index: The index in the vocabulary to ignore weight. 
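+
+        The smoothed target distribution assigns probability ``label_confidence``
+        to the ground-truth class, zero to ``ignore_index``, and
+        ``(1 - label_confidence) / (tgt_vocab_size - 2)`` to every remaining
+        class; rows whose target equals ``ignore_index`` are zeroed out entirely.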
+ """ + + def __init__(self, label_confidence, tgt_vocab_size, ignore_index=0): + super().__init__() + self.ignore_index = ignore_index + self.tgt_vocab_size = tgt_vocab_size + + label_smoothing = 1 - label_confidence + assert 0.0 < label_smoothing <= 1.0 + smoothing_value = label_smoothing / (tgt_vocab_size - 2) + one_hot = torch.full((tgt_vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer("one_hot", one_hot.unsqueeze(0)) + self.confidence = label_confidence + + def forward( # type: ignore + self, + output: torch.Tensor, + target: torch.Tensor, + label_lengths: torch.LongTensor, + ) -> torch.Tensor: + r"""Compute the label smoothing loss. + + Args: + output (FloatTensor): batch_size x seq_length * n_classes + target (LongTensor): batch_size * seq_length, specify the label + target + label_lengths(torch.LongTensor): specify the length of the labels + """ + orig_shapes = (output.size(), target.size()) + output = output.view(-1, self.tgt_vocab_size) + target = target.view(-1) + model_prob = self.one_hot.repeat(target.size(0), 1) + model_prob = model_prob.to(device=target.device) + model_prob.scatter_(1, target.unsqueeze(1), self.confidence) + model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) + + output = output.view(orig_shapes[0]) + model_prob = model_prob.view(orig_shapes[0]) + + return tx.losses.sequence_softmax_cross_entropy( + labels=model_prob, + logits=output, + sequence_length=label_lengths, + average_across_batch=False, + sum_over_timesteps=False, + ) diff --git a/examples/transformer/requirements.txt b/examples/transformer/requirements.txt index 5391f327c..2abbda538 100644 --- a/examples/transformer/requirements.txt +++ b/examples/transformer/requirements.txt @@ -1,2 +1,2 @@ sentencepiece -torchtext +tqdm diff --git a/examples/transformer/transformer_main.py b/examples/transformer/transformer_main.py index 704e1c05d..35a4a72e7 100644 --- a/examples/transformer/transformer_main.py +++ b/examples/transformer/transformer_main.py @@ -13,59 +13,39 @@ # limitations under the License. """Transformer model. """ -from typing import Optional - import argparse import functools import importlib import os -import random + import torch -from torch import nn -from torchtext import data -from tqdm import tqdm +import tqdm import texar as tx from texar.data import Vocab -from texar.module_base import ModuleBase -from texar.modules import TransformerDecoder -from texar.modules import WordEmbedder -from texar.modules import SinusoidsPositionEmbedder -from texar.modules import TransformerEncoder -from texar.losses import sequence_softmax_cross_entropy from bleu_tool import bleu_wrapper -from utils import data_utils, utils +from model import Transformer +import utils.data_utils as data_utils +import utils.utils as utils parser = argparse.ArgumentParser() parser.add_argument( - "--config_model", type=str, default="config_model", help="The model config." 
-) + "--config_model", type=str, default="config_model", + help="The model config.") parser.add_argument( - "--config_data", - type=str, - default="config_iwslt15", - help="The dataset config.", -) + "--config_data", type=str, default="config_iwslt15", + help="The dataset config.") parser.add_argument( - "--run_mode", - type=str, - default="train_and_evaluate", - help="Either train_and_evaluate or evaluate or test.", -) + "--run_mode", type=str, default="train_and_evaluate", + help="Either train_and_evaluate or evaluate or test.") parser.add_argument( - "--model_dir", - type=str, - default="./outputs/", - help="Path to save the trained model and logs.", -) + "--model_dir", type=str, default="./outputs/", + help="Path to save the trained model and logs.") parser.add_argument( - "--model_fn", - type=str, - default="best-model.ckpt", - help="Model filename to save the trained weights", -) + "--model_fn", type=str, default="best-model.ckpt", + help="Model filename to save the trained weights") args = parser.parse_args() @@ -75,215 +55,34 @@ utils.set_random_seed(config_model.random_seed) -class Transformer(ModuleBase): - r"""A standalone sequence-to-sequence Transformer model. - TODO: Add detailed docstrings. - """ - - def __init__(self, model_config, data_config, vocab: Vocab): - ModuleBase.__init__(self) - - self.config_model = model_config - self.config_data = data_config - self.vocab = vocab - self.vocab_size = vocab.size - - self.word_embedder = WordEmbedder( - vocab_size=self.vocab_size, hparams=config_model.emb - ) - self.pos_embedder = SinusoidsPositionEmbedder( - position_size=config_data.max_decoding_length, - hparams=config_model.position_embedder_hparams, - ) - - self.encoder = TransformerEncoder(hparams=config_model.encoder) - self.decoder = TransformerDecoder( - vocab_size=self.vocab_size, - output_layer=self.word_embedder.embedding, - hparams=config_model.decoder, - ) - - self.smoothed_loss_func = LabelSmoothingLoss( - label_confidence=config_model.loss_label_confidence, - tgt_vocab_size=self.vocab_size, - ignore_index=0, - ) - - def forward( # type: ignore - self, - encoder_input: torch.Tensor, - is_train_mode: Optional[bool], - decoder_input: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - beam_width: Optional[int] = None, - ): - r"""TODO: Add detailed docstrings. 
- - Args: - encoder_input: - is_train_mode: - decoder_input: - labels: - beam_width: - - Returns: - - """ - - batch_size = encoder_input.size()[0] - # (text sequence length excluding padding) - encoder_input_length = (encoder_input != 0).int().sum(dim=1) - - if is_train_mode: - self.train() - - else: - self.eval() - - # Source word embedding - src_word_embeds = self.word_embedder(encoder_input) - src_word_embeds = src_word_embeds * config_model.hidden_dim ** 0.5 - - # Position embedding (shared b/w source and target) - src_seq_len = torch.full( - (batch_size,), encoder_input.size()[1], dtype=torch.int32 - ) - src_seq_len = src_seq_len.to(device=encoder_input.device) - - src_pos_embeds = self.pos_embedder(sequence_length=src_seq_len) - src_input_embedding = src_word_embeds + src_pos_embeds - - encoder_output = self.encoder( - inputs=src_input_embedding, sequence_length=encoder_input_length - ) - - if is_train_mode: - assert decoder_input is not None - assert labels is not None - - tgt_word_embeds = self.word_embedder(decoder_input) - tgt_word_embeds = ( - tgt_word_embeds * config_model.hidden_dim ** 0.5 - ) - tgt_seq_len = torch.full( - (batch_size,), decoder_input.size()[1], dtype=torch.int32 - ) - tgt_seq_len = tgt_seq_len.to(device=decoder_input.device) - - tgt_pos_embeds = self.pos_embedder(sequence_length=tgt_seq_len) - - tgt_input_embedding = tgt_word_embeds + tgt_pos_embeds - - # For training - outputs = self.decoder( - memory=encoder_output, - memory_sequence_length=encoder_input_length, - inputs=tgt_input_embedding, - decoding_strategy="train_greedy", - ) - labels = labels.to(device=outputs.logits.device) - label_lengths = (labels != 0).long().sum(dim=1) - label_lengths = label_lengths.to(device=outputs.logits.device) - is_target = (labels != 0).float() - mle_loss = self.smoothed_loss_func( - outputs.logits, labels, label_lengths - ) - mle_loss = (mle_loss * is_target).sum() / is_target.sum() - return mle_loss - else: - start_tokens = encoder_input.new_full( - (batch_size,), self.vocab.bos_token_id, dtype=torch.long - ) - - def _embedding_fn(x, y): - word_embed = self.word_embedder(x) - scale = config_model.hidden_dim ** 0.5 - pos_embed = self.pos_embedder(y) - return word_embed * scale + pos_embed - - predictions = self.decoder( - memory=encoder_output, - memory_sequence_length=encoder_input_length, - beam_width=beam_width, - length_penalty=config_model.length_penalty, - start_tokens=start_tokens, - end_token=self.vocab.eos_token_id, - embedding=_embedding_fn, - max_decoding_length=config_data.max_decoding_length, - decoding_strategy="infer_greedy", - ) - # Uses the best sample by beam search - return predictions - - -class LabelSmoothingLoss(nn.Module): - r"""With label smoothing, - KL-divergence between q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. - - Args: - label_confidence: the confidence weight on the ground truth label. - tgt_vocab_size: the size of the final classification. - ignore_index: The index in the vocabulary to ignore weight. 
- """ - - def __init__(self, label_confidence, tgt_vocab_size, ignore_index=0): - self.ignore_index = ignore_index - self.tgt_vocab_size = tgt_vocab_size - super().__init__() - - label_smoothing = 1 - label_confidence - assert 0.0 < label_smoothing <= 1.0 - smoothing_value = label_smoothing / (tgt_vocab_size - 2) - one_hot = torch.full((tgt_vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer("one_hot", one_hot.unsqueeze(0)) - self.confidence = label_confidence - - def forward( # type: ignore - self, - output: torch.Tensor, - target: torch.Tensor, - label_lengths: torch.LongTensor, - ) -> torch.Tensor: - r""" - - Args: - output (FloatTensor): batch_size x seq_length * n_classes - target (LongTensor): batch_size * seq_length, specify the label - target - label_lengths(torch.LongTensor): specify the length of the labels - """ - ori_shapes = (output.size(), target.size()) - output = output.view(-1, self.tgt_vocab_size) - target = target.view(-1) - model_prob = self.one_hot.repeat(target.size(0), 1) - model_prob = model_prob.to(device=target.device) - model_prob.scatter_(1, target.unsqueeze(1), self.confidence) - model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) - - output = output.view(ori_shapes[0]) - model_prob = model_prob.view(ori_shapes[0]) - - return sequence_softmax_cross_entropy( - labels=model_prob, - logits=output, - sequence_length=label_lengths, - average_across_batch=False, - sum_over_timesteps=False, - ) - - def main(): """Entry point. """ - # Load data - train_data, dev_data, test_data = data_utils.load_data_numpy( - config_data.input_dir, config_data.filename_prefix - ) + if torch.cuda.is_available(): + device = torch.device(torch.cuda.current_device()) + print(f"Using CUDA device {device}") + else: + device = None + # Load data vocab = Vocab(config_data.vocab_file) - + data_hparams = { + # "batch_size" is ignored for train since we use dynamic batching + "batch_size": config_data.test_batch_size, + "bos_id": vocab.bos_token_id, + "eos_id": vocab.eos_token_id, + } + datasets = { + split: data_utils.Seq2SeqData( + os.path.join( + config_data.input_dir, + f"{config_data.filename_prefix}{split}.npy" + ), + hparams=data_hparams, + device=device + ) for split in ["train", "valid", "test"] + } + print(f"Training data size: {len(datasets['train'])}") beam_width = config_model.beam_width # Create logging @@ -292,12 +91,8 @@ def main(): logger = utils.get_logger(logging_file) print(f"logging file is saved in: {logging_file}") - model = Transformer(config_model, config_data, vocab) - if torch.cuda.is_available(): - model = model.cuda() - device = torch.cuda.current_device() - else: - device = None + # Create model and optimizer + model = Transformer(config_model, config_data, vocab).to(device) best_results = {"score": 0, "epoch": -1} lr_config = config_model.lr_config @@ -314,43 +109,34 @@ def main(): ) scheduler = torch.optim.lr_scheduler.LambdaLR(optim, scheduler_lambda) - def _eval_epoch(epoch, mode): - - if mode == "eval": - eval_data = dev_data - elif mode == "test": - eval_data = test_data + @torch.no_grad() + def _eval_epoch(epoch, mode, print_fn=None): + if print_fn is None: + print_fn = print + tqdm_leave = True else: - raise ValueError('`mode` should be either "eval" or "test".') + tqdm_leave = False model.eval() + eval_data = datasets[mode] + eval_iter = tx.data.DataIterator(eval_data) references, hypotheses = [], [] - bsize = config_data.test_batch_size - for i in tqdm(range(0, len(eval_data), bsize)): - sources, targets 
= zip(*eval_data[i: i + bsize]) - with torch.no_grad(): - x_block = data_utils.source_pad_concat_convert( - sources, device=device - ) - predictions = model( - encoder_input=x_block, - is_train_mode=False, - beam_width=beam_width, - ) - if beam_width == 1: - decoded_ids = predictions[0].sample_id - else: - decoded_ids = predictions["sample_id"][:, :, 0] + for batch in tqdm.tqdm(eval_iter, ncols=120, leave=tqdm_leave, + desc=f"Eval on {mode} set"): + predictions = model( + encoder_input=batch.source, + beam_width=beam_width, + ) + if beam_width == 1: + decoded_ids = predictions[0].sample_id + else: + decoded_ids = predictions["sample_id"][:, :, 0] - hypotheses.extend(h.tolist() for h in decoded_ids) - references.extend(r.tolist() for r in targets) - hypotheses = utils.list_strip_eos( - hypotheses, vocab.eos_token_id - ) - references = utils.list_strip_eos( - references, vocab.eos_token_id - ) + hypotheses.extend(h.tolist() for h in decoded_ids) + references.extend(r.tolist() for r in batch.target_output) + hypotheses = utils.list_strip_eos(hypotheses, vocab.eos_token_id) + references = utils.list_strip_eos(references, vocab.eos_token_id) - if mode == "eval": + if mode == "valid": # Writes results to files to evaluate BLEU # For 'eval' mode, the BLEU is based on token ids (rather than # text tokens) and serves only as a surrogate metric to monitor @@ -364,17 +150,13 @@ def _eval_epoch(epoch, mode): hwords = tx.utils.str_join(hwords) rwords = tx.utils.str_join(rwords) hyp_fn, ref_fn = tx.utils.write_paired_text( - hwords, - rwords, - fname, - mode="s", - src_fname_suffix="hyp", - tgt_fname_suffix="ref", + hwords, rwords, fname, mode="s", + src_fname_suffix="hyp", tgt_fname_suffix="ref", ) eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True) eval_bleu = 100.0 * eval_bleu logger.info("epoch: %d, eval_bleu %.4f", epoch, eval_bleu) - print(f"epoch: {epoch:d}, eval_bleu {eval_bleu:.4f}") + print_fn(f"epoch: {epoch:d}, eval_bleu {eval_bleu:.4f}") if eval_bleu > best_results["score"]: logger.info("epoch: %d, best bleu: %.4f", epoch, eval_bleu) @@ -382,7 +164,7 @@ def _eval_epoch(epoch, mode): best_results["epoch"] = epoch model_path = os.path.join(args.model_dir, args.model_fn) logger.info("Saving model to %s", model_path) - print(f"Saving model to {model_path}") + print_fn(f"Saving model to {model_path}") states = { "model": model.state_dict(), @@ -397,43 +179,34 @@ def _eval_epoch(epoch, mode): fname = os.path.join(args.model_dir, "test.output") hwords, rwords = [], [] for hyp, ref in zip(hypotheses, references): - hwords.append([vocab.id_to_token_map_py[y] for y in hyp]) - rwords.append([vocab.id_to_token_map_py[y] for y in ref]) + hwords.append(vocab.map_ids_to_tokens_py(hyp)) + rwords.append(vocab.map_ids_to_tokens_py(ref)) hwords = tx.utils.str_join(hwords) rwords = tx.utils.str_join(rwords) hyp_fn, ref_fn = tx.utils.write_paired_text( - hwords, - rwords, - fname, - mode="s", - src_fname_suffix="hyp", - tgt_fname_suffix="ref", + hwords, rwords, fname, mode="s", + src_fname_suffix="hyp", tgt_fname_suffix="ref", ) logger.info("Test output written to file: %s", hyp_fn) - print(f"Test output written to file: {hyp_fn}") + print_fn(f"Test output written to file: {hyp_fn}") def _train_epoch(epoch: int): - random.shuffle(train_data) model.train() - train_iter = data.iterator.pool( - train_data, - config_data.batch_size, - key=lambda x: (len(x[0]), len(x[1])), - # key is not used if sort_within_batch is False by default - batch_size_fn=utils.batch_size_fn, - 
random_shuffler=data.iterator.RandomShuffler(), + train_iter = tx.data.DataIterator( + datasets["train"], + data_utils.CustomBatchingStrategy(config_data.max_batch_tokens) ) - for _, train_batch in tqdm(enumerate(train_iter)): + progress = tqdm.tqdm( + train_iter, ncols=120, + desc=f"Training epoch {epoch}", + ) + for train_batch in progress: optim.zero_grad() - in_arrays = data_utils.seq2seq_pad_concat_convert( - train_batch, device=device - ) loss = model( - encoder_input=in_arrays[0], - is_train_mode=True, - decoder_input=in_arrays[1], - labels=in_arrays[2], + encoder_input=train_batch.source, + decoder_input=train_batch.target_input, + labels=train_batch.target_output, ) loss.backward() @@ -444,40 +217,33 @@ def _train_epoch(epoch: int): if step % config_data.display_steps == 0: logger.info("step: %d, loss: %.4f", step, loss) lr = optim.param_groups[0]["lr"] - print(f"lr: {lr} step: {step}, loss: {loss:.4}") + progress.write(f"lr: {lr} step: {step}, loss: {loss:.4}") if step and step % config_data.eval_steps == 0: - _eval_epoch(epoch, mode="eval") + _eval_epoch(epoch, mode="valid", print_fn=progress.write) + progress.close() + + model_path = os.path.join(args.model_dir, args.model_fn) if args.run_mode == "train_and_evaluate": logger.info("Begin running with train_and_evaluate mode") - model_path = os.path.join(args.model_dir, args.model_fn) if os.path.exists(model_path): - logger.info("Restore latest checkpoint in", model_path) + logger.info("Restore latest checkpoint in %s", model_path) ckpt = torch.load(model_path) model.load_state_dict(ckpt["model"]) optim.load_state_dict(ckpt["optimizer"]) scheduler.load_state_dict(ckpt["scheduler"]) - _eval_epoch(0, mode="test") + _eval_epoch(0, mode="valid") for epoch in range(config_data.max_train_epoch): _train_epoch(epoch) - _eval_epoch(epoch, mode="eval") + _eval_epoch(epoch, mode="valid") - elif args.run_mode == "evaluate": - logger.info("Begin running with evaluate mode") - model_path = os.path.join(args.model_dir, args.model_fn) + elif args.run_mode in ["evaluate", "test"]: + logger.info("Begin running with %s mode", args.run_mode) logger.info("Restore latest checkpoint in %s", model_path) ckpt = torch.load(model_path) model.load_state_dict(ckpt["model"]) - _eval_epoch(0, mode="eval") - - elif args.run_mode == "test": - logger.info("Begin running with test mode") - model_path = os.path.join(args.model_dir, args.model_fn) - logger.info("Restore latest checkpoint in", model_path) - ckpt = torch.load(model_path) - model.load_state_dict(ckpt["model"]) - _eval_epoch(0, mode="test") + _eval_epoch(0, mode=("test" if args.run_mode == "test" else "valid")) else: raise ValueError(f"Unknown mode: {args.run_mode}") diff --git a/examples/transformer/utils/data_utils.py b/examples/transformer/utils/data_utils.py index bba6c8961..a576916e9 100644 --- a/examples/transformer/utils/data_utils.py +++ b/examples/transformer/utils/data_utils.py @@ -14,115 +14,107 @@ """Data read/write utilities for Transformer. 
""" -import codecs -import os +from typing import List, Optional, Tuple import numpy as np import torch +import texar as tx -def load_data_numpy(input_dir, prefix): - train_data = np.load( - os.path.join(input_dir, prefix + "train.npy"), - encoding="latin1", - allow_pickle=True, - ).tolist() - dev_data = np.load( - os.path.join(input_dir, prefix + "valid.npy"), - encoding="latin1", - allow_pickle=True, - ).tolist() - test_data = np.load( - os.path.join(input_dir, prefix + "test.npy"), - encoding="latin1", - allow_pickle=True, - ).tolist() - print("train data size:{}".format(len(train_data))) - return train_data, dev_data, test_data - - -def seq2seq_pad_concat_convert(xy_batch, eos_id=2, bos_id=1, device=None): - """ - Args: - xy_batch (list of tuple of two numpy.ndarray-s or cupy.ndarray-s): - xy_batch[i][0] is an array - of token ids of i-th input sentence in a minibatch. - xy_batch[i][1] is an array - of token ids of i-th target sentence in a minibatch. - The shape of each array is `(sentence length, )`. - eos_id: The index of end-of-sentence special token in the - dictionary. - bos_id: The index of begin-of-sentence special token in the - dictionary. - device: The device of the generated tensors. - - Returns: - Tuple of Converted array. - (input_sent_batch_array, target_sent_batch_input_array, - target_sent_batch_output_array). - The shape of each array is `(batchsize, max_sentence_length)`. - All sentences are padded with 0 to reach max_sentence_length. - """ - - x_seqs, y_seqs = zip(*xy_batch) - x_block = _concat_examples(x_seqs, padding=0) - y_block = _concat_examples(y_seqs, padding=0) - - # Add EOS - x_block = np.pad(x_block, ((0, 0), (0, 1)), "constant", constant_values=0) - for i_batch, seq in enumerate(x_seqs): - x_block[i_batch, len(seq)] = eos_id - - y_out_block = np.pad( - y_block, ((0, 0), (0, 1)), "constant", constant_values=0) - for i_batch, seq in enumerate(y_seqs): - y_out_block[i_batch, len(seq)] = eos_id +Example = Tuple[np.ndarray, np.ndarray] - # Add BOS in target language - y_in_block = np.pad( - y_block, ((0, 0), (1, 0)), "constant", constant_values=bos_id) - x_block = torch.tensor(x_block, dtype=torch.long, device=device) - y_in_block = torch.tensor(y_in_block, dtype=torch.long, device=device) - y_out_block = torch.tensor(y_out_block, dtype=torch.long, device=device) - return x_block, y_in_block, y_out_block +class CustomBatchingStrategy(tx.data.BatchingStrategy[Example]): + r"""Create dynamically-sized batches for paired text data so that the total + number of source and target tokens (including padding) inside each batch is + constrained. - -def source_pad_concat_convert(x_seqs, eos_id=2, device=None): - """ - This function is used when testing the model without target input. + Args: + max_tokens (int): The maximum number of source or target tokens inside + each batch. 
""" - x_block = _concat_examples(x_seqs, padding=0) - # add EOS - x_block = np.pad(x_block, ((0, 0), (0, 1)), "constant", constant_values=0) - for i_batch, seq in enumerate(x_seqs): - x_block[i_batch, len(seq)] = eos_id - x_block = torch.tensor(x_block, dtype=torch.long, device=device) - return x_block + def __init__(self, max_tokens: int): + self.max_tokens = max_tokens + self.max_src_len = 0 + self.max_tgt_len = 0 + self.cur_batch_size = 0 + def reset_batch(self) -> None: + self.max_src_len = 0 + self.max_tgt_len = 0 + self.cur_batch_size = 0 -def _concat_examples(arrays, padding=0): - if len(arrays) == 0: - raise ValueError("batch is empty") + def add_example(self, ex: Example) -> bool: + max_src_len = max(self.max_src_len, len(ex[0])) + max_tgt_len = max(self.max_tgt_len, len(ex[1])) + if ((self.cur_batch_size + 1) * + max(max_src_len, max_tgt_len) > self.max_tokens): + return False + self.max_src_len = max_src_len + self.max_tgt_len = max_tgt_len + self.cur_batch_size += 1 + return True - first_elem = arrays[0] - assert isinstance(first_elem, np.ndarray) - shape = np.array(arrays[0].shape, dtype=int) - for array in arrays[1:]: - if np.any(shape != array.shape): - np.maximum(shape, array.shape, shape) - shape = tuple(np.insert(shape, 0, len(arrays))) - - result = np.full(shape, padding, dtype=arrays[0].dtype) - for i, src in enumerate(arrays): - slices = tuple(slice(dim) for dim in src.shape) - result[(i,) + slices] = src - return result +class Seq2SeqData(tx.data.DataBase[Example, Example]): + r"""A dataset that reads processed paired text from dumped NumPy files. + Args: + filename (str): The path to the dumped NumPy file. + hparams: A `dict` or instance of :class:`~texar.HParams` containing + hyperparameters. See :meth:`default_hparams` for the defaults. + device: The device of the produces batches. For GPU training, set to + current CUDA device. + """ -def write_words(words_list, filename): - with codecs.open(filename, "w+", "utf-8") as myfile: - for words in words_list: - myfile.write(" ".join(words) + "\n") + def __init__(self, filename: str, hparams=None, + device: Optional[torch.device] = None): + data: List[Example] = np.load( + filename, + encoding="latin1", + allow_pickle=True).tolist() + source = tx.data.SequenceDataSource(data) + super().__init__(source, hparams, device) + + @staticmethod + def default_hparams(): + return { + **tx.data.DataBase.default_hparams(), + "bos_id": 1, + "eos_id": 2, + } + + def process(self, raw_example: Example) -> Example: # pylint: disable=no-self-use + # No-op. The data should already be processed. + return raw_example + + def collate(self, examples: List[Example]) -> tx.data.Batch: + src_seqs = [ex[0] for ex in examples] + tgt_seqs = [ex[1] for ex in examples] + max_src_len = max(map(len, src_seqs)) + max_tgt_len = max(map(len, tgt_seqs)) + # Add EOS token by setting pad_length to max length + 1. + source, _ = tx.data.padded_batch( + src_seqs, pad_length=(max_src_len + 1), + pad_value=self._hparams.eos_id, + ) + target_output, _ = tx.data.padded_batch( + tgt_seqs, pad_length=(max_tgt_len + 1), + pad_value=self._hparams.eos_id, + ) + # Add BOS token to the target inputs. 
+ target_input = np.pad( + target_output[:, :max_tgt_len], ((0, 0), (1, 0)), + "constant", constant_values=self._hparams.bos_id, + ) + source, target_input, target_output = [ + torch.from_numpy(x).to(device=self.device) + for x in [source, target_input, target_output] + ] + return tx.data.Batch( + len(examples), + source=source, + target_input=target_input, + target_output=target_output + ) diff --git a/examples/transformer/utils/utils.py b/examples/transformer/utils/utils.py index d641e215e..7cdb23465 100644 --- a/examples/transformer/utils/utils.py +++ b/examples/transformer/utils/utils.py @@ -15,9 +15,10 @@ Helper functions for model training. """ -import random -import math import logging +import math +import random + import numpy as np import torch @@ -30,18 +31,6 @@ def set_random_seed(seed): torch.cuda.manual_seed(seed) -def batch_size_fn(new, count, size_so_far): # pylint: disable=unused-argument - if count == 1 or not hasattr(batch_size_fn, 'max_src_in_batch'): - batch_size_fn.max_src_in_batch = 0 - batch_size_fn.max_tgt_in_batch = 0 - batch_size_fn.max_src_in_batch = max( - batch_size_fn.max_src_in_batch, len(new[0]) + 1) - batch_size_fn.max_tgt_in_batch = max( - batch_size_fn.max_tgt_in_batch, len(new[1]) + 1) - return count * max(batch_size_fn.max_src_in_batch, - batch_size_fn.max_tgt_in_batch) - - def get_lr_multiplier(step: int, warmup_steps: int) -> float: r"""Calculate the learning rate multiplier given current step and the number of warm-up steps. The learning rate schedule follows a linear warm-up and @@ -62,8 +51,7 @@ def get_logger(log_path): logger.setLevel(logging.DEBUG) fh = logging.FileHandler(log_path) fh.setLevel(logging.DEBUG) - fh.setFormatter( - logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')) + fh.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')) logger.addHandler(fh) return logger diff --git a/texar/data/data/__init__.py b/texar/data/data/__init__.py index a0531ab69..c35beaf36 100644 --- a/texar/data/data/__init__.py +++ b/texar/data/data/__init__.py @@ -17,9 +17,10 @@ from texar.data.data.data_base import * from texar.data.data.data_iterators import * +from texar.data.data.dataset_utils import * from texar.data.data.mono_text_data import * -from texar.data.data.paired_text_data import * -from texar.data.data.scalar_data import * from texar.data.data.multi_aligned_data import * +from texar.data.data.paired_text_data import * from texar.data.data.record_data import * +from texar.data.data.scalar_data import * from texar.data.data.text_data_base import * diff --git a/texar/data/data/data_base.py b/texar/data/data/data_base.py index dba446d67..383912962 100644 --- a/texar/data/data/data_base.py +++ b/texar/data/data/data_base.py @@ -259,7 +259,7 @@ class DataBase(Dataset, Generic[RawExample, Example], ABC): _source: DataSource[RawExample] _dataset_size: Optional[int] - def __init__(self, source: DataSource[RawExample], hparams, + def __init__(self, source: DataSource[RawExample], hparams=None, device: Optional[torch.device] = None): self._source = source self._hparams = HParams(hparams, self.default_hparams()) @@ -608,13 +608,8 @@ def _prefetch_source(self, index: int) -> Optional[int]: def __len__(self) -> int: if self._dataset_size is None: - warnings.warn( - "The provided data source does not support random access. To " - "obtain dataset size, a full traversal must be performed. 
" - "This is often unnecessary and slow, consider redesigning your " - "use case.") - self._prefetch_all_source() - assert self._dataset_size is not None + raise TypeError( + "__len__ not supported for datasets with undetermined size") return self._dataset_size def process(self, raw_example: RawExample) -> Example: diff --git a/texar/data/data/data_iterators.py b/texar/data/data/data_iterators.py index d479615d3..ada4f548b 100644 --- a/texar/data/data/data_iterators.py +++ b/texar/data/data/data_iterators.py @@ -18,28 +18,39 @@ # pylint: disable=protected-access from typing import ( - Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, - Tuple, Union) + Any, Callable, Dict, Generic, Iterable, Iterator, List, Mapping, Optional, + Sequence, Tuple, TypeVar, Union) -import torch from torch.utils.data import DataLoader from torch.utils.data import sampler as torch_sampler from torch.utils.data.dataloader import _DataLoaderIter as torch_DataLoaderIter +import torch from texar.data.data.data_base import DataBase from texar.data.data.dataset_utils import Batch from texar.utils.types import MaybeSeq +from texar.utils.utils import ceildiv __all__ = [ "DataIterator", "TrainTestDataIterator", + "BatchingStrategy", + "TokenCountBatchingStrategy", ] DatasetsType = Union[Dict[str, DataBase], MaybeSeq[DataBase]] +Example = TypeVar('Example') + +# pylint: disable=attribute-defined-outside-init +# TODO: Remove this when Pylint fixes the bug. If the `disable` directive is not +# added, Pylint incorrectly reports this error for `self.size` in subclasses of +# `SamplerBase` in Python 3.6 due to use of the Generic class. +# See Pylint issue: https://github.com/PyCQA/pylint/issues/2981 -class SamplerBase(torch_sampler.Sampler): - r"""A subclass of :class:`~torch.utils.data.Sampler` that supports: +class SamplerBase(torch_sampler.Sampler, Generic[Example]): + r"""A subclass of :torch_docs:`~torch.utils.data.Sampler + ` that supports: - Returning raw examples when required. - Creating iterators with unknown dataset size. @@ -51,12 +62,13 @@ class SamplerBase(torch_sampler.Sampler): Args: data: The :class:`~texar.data.data.DataBase` instance. """ + size: Optional[int] - def __init__(self, data: DataBase): + def __init__(self, data: DataBase[Any, Example]): super().__init__(data) self._data = data - self.size: Optional[int] = None + self.size = None def _iterator_given_size(self, size: int) -> Iterator[int]: r"""Return an iterator that generates samples when the dataset size @@ -76,7 +88,7 @@ def _iterator_unknown_size(self) -> Iterator[int]: """ raise NotImplementedError - def __iter__(self) -> Union[Iterator[int], Iterator[Tuple[int, Any]]]: + def __iter__(self) -> Union[Iterator[int], Iterator[Tuple[int, Example]]]: r"""Return an iterator based on the dataset settings. """ self.size = self._data._dataset_size @@ -111,12 +123,13 @@ def __len__(self): raise AttributeError("Dataset size cannot be determined at this point") -class SequentialSampler(SamplerBase): +class SequentialSampler(SamplerBase[Example]): r"""Samples elements sequentially, always in the same order. Same as - :class:`torch.utils.data.SequentialSampler`. 
+ :torch_docs:`~torch.utils.data.SequentialSampler + ` """ - def _iterator_given_size(self, size: int) -> Iterator[int]: + def _iterator_given_size(self, size: int) -> Iterator[int]: # pylint: disable=no-self-use return iter(range(size)) def _iterator_unknown_size(self) -> Iterator[int]: @@ -130,12 +143,13 @@ def _iterator_unknown_size(self) -> Iterator[int]: index += 1 -class RandomSampler(SamplerBase): +class RandomSampler(SamplerBase[Example]): r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then user can specify ``num_samples`` to draw. - This class uses :class:`torch.utils.data.RandomSampler` directly. Given the + This class uses :torch_docs:`torch.utils.data.RandomSampler + ` directly. Given the nature of such shuffling, it cannot be used for iterators with unknown size. Args: @@ -145,23 +159,25 @@ class RandomSampler(SamplerBase): default=False """ - def __init__(self, data: DataBase, replacement: bool = False, + def __init__(self, data: DataBase[Any, Example], replacement: bool = False, num_samples: Optional[int] = None): super().__init__(data) self._sampler = torch_sampler.RandomSampler( data, replacement, num_samples) def _iterator_given_size(self, size: int) -> Iterator[int]: + del size # not used return iter(self._sampler) - def _iterator_unknown_size(self) -> Iterator[int]: + def _iterator_unknown_size(self) -> Iterator[int]: # pylint: disable=no-self-use raise TypeError( "RandomSampler does not support lazy data loading. To perform " "shuffling with lazy loading, use BufferShuffleSampler.") -class BufferShuffleSampler(SamplerBase): - r"""A :class:`~torch.utils.data.Sampler` that uses a shuffle buffer, as +class BufferShuffleSampler(SamplerBase[Example]): + r"""A :torch_docs:`~torch.utils.data.Sampler + ` that uses a shuffle buffer, as in TensorFlow. The buffer is first filled with data examples. Each time a sample is drawn from the buffer, and the drawn sample is replaced with the next data example. @@ -176,7 +192,7 @@ class BufferShuffleSampler(SamplerBase): more uniformly-random shuffling. """ - def __init__(self, data: DataBase, buffer_size: int): + def __init__(self, data: DataBase[Any, Example], buffer_size: int): super().__init__(data) self.buffer_size = buffer_size @@ -210,6 +226,121 @@ def _iterator_unknown_size(self) -> Iterator[int]: if buffer[x] < self.size) +# pylint: enable=attribute-defined-outside-init + + +class BatchingStrategy(Generic[Example]): + r"""Decides batch boundaries in dynamic batching. Please refer to + :class:`TokenCountBatchingStrategy` for a concrete example. + """ + + def reset_batch(self) -> None: + r"""Reset the internal state of the batching strategy. This method is + called at the start of iteration, and after each batch is yielded. + """ + raise NotImplementedError + + def add_example(self, example: Example) -> bool: + r"""Add an example into the current batch, and modify internal states + accordingly. If the example should not be added to the batch, this + method does not modify the internal state, and returns `False`. + + Args: + example: The example to add to the batch. + + Returns: + A boolean value indicating whether :attr:`example` should be added + to the batch. + """ + raise NotImplementedError + + +class TokenCountBatchingStrategy(BatchingStrategy[Example]): + r"""Create dynamically-sized batches so that the total number of tokens + inside each batch is constrained. + + Args: + max_tokens (int): The maximum number of tokens inside each batch. 
+ max_batch_size (int, optional): The maximum number of examples for each + batch. If `None`, batches can contain arbitrary number of examples + as long as the total number of tokens does not exceed + :attr:`max_tokens`. + length_fn (callable, optional): A function taking a data example as + argument, and returning the number of tokens in the example. By + default, :python:`len` is used, which is the desired behavior if the + dataset in question is a :class:`~texar.data.MonoTextData`. + """ + + def __init__(self, max_tokens: int, max_batch_size: Optional[int] = None, + length_fn: Optional[Callable[[Example], int]] = None): + self.max_batch_size = max_batch_size + self.max_tokens = max_tokens + self.length_fn: Callable[[Example], int] + self.length_fn = length_fn or len # type: ignore + self.sum_tokens = 0 + self.cur_batch_size = 0 + + def reset_batch(self) -> None: + self.sum_tokens = 0 + self.cur_batch_size = 0 + + def add_example(self, example: Example) -> bool: + if self.cur_batch_size == self.max_batch_size: + return False + cur_tokens = self.length_fn(example) + if cur_tokens + self.sum_tokens > self.max_tokens: + return False + + self.cur_batch_size += 1 + self.sum_tokens += cur_tokens + return True + + +class DynamicBatchSampler(torch_sampler.BatchSampler, Generic[Example]): + r"""A subclass of :torch_docs:`~torch.utils.data.BatchSampler + ` that supports dynamic batching + through a user-provided :class:`BatchingStrategy`. This class is used + internally. + + Args: + dataset: The dataset to create batches from. + sampler: An instance of :class:`SamplerBase` that returns indices of + each sampled example. + strategy: An instance of :class:`BatchingStrategy` that decides whether + a batch should be yielded. + """ + + def __init__(self, dataset: DataBase[Any, Example], # pylint: disable=super-init-not-called + sampler: SamplerBase, strategy: BatchingStrategy[Example]): + self.dataset = dataset + self.sampler = sampler + self.strategy = strategy + + def __iter__(self) -> Union[Iterator[List[int]], # type: ignore + Iterator[List[Tuple[int, Example]]]]: + batch = [] # type: ignore + self.strategy.reset_batch() + for idx in self.sampler: + if isinstance(idx, tuple): + example = self.dataset[idx[0]] + else: + example = self.dataset[idx] + while not self.strategy.add_example(example): + if len(batch) == 0: + raise ValueError(f"Batching strategy refused to add " + f"example {idx} to empty batch.") + yield batch + batch = [] + self.strategy.reset_batch() + batch.append(idx) + if len(batch) > 0: + yield batch + self.strategy.reset_batch() + + def __len__(self): + raise TypeError("DynamicBatchSampler does not support __len__") + + class _DataLoaderIter(torch_DataLoaderIter): # pylint: disable=abstract-method r"""Iterates once over the DataLoader's dataset. This is almost identical to PyTorch :class:`torch.utils.data.dataloader._DataLoaderIter`, except @@ -225,7 +356,10 @@ def __init__(self, loader: DataLoader): def __next__(self): batch = super().__next__() - if (batch.batch_size < self._batch_size and + # Drop smaller final batch according to settings. Note that + # `_batch_size` could be None if dynamic batching is used. 
+ if (self._batch_size is not None and + batch.batch_size < self._batch_size and not self.dataset.hparams.allow_smaller_final_batch): raise StopIteration return batch @@ -294,7 +428,8 @@ def __next__(self): self.dataset._add_cached_examples(indices, examples) else: batch = super().__next__() - if (batch.batch_size < self.dataset.batch_size and + if (self._batch_size is not None and + batch.batch_size < self.dataset.batch_size and not self.dataset.hparams.allow_smaller_final_batch): raise StopIteration return batch @@ -309,10 +444,13 @@ class SingleDatasetIterator(DataLoader): dataset: The dataset to iterator through. The dataset must be an instance of :class:`texar.data.DataBase`, because configurations are read from the dataset `HParams`. + batching_strategy: The batching strategy to use when performing dynamic + batching. If `None`, fixed-sized batching is used. """ dataset: DataBase - def __init__(self, dataset: DataBase): + def __init__(self, dataset: DataBase, + batching_strategy: Optional[BatchingStrategy] = None): shuffle = dataset.hparams.shuffle shuffle_buffer_size = dataset.hparams.shuffle_buffer_size sampler: SamplerBase @@ -325,10 +463,18 @@ def __init__(self, dataset: DataBase): num_parallel_calls = dataset.hparams.num_parallel_calls collate_fn = dataset._collate_and_maybe_return - super().__init__( - dataset, dataset.batch_size, sampler=sampler, collate_fn=collate_fn, - num_workers=(0 if num_parallel_calls == 1 else num_parallel_calls), - drop_last=False) + num_workers = (0 if num_parallel_calls == 1 else num_parallel_calls) + + if batching_strategy is not None: + batch_sampler = DynamicBatchSampler( + dataset, sampler, batching_strategy) + super().__init__( + dataset, batch_sampler=batch_sampler, + collate_fn=collate_fn, num_workers=num_workers) + else: + super().__init__( + dataset, batch_size=dataset.batch_size, drop_last=False, + sampler=sampler, collate_fn=collate_fn, num_workers=num_workers) def __iter__(self): if self.dataset._should_return_processed_examples: @@ -337,6 +483,14 @@ def __iter__(self): else: return _DataLoaderIter(self) + def __len__(self): + if self.batch_size is None: + raise TypeError("__len__ not supported for dynamic batching") + data_length = len(self.dataset) # may throw TypeError + if self.dataset.hparams.allow_smaller_final_batch: + return ceildiv(data_length, self.batch_size) + return data_length // self.batch_size + class DataIterator: r"""Data iterator that switches and iterates through multiple datasets. @@ -352,8 +506,13 @@ class DataIterator: - A `list` of instances of :class:`texar.data.DataBase`. The name of instances (:attr:`texar.data.DataBase.name`) must be unique. + batching_strategy: The batching strategy to use when performing dynamic + batching. If `None`, fixed-sized batching is used. + Example: + Create an iterator over two datasets and generating fixed-sized batches: + .. code-block:: python train_data = MonoTextData(hparams_train) @@ -369,11 +528,50 @@ class DataIterator: # Starts iterating through test data from the beginning for batch in iterator.get_iterator('test'): ... # Do testing with the batch. + + Dynamic batching based on total number of tokens: + + .. code-block:: python + + iterator = DataIterator( + {'train': train_data, 'test': test_data}, + batching_strategy=TokenCountBatchingStrategy(max_tokens=1000)) + + Dynamic batching with custom strategy (e.g. total number of tokens in + examples from :class:`~texar.data.PairedTextData`, including padding): + + .. 
code-block:: python + + class CustomBatchingStrategy(BatchingStrategy): + def __init__(self, max_tokens: int): + self.max_tokens = max_tokens + self.reset_batch() + + def reset_batch(self) -> None: + self.max_src_len = 0 + self.max_tgt_len = 0 + self.cur_batch_size = 0 + + def add_example(self, ex: Tuple[List[str], List[str]]) -> bool: + max_src_len = max(self.max_src_len, len(ex[0])) + max_tgt_len = max(self.max_tgt_len, len(ex[0])) + if (max(max_src_len + max_tgt_len) * + (self.cur_batch_size + 1) > self.max_tokens): + return False + self.max_src_len = max_src_len + self.max_tgt_len = max_tgt_len + self.cur_batch_size += 1 + return True + + iterator = DataIterator( + {'train': train_data, 'test': test_data}, + batching_strategy=CustomBatchingStrategy(max_tokens=1000)) """ # TODO: Think about whether we should support save/load. - def __init__(self, datasets: DatasetsType): + def __init__(self, datasets: DatasetsType, + batching_strategy: Optional[BatchingStrategy] = None): self._default_dataset_name = 'data' if isinstance(datasets, DataBase): datasets = {self._default_dataset_name: datasets} @@ -386,7 +584,7 @@ def __init__(self, datasets: DatasetsType): if len(datasets) < num_datasets: raise ValueError("Names of datasets must be unique.") - _datasets = {name: SingleDatasetIterator(dataset) + _datasets = {name: SingleDatasetIterator(dataset, batching_strategy) for name, dataset in datasets.items()} self._datasets = _datasets @@ -452,6 +650,9 @@ def __iter__(self) -> Iterator[Batch]: """ return self.get_iterator() + def __len__(self): + return len(self._datasets[self._validate_dataset_name(None)]) + class TrainTestDataIterator(DataIterator): r"""Data iterator that alternatives between train, val, and test datasets. diff --git a/texar/data/data/data_iterators_test.py b/texar/data/data/data_iterators_test.py index acd3ef2cd..aa4a7eb1c 100644 --- a/texar/data/data/data_iterators_test.py +++ b/texar/data/data/data_iterators_test.py @@ -7,8 +7,9 @@ from typing import List, Tuple, no_type_check import numpy as np - import torch + +import texar as tx from texar.data.data.data_base import ( DataBase, DataSource, IterDataSource, SequenceDataSource, ZipDataSource) from texar.data.data.data_iterators import ( @@ -246,6 +247,35 @@ def test_train_test_data_iterator(self): data_iterator.switch_to_val_data() self.assertTrue('Val data not provided' in str(context.exception)) + def test_dynamic_batching(self): + r"""Tests dynamic batching using :class:`texar.data.BatchingStrategy`. + """ + sent_lengths = np.random.randint(10, 20, size=(100,)) + sentences = [['a'] * length for length in sent_lengths] + data_source = tx.data.SequenceDataSource(sentences) + + class CustomData(tx.data.DataBase): + def __init__(self, source): + super().__init__(source) + + def process(self, raw_example): + return raw_example + + def collate(self, examples): + return Batch(len(examples), text=examples) + + train_data = CustomData(data_source) + + batch_size = 5 + max_tokens = 75 + strategy = tx.data.TokenCountBatchingStrategy( + max_tokens, batch_size, len) + iterator = tx.data.DataIterator(train_data, strategy) + + for batch in iterator: + self.assertLessEqual(len(batch), batch_size) + self.assertLessEqual(sum(len(s) for s in batch.text), max_tokens) + RawExample = Tuple[List[int], str] Example = Tuple[List[int], List[str]]
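
Usage sketch: the snippet below illustrates how the pieces added above fit
together for training-time dynamic batching. It is a minimal sketch, assuming
the IWSLT'15 example config referenced in ``transformer_main.py``
(``input_dir``, ``filename_prefix``, ``vocab_file``, ``test_batch_size``, and
the new ``max_batch_tokens`` field) and the ``examples/transformer`` layout;
it is not the example's canonical training loop.

.. code-block:: python

    import os

    import texar as tx
    from texar.data import Vocab

    import config_iwslt15 as config_data
    from utils.data_utils import Seq2SeqData, CustomBatchingStrategy

    vocab = Vocab(config_data.vocab_file)
    train_data = Seq2SeqData(
        os.path.join(config_data.input_dir,
                     f"{config_data.filename_prefix}train.npy"),
        hparams={"batch_size": config_data.test_batch_size,  # ignored for train
                 "bos_id": vocab.bos_token_id,
                 "eos_id": vocab.eos_token_id})

    # Batch boundaries are decided by CustomBatchingStrategy.add_example, so
    # batch_size * max(source length, target length) per batch stays within
    # max_batch_tokens instead of relying on a fixed batch_size.
    train_iter = tx.data.DataIterator(
        train_data,
        batching_strategy=CustomBatchingStrategy(config_data.max_batch_tokens))

    for batch in train_iter:
        source = batch.source                 # [batch_size, max_src_len + 1]
        target_input = batch.target_input     # rows: [bos, tgt_1, ..., tgt_n, ...]
        target_output = batch.target_output   # rows: [tgt_1, ..., tgt_n, eos, ...]
        # Feed into Transformer(...) exactly as transformer_main.py does.

Compared with the removed ``torchtext`` pooling in ``_train_epoch``, the same
token budget is now enforced by ``DynamicBatchSampler`` inside
``DataIterator``, so no manual shuffling or ``batch_size_fn`` is needed.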