Merge branch 'mblaz/fpsl-model-weights' into mblaz/fpsl-model-weights-load
mikolajblaz committed Apr 18, 2024
2 parents 48efe4b + 68eac5c commit 2369270
Showing 27 changed files with 936 additions and 268 deletions.
docs/llama2.md (27 changes: 17 additions & 10 deletions)
@@ -27,24 +27,31 @@ Users must first apply for access to download the Llama-2 checkpoints either dir

# Convert checkpoint format

Depending on which checkpoint format is downloaded (Meta or HF), one or two steps must be taken to convert to Megatron format.
We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16.

### Meta format

The Meta format checkpoints must first be converted to HF format before converting to Megatron format. The `transformers` package is required for the first step, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format:
The Meta format checkpoints are converted to HF format as an intermediate step before converting to Megatron format. The `transformers` package is required, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format in bfloat16:

```
$>: python $LIB_DIR/transformers/models/llama/convert_llama_weights_to_hf.py \
> --input_dir $LLAMA_FORMAT_DIR \
> --output_dir $HF_FORMAT_DIR \
> --model_size 7B
python tools/checkpoint/util.py --model-type GPT \
> --loader llama2 \
> --saver megatron \
> --checkpoint-type meta \
> --model_size 7B \
> --load-dir $LLAMA_META_FORMAT_DIR \
> --save-dir ${MEGATRON_FORMAT_DIR} \
> --tokenizer-model ${TOKENIZER_MODEL} \
> --target-tensor-parallel-size ${TP} \
> --target-pipeline-parallel-size ${PP} \
> --bf16
```

Valid values for `--model_size` include `7B`, `13B`, and `70B` (for pretrained-only models), and `7Bf`, `13Bf`, and `70Bf` (for chat-finetuned models). Use `python convert_llama_weights_to_hf.py --help` for additional argument details. Once the checkpoints have been converted to HF format, proceed to the Huggingface format section below.
Valid values for `--model_size` include `7B`, `13B`, and `70B` (for pretrained-only models), and `7Bf`, `13Bf`, and `70Bf` (for chat-finetuned models).

### Huggingface format

The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama2_hf.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values:
The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama2.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values:

| Model size | Tensor parallel size (`TP`) |
| ---------- | --------------------------- |
@@ -57,9 +64,10 @@ Using these values for `TP`, along with the path to the Llama-2 tokenizer model
```
$>: python tools/checkpoint/util.py \
> --model-type GPT \
> --loader llama2_hf \
> --loader llama2 \
> --saver megatron \
> --target-tensor-parallel-size ${TP} \
> --checkpoint-type hf \
> --load-dir ${HF_FORMAT_DIR} \
> --save-dir ${MEGATRON_FORMAT_DIR} \
> --tokenizer-model ${TOKENIZER_MODEL}
@@ -85,7 +93,6 @@ If loading for either inference or finetuning, use the following arguments:
--use-checkpoint-args \
--no-load-optim \
--no-load-rng \
--fp16 \
--untie-embeddings-and-output-weights \
--use-rotary-position-embeddings \
--normalization RMSNorm \
examples/detxoify_lm/finetune_gpt.py (5 changes: 3 additions & 2 deletions)
@@ -18,6 +18,7 @@
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.blended_megatron_dataset_config import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import GPTDataset
from megatron.core.datasets.utils import get_blend_from_list
from megatron.legacy.model import GPTModel
from megatron.core.enums import ModelType
from megatron.training import pretrain
@@ -107,7 +108,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
train_val_test_num_samples,
lambda: True,
GPTDatasetConfig(
blend=args.data_path,
blend=get_blend_from_list(args.data_path),
split=args.split,
random_seed=args.seed,
sequence_length=args.seq_length,
@@ -122,7 +123,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
train_val_test_num_samples,
lambda: True,
GPTDatasetConfig(
blend=args.data_path2,
blend=get_blend_from_list(args.data_path2),
split="98,2,0",
random_seed=1234,
sequence_length=2048,
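For context on the new `get_blend_from_list` helper used above, a minimal sketch of the intended call. It assumes the helper splits an interleaved weight/prefix list (as passed on the data-path arguments) into a `(prefixes, weights)` pair for `GPTDatasetConfig.blend`; that reading, and the paths shown, are assumptions for illustration, not taken from the diff.

```
# Sketch only: assumes get_blend_from_list turns a data-path style list
# into a (prefixes, weights) pair accepted by GPTDatasetConfig.blend.
from megatron.core.datasets.utils import get_blend_from_list

# Weighted form, e.g. "0.3 /data/corpus-a 0.7 /data/corpus-b" (hypothetical paths)
blend = get_blend_from_list(["0.3", "/data/corpus-a", "0.7", "/data/corpus-b"])
# Assumed result: (["/data/corpus-a", "/data/corpus-b"], [0.3, 0.7])

# Unweighted form, e.g. "/data/corpus-a /data/corpus-b"
blend = get_blend_from_list(["/data/corpus-a", "/data/corpus-b"])
# Assumed result: (["/data/corpus-a", "/data/corpus-b"], None)
```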
megatron/__init__.py (empty file removed)
megatron/core/datasets/blended_dataset.py (67 changes: 37 additions & 30 deletions)
@@ -6,7 +6,7 @@
import os
import time
from collections import OrderedDict
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy
import torch
@@ -26,9 +26,9 @@ class BlendedDataset(torch.utils.data.Dataset):
Args:
datasets (List[MegatronDataset]): The MegatronDataset instances to blend
weights (List[float]): The weights which determines the dataset blend ratios
weights (List[Union[int, float]]): The weights that determine the dataset blend ratios
size (int): The number of samples to draw from the blend
size (Optional[int]): The number of samples to draw from the blend. If None, for each dataset index idx draw exactly weights[idx] samples from datasets[idx].
config (BlendedMegatronDatasetConfig): The config
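To make the two modes described in the updated docstring concrete, a small illustration (the numbers and dataset names are invented, not taken from the diff):

```
# Weighted mode: size is given and weights are blend ratios.
#   BlendedDataset(datasets=[d0, d1], weights=[0.3, 0.7], size=1000, config=cfg)
#   draws roughly 300 samples from d0 and 700 from d1.
#
# Exhaustive mode: size is None and weights are integer sample counts.
#   BlendedDataset(datasets=[d0, d1], weights=[200, 800], size=None, config=cfg)
#   draws exactly 200 samples from d0 and 800 from d1.
weights = [200, 800]
assert sum(weights) == 1000  # blended length when size is None
```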
@@ -39,32 +39,38 @@ class BlendedDataset(torch.utils.data.Dataset):
def __init__(
self,
datasets: List[MegatronDataset],
weights: List[float],
size: int,
weights: List[Union[int, float]],
size: Optional[int],
config: BlendedMegatronDatasetConfig,
) -> None:
assert len(datasets) < 32767
assert len(datasets) == len(weights)
assert numpy.isclose(sum(weights), 1.0)
assert len(datasets) < 32767
assert all(map(lambda _: type(_) == type(datasets[0]), datasets))
assert all(map(lambda _: _.index_split == datasets[0].index_split, datasets))
assert all(map(lambda _: _ > 0, weights))
assert all(map(lambda _: type(_) == type(weights[0]), weights))
if size is None and isinstance(weights[0], float):
assert all(map(lambda _: _ == int(_), weights))

# Alert user to unnecessary blending
if len(datasets) == 1:
log_single_rank(
logger, logging.WARNING, f"Building a BlendedDataset for a single MegatronDataset"
)

# Redundant normalization for bitwise identical comparison with Megatron-LM
weights = normalize(weights)
if size is not None:
weights = normalize(weights)

self.datasets = datasets
self.split = self.datasets[0].index_split
self.weights = weights
self.size = size
self.config = config

unique_identifiers = OrderedDict()
unique_identifiers["class"] = type(self).__name__
unique_identifiers["datasets"] = [dataset.unique_identifiers for dataset in self.datasets]
unique_identifiers["split"] = self.split.name
unique_identifiers["weights"] = self.weights
unique_identifiers["size"] = self.size

@@ -77,16 +83,8 @@ def __init__(

self.dataset_index, self.dataset_sample_index = self._build_indices()

# Check size
_ = self[self.size - 1]
try:
_ = self[self.size]
raise RuntimeError(f"{type(self).__name__} size is improperly bounded")
except IndexError:
log_single_rank(logger, logging.INFO, f"> {type(self).__name__} length: {len(self)}")

def __len__(self) -> int:
return self.size
return self.dataset_index.shape[0]

def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
dataset_id = self.dataset_index[idx]
@@ -110,7 +108,8 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:

if path_to_cache:
get_path_to = lambda suffix: os.path.join(
path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
path_to_cache,
f"{self.unique_description_hash}-{type(self).__name__}-{self.split.name}-{suffix}",
)
path_to_description = get_path_to("description.txt")
path_to_dataset_index = get_path_to("dataset_index.npy")
@@ -136,16 +135,24 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
t_beg = time.time()
from megatron.core.datasets import helpers

dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
helpers.build_blending_indices(
dataset_index,
dataset_sample_index,
self.weights,
len(self.datasets),
self.size,
_VERBOSE,
)
if self.size is not None:
dataset_index = numpy.zeros(self.size, dtype=numpy.int16)
dataset_sample_index = numpy.zeros(self.size, dtype=numpy.int64)
helpers.build_blending_indices(
dataset_index,
dataset_sample_index,
self.weights,
len(self.datasets),
self.size,
_VERBOSE,
)
else:
size = sum(self.weights)
dataset_index = numpy.zeros(size, dtype=numpy.int16)
dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
helpers.build_exhaustive_blending_indices(
dataset_index, dataset_sample_index, self.weights, len(self.datasets)
)

if path_to_cache:
os.makedirs(path_to_cache, exist_ok=True)
@@ -159,7 +166,7 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
log_single_rank(
logger,
logging.WARNING,
"Unable to save the indexes because path_to_cache is None",
"Unable to save the blending indexes because path_to_cache is None",
)

t_end = time.time()
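As an aid to reading the new `size is None` branch above, a minimal pure-Python sketch of what an exhaustive blending index amounts to. The function name is invented, and the compiled `helpers.build_exhaustive_blending_indices` may interleave the datasets in a different order; only the per-dataset counts are the point.

```
import numpy


def build_exhaustive_blending_indices_sketch(weights):
    """Illustrative only: index every requested sample of every dataset once.

    weights[i] is the integer number of samples to draw from dataset i, so
    the blended length is sum(weights).
    """
    size = sum(weights)
    dataset_index = numpy.zeros(size, dtype=numpy.int16)
    dataset_sample_index = numpy.zeros(size, dtype=numpy.int64)
    position = 0
    for i, weight in enumerate(weights):
        dataset_index[position : position + weight] = i
        dataset_sample_index[position : position + weight] = numpy.arange(weight)
        position += weight
    return dataset_index, dataset_sample_index


# Two samples from dataset 0 and three from dataset 1 -> five blended samples.
dataset_index, dataset_sample_index = build_exhaustive_blending_indices_sketch([2, 3])
assert dataset_index.tolist() == [0, 0, 1, 1, 1]
assert dataset_sample_index.tolist() == [0, 1, 0, 1, 2]
```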