From 1a5e7c023eb6240fc639d084ad36d425521e20f8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 27 Feb 2024 17:25:28 +0100 Subject: [PATCH] (1/n) Data Refactor - TinyLlama Pretraining (#958) --- .gitignore | 2 + lit_gpt/data/__init__.py | 4 ++ lit_gpt/data/data_module.py | 25 ++++++++ lit_gpt/data/tinyllama.py | 99 ++++++++++++++++++++++++++++++++ pretrain/tinyllama.py | 61 +++++--------------- tests/data/__init__.py | 0 tests/data/test_tinyllama.py | 35 +++++++++++ tests/test_pretrain_tinyllama.py | 9 ++- tutorials/pretrain_tinyllama.md | 13 ++--- 9 files changed, 190 insertions(+), 58 deletions(-) create mode 100644 lit_gpt/data/__init__.py create mode 100644 lit_gpt/data/data_module.py create mode 100644 lit_gpt/data/tinyllama.py create mode 100644 tests/data/__init__.py create mode 100644 tests/data/test_tinyllama.py diff --git a/.gitignore b/.gitignore index 6eab51ae0e..6e6500ef5c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,8 @@ build # data data datasets +!lit_gpt/data +!tests/data checkpoints out wandb diff --git a/lit_gpt/data/__init__.py b/lit_gpt/data/__init__.py new file mode 100644 index 0000000000..aa88d20ffb --- /dev/null +++ b/lit_gpt/data/__init__.py @@ -0,0 +1,4 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +from lit_gpt.data.data_module import LitDataModule +from lit_gpt.data.tinyllama import TinyLlama diff --git a/lit_gpt/data/data_module.py b/lit_gpt/data/data_module.py new file mode 100644 index 0000000000..441294b205 --- /dev/null +++ b/lit_gpt/data/data_module.py @@ -0,0 +1,25 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +from abc import abstractmethod +from typing import Optional + +from lightning import LightningDataModule +from lit_gpt import Tokenizer + + +class LitDataModule(LightningDataModule): + """Base class for all data modules in Lit-GPT.""" + + @abstractmethod + def connect( + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None + ) -> None: + """All settings that can't be determined at the time of instantiation need to be passed through here + before any dataloaders can be accessed. + """ + + def setup(self, stage: str = "") -> None: + # Stub is to redefine the default signature, because the concept of 'stage' does not exist in Lit-GPT + pass diff --git a/lit_gpt/data/tinyllama.py b/lit_gpt/data/tinyllama.py new file mode 100644 index 0000000000..21ab7137de --- /dev/null +++ b/lit_gpt/data/tinyllama.py @@ -0,0 +1,99 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +from pathlib import Path +from typing import Union, Optional + +from torch.utils.data import DataLoader + +from lit_gpt import Tokenizer +from lit_gpt.data import LitDataModule + + +class TinyLlama(LitDataModule): + """The TinyLlama data module is composed of a mix of SlimPajama and Starcoder data. + + Provides training and validation streaming dataloaders that return batches of tokens. + + Args: + data_path: The path to the data directory, containing two folders 'slimpajama' and 'starcoder' + which are the output of the preprocessing step done in advance. See the `tutorial/pretrain_tinyllama.md` + for instructions. The path can also be a remote path (e.g., s3://). + seed: The seed to use for shuffling the training data. + num_workers: The number of workers to use for the dataloaders. 
+ """ + + def __init__( + self, + data_path: Union[str, Path] = Path("data/"), + seed: int = 42, + num_workers: int = 8, + ) -> None: + super().__init__() + self.seed = seed + self.num_workers = num_workers + + self.batch_size = 1 + self.seq_length = 2048 + + # Could be a remote path (s3://) or a local path + self.slimpajama_train = str(data_path).rstrip("/") + "/slimpajama/train" + self.slimpajama_val = str(data_path).rstrip("/") + "/slimpajama/val" + self.starcoder_train = str(data_path).rstrip("/") + "/starcoder" + + def connect( + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None + ) -> None: + self.batch_size = batch_size + self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + + def prepare_data(self) -> None: + for path in (self.slimpajama_train, self.slimpajama_val, self.starcoder_train): + if not path.startswith("s3://") and not Path(path).is_dir(): + raise FileNotFoundError( + "The data path for TinyLlama is expected to be the directory containing these subdirectories:" + f" `slimpajama/train`, `slimpajama/val`, `starcoder`. The directory {path} does not exist." + ) + + def train_dataloader(self) -> DataLoader: + from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader + + train_datasets = [ + StreamingDataset( + input_dir=self.slimpajama_train, + item_loader=TokensLoader(block_size=self.seq_length), + shuffle=True, + drop_last=True, + ), + StreamingDataset( + input_dir=self.starcoder_train, + item_loader=TokensLoader(block_size=self.seq_length), + shuffle=True, + drop_last=True, + ), + ] + + # Mix SlimPajama data and Starcoder data with these proportions: + weights = (0.693584, 0.306416) + combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=self.seed, weights=weights) + train_dataloader = StreamingDataLoader( + combined_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True + ) + return train_dataloader + + def val_dataloader(self) -> DataLoader: + from lightning.data.streaming import StreamingDataset, TokensLoader + + val_dataset = StreamingDataset( + input_dir=self.slimpajama_val, + item_loader=TokensLoader(block_size=self.seq_length), + shuffle=True, + # Consider setting to False, but we would lose some samples due to truncation when world size > 1 + drop_last=True, + ) + val_dataloader = DataLoader( + val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True + ) + return val_dataloader diff --git a/pretrain/tinyllama.py b/pretrain/tinyllama.py index 06c1a5cbc6..6f025557e3 100644 --- a/pretrain/tinyllama.py +++ b/pretrain/tinyllama.py @@ -24,6 +24,7 @@ from torchmetrics.aggregation import RunningMean from typing_extensions import Literal + # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) @@ -31,14 +32,15 @@ from lit_gpt.args import EvalArgs, IOArgs, TrainArgs from lit_gpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters +from lit_gpt.data import TinyLlama, LitDataModule def setup( - model_name: str = "tiny-llama-1.1b", - name: str = "lit-tiny-llama-1.1b", + model: Config = Config(name="tiny-llama-1.1b"), logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard", resume: Union[bool, Path] = False, devices: int = torch.cuda.device_count() 
or 1, + data: LitDataModule = TinyLlama(), io: IOArgs = IOArgs( out_dir=Path(os.getenv("LIGHTNING_ARTIFACTS_DIR", "out")) / "lit-tiny-llama-1.1b", train_data_dir=None ), @@ -59,7 +61,7 @@ def setup( eval: EvalArgs = EvalArgs(interval=1000, max_iters=100), ): hparams = locals() - logger = choose_logger(io.out_dir, logger_name, name=name, resume=resume) + logger = choose_logger(io.out_dir, logger_name, name=f"pretrain-{model.name}", resume=resume) strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD") fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-mixed", loggers=[logger]) @@ -69,7 +71,7 @@ def setup( if logger_name in ("tensorboard", "wandb"): fabric.logger.log_hyperparams(hparams) - fabric.launch(main, devices, resume, Config.from_name(name=model_name), io, train, eval) + fabric.launch(main, devices, resume, model, data, io, train, eval) def main( @@ -77,6 +79,7 @@ def main( devices: int, resume: Union[bool, Path], config: Config, + data: LitDataModule, io: IOArgs, train: TrainArgs, eval: EvalArgs, @@ -86,9 +89,7 @@ def main( if fabric.global_rank == 0: io.out_dir.mkdir(parents=True, exist_ok=True) - train_dataloader, val_dataloader = create_dataloaders( - batch_size=train.micro_batch_size, block_size=config.block_size - ) + train_dataloader, val_dataloader = get_dataloaders(fabric, data, train, config.block_size) train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) fabric.seed_everything(3407) # same seed for every process to init model (FSDP) @@ -274,45 +275,13 @@ def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max return losses.mean() -def create_dataloaders(batch_size: int, block_size: int, num_workers: int = 8) -> Tuple[DataLoader, DataLoader]: - from lightning.data import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset - from lightning.data.streaming.item_loader import TokensLoader - - # Increase by one because we need the next word as well - effective_block_size = block_size + 1 - - train_datasets = [ - StreamingDataset( - input_dir="data/slimpajama/train", - item_loader=TokensLoader(block_size=effective_block_size), - shuffle=True, - drop_last=True, - ), - StreamingDataset( - input_dir="data/starcoder", - item_loader=TokensLoader(block_size=effective_block_size), - shuffle=True, - drop_last=True, - ), - ] - - # Mix SlimPajama data and Starcoder data with these proportions: - weights = (0.693584, 0.306416) - combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=weights) - train_dataloader = StreamingDataLoader( - combined_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True - ) - - val_dataset = StreamingDataset( - input_dir="data/slimpajama/val", - item_loader=TokensLoader(block_size=effective_block_size), - shuffle=True, - # Consider setting to False, but we would lose some samples due to truncation when world size > 1 - drop_last=True, - ) - val_dataloader = DataLoader( - val_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True - ) +def get_dataloaders(fabric: L.Fabric, data: LitDataModule, train: TrainArgs, block_size: int) -> Tuple[DataLoader, DataLoader]: + data.connect(batch_size=train.micro_batch_size, max_seq_length=block_size) + with fabric.rank_zero_first(): + data.prepare_data() + data.setup() + train_dataloader = data.train_dataloader() + val_dataloader = data.val_dataloader() return train_dataloader, 
val_dataloader diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/test_tinyllama.py b/tests/data/test_tinyllama.py new file mode 100644 index 0000000000..7360ff68d9 --- /dev/null +++ b/tests/data/test_tinyllama.py @@ -0,0 +1,35 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import pytest +from torch.utils.data import DataLoader + + +def test_tinyllama(tmp_path, monkeypatch): + from lit_gpt.data import TinyLlama + from lightning.data.streaming import StreamingDataLoader, StreamingDataset, CombinedStreamingDataset + + data = TinyLlama(data_path=(tmp_path / "data")) + assert data.seq_length == 2048 + assert data.batch_size == 1 + + data.connect(batch_size=2, max_seq_length=1024) + assert data.seq_length == 1025 + assert data.batch_size == 2 + + with pytest.raises(FileNotFoundError, match="The directory .*data/slimpajama/train does not exist"): + data.prepare_data() + + (tmp_path / "data" / "slimpajama" / "train").mkdir(parents=True) + (tmp_path / "data" / "slimpajama" / "val").mkdir(parents=True) + (tmp_path / "data" / "starcoder").mkdir(parents=True) + + data.prepare_data() + data.setup() + + train_dataloader = data.train_dataloader() + assert isinstance(train_dataloader, StreamingDataLoader) + assert isinstance(train_dataloader.dataset, CombinedStreamingDataset) + + val_dataloader = data.val_dataloader() + assert isinstance(val_dataloader, DataLoader) + assert isinstance(val_dataloader.dataset, StreamingDataset) diff --git a/tests/test_pretrain_tinyllama.py b/tests/test_pretrain_tinyllama.py index 0649310b38..f781f1db76 100644 --- a/tests/test_pretrain_tinyllama.py +++ b/tests/test_pretrain_tinyllama.py @@ -17,20 +17,19 @@ def test_pretrain_tiny_llama(tmp_path, monkeypatch): import pretrain.tinyllama as module from lit_gpt.args import EvalArgs, IOArgs, TrainArgs - from lit_gpt.config import name_to_config + from lit_gpt.config import Config - model_config = dict(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) - monkeypatch.setitem(name_to_config, "tmp", model_config) + model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]]) dataloader = DataLoader(dataset) - module.create_dataloaders = Mock(return_value=(dataloader, dataloader)) + module.get_dataloaders = Mock(return_value=(dataloader, dataloader)) stdout = StringIO() with redirect_stdout(stdout): module.setup( devices=2, - model_name="tmp", + model=model_config, io=IOArgs(out_dir=tmp_path, train_data_dir=None), train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0), eval=EvalArgs(interval=1, max_iters=1), diff --git a/tutorials/pretrain_tinyllama.md b/tutorials/pretrain_tinyllama.md index 936b4facfa..3f5d847910 100644 --- a/tutorials/pretrain_tinyllama.md +++ b/tutorials/pretrain_tinyllama.md @@ -104,18 +104,16 @@ python pretrain/tinyllama.py ``` The script will save checkpoints periodically to the folder `out/`. -By default, the `pretrain/tinyllama.py` script will pretrain the Llama 2 7B model with FSDP in +By default, the `pretrain/tinyllama.py` script will pretrain the model with FSDP in `bfloat16` mixed precision and gradient accumulation. 
Note that the `pretrain/tinyllama.py` is not actually a model-specific training script, so feel free to change
-the configuration and size by passing a different string to the model name variable
+the configuration and size by passing a different string to the model name argument, for example:

```shell
---model_name "tiny-llama-1.1b"
+python pretrain/tinyllama.py --model.name Gemma-2b
```

-at the top of this script.
-
The currently supported model names are contained in the
[config.py](https://github.com/Lightning-AI/lit-gpt/lit_gpt/config.py)
file. You can
@@ -126,17 +124,18 @@ Keep in mind that training with a single machine will take weeks. To speed up th
Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/fundamentals/launch.html#launch-on-a-cluster) to launch the script across machines:

+- [Lightning AI](https://lightning.ai/docs/fabric/stable/guide/multi_node/cloud.html)
- [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)
- [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)
- [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)

-The exposes several hyperparameters you can tweak through the command line.
+The script exposes several hyperparameters you can tweak through the command line.
For instance, `--train.micro_batch_size` should be adjusted so the process will use the available GPU memory.
For more tips to avoid out-of-memory issues, please also see the more detailed
[Dealing with out-of-memory (OOM) errors](oom.md) guide.

-Last, logging is kept minimal in the script, but for long running experiments we recommend switching to a proper experiment tracker.
+Last, logging is kept minimal in the script, but for long-running experiments we recommend switching to a proper experiment tracker.
As an example, we included WandB (set `use_wandb=True`) to show how you can integrate any experiment tracking framework.
For reference, [here are the loss curves for our reproduction](https://api.wandb.ai/links/awaelchli/y7pzdpwy).
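
For reviewers who want to see the new data-module API end to end, below is a minimal usage sketch of the `LitDataModule` flow this patch introduces, mirroring what `get_dataloaders` in `pretrain/tinyllama.py` and `tests/data/test_tinyllama.py` do. It assumes the preprocessed `data/slimpajama/train`, `data/slimpajama/val`, and `data/starcoder` directories from the tutorial already exist; the batch size and sequence length shown are illustrative values, not defaults prescribed by the patch.

```python
# Minimal sketch of the data-module flow added in this patch (not part of the diff itself).
# Assumes the preprocessed SlimPajama/Starcoder directories already exist under data/,
# as described in tutorials/pretrain_tinyllama.md; batch_size and max_seq_length are examples.
from lit_gpt.data import TinyLlama

data = TinyLlama(data_path="data/", seed=42, num_workers=8)

# Settings that can't be known at instantiation time are passed through connect().
# Internally, seq_length becomes max_seq_length + 1 so the next token is available as the label.
data.connect(batch_size=4, max_seq_length=2048)

data.prepare_data()  # raises FileNotFoundError if the expected subdirectories are missing
data.setup()         # stub; the 'stage' concept is unused in Lit-GPT

train_dataloader = data.train_dataloader()  # StreamingDataLoader over the SlimPajama + Starcoder mix
val_dataloader = data.val_dataloader()      # regular DataLoader over the SlimPajama validation split

# Each batch is a (batch_size, max_seq_length + 1) tensor of token ids; a training loop
# can then shift by one position to form inputs and targets, for example:
batch = next(iter(train_dataloader))
input_ids, targets = batch[:, :-1], batch[:, 1:]
```

In the pretraining script itself, `prepare_data()` runs inside `fabric.rank_zero_first()` so that only one process per node performs the directory checks before the others proceed.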