Commit 45aa8bb

entrpn, jfacevedo-google, and sayakpaul authored
Ptxla sd training (huggingface#9381)
* enable pxla training of stable diffusion 2.x models.
* run linter/style and run pipeline test for stable diffusion and fix issues.
* update xla libraries
* fix read me newline.
* move files to research folder.
* update per comments.
* rename readme.

---------

Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 5e1427a commit 45aa8bb

File tree: 4 files changed, +855 −0 lines changed

examples/research_projects/pytorch_xla/README.md
@@ -0,0 +1,167 @@
# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA

The `train_text_to_image_xla.py` script shows how to fine-tune a Stable Diffusion model on TPU devices using PyTorch/XLA.

It has been tested on TPU v4 and v5p. The training code has been tested on multi-host setups.

This script implements Distributed Data Parallel using the GSPMD feature of the XLA compiler, sharding the input batches across the TPU devices (a minimal illustrative sketch follows).
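
For intuition, here is a hypothetical sketch of GSPMD-style input-batch sharding in PyTorch/XLA; the mesh layout, axis name, and batch shape are illustrative assumptions, not code taken from the training script:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # enable the GSPMD/SPMD execution mode

# Build a 1-D device mesh over all TPU devices; "data" names the batch axis.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ("data",))

# A per-host input batch of 512x512 RGB images (illustrative shape).
batch = torch.randn(32, 3, 512, 512).to(xm.xla_device())

# Shard dimension 0 (the batch) across the "data" axis; the remaining
# dimensions stay replicated on every device.
xs.mark_sharding(batch, mesh, ("data", None, None, None))
```

With this partition spec, each device sees only its slice of the global batch, which is the Distributed Data Parallel pattern described above.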

As of 9-11-2024, these are the expected step times:

| accelerator | global batch size | step time (seconds) |
| ----------- | ----------------- | ------------------- |
| v5p-128     | 1024              | 0.245               |
| v5p-256     | 2048              | 0.234               |
| v5p-512     | 4096              | 0.2498              |

## Create TPU

To create a TPU on Google Cloud, first set these environment variables:

```bash
export TPU_NAME=<tpu-name>
export PROJECT_ID=<project-id>
export ZONE=<google-cloud-zone>
export ACCELERATOR_TYPE=<accelerator type like v5p-8>
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>
```

Then run the create TPU command:

```bash
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} \
--reserved
```

You can also reserve TPUs through other means, such as GKE or queued resources.

## Set up the TPU environment

Install the PyTorch and PyTorch/XLA nightly versions:

```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
'
```

Verify that PyTorch and PyTorch/XLA were installed correctly:

```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
--command='python3 -c "import torch; import torch_xla;"'
```

Install the dependencies:

```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
git clone https://github.com/huggingface/diffusers.git
cd diffusers
git checkout main
cd examples/research_projects/pytorch_xla
pip3 install -r requirements.txt
pip3 install pillow --upgrade
cd ../../..
pip3 install .'
```

## Run the training job

### Authenticate

Run the following command to authenticate with your Hugging Face token:

```bash
huggingface-cli login
```

This script trains only the UNet part of the network; the VAE and text encoder are frozen (a sketch of the pattern follows).
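
As a sketch of what freezing the VAE and text encoder typically looks like in diffusers-style training code (the actual script's loading code may differ; the model name mirrors the command below):

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

model_id = "stabilityai/stable-diffusion-2-base"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# Freeze everything except the UNet: no gradients flow to the VAE
# or the text encoder, so only the UNet's weights are updated.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.train()

# Only the UNet's parameters go to the optimizer.
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)
```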

```bash
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
export XLA_DISABLE_FUNCTIONALIZATION=1
export PROFILE_DIR=/tmp/
export CACHE_DIR=/tmp/
export DATASET_NAME=lambdalabs/naruto-blip-captions
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
export TRAIN_STEPS=50
export OUTPUT_DIR=/tmp/trained-model/
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base \
  --dataset_name=$DATASET_NAME \
  --resolution=512 \
  --center_crop \
  --random_flip \
  --train_batch_size=$PER_HOST_BATCH_SIZE \
  --max_train_steps=$TRAIN_STEPS \
  --learning_rate=1e-06 \
  --mixed_precision=bf16 \
  --profile_duration=80000 \
  --output_dir=$OUTPUT_DIR \
  --dataloader_num_workers=4 \
  --loader_prefetch_size=4 \
  --device_prefetch_size=4'
```

### Environment variables explained

* `XLA_DISABLE_FUNCTIONALIZATION`: Disables functionalization, which improves the performance of the AdamW optimizer.
* `PROFILE_DIR`: Directory in which to store the profiling results.
* `CACHE_DIR`: Directory in which to store the XLA-compiled graphs for persistent caching (see the sketch after this list).
* `DATASET_NAME`: Dataset used to train the model.
* `PER_HOST_BATCH_SIZE`: Batch size to load per CPU host. For example, on a v5p-16 with 2 CPU hosts, the global batch size is 2 x `PER_HOST_BATCH_SIZE`. The input batch is sharded along the batch axis.
* `TRAIN_STEPS`: Total number of training steps to run.
* `OUTPUT_DIR`: Directory in which to store the fine-tuned model.
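
For reference, a minimal sketch of how a persistent compilation cache is typically enabled in PyTorch/XLA (the same call the inference example below uses; the training script may wire it up differently):

```python
import os

import torch_xla.runtime as xr

# Compiled XLA graphs are written to CACHE_DIR and reused across runs,
# so later runs skip the expensive recompilation step.
cache_dir = os.environ.get("CACHE_DIR", None)
if cache_dir:
    xr.initialize_cache(cache_dir, readonly=False)
```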

## Run inference using the output model

To run inference, load the trained model and pass it input prompts. The first call compiles the graph and takes much longer; subsequent calls run much faster. Set `CACHE_DIR` so compiled graphs are cached persistently:

```bash
export CACHE_DIR=/tmp/
```

```python
import os
from time import time

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from diffusers import StableDiffusionPipeline

# Reuse compiled graphs across runs when CACHE_DIR is set.
CACHE_DIR = os.environ.get("CACHE_DIR", None)
if CACHE_DIR:
    xr.initialize_cache(CACHE_DIR, readonly=False)

def main():
    device = xm.xla_device()
    model_path = "jffacevedo/pxla_trained_model"
    pipe = StableDiffusionPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )
    pipe.to(device)
    prompt = ["A naruto with green eyes and red legs."]

    # First call: compiles the graph, so it is slow.
    start = time()
    print("compiling...")
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
    print(f"compile time: {time() - start}")

    # Second call: reuses the compiled graph and is much faster.
    print("generate...")
    start = time()
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
    print(f"generation time (after compile): {time() - start}")
    image.save("naruto.png")

if __name__ == "__main__":
    main()
```

Expected results:

```bash
compiling...
100%|██████████| 30/30 [10:03<00:00, 20.10s/it]
compile time: 720.656970500946
generate...
100%|██████████| 30/30 [00:01<00:00, 17.65it/s]
generation time (after compile): 1.8461642265319824
```
examples/research_projects/pytorch_xla/requirements.txt
@@ -0,0 +1,8 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
datasets>=2.19.1
ftfy
tensorboard
Jinja2
peft==0.7.0
