Merge branch 'xren/dataset_fix' into 'main'
skip unnecessary attention mask generation

See merge request ADLR/megatron-lm!1259
jaredcasper committed Mar 26, 2024
2 parents 6835eb7 + e7f376c commit 9de386d
Showing 6 changed files with 143 additions and 41 deletions.
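
To make the diff easier to follow, here is a minimal, standalone sketch (not the actual Megatron-LM classes; ToyGPTSampleBuilder is a hypothetical stand-in) of the pattern this commit introduces: when none of the per-document reset options are enabled, the attention mask and position ids are identical for every sample, so they are computed once and cached; and when the attention kernel builds its own causal mask, the attention mask is skipped entirely and omitted from the returned sample.

import torch

class ToyGPTSampleBuilder:
    """Toy stand-in for the dataset __getitem__ logic changed below (illustration only)."""

    def __init__(self, seq_length: int, create_attention_mask: bool = True):
        self.seq_length = seq_length
        self.create_attention_mask = create_attention_mask
        self.cached_attention_mask = None
        self.cached_position_ids = None

    def __getitem__(self, idx: int):
        tokens = torch.randint(0, 100, (self.seq_length,), dtype=torch.int64)
        if self.cached_position_ids is None:
            # With no per-document resets, position ids never change: cache them.
            self.cached_position_ids = torch.arange(self.seq_length, dtype=torch.int64)
        sample = {"tokens": tokens, "position_ids": self.cached_position_ids}
        if self.create_attention_mask:
            if self.cached_attention_mask is None:
                # Lower-triangular causal mask, converted to bool: True = blocked.
                mask = torch.tril(torch.ones(self.seq_length, self.seq_length)).unsqueeze(0)
                self.cached_attention_mask = mask < 0.5
            sample["attention_mask"] = self.cached_attention_mask
        return sample

print(sorted(ToyGPTSampleBuilder(8, create_attention_mask=False)[0]))  # ['position_ids', 'tokens']
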
166 changes: 128 additions & 38 deletions megatron/core/datasets/gpt_dataset.py
@@ -5,14 +5,14 @@
import sys
import time
from dataclasses import dataclass
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple

import numpy
import torch

from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.megatron_dataset import MegatronDataset, MockDataset
from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset, MockDataset
from megatron.core.datasets.utils import Split, log_single_rank

logger = logging.getLogger(__name__)
@@ -29,6 +29,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):
eod_mask_loss (bool): Option to enable the EOD mask loss
create_attention_mask (bool): Option to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
vocab_size (int): Size of vocabulary
"""
@@ -39,6 +41,8 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):

eod_mask_loss: bool = None

create_attention_mask: bool = True

vocab_size: int = sys.maxsize

def __post_init__(self) -> None:
@@ -57,6 +61,29 @@ class MockGPTDataset(MockDataset):
"""The mock GPT dataset
"""

def __init__(
self,
dataset: Optional[LowLevelDataset],
dataset_path: Optional[str],
indices: Optional[numpy.ndarray],
num_samples: int,
index_split: Split,
config: BlendedMegatronDatasetConfig,
) -> None:
super().__init__(dataset, dataset_path, indices, num_samples, index_split, config)

self.masks_and_position_ids_are_cacheable = not any(
[
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
]
)
self.masks_and_position_ids_are_cached = False
self.cached_attention_mask = None
self.cached_loss_mask = None
self.cached_position_ids = None

def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Return a sequence_length + 1 token sequence consisting of the following:
- (1) S, the RNG length-sentinel in the range [0, sequence_length)
@@ -89,21 +116,43 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
labels = text[1:].contiguous()
tokens = text[:-1].contiguous()

attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
tokens,
eod,
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
)

return {
"tokens": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
if (
not self.masks_and_position_ids_are_cacheable
or not self.masks_and_position_ids_are_cached
):
attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
tokens,
eod,
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
self.config.create_attention_mask,
)
if self.masks_and_position_ids_are_cacheable:
self.cached_attention_mask = attention_mask
self.cached_loss_mask = loss_mask
self.cached_position_ids = position_ids
self.masks_and_position_ids_are_cached = True
else:
attention_mask = self.cached_attention_mask
loss_mask = self.cached_loss_mask
position_ids = self.cached_position_ids

if self.config.create_attention_mask:
return {
"tokens": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
else:
return {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"position_ids": position_ids,
}


class GPTDataset(MegatronDataset):
@@ -138,6 +187,18 @@ def __init__(

self.vocab_size = config.vocab_size

self.masks_and_position_ids_are_cacheable = not any(
[
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
]
)
self.masks_and_position_ids_are_cached = False
self.cached_attention_mask = None
self.cached_loss_mask = None
self.cached_position_ids = None

def _finalize(self) -> None:
"""Abstract method implementation
@@ -205,21 +266,43 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
tokens >= self.vocab_size
), "An input token is out of bounds of the tokenizer vocabulary"

attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
tokens,
self.config.tokenizer.eod,
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
)

return {
"tokens": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
if (
not self.masks_and_position_ids_are_cacheable
or not self.masks_and_position_ids_are_cached
):
attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
tokens,
self.config.tokenizer.eod,
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
self.config.create_attention_mask,
)
if self.masks_and_position_ids_are_cacheable:
self.cached_attention_mask = attention_mask
self.cached_loss_mask = loss_mask
self.cached_position_ids = position_ids
self.masks_and_position_ids_are_cached = True
else:
attention_mask = self.cached_attention_mask
loss_mask = self.cached_loss_mask
position_ids = self.cached_position_ids

if self.config.create_attention_mask:
return {
"tokens": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
else:
return {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"position_ids": position_ids,
}

def _query_document_sample_shuffle_indices(
self, idx: int
@@ -575,6 +658,7 @@ def _get_ltor_masks_and_position_ids(
reset_position_ids: bool,
reset_attention_mask: bool,
eod_mask_loss: bool,
create_attention_mask: bool,
):
"""Build masks and position id for left to right model.
@@ -589,6 +673,8 @@
eod_mask_loss (bool): Switch to enable the EOD mask loss
create_attention_mask (bool): Switch to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
Returns:
torch.Tensor: Attention mask needed to be used for Attention
@@ -598,9 +684,12 @@
"""
seq_length = data.numel()

attention_mask = torch.tril(torch.ones((seq_length, seq_length), device=data.device)).unsqueeze(
0
)
if create_attention_mask:
attention_mask = torch.tril(
torch.ones((seq_length, seq_length), device=data.device)
).unsqueeze(0)
else:
attention_mask = None

# Loss mask.
loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
@@ -625,14 +714,15 @@
for j in range(eod_index.numel()):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
if reset_attention_mask and attention_mask is not None:
attention_mask[0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[(i + 1) :] -= i + 1 - prev_index
prev_index = i + 1

# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
if attention_mask is not None:
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5

return attention_mask, loss_mask, position_ids
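
A quick note on the mask convention used by the helper above (illustration only, not part of the diff): torch.tril yields 1.0 on and below the diagonal, and the final comparison attention_mask < 0.5 flips it into a boolean mask that is True strictly above the diagonal, i.e. True marks the positions a token must not attend to.

import torch

seq_length = 4
attention_mask = torch.tril(torch.ones(seq_length, seq_length)).unsqueeze(0)
binary_mask = attention_mask < 0.5  # True = position is masked out
print(binary_mask[0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
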
3 changes: 3 additions & 0 deletions megatron/training/arguments.py
@@ -1402,6 +1402,9 @@ def _add_data_args(parser):
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.')
group.add_argument('--no-create-attention-mask-in-dataloader', action='store_false',
help='If set, do not create attention_masks in dataloader.',
dest='create_attention_mask_in_dataloader')

return parser

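
For readers less familiar with this argparse idiom, a standalone sketch mirroring the flag added above (not the full Megatron-LM parser): action='store_false' together with an explicit dest gives an attribute that defaults to True and is flipped to False by the --no-... flag.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-create-attention-mask-in-dataloader', action='store_false',
                    help='If set, do not create attention_masks in dataloader.',
                    dest='create_attention_mask_in_dataloader')

print(parser.parse_args([]).create_attention_mask_in_dataloader)  # True (default)
args = parser.parse_args(['--no-create-attention-mask-in-dataloader'])
print(args.create_attention_mask_in_dataloader)  # False
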
12 changes: 9 additions & 3 deletions megatron/training/utils.py
@@ -278,7 +278,8 @@ def get_batch_on_this_tp_rank(data_iterator):
args = get_args()

def _broadcast(item):
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

if mpu.get_tensor_model_parallel_rank() == 0:

@@ -291,7 +292,7 @@ def _broadcast(item):
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': data["attention_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
}

@@ -317,7 +318,12 @@ def _broadcast(item):
tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
attention_mask=torch.empty((args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
)
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())

if args.pipeline_model_parallel_size == 1:
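
A simplified, CPU-only sketch of the None-tolerant batch handling shown above (assemble_batch is a hypothetical stand-in; the real helper moves tensors to the current CUDA device and broadcasts them across the tensor-parallel group): when the dataloader omits the attention mask, the batch carries None for that key and later steps simply skip None entries.

import torch

def assemble_batch(data, device="cpu"):
    # device="cpu" keeps the sketch runnable anywhere; the real code targets CUDA.
    return {
        "tokens": data["tokens"].to(device),
        "labels": data["labels"].to(device),
        "loss_mask": data["loss_mask"].to(device),
        "attention_mask": None if "attention_mask" not in data
        else data["attention_mask"].to(device),
        "position_ids": data["position_ids"].to(device),
    }

sample = {
    "tokens": torch.zeros(8, dtype=torch.int64),
    "labels": torch.zeros(8, dtype=torch.int64),
    "loss_mask": torch.ones(8),
    "position_ids": torch.arange(8),
}
print(assemble_batch(sample)["attention_mask"])  # None: nothing to move or broadcast
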
1 change: 1 addition & 0 deletions pretrain_gpt.py
@@ -179,6 +179,7 @@ def core_gpt_dataset_config_from_args(args):
reset_position_ids=args.reset_position_ids,
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
vocab_size=get_tokenizer().vocab_size,
)

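
Why skipping the mask is safe at all (an illustration using PyTorch's fused attention API, not code from this commit): kernels such as torch.nn.functional.scaled_dot_product_attention can apply the causal mask internally via is_causal=True, so a dataloader-built mask is redundant in that configuration. Note that PyTorch's boolean attn_mask uses True to mean "may attend", the opposite of the dataloader mask convention above.

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 8, 16)  # (batch, heads, seq_len, head_dim)

# Explicit mask path: boolean mask with True where attention is allowed.
allowed = torch.tril(torch.ones(8, 8, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

# Kernel-generated causal mask: no mask tensor from the dataloader is needed.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_masked, out_causal, atol=1e-5))  # expected: True
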
1 change: 1 addition & 0 deletions tests/functional_tests/jet_recipes/MR-gpt.yaml
@@ -56,6 +56,7 @@ spec:
products:
# MCore
- {tp_size: [2], pp_size: [2]}
- {tp_size: [2], pp_size: [2], extra_args: ["--no-create-attention-mask-in-dataloader"], args_meta: ["no_create_attention_mask_in_dataloader"]}
- {tp_size: [2], pp_size: [2], extra_args: ["--no-mmap-bin-files"], args_meta: ["no_mmap_bin_files"]}
- {tp_size: [1], pp_size: [4], vp_size: [1]}
- {tp_size: [4], pp_size: [1], extra_args: ["--qk-layernorm --test-mode"]}
@@ -0,0 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.92392, 10.93645, 10.89657, 10.86919, 10.74782, 10.658, 10.15864, 10.24906, 10.15088, 9.83933]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1735.0, 1861.0, 2111.0, 1844.0, 1762.0, 1858.0, 1554.0, 2031.0, 2309.0, 2225.0]}, "iteration_timing_avg": 0.15396205882352942}
