Commit: Retro

lmcafee-nvidia authored and jaredcasper committed Feb 22, 2023
1 parent ef59b68 commit 17a6044

Showing 56 changed files with 5,990 additions and 113 deletions.
1 change: 1 addition & 0 deletions .gitignore

@@ -3,3 +3,4 @@ __pycache__
 build
 .coverage_*
 *.egg-info
+*~
15 changes: 15 additions & 0 deletions README.md

@@ -42,7 +42,9 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization
 * [Distributed Pretraining](#distributed-pretraining)
 * [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation)
 * [Distributed Optimizer](#distributed-optimizer)
+* [FlashAttention](#flashattention)
 * [GPT-3 Example](#gpt-3-example)
+* [Retro](#retro)
 * [Evaluation and Tasks](#evaluation-and-tasks)
 * [GPT Text Generation](#gpt-text-generation)
 * [GPT Evaluation](#gpt-evaluation)
@@ -323,6 +325,19 @@ In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to config
 With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs.
 
+
+## Retro
+
+See:
+
+- `tools/retro/README.md` for an overview.
+- `tools/retro/examples/get_preprocess_cmd.sh` for an example of common preprocessing arguments.
+- `tools/retro/examples/preprocess_data.sh` for an example of how to preprocess data.
+- `tools/retro/examples/pretrain_model.sh` for an example of how to pretrain a model.
+
+Retro is a retrieval-enhanced model based on GPT. As described in [Improving language models by retrieving from trillions of tokens](https://arxiv.org/abs/2112.04426), Retro retrieves from a database of document chunks by performing nearest-neighbor (locality) search using a sample's tokens. The retrieval database can be large, often billions or even trillions of tokens, and stores factual knowledge more efficiently than encoding it implicitly within the network's parameters.
+
+Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. See `tools/retro/README.md` for a detailed overview.
 
 <!--
 ## REALM Pipeline
 We are working on implementing the [REALM](https://arxiv.org/pdf/2002.08909.pdf) system. The following sections (will) reflect the three stages of training it. For now it's just the ICT code.
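For intuition about the retrieval step, here is a minimal sketch of chunk-neighbor lookup. It is an illustration under stated assumptions, not the tools/retro pipeline: the preprocessing tools do build a faiss index over embeddings of database chunks, but the dimensions, exact index type, and random stand-in embeddings below are made up.

```python
import numpy as np
import faiss  # nearest-neighbor library used by the Retro preprocessing tools

chunk_dim = 1024          # illustrative embedding size
num_db_chunks = 100_000   # tiny stand-in for a multi-billion-token database
num_neighbors = 2         # cf. --retro-num-neighbors

# Pretend these are embeddings of fixed-length database chunks.
db_embeds = np.random.rand(num_db_chunks, chunk_dim).astype(np.float32)
index = faiss.IndexFlatL2(chunk_dim)  # exact search; large setups use ANN indexes
index.add(db_embeds)

# Embed a sample's chunks the same way, then fetch nearest database chunks.
sample_embeds = np.random.rand(4, chunk_dim).astype(np.float32)
_, neighbor_ids = index.search(sample_embeds, num_neighbors)
print(neighbor_ids.shape)  # (4, 2): num_neighbors ids per sample chunk
```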
4 changes: 2 additions & 2 deletions examples/detxoify_lm/finetune_gpt.py

@@ -11,10 +11,10 @@
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
4 changes: 2 additions & 2 deletions examples/detxoify_lm/generate_samples_gpt.py

@@ -10,10 +10,10 @@
                                              os.path.pardir, os.path.pardir)))
 import torch
 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron import print_rank_0
 from megatron.checkpointing import load_checkpoint
+from megatron.core import mpu
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
3 changes: 2 additions & 1 deletion megatron/__init__.py

@@ -1,7 +1,8 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 import torch
-from .global_vars import get_args
+
+from .global_vars import get_args, get_retro_args
 from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
 from .global_vars import get_signal_handler
95 changes: 87 additions & 8 deletions megatron/arguments.py

@@ -3,9 +3,14 @@
 """Megatron arguments."""
 
 import argparse
+import json
 import os
 
 import torch
+import types
+
+from megatron.global_vars import set_retro_args, get_retro_args
+from tools.retro.utils import get_args_path as get_retro_args_path
 
 
 def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     """Parse all arguments."""
@@ -29,6 +34,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     parser = _add_logging_args(parser)
     parser = _add_inference_args(parser)
     parser = _add_transformer_engine_args(parser)
+    parser = _add_retro_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -333,7 +339,6 @@ def validate_args(args, defaults={}):
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
 
-
     if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
         if args.sequence_parallel:
             raise RuntimeError(
@@ -344,23 +349,39 @@
             "Using async gradient all reduce requires setting the environment "
             "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
 
+    # Load retro args.
+    if args.retro_workdir:
+        retro_args_path = get_retro_args_path(args.retro_workdir)
+        if os.path.exists(retro_args_path):
+            with open(retro_args_path) as f:
+                retro_args = types.SimpleNamespace(**json.load(f))
+            retro_args.retro_return_doc_ids = args.retro_return_doc_ids
+            retro_args.retro_gpt_retrieved_length = \
+                args.retro_num_retrieved_chunks * \
+                retro_args.retro_gpt_chunk_length
+            set_retro_args(retro_args)
+
     # Print arguments.
+    _print_args("arguments", args)
+    retro_args = get_retro_args()
+    if retro_args and args != retro_args:
+        _print_args("retro arguments", types.SimpleNamespace(**{k: v for k, v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))
 
-    _print_args(args)
     return args
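The retro args file loaded above is plain JSON, so the load reduces to the pattern below. The file contents and field values are illustrative assumptions; the real path comes from `tools/retro/utils.get_args_path(retro_workdir)`.

```python
import json
import types

# Hypothetical stand-in for the contents of <retro-workdir>'s args file.
args_json = '{"retro_gpt_chunk_length": 64, "retro_gpt_seq_length": 2048}'
retro_args = types.SimpleNamespace(**json.loads(args_json))

# Derived retrieved length, computed the same way as in validate_args.
retro_num_retrieved_chunks = 2  # cf. --retro-num-retrieved-chunks
print(retro_num_retrieved_chunks * retro_args.retro_gpt_chunk_length)  # 128
```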


-def _print_args(args):
+def _print_args(title, args):
     """Print arguments."""
     if args.rank == 0:
-        print('------------------------ arguments ------------------------',
+        print(f'------------------------ {title} ------------------------',
               flush=True)
         str_list = []
         for arg in vars(args):
             dots = '.' * (48 - len(arg))
             str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
             print(arg, flush=True)
-        print('-------------------- end of arguments ---------------------',
+        print(f'-------------------- end of {title} ---------------------',
               flush=True)
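A note on the `types.SimpleNamespace(**{...}, rank=args.rank)` call in validate_args above: `_print_args` only guards on `args.rank`, so the filtered retro namespace re-attaches `rank` to satisfy that guard. A self-contained sketch of that filtering (values are made up):

```python
import types

# Stand-in args objects; only retro-prefixed fields should be printed.
args = types.SimpleNamespace(rank=0, retro_workdir="wd", lr=1e-4)
retro_args = types.SimpleNamespace(retro_gpt_chunk_length=64, seed=1234)

filtered = types.SimpleNamespace(
    **{k: v for k, v in vars(retro_args).items() if k.startswith("retro")},
    rank=args.rank)  # re-attach rank so the rank-0 guard still works
print(vars(filtered))  # {'retro_gpt_chunk_length': 64, 'rank': 0}
```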


@@ -403,15 +424,67 @@ def _add_inference_args(parser):
                        help='During inference, if batch-size times '
                        'sequence-length is smaller than this threshold '
                        'then we will not use pipelining, otherwise we will.')
 
+    group.add_argument('--max-tokens-to-oom',
+                       type=int, default=12000,
+                       help='Maximum number of tokens during inference. '
+                       'Tokens here is # in prompt + # to generate. '
+                       'Allows us to throw an error before OOM crashes the '
+                       'server.')
+    group.add_argument('--output-bert-embeddings', action='store_true',
+                       help='Output Bert embeddings (via mean pooling) from '
+                       'model, rather than its binary head output or entire '
+                       'hidden batch.')
+    group.add_argument('--bert-embedder-type', default="megatron",
+                       choices=["megatron", "huggingface"],
+                       help='Select either Megatron or Huggingface as the '
+                       'Bert embedder.')
 
     return parser



+def _add_retro_args(parser):
+    group = parser.add_argument_group(title='retro')
+
+    group.add_argument('--retro-workdir', default=None,
+                       help='Retro working directory, which contains the '
+                       'preprocessed data for pretraining. This directory '
+                       'is built during preprocessing (see '
+                       'tools/retro/README.md), and contains subdirectories '
+                       'for the chunk database and pretraining neighbors.')
+    group.add_argument('--retro-add-retriever',
+                       action='store_true', default=False,
+                       help='Add a retriever to the transformer, for use in '
+                       'pretraining a Retro model.')
+    group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
+                       help='Set number of training iterations for cyclic '
+                       'Retro training.')
+    group.add_argument('--retro-encoder-layers', type=int, default=2,
+                       help='Number of layers to use for the retrieval '
+                       'encoder.')
+    group.add_argument('--retro-encoder-hidden-dropout',
+                       type=float, default=0.1, help='Hidden dropout for '
+                       'retrieval encoder.')
+    group.add_argument('--retro-encoder-attention-dropout',
+                       type=float, default=0.1, help='Attention dropout for '
+                       'retrieval encoder.')
+    group.add_argument("--retro-num-neighbors", type=int, default=2,
+                       help='Number of neighbors to retrieve during '
+                       'pretraining.')
+    group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
+                       help='Number of chunks to retrieve from the retrieval '
+                       'database.')
+    group.add_argument("--retro-return-doc-ids", action="store_true",
+                       help="Turn this on when preprocessing retro data.")
+
+    # Enforce argument naming convention.
+    for action in group._group_actions:
+        prefix = action.dest.split("_")[0]
+        assert prefix == "retro", \
+            "Retro args must be prefixed with '--retro-*', for consistent " \
+            "styling. Please fix '%s'." % ", ".join(action.option_strings)
+
+    return parser
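The convention check above walks argparse's private `_group_actions` list. A runnable, self-contained sketch of the same idea (the second argument deliberately violates the convention, so running this raises AssertionError):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='retro')
group.add_argument('--retro-num-neighbors', type=int, default=2)
group.add_argument('--num-neighbors', type=int)  # violates the convention

for action in group._group_actions:  # private argparse API, as in the hunk
    prefix = action.dest.split("_")[0]  # '--retro-num-neighbors' -> 'retro'
    assert prefix == "retro", \
        "Retro args must be prefixed with '--retro-*'. Please fix '%s'." \
        % ", ".join(action.option_strings)
```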


 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
 
@@ -775,6 +848,10 @@ def _add_checkpointing_args(parser):
     group.add_argument('--use-checkpoint-args', action='store_true',
                        help='Override any command line arguments with arguments '
                        'from the checkpoint')
+    group.add_argument('--exit-on-missing-checkpoint', action='store_true',
+                       help="If '--load' is set, but checkpoint is not found "
+                       "(e.g., path typo), then exit instead of random "
+                       "initialization.")
 
     return parser

@@ -835,6 +912,8 @@ def _add_distributed_args(parser):
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
+    group.add_argument('--distributed-timeout-minutes', type=int, default=10,
+                       help='Timeout minutes for torch.distributed.')
     group.add_argument('--DDP-impl', default='local',
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
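A timeout flag like this is typically threaded through to `torch.distributed.init_process_group`; whether initialize_megatron wires it exactly this way is an assumption here, but the call shape is standard:

```python
from datetime import timedelta

import torch

# Assumes the usual env:// variables (MASTER_ADDR, MASTER_PORT, RANK,
# WORLD_SIZE) are set by the launcher.
distributed_timeout_minutes = 10  # cf. --distributed-timeout-minutes

torch.distributed.init_process_group(
    backend='nccl',
    timeout=timedelta(minutes=distributed_timeout_minutes),
)
```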
9 changes: 9 additions & 0 deletions megatron/checkpointing.py

@@ -508,7 +508,16 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
         use_distributed_optimizer=args.use_distributed_optimizer,
         rank0=False)
 
+    # Checkpoint not loaded.
     if model_state_dict is None:
+
+        # Conditionally exit at this point.
+        if args.exit_on_missing_checkpoint:
+            print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
+            torch.distributed.barrier()
+            sys.exit()
+
+        # Iteration defaults to 0.
         return 0
 
     # set checkpoint version
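The ordering in the hunk above matters: the barrier makes every rank rendezvous before any rank exits, so no surviving rank is left hanging in a later collective. A stripped-down sketch of the pattern (not the real load_checkpoint; process-group initialization is assumed):

```python
import sys

import torch

def maybe_exit_on_missing_checkpoint(model_state_dict, exit_on_missing):
    # Sketch of the early-exit path: rendezvous first, then exit together.
    if model_state_dict is not None:
        return
    if exit_on_missing:
        if torch.distributed.is_initialized():
            torch.distributed.barrier()  # all ranks reach here before exiting
        sys.exit()
```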
5 changes: 3 additions & 2 deletions megatron/data/bert_dataset.py

@@ -19,7 +19,6 @@
     create_masked_lm_predictions
 )
 
-
 class BertDataset(torch.utils.data.Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
@@ -156,7 +155,9 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     # Some checks.
     num_tokens = len(tokens)
     padding_length = max_seq_length - num_tokens
-    assert padding_length >= 0
+    assert padding_length >= 0, \
+        f"num_tokens ({num_tokens}) is greater than " \
+        f"max_seq_length ({max_seq_length})."
     assert len(tokentypes) == num_tokens
     assert len(masked_positions) == len(masked_labels)
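One subtlety in the new assert message: adjacent string literals are concatenated before formatting, so every piece that interpolates a value needs its own f prefix, otherwise the placeholder is emitted verbatim. A quick illustration:

```python
num_tokens, max_seq_length = 10, 8

# Only the first literal is an f-string; {max_seq_length} stays verbatim.
broken = f"num_tokens ({num_tokens}) is greater than " \
         "max_seq_length ({max_seq_length})."
print(broken)  # ... max_seq_length ({max_seq_length}).

# With f on both pieces, both values interpolate.
fixed = f"num_tokens ({num_tokens}) is greater than " \
        f"max_seq_length ({max_seq_length})."
print(fixed)   # ... max_seq_length (8).
```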
5 changes: 4 additions & 1 deletion megatron/data/blendable_dataset.py

@@ -50,4 +50,7 @@ def __len__(self):
     def __getitem__(self, idx):
         dataset_idx = self.dataset_index[idx]
         sample_idx = self.dataset_sample_index[idx]
-        return self.datasets[dataset_idx][sample_idx]
+        return {
+            "dataset_idx" : dataset_idx,
+            **self.datasets[dataset_idx][sample_idx],
+        }
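With dataset_idx in the returned sample, downstream code can tell which constituent dataset produced each blended sample. A toy sketch of the new return shape (plain lists stand in for the real datasets):

```python
# Toy stand-ins for BlendableDataset's constituent datasets and index maps.
datasets = [
    [{"text": "sample A0"}, {"text": "sample A1"}],
    [{"text": "sample B0"}],
]
dataset_index = [0, 1, 0]         # which dataset each blended idx maps to
dataset_sample_index = [0, 0, 1]  # which sample within that dataset

def getitem(idx):
    dataset_idx = dataset_index[idx]
    sample_idx = dataset_sample_index[idx]
    return {"dataset_idx": dataset_idx, **datasets[dataset_idx][sample_idx]}

print(getitem(1))  # {'dataset_idx': 1, 'text': 'sample B0'}
```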
5 changes: 3 additions & 2 deletions megatron/data/dataset_utils.py

@@ -452,15 +452,16 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
-            seed, skip_warmup, binary_head, dataset_type=dataset_type)
+            seed, skip_warmup, binary_head, max_seq_length_dec,
+            dataset_type=dataset_type)
         if train_ds:
             train_datasets.append(train_ds)
         if valid_ds:
             valid_datasets.append(valid_ds)
         if test_ds:
             test_datasets.append(test_ds)
 
-    # Blend.
+    # Blend.
     blending_train_dataset = None
     if train_datasets:
         blending_train_dataset = BlendableDataset(train_datasets, weights)