This tutorial will walk you through setting up the OpenWebText dataset and launching the pretraining script.
OpenWebText is an open-source reproduction of OpenAI's unreleased WebText training dataset, which was originally used to train GPT-2. The version used here consists of 8M documents and is loaded via the load_dataset("openwebtext", ...) function from the datasets Python package. Please refer to the website hosting the dataset for licensing information.
Before you can pretrain lit-gpt on it, you need to read and tokenize the data and write it in binary format.
To prepare the dataset with the Llama 2 tokenizer, run:
pip install datasets
python scripts/prepare_openwebtext.py \
--checkpoint_dir checkpoints/meta-llama/Llama-2-7b-hf/ \
--destination_path data/openwebtext
The script will take about 15 min to run.
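To make the "read, tokenize, and write in binary format" step more concrete, here is a toy sketch. It is not how scripts/prepare_openwebtext.py is implemented: the whitespace tokenizer, the end-of-sequence convention, and the helper names (build_vocab, tokenize, write_tokens, read_tokens) are hypothetical stand-ins, whereas the real script uses the tokenizer found in --checkpoint_dir and its own file layout.

```python
import array
import os
import tempfile

def build_vocab(documents):
    # Hypothetical toy vocabulary: assign IDs to words in order of first appearance.
    vocab = {}
    for doc in documents:
        for word in doc.split():
            vocab.setdefault(word, len(vocab))
    return vocab

def tokenize(doc, vocab, eos_id):
    # Whitespace "tokenizer" standing in for a real subword tokenizer;
    # each document is terminated with an end-of-sequence token.
    return [vocab[word] for word in doc.split()] + [eos_id]

def write_tokens(documents, path):
    vocab = build_vocab(documents)
    eos_id = len(vocab)  # reserve the next free ID for end-of-sequence
    # "H" = unsigned 16-bit integer; fall back to "I" (unsigned int, 32-bit on
    # common platforms) if the vocabulary does not fit into 16 bits.
    typecode = "H" if eos_id < 2**16 else "I"
    tokens = array.array(typecode)
    for doc in documents:
        tokens.extend(tokenize(doc, vocab, eos_id))
    with open(path, "wb") as f:
        tokens.tofile(f)  # raw machine-native bytes, no header
    return vocab, typecode

def read_tokens(path, typecode):
    tokens = array.array(typecode)
    with open(path, "rb") as f:
        tokens.frombytes(f.read())
    return tokens

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "toy.bin")
        vocab, typecode = write_tokens(["the cat sat", "the dog ran"], path)
        print(list(read_tokens(path, typecode)))  # [0, 1, 2, 5, 0, 3, 4, 5]
```

Storing token IDs as fixed-width unsigned integers keeps the file compact and lets a training loop read contiguous slices without re-tokenizing; a 16-bit type is enough as long as the vocabulary has fewer than 65,536 entries.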
Running the pretraining script with its default settings requires at least 4 GPUs with 40 GB or more of memory each. Alternatively, you can train a smaller Pythia-70m model on a single GPU (more information about that further below).
python pretrain/openwebtext.py --devices 4
The script will periodically save checkpoints to the out/ folder.
By default, the pretrain/openwebtext.py script pretrains the Llama 2 7B model with FSDP in bfloat16 precision and uses gradient accumulation.
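Gradient accumulation splits a large batch into micro-batches, accumulates their gradients, and performs one optimizer step per full batch. The framework-free toy below (a one-parameter linear model with a mean-squared-error loss; all names are hypothetical) illustrates why this works: scaling each micro-batch gradient by 1/n_micro and summing reproduces the full-batch gradient when the micro-batches are equally sized. It only illustrates the technique and is not the training loop in pretrain/openwebtext.py.

```python
def grad_mse(w, xs, ys):
    # Gradient of mean((w * x - y) ** 2) with respect to the single weight w.
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, micro_batch_size):
    assert len(xs) % micro_batch_size == 0, "batch must split into equal micro-batches"
    n_micro = len(xs) // micro_batch_size
    total = 0.0
    for i in range(n_micro):
        start, stop = i * micro_batch_size, (i + 1) * micro_batch_size
        # Scale each micro-batch gradient by 1 / n_micro so that the sum
        # equals the gradient of the mean loss over the whole batch.
        total += grad_mse(w, xs[start:stop], ys[start:stop]) / n_micro
    return total

def train(xs, ys, micro_batch_size, lr=0.01, steps=100):
    w = 0.0
    for _ in range(steps):
        # One plain SGD step per full batch, taken after all micro-batches.
        w -= lr * accumulated_grad(w, xs, ys, micro_batch_size)
    return w

if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [3.0 * x for x in xs]
    print(accumulated_grad(0.5, xs, ys, 2), grad_mse(0.5, xs, ys))  # equal up to rounding
    print(train(xs, ys, micro_batch_size=2))  # converges towards 3.0
```

In a framework such as PyTorch the same idea usually appears as calling backward() on each scaled micro-batch loss and stepping the optimizer only after the last one.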
You can change the size of the model by passing a different string via the --model_name argument, or by editing the model name variable (set to "Llama-2-7b-hf") at the top of the script.
The currently supported model names are listed in the config.py file. To find them, you can
- either search this file for lines containing "name =",
- or run python scripts/download.py without additional command-line arguments.
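If you prefer to search config.py programmatically rather than by hand, a small regular expression is enough. This sketch assumes that each model name appears in the file as a quoted string assigned to a name field (for example name="pythia-70m"); the helper names are hypothetical, and the pattern may need adjusting if the file's layout differs. Duplicates are dropped while keeping the first occurrence, in case a name is repeated in nested settings.

```python
import re

# Assumption: names appear as name="..." or name = "..."; the leading \b
# keeps keys such as model_name from matching.
NAME_PATTERN = re.compile(r"""\bname\s*=\s*["']([^"']+)["']""")

def list_model_names(config_source):
    # dict.fromkeys removes duplicates while preserving the original order.
    return list(dict.fromkeys(NAME_PATTERN.findall(config_source)))

def list_model_names_from_file(path):
    with open(path, encoding="utf-8") as f:
        return list_model_names(f.read())
```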
Keep in mind that the original LLaMA training for the 7B model required 83k A100 80GB hours (on a bigger dataset). Even though OpenWebText is smaller, you'll likely still need access to a cluster for full pretraining.
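To put the 83k A100-hour figure into perspective, the following back-of-the-envelope calculation converts GPU-hours into wall-clock days. It assumes perfect linear scaling across GPUs, which real runs rarely achieve, and the GPU counts are arbitrary examples; because the LLaMA run used a bigger dataset, the result is a reference point rather than an estimate for OpenWebText.

```python
def wall_clock_days(gpu_hours, num_gpus):
    # Assumes the total GPU-hours split evenly across all GPUs (ideal scaling).
    return gpu_hours / num_gpus / 24

if __name__ == "__main__":
    for num_gpus in (8, 64, 512):
        print(f"{num_gpus:>4} GPUs: {wall_clock_days(83_000, num_gpus):7.1f} days")
```

On a single 8-GPU node this comes to roughly 432 days, and still about 54 days with 64 GPUs.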
Once you have access to a cluster, you can follow these instructions to launch the script across machines:
The script exposes several hyperparameters that you can tweak through the command line.
For instance, --train.micro_batch_size should be adjusted so that the process fits into the available GPU memory. For more tips on avoiding out-of-memory issues, please also see the more detailed Dealing with out-of-memory (OOM) errors guide.
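Lowering the micro-batch size reduces memory use per step. When training with gradient accumulation, the number of accumulation iterations is typically derived so that the effective (global) batch size stays the same. The helper below sketches that relationship; its name and the global_batch_size parameter are hypothetical and not necessarily how pretrain/openwebtext.py computes it.

```python
def grad_accum_iters(global_batch_size, micro_batch_size, devices):
    # Samples per optimizer step = micro_batch_size * devices * iterations.
    samples_per_iter = micro_batch_size * devices
    if global_batch_size % samples_per_iter != 0:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * devices")
    return global_batch_size // samples_per_iter

if __name__ == "__main__":
    # Halving the micro-batch size doubles the accumulation iterations.
    print(grad_accum_iters(512, micro_batch_size=4, devices=4))
    print(grad_accum_iters(512, micro_batch_size=2, devices=4))
```

With a global batch of 512 samples on 4 devices, going from a micro-batch size of 4 to 2 raises the accumulation iterations from 32 to 64, trading speed for memory.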
Lastly, logging is kept minimal in the script. To use a particular logger, please refer to https://lightning.ai/docs/fabric/stable/api/loggers.html or call a logging client library like wandb directly.
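For quick experiments without an extra dependency, you could also append metrics to a CSV file yourself. The class below is a hypothetical minimal sketch built only on Python's csv module; it is neither one of the Fabric loggers nor the wandb API, which remain the options referenced above for real runs.

```python
import csv
import os
import tempfile

class CSVMetricLogger:
    """Hypothetical minimal logger that appends one CSV row per logging call."""

    def __init__(self, path, metric_names):
        self.path = path
        self.fieldnames = ["step", *metric_names]
        # Create (or overwrite) the file and write the header row once.
        with open(self.path, "w", newline="") as f:
            csv.DictWriter(f, fieldnames=self.fieldnames).writeheader()

    def log(self, step, **metrics):
        # Re-open in append mode so rows written before a crash stay on disk.
        with open(self.path, "a", newline="") as f:
            csv.DictWriter(f, fieldnames=self.fieldnames).writerow({"step": step, **metrics})

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        logger = CSVMetricLogger(os.path.join(tmp, "metrics.csv"), ["loss", "lr"])
        logger.log(0, loss=2.5, lr=0.001)
```

In a multi-GPU run, only one process (for example rank 0) should write to the file to avoid interleaved rows.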
To train a smaller Pythia 70M model on a single GPU, you can pass the --model_name "pythia-70m" argument.
(Please see the download_* scripts in the tutorials for more information on downloading model checkpoints for different models.)
Before you start training, note that you need to prepare the dataset specifically for this model, since it uses a different tokenizer than Llama 2:
python scripts/prepare_openwebtext.py \
--checkpoint_dir checkpoints/EleutherAI/pythia-70m/ \
--destination_path data/lit-openwebtext
python pretrain/openwebtext.py --model_name "pythia-70m" --devices 1