forked from NVIDIA/Megatron-LM
Commit
ADLR/megatron-lm!1522 - ModelOpt Distillation API
Showing 14 changed files with 1,066 additions and 26 deletions.
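The change below threads the model handle into the loss function so that distillation losses computed during the forward pass can be added to the language-model loss. As background for that pattern, here is a minimal, hypothetical PyTorch sketch of a distillation wrapper; the names DistillationWrapper, teacher, and kd_loss are assumptions for illustration only (not the ModelOpt API), and it assumes a model whose forward returns logits.

# Hypothetical sketch, not the ModelOpt API: shows how a wrapped student
# model could record a knowledge-distillation loss during its forward pass.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationWrapper(nn.Module):
    """Runs a frozen teacher alongside the student and stores a KD loss."""

    def __init__(self, student: nn.Module, teacher: nn.Module, temperature: float = 2.0):
        super().__init__()
        self.student = student
        self.teacher = teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad_(False)
        self.temperature = temperature
        self.kd_loss = torch.tensor(0.0)

    def forward(self, *args, **kwargs):
        student_logits = self.student(*args, **kwargs)
        with torch.no_grad():
            teacher_logits = self.teacher(*args, **kwargs)
        # Soft-target KL divergence between teacher and student distributions.
        t = self.temperature
        self.kd_loss = F.kl_div(
            F.log_softmax(student_logits / t, dim=-1),
            F.softmax(teacher_logits / t, dim=-1),
            reduction="batchmean",
        ) * (t * t)
        return student_logits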
examples/export/knowledge_distillation/pretrain_gpt_modelopt.py: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT."""
import os
import sys
from functools import partial

# This file isn't located in project root, but to import, it should pretend to be.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))

from megatron.core import mpu
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.utils import StragglerDetector
from megatron.inference.arguments import add_modelopt_args
from megatron.inference.gpt import loss_func, model_provider
from megatron.training import get_args, get_timers, get_tokenizer, pretrain
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    print_rank_0,
)

stimer = StragglerDetector()


def get_batch(data_iterator):
    """Generate a batch."""

    # TODO: this is pretty hacky, find a better way
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    # get batches based on the TP rank you are on
    batch = get_batch_on_this_tp_rank(data_iterator)

    # slice batch along sequence dimension for context parallelism
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()


def forward_step(data_iterator, model: GPTModel):
    """Forward training step.
    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    # [ModelOpt]: model is needed to access ModelOpt distillation losses
    return output_tensor, partial(loss_func, loss_mask, model)


def is_dataset_built_on_rank():
    return (
        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    ) and mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
    tokenizer = get_tokenizer()

    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=get_blend_from_list(args.data_path),
        blend_per_split=[
            get_blend_from_list(args.train_data_path),
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path),
        ],
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
        mmap_bin_files=args.mmap_bin_files,
        tokenizer=tokenizer,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        create_attention_mask=args.create_attention_mask_in_dataloader,
    )


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, test, and validation datasets.
    Args:
        train_val_test_num_samples : A list containing the number of samples in train, test, and validation.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.mock_data:
        dataset_type = MockGPTDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
        extra_args_provider=add_modelopt_args,
    )
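As the # [ModelOpt] comment in forward_step notes, the model handle is passed into the loss function so that distillation losses recorded on the model can be folded into the training loss. The real loss_func is imported from megatron.inference.gpt; the sketch below is only a hypothetical outline of that idea, and the kd_loss attribute it reads is an assumed name rather than the ModelOpt API.

# Hypothetical outline only: the actual loss_func lives in
# megatron.inference.gpt. The kd_loss attribute is an assumed name.
import torch


def sketch_loss_func(loss_mask, model, output_tensor):
    """Masked LM loss plus any distillation loss the model recorded."""
    losses = output_tensor.float().view(-1)
    mask = loss_mask.view(-1).float()
    lm_loss = torch.sum(losses * mask) / mask.sum()

    # Read back the distillation loss stored on the (wrapped) model during
    # forward_step; fall back to zero if the model was not converted for KD.
    kd_loss = getattr(model, "kd_loss", torch.tensor(0.0, device=lm_loss.device))
    total_loss = lm_loss + kd_loss

    return total_loss, {"lm loss": lm_loss, "kd loss": kd_loss}

Matching the partial(loss_func, loss_mask, model) call in forward_step, the signature takes loss_mask and model first and the per-token output_tensor last.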
@@ -0,0 +1 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.