
TPU support

This project utilizes Fabric, which supports TPUs via PyTorch XLA.

Note

This guide assumes that you have already set up your Google Cloud environment.

To set up a Google Cloud instance with a TPU v4 VM, run the following commands:

gcloud compute tpus tpu-vm create lit-gpt --version=tpu-vm-v4-base --accelerator-type=v4-8 --zone=us-central2-b
gcloud compute tpus tpu-vm ssh lit-gpt --zone=us-central2-b

You can also choose a different TPU type. To do so, change the --version, --accelerator-type, and --zone arguments. Find all regions and zones here.
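When choosing a different accelerator type, it helps to know how many hosts the slice spans, because anything beyond one host brings the multihost caveats below into play. For v4, the number in the accelerator type counts TensorCores: each v4 chip has 2 TensorCores, and each host has 4 chips. The sketch below is a hypothetical helper (`v4_topology` is not part of Lit-GPT) and only applies to the v4 naming scheme.

```python
from __future__ import annotations

TENSORCORES_PER_CHIP = 2  # TPU v4: 2 TensorCores per chip
CHIPS_PER_HOST = 4  # TPU v4: 4 chips per host


def v4_topology(accelerator_type: str) -> tuple[int, int]:
    """Return (chips, hosts) for a TPU v4 accelerator type such as "v4-32"."""
    generation, _, cores = accelerator_type.partition("-")
    if generation != "v4" or not cores.isdigit():
        raise ValueError(f"expected a type like 'v4-8', got {accelerator_type!r}")
    chips = int(cores) // TENSORCORES_PER_CHIP
    hosts = max(1, chips // CHIPS_PER_HOST)
    return chips, hosts


for accelerator_type in ("v4-8", "v4-32", "v4-512"):
    chips, hosts = v4_topology(accelerator_type)
    print(f"{accelerator_type}: {chips} chips on {hosts} host(s)")
```

By this count, v4-8 is the only v4 type that fits on a single host; v4-32 already spans 4 hosts.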

Multihost caveats

TPU v4-8 uses a single host. SSH'ing into the machine and running commands manually only works with a single host (one slice in the TPU pod). In multihost environments, such as larger TPU pod slices, every command must be launched on all hosts simultaneously to avoid hangs. For local development, it is advisable to zip your current changes, upload the archive to every host, and launch commands on all workers from your personal computer:

# Zip the local directory, excluding large directories (remove a pattern from -x if you do want to upload it)
zip -r local_changes.zip . -x ".git/*" "checkpoints/*" "data/*" "out/*"
# Copy the .zip file to every host of the TPU VM
gcloud compute tpus tpu-vm scp --worker=all local_changes.zip "lit-gpt:~"
# Unzip on each host
gcloud compute tpus tpu-vm ssh lit-gpt --worker=all --command="cd ~; unzip -q -o local_changes.zip"

# Example of a typical workflow (the .sh scripts are placeholders for your own)
gcloud compute tpus tpu-vm ssh lit-gpt --worker=all --command="cd ~; bash install_dependencies.sh"
gcloud compute tpus tpu-vm ssh lit-gpt --worker=all --command="cd ~; bash prepare_checkpoints.sh"
gcloud compute tpus tpu-vm ssh lit-gpt --worker=all --command="cd ~; bash run_desired_script.sh"

# Kill all python processes on all workers
gcloud compute tpus tpu-vm ssh lit-gpt --worker=all --command="pkill -e python"

Note that the commands to install the environment and prepare checkpoints must be run on all workers, since each worker (host) has its own filesystem that is not shared with the others.
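If `zip` is not available on your personal computer, the packaging step can also be done with Python's standard library. A minimal sketch, assuming it is run from the repository root; `zip_local_changes` is a hypothetical helper that reproduces the `-x` exclusions from the `zip` command above, and the resulting archive is copied and unzipped with the same `gcloud` commands.

```python
from __future__ import annotations

import fnmatch
import os
import zipfile

# Same exclusions as the `zip -x` patterns (paths relative to the repository root).
EXCLUDE = (".git/*", "checkpoints/*", "data/*", "out/*")


def zip_local_changes(root: str, archive: str, exclude: tuple[str, ...] = EXCLUDE) -> list[str]:
    """Zip `root` into `archive`, skipping paths that match any exclude pattern.

    Returns the archived paths (relative to root, with forward slashes), sorted.
    """
    archive_abs = os.path.abspath(archive)
    added = []
    with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for dirpath, _dirnames, filenames in os.walk(root):
            for filename in filenames:
                full = os.path.join(dirpath, filename)
                if os.path.abspath(full) == archive_abs:
                    continue  # never add the archive to itself
                rel = os.path.relpath(full, root).replace(os.sep, "/")
                # fnmatch's "*" also matches "/", so ".git/*" covers nested files.
                if any(fnmatch.fnmatch(rel, pattern) for pattern in exclude):
                    continue
                zf.write(full, arcname=rel)
                added.append(rel)
    return sorted(added)
```

For example, `zip_local_changes(".", "local_changes.zip")` produces an archive with the same contents the `zip` command would.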

For simplicity, the rest of this tutorial assumes a single host.

Once inside the machine, clone the repository and install the dependencies:

git clone https://github.com/Lightning-AI/lit-gpt
cd lit-gpt
pip install -r requirements.txt

Install an optimized BLAS library:

sudo apt update
sudo apt install libopenblas-dev

Since Lit-GPT requires a torch version newer than 2.0.0, manually install the nightly builds of torch and torch_xla:

pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
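To confirm that the nightly wheels are the ones actually installed, you can compare the installed versions against 2.0.0. A minimal sketch using only the standard library; `release_tuple` and `is_newer_than_2_0_0` are hypothetical helpers that compare only the leading numeric `major.minor.patch` part of the version string, which is enough to tell a 2.1 nightly apart from the 2.0.0 release.

```python
from __future__ import annotations

import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(ver: str) -> tuple[int, ...]:
    """Extract the numeric release part, e.g. "2.1.0a0+git1234" -> (2, 1, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def is_newer_than_2_0_0(ver: str) -> bool:
    # Pre-release tags are ignored, so "2.0.0a0" counts as 2.0.0 (not newer).
    return release_tuple(ver) > (2, 0, 0)


if __name__ == "__main__":
    for pkg in ("torch", "torch_xla"):
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            print(f"{pkg}: not installed")
            continue
        print(f"{pkg} {installed}: newer than 2.0.0 -> {is_newer_than_2_0_0(installed)}")
```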

While computations will run by default using the new PjRT runtime, it is recommended to set the following environment variables:

export ALLOW_MULTIPLE_LIBTPU_LOAD=1
export PJRT_DEVICE=TPU
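If you launch scripts from Python rather than from an interactive shell, the same variables can be set programmatically. A minimal sketch, assuming they are set before `torch_xla` is imported so that the runtime sees them; `configure_tpu_env` is a hypothetical helper, and `setdefault` keeps any value you have already exported.

```python
from __future__ import annotations

import os

# The two variables recommended above.
TPU_ENV = {
    "ALLOW_MULTIPLE_LIBTPU_LOAD": "1",
    "PJRT_DEVICE": "TPU",
}


def configure_tpu_env(env: dict[str, str] | None = None) -> dict[str, str]:
    """Set the recommended variables unless already set; return the effective values."""
    target = os.environ if env is None else env
    for key, value in TPU_ENV.items():
        target.setdefault(key, value)
    return {key: target[key] for key in TPU_ENV}


configure_tpu_env()
# Import torch_xla only after this point, for example:
# import torch_xla.core.xla_model as xm
```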

Note

An extensive guide on setup and available options can be found here.

Since a new machine was created, you may need to download pretrained weights. They can be copied to the machine using gcloud compute tpus tpu-vm scp, or you can follow the steps described in our downloading guide.

It is also recommended to set up a persistent disk from which to load checkpoints. Follow this guide to do so. Read-write disks are not supported in multihost VM setups, so in that case persistent disks cannot be used to save checkpoints; they are still useful in read-only mode for loading pretrained weights before finetuning or inference.

In multihost settings, FSDP saves checkpoint shards per host, which are then consolidated into a single checkpoint. For safekeeping, it is recommended to upload the consolidated checkpoints to a Google Cloud bucket. Alternatively, you can periodically transfer these checkpoints off the TPU VM with scp, although this is not implemented in our scripts.
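For the bucket upload, one option is to generate `gcloud storage cp` commands for the consolidated files only. A minimal sketch: `upload_commands` and the bucket URI are hypothetical, and it assumes that shard files end in `_rank-<r>-of-<n>.pth` (the suffix used in the consolidation command in the Fine-tuning section) while consolidated checkpoints do not. It prints the commands rather than running them.

```python
from __future__ import annotations

import re
import shlex
from pathlib import Path

# Assumed shard naming: <prefix>_rank-<r>-of-<n>.pth
SHARD_RE = re.compile(r"_rank-\d+-of-\d+\.pth$")


def upload_commands(out_dir: str, bucket_uri: str) -> list[str]:
    """Build one `gcloud storage cp` command per consolidated .pth file under out_dir."""
    root = Path(out_dir)
    commands = []
    for path in sorted(root.rglob("*.pth")):
        if SHARD_RE.search(path.name):
            continue  # per-host shard, not a consolidated checkpoint
        dest = f"{bucket_uri.rstrip('/')}/{path.relative_to(root).as_posix()}"
        commands.append(shlex.join(["gcloud", "storage", "cp", str(path), dest]))
    return commands


if __name__ == "__main__" and Path("out").is_dir():
    # Hypothetical bucket; replace with your own.
    for cmd in upload_commands("out", "gs://my-bucket/lit-gpt-checkpoints"):
        print(cmd)
```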

Inference

The xla directory contains versions of the regular recipes adapted to run with XLA. To generate text, use the following command:

python3 xla/generate/base.py --prompt "Hello, my name is" --num_samples 3

The first generation takes around 17 seconds because XLA needs to compile the graph. Subsequent generations take around 2 seconds.
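The gap between the first and later generations comes from XLA compiling the graph once and then reusing the compiled program. The toy below is not torch_xla code: it is a minimal sketch of that caching behavior, assuming a compiled program is reused whenever the input shape is unchanged and a new shape triggers a fresh compilation. `ToyCompileCache` is hypothetical, and the sleep stands in for compile time.

```python
from __future__ import annotations

import time
from typing import Callable


class ToyCompileCache:
    """Mimic compile-once-then-reuse: the first call per input shape pays a compile cost."""

    def __init__(self, fn: Callable[[list[float]], float], compile_seconds: float = 0.05):
        self.fn = fn
        self.compile_seconds = compile_seconds
        self.compiled: dict[tuple[int, ...], Callable[[list[float]], float]] = {}
        self.compile_count = 0

    def __call__(self, xs: list[float]) -> float:
        shape = (len(xs),)
        if shape not in self.compiled:
            time.sleep(self.compile_seconds)  # stand-in for XLA graph compilation
            self.compiled[shape] = self.fn
            self.compile_count += 1
        return self.compiled[shape](xs)


step = ToyCompileCache(sum)
for xs in ([1.0, 2.0], [3.0, 4.0], [1.0, 2.0, 3.0]):
    start = time.perf_counter()
    step(xs)
    elapsed = time.perf_counter() - start
    print(f"shape={len(xs)}: {elapsed:.3f}s, compilations so far: {step.compile_count}")
```

The second call reuses the cached program, while the third call with a new shape pays the compile cost again. This is also why keeping input shapes stable tends to matter on TPUs.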

Fine-tuning

To fine-tune Falcon 7B with the adapter method, first prepare the Alpaca dataset, then start fine-tuning:

python3 scripts/prepare_alpaca.py --checkpoint_dir checkpoints/tiiuae/falcon-7b

python3 xla/finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true

Multihost caveats

This script is configured to save "full" checkpoints, which isn't possible on multihost TPU VMs. Instead, train with state_dict_type="sharded", then consolidate the resulting shards into a single checkpoint as follows:

path_to_shards="out/adapter/alpaca/lit_model_adapter_finetuned"
mkdir -p $path_to_shards
workers=4  # 4 hosts
for ((i = 0; i < workers; i++)); do
  # aggregate all shards locally
  gcloud compute tpus tpu-vm scp --worker=$i "lit-gpt:${path_to_shards}/*" "${path_to_shards}/" --zone us-central2-b
done
# copy all shards to all workers
gcloud compute tpus tpu-vm scp --worker=all ${path_to_shards}/* "lit-gpt:${path_to_shards}/" --zone us-central2-b
# consolidate the shards in each worker
gcloud compute tpus tpu-vm ssh lit-gpt --worker=all --command="python -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix ${path_to_shards}/checkpoint --ckpt_suffix '_rank-*-of-*.pth' --save_path ${path_to_shards}.pth" --zone us-central2-b
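Before running the consolidation command, it can be worth checking that the aggregated directory holds exactly one shard per rank, since a failed `scp` from one worker would otherwise go unnoticed. A minimal sketch, assuming shard files follow the `checkpoint` prefix and `_rank-*-of-*.pth` suffix used in the command above and that ranks are numbered from 0; `missing_ranks` is a hypothetical helper.

```python
from __future__ import annotations

import re
from pathlib import Path

# Assumed file name: checkpoint_rank-<rank>-of-<world_size>.pth
SHARD_RE = re.compile(r"^checkpoint_rank-(\d+)-of-(\d+)\.pth$")


def missing_ranks(shard_dir: str) -> list[int]:
    """Return the ranks with no shard file; raise if shards disagree on world size."""
    ranks: set[int] = set()
    world_sizes: set[int] = set()
    for path in Path(shard_dir).iterdir():
        match = SHARD_RE.match(path.name)
        if match:
            ranks.add(int(match.group(1)))
            world_sizes.add(int(match.group(2)))
    if not world_sizes:
        raise FileNotFoundError(f"no shard files found in {shard_dir}")
    if len(world_sizes) > 1:
        raise ValueError(f"shards report different world sizes: {sorted(world_sizes)}")
    (world_size,) = world_sizes
    return [rank for rank in range(world_size) if rank not in ranks]
```

For example, `missing_ranks("out/adapter/alpaca/lit_model_adapter_finetuned")` should return an empty list before you consolidate.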

Since host RAM on a TPU VM is limited (~200 GB), we implement a technique to load and shard the checkpoint sequentially, which can be enabled by setting reduce_cpu_memory_usage_during_load = True. This is required to load Falcon 40B.
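Rough arithmetic shows why Falcon 40B needs this. The sketch below rests on our own reading of "sequentially load", not on the scripts' internals: without the option, each of the 4 processes on a host holds a full half-precision copy of the weights at the same time; with it, only one copy sits in host RAM at a time. It uses the 188 GiB per-host RAM figure given under Computational Performance and nominal 7B/40B parameter counts, and it ignores all other allocations.

```python
from __future__ import annotations

GIB = 2**30
HOST_RAM_GIB = 188  # per-host RAM quoted in this guide
PROCESSES_PER_HOST = 4  # assumed: one process per device on a v4 host
BYTES_PER_PARAM = 2  # bf16 / half precision


def peak_load_gib(n_params: float, sequential: bool) -> float:
    """Rough peak host RAM while loading full weights before sharding (no overhead)."""
    one_copy = n_params * BYTES_PER_PARAM / GIB
    return one_copy if sequential else one_copy * PROCESSES_PER_HOST


for name, n_params in (("falcon-7b", 7e9), ("falcon-40b", 40e9)):
    for sequential in (False, True):
        peak = peak_load_gib(n_params, sequential)
        verdict = "fits" if peak <= HOST_RAM_GIB else "exceeds host RAM"
        print(f"{name}, sequential={sequential}: ~{peak:.0f} GiB ({verdict})")
```

Under these assumptions, Falcon 7B fits either way, while Falcon 40B fits only when the copies are loaded one at a time.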

To generate text with the adapter fine-tuned model weights, use the following command:

python3 xla/generate/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --adapter_path out/adapter/alpaca/lit_model_adapter_finetuned.pth

Warning

Remember to delete your instance when you are done:

gcloud compute tpus tpu-vm delete lit-gpt --zone=us-central2-b

Computational Performance

Using the adapter finetuning script and XLA's FSDP implementation, a 49.57% MFU was achieved with Falcon 7B on a v4-32 (micro batch size 7), and a 39.67% MFU was achieved with Falcon 40B on a v4-512 (micro batch size 3) at a fixed 1034 maximum sequence length.
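For reference, MFU (model FLOPs utilization) is the ratio of the model FLOPs actually performed per second to the hardware's peak FLOPs per second. A minimal sketch, assuming the common estimate of 6 × parameters FLOPs per training token and a TPU v4 peak of 275 bf16 TFLOP/s per chip. The scripts may count FLOPs differently (adapter fine-tuning skips much of the backward work for frozen weights), so treat this as a way to read the numbers above, not to reproduce them. `v4_chips` and `mfu` are hypothetical helpers, and the throughput in the example is made up.

```python
from __future__ import annotations

V4_PEAK_FLOPS_PER_CHIP = 275e12  # TPU v4 peak bf16 FLOP/s per chip


def v4_chips(accelerator_type: str) -> int:
    """A v4-N slice has N TensorCores, i.e. N // 2 chips."""
    return int(accelerator_type.split("-")[1]) // 2


def mfu(n_params: float, tokens_per_second: float, n_chips: int,
        peak_flops_per_chip: float = V4_PEAK_FLOPS_PER_CHIP) -> float:
    """Model FLOPs utilization using the ~6 * N FLOPs-per-token training estimate."""
    achieved = 6 * n_params * tokens_per_second
    return achieved / (n_chips * peak_flops_per_chip)


# Hypothetical throughput, for illustration only.
print(f"{mfu(7e9, 20_000, v4_chips('v4-32')):.2%}")
```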

Since the TPU VM host has limited system memory (RAM) compared to device memory (HBM), specific techniques were implemented to limit peak RAM usage when loading the model and pretrained weights before sharding, as well as when saving sharded checkpoints. A v4 chip has 32 GiB of HBM, and each host has 4 devices (4 * 32 = 128 GiB HBM) and 188 GiB of RAM shared across them. Therefore, any per-device RAM allocation over 188/4 = 47 GiB would exceed the host's RAM capacity. Without the techniques used in our scripts, a ~24B parameter model on CPU (in half precision) would be the largest that fits under this setup.
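The memory budget described above can be written out as plain arithmetic. A minimal sketch; the only value not stated in the text is 2 bytes per parameter for half precision, and activations and other allocations are ignored.

```python
from __future__ import annotations

HBM_PER_CHIP_GIB = 32
DEVICES_PER_HOST = 4
HOST_RAM_GIB = 188
BYTES_PER_PARAM = 2  # half precision (bf16/fp16)

host_hbm_gib = DEVICES_PER_HOST * HBM_PER_CHIP_GIB  # 4 * 32 = 128 GiB of HBM per host
per_device_ram_gib = HOST_RAM_GIB / DEVICES_PER_HOST  # 188 / 4 = 47 GiB of RAM per device
max_params = per_device_ram_gib * 2**30 / BYTES_PER_PARAM

print(f"HBM per host: {host_hbm_gib} GiB")
print(f"RAM budget per device: {per_device_ram_gib:.0f} GiB")
print(f"Largest half-precision model per device: ~{max_params / 1e9:.1f}B parameters")
```

With binary units this comes out at about 25B parameters. The ~24B figure corresponds to reading 47 GiB as 47 × 10^9 bytes (23.5B parameters). Either way, it is an upper bound before any other allocation is counted.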