(1/n) Data Refactor - TinyLlama Pretraining (Lightning-AI#958)
awaelchli authored Feb 27, 2024
1 parent 7c15749 commit 1a5e7c0
Showing 9 changed files with 190 additions and 58 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -9,6 +9,8 @@ build
# data
data
datasets
!lit_gpt/data
!tests/data
checkpoints
out
wandb
4 changes: 4 additions & 0 deletions lit_gpt/data/__init__.py
@@ -0,0 +1,4 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from lit_gpt.data.data_module import LitDataModule
from lit_gpt.data.tinyllama import TinyLlama
25 changes: 25 additions & 0 deletions lit_gpt/data/data_module.py
@@ -0,0 +1,25 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from abc import abstractmethod
from typing import Optional

from lightning import LightningDataModule
from lit_gpt import Tokenizer


class LitDataModule(LightningDataModule):
"""Base class for all data modules in Lit-GPT."""

@abstractmethod
def connect(
self,
tokenizer: Optional[Tokenizer] = None,
batch_size: int = 1,
max_seq_length: Optional[int] = None
) -> None:
"""All settings that can't be determined at the time of instantiation need to be passed through here
before any dataloaders can be accessed.
"""

def setup(self, stage: str = "") -> None:
# Stub is to redefine the default signature, because the concept of 'stage' does not exist in Lit-GPT
pass
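
As a point of reference, a concrete data module only has to implement `connect()` plus the usual dataloader hooks. Below is a minimal, hypothetical sketch; the `MyTextData` class and its random-token data are illustrative assumptions, not part of this commit:

```python
# Hypothetical sketch of a LitDataModule subclass; `MyTextData` is illustrative
# and not part of this commit.
from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset

from lit_gpt import Tokenizer
from lit_gpt.data import LitDataModule


class MyTextData(LitDataModule):
    def __init__(self, num_samples: int = 8) -> None:
        super().__init__()
        self.num_samples = num_samples
        self.batch_size = 1
        self.seq_length = 2048

    def connect(
        self,
        tokenizer: Optional[Tokenizer] = None,
        batch_size: int = 1,
        max_seq_length: Optional[int] = None,
    ) -> None:
        # Runtime settings the training script passes in before dataloaders are built.
        self.batch_size = batch_size
        if max_seq_length is not None:
            self.seq_length = max_seq_length + 1  # +1 so each sample carries the next-token target

    def train_dataloader(self) -> DataLoader:
        # Random token ids stand in for real data; shape is (num_samples, seq_length).
        tokens = torch.randint(0, 100, (self.num_samples, self.seq_length))
        return DataLoader(TensorDataset(tokens), batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        tokens = torch.randint(0, 100, (self.num_samples, self.seq_length))
        return DataLoader(TensorDataset(tokens), batch_size=self.batch_size)
```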
99 changes: 99 additions & 0 deletions lit_gpt/data/tinyllama.py
@@ -0,0 +1,99 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from pathlib import Path
from typing import Union, Optional

from torch.utils.data import DataLoader

from lit_gpt import Tokenizer
from lit_gpt.data import LitDataModule


class TinyLlama(LitDataModule):
"""The TinyLlama data module is composed of a mix of SlimPajama and Starcoder data.
Provides training and validation streaming dataloaders that return batches of tokens.
Args:
data_path: The path to the data directory, containing two folders 'slimpajama' and 'starcoder'
which are the output of the preprocessing step done in advance. See the `tutorials/pretrain_tinyllama.md`
for instructions. The path can also be a remote path (e.g., s3://).
seed: The seed to use for shuffling the training data.
num_workers: The number of workers to use for the dataloaders.
"""

def __init__(
self,
data_path: Union[str, Path] = Path("data/"),
seed: int = 42,
num_workers: int = 8,
) -> None:
super().__init__()
self.seed = seed
self.num_workers = num_workers

self.batch_size = 1
self.seq_length = 2048

# Could be a remote path (s3://) or a local path
self.slimpajama_train = str(data_path).rstrip("/") + "/slimpajama/train"
self.slimpajama_val = str(data_path).rstrip("/") + "/slimpajama/val"
self.starcoder_train = str(data_path).rstrip("/") + "/starcoder"

def connect(
self,
tokenizer: Optional[Tokenizer] = None,
batch_size: int = 1,
max_seq_length: Optional[int] = None
) -> None:
self.batch_size = batch_size
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well

def prepare_data(self) -> None:
for path in (self.slimpajama_train, self.slimpajama_val, self.starcoder_train):
if not path.startswith("s3://") and not Path(path).is_dir():
raise FileNotFoundError(
"The data path for TinyLlama is expected to be the directory containing these subdirectories:"
f" `slimpajama/train`, `slimpajama/val`, `starcoder`. The directory {path} does not exist."
)

def train_dataloader(self) -> DataLoader:
from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader

train_datasets = [
StreamingDataset(
input_dir=self.slimpajama_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
),
StreamingDataset(
input_dir=self.starcoder_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
),
]

# Mix SlimPajama data and Starcoder data with these proportions:
weights = (0.693584, 0.306416)
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=self.seed, weights=weights)
train_dataloader = StreamingDataLoader(
combined_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader

def val_dataloader(self) -> DataLoader:
from lightning.data.streaming import StreamingDataset, TokensLoader

val_dataset = StreamingDataset(
input_dir=self.slimpajama_val,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
# Consider setting to False, but we would lose some samples due to truncation when world size > 1
drop_last=True,
)
val_dataloader = DataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader
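
Taken together, the intended call order mirrors what the pretraining script does further down: `connect()`, then `prepare_data()`/`setup()`, then the dataloaders. A rough usage sketch, assuming the preprocessed data already sits under `data/` in the layout described in the docstring:

```python
# Rough usage sketch for the TinyLlama data module; the local `data/` layout below
# is assumed to have been produced by the preprocessing step in the tutorial.
from lit_gpt.data import TinyLlama

# Expected layout:
#   data/slimpajama/train
#   data/slimpajama/val
#   data/starcoder
data = TinyLlama(data_path="data/", seed=42, num_workers=8)

# Runtime settings arrive via connect(); seq_length becomes max_seq_length + 1.
data.connect(batch_size=4, max_seq_length=2048)

data.prepare_data()  # raises FileNotFoundError if the local directories are missing
data.setup()

train_loader = data.train_dataloader()  # StreamingDataLoader over the weighted mix
val_loader = data.val_dataloader()      # plain DataLoader over slimpajama/val
```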
61 changes: 15 additions & 46 deletions pretrain/tinyllama.py
@@ -24,21 +24,23 @@
from torchmetrics.aggregation import RunningMean
from typing_extensions import Literal


# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters
from lit_gpt.data import TinyLlama, LitDataModule


def setup(
model_name: str = "tiny-llama-1.1b",
name: str = "lit-tiny-llama-1.1b",
model: Config = Config(name="tiny-llama-1.1b"),
logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
resume: Union[bool, Path] = False,
devices: int = torch.cuda.device_count() or 1,
data: LitDataModule = TinyLlama(),
io: IOArgs = IOArgs(
out_dir=Path(os.getenv("LIGHTNING_ARTIFACTS_DIR", "out")) / "lit-tiny-llama-1.1b", train_data_dir=None
),
@@ -59,7 +61,7 @@ def setup(
eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
):
hparams = locals()
logger = choose_logger(io.out_dir, logger_name, name=name, resume=resume)
logger = choose_logger(io.out_dir, logger_name, name=f"pretrain-{model.name}", resume=resume)

strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD")
fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-mixed", loggers=[logger])
@@ -69,14 +71,15 @@
if logger_name in ("tensorboard", "wandb"):
fabric.logger.log_hyperparams(hparams)

fabric.launch(main, devices, resume, Config.from_name(name=model_name), io, train, eval)
fabric.launch(main, devices, resume, model, data, io, train, eval)


def main(
fabric: L.Fabric,
devices: int,
resume: Union[bool, Path],
config: Config,
data: LitDataModule,
io: IOArgs,
train: TrainArgs,
eval: EvalArgs,
@@ -86,9 +89,7 @@ def main(
if fabric.global_rank == 0:
io.out_dir.mkdir(parents=True, exist_ok=True)

train_dataloader, val_dataloader = create_dataloaders(
batch_size=train.micro_batch_size, block_size=config.block_size
)
train_dataloader, val_dataloader = get_dataloaders(fabric, data, train, config.block_size)
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

fabric.seed_everything(3407) # same seed for every process to init model (FSDP)
@@ -274,45 +275,13 @@ def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max
return losses.mean()


def create_dataloaders(batch_size: int, block_size: int, num_workers: int = 8) -> Tuple[DataLoader, DataLoader]:
from lightning.data import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

# Increase by one because we need the next word as well
effective_block_size = block_size + 1

train_datasets = [
StreamingDataset(
input_dir="data/slimpajama/train",
item_loader=TokensLoader(block_size=effective_block_size),
shuffle=True,
drop_last=True,
),
StreamingDataset(
input_dir="data/starcoder",
item_loader=TokensLoader(block_size=effective_block_size),
shuffle=True,
drop_last=True,
),
]

# Mix SlimPajama data and Starcoder data with these proportions:
weights = (0.693584, 0.306416)
combined_dataset = CombinedStreamingDataset(datasets=train_datasets, seed=42, weights=weights)
train_dataloader = StreamingDataLoader(
combined_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True
)

val_dataset = StreamingDataset(
input_dir="data/slimpajama/val",
item_loader=TokensLoader(block_size=effective_block_size),
shuffle=True,
# Consider setting to False, but we would lose some samples due to truncation when world size > 1
drop_last=True,
)
val_dataloader = DataLoader(
val_dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers, drop_last=True
)
def get_dataloaders(fabric: L.Fabric, data: LitDataModule, train: TrainArgs, block_size: int) -> Tuple[DataLoader, DataLoader]:
data.connect(batch_size=train.micro_batch_size, max_seq_length=block_size)
with fabric.rank_zero_first():
data.prepare_data()
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()
return train_dataloader, val_dataloader


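
One practical consequence of this refactor is that `setup()` now takes the data module as an argument, so a differently configured `TinyLlama` instance (or any other `LitDataModule`) can be passed in directly. A hedged, programmatic sketch; the S3 path and device count below are placeholders, not values shipped with this commit:

```python
# Hypothetical sketch: plugging a custom-configured data module into setup().
from lit_gpt.data import TinyLlama
import pretrain.tinyllama as module

# Placeholder remote path; any location with the slimpajama/starcoder layout works.
data = TinyLlama(data_path="s3://my-bucket/tinyllama-data", num_workers=16)

# setup() wires the data module through get_dataloaders() instead of hard-coded paths.
module.setup(data=data, devices=8)
```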
Empty file added tests/data/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions tests/data/test_tinyllama.py
@@ -0,0 +1,35 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import pytest
from torch.utils.data import DataLoader


def test_tinyllama(tmp_path, monkeypatch):
from lit_gpt.data import TinyLlama
from lightning.data.streaming import StreamingDataLoader, StreamingDataset, CombinedStreamingDataset

data = TinyLlama(data_path=(tmp_path / "data"))
assert data.seq_length == 2048
assert data.batch_size == 1

data.connect(batch_size=2, max_seq_length=1024)
assert data.seq_length == 1025
assert data.batch_size == 2

with pytest.raises(FileNotFoundError, match="The directory .*data/slimpajama/train does not exist"):
data.prepare_data()

(tmp_path / "data" / "slimpajama" / "train").mkdir(parents=True)
(tmp_path / "data" / "slimpajama" / "val").mkdir(parents=True)
(tmp_path / "data" / "starcoder").mkdir(parents=True)

data.prepare_data()
data.setup()

train_dataloader = data.train_dataloader()
assert isinstance(train_dataloader, StreamingDataLoader)
assert isinstance(train_dataloader.dataset, CombinedStreamingDataset)

val_dataloader = data.val_dataloader()
assert isinstance(val_dataloader, DataLoader)
assert isinstance(val_dataloader.dataset, StreamingDataset)
9 changes: 4 additions & 5 deletions tests/test_pretrain_tinyllama.py
@@ -17,20 +17,19 @@
def test_pretrain_tiny_llama(tmp_path, monkeypatch):
import pretrain.tinyllama as module
from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.config import name_to_config
from lit_gpt.config import Config

model_config = dict(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
monkeypatch.setitem(name_to_config, "tmp", model_config)
model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)

dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
dataloader = DataLoader(dataset)
module.create_dataloaders = Mock(return_value=(dataloader, dataloader))
module.get_dataloaders = Mock(return_value=(dataloader, dataloader))

stdout = StringIO()
with redirect_stdout(stdout):
module.setup(
devices=2,
model_name="tmp",
model=model_config,
io=IOArgs(out_dir=tmp_path, train_data_dir=None),
train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),
eval=EvalArgs(interval=1, max_iters=1),
13 changes: 6 additions & 7 deletions tutorials/pretrain_tinyllama.md
@@ -104,18 +104,16 @@ python pretrain/tinyllama.py
```

The script will save checkpoints periodically to the folder `out/`.
By default, the `pretrain/tinyllama.py` script will pretrain the Llama 2 7B model with FSDP in
By default, the `pretrain/tinyllama.py` script will pretrain the model with FSDP in
`bfloat16` mixed precision and gradient accumulation.

Note that the `pretrain/tinyllama.py` is not actually a model-specific training script, so feel free to change
the configuration and size by passing a different string to the model name variable
the configuration and size by passing a different string to the model name argument, for example:

```shell
--model_name "tiny-llama-1.1b"
python pretrain/tinyllama.py --model.name Gemma-2b
```

at the top of this script.

The currently supported model names are contained in the [config.py](https://github.com/Lightning-AI/lit-gpt/lit_gpt/config.py) file.
You can

@@ -126,17 +124,18 @@ Keep in mind that training with a single machine will take weeks. To speed up th
Once you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/fundamentals/launch.html#launch-on-a-cluster)
to launch the script across machines:

- [Lightning AI](https://lightning.ai/docs/fabric/stable/guide/multi_node/cloud.html)
- [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)
- [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)
- [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)

The exposes several hyperparameters you can tweak through the command line.
The script exposes several hyperparameters you can tweak through the command line.

For instance, `--train.micro_batch_size` should be adjusted so the process will use the available
GPU memory. For more tips to avoid out-of-memory issues, please also see the more detailed
[Dealing with out-of-memory (OOM) errors](oom.md) guide.
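
As a rule of thumb (a conventional assumption, not something this diff spells out), the micro batch size trades off against the number of gradient-accumulation steps for a fixed global batch size:

```python
# Hedged sketch of the usual relationship between the TrainArgs batch sizes; the
# numbers are illustrative and the formula is the conventional one, not quoted
# from this commit.
global_batch_size = 512  # effective batch size per optimizer step
micro_batch_size = 4     # what fits in GPU memory per forward/backward pass
devices = 8              # data-parallel workers

grad_accum_steps = global_batch_size // (devices * micro_batch_size)
print(grad_accum_steps)  # 16
```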

Last, logging is kept minimal in the script, but for long running experiments we recommend switching to a proper experiment tracker.
Last, logging is kept minimal in the script, but for long-running experiments we recommend switching to a proper experiment tracker.
As an example, we included WandB (set `use_wandb=True`) to show how you can integrate any experiment tracking framework.
For reference, [here are the loss curves for our reproduction](https://api.wandb.ai/links/awaelchli/y7pzdpwy).
