forked from huggingface/diffusers
refactor onnxruntime integration (huggingface#2042)
* refactor onnxruntime integration
* fix requirements.txt bug
* make style
* add support for textual_inversion
* make style
* add readme
* cleanup README files
* 1/27/2023 update to training scripts
* make style
* 1/30 update to train_unconditional
* style with black-22.8.0

---------

Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: anton- <[email protected]>
1 parent ecadcde, commit a87e87f
Showing 11 changed files with 1,889 additions and 63 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
## Diffusers examples with ONNXRuntime optimizations | ||
|
||
**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub.** | ||
|
||
This project provides diffusers examples with ONNXRuntime optimizations for training/fine-tuning unconditional image generation, text-to-image, and textual inversion. Please see the individual directories for more details on how to run each task using ONNXRuntime. |
74 changes: 74 additions & 0 deletions
examples/research_projects/onnxruntime/text_to_image/README.md
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Stable Diffusion text-to-image fine-tuning | ||
|
||
The `train_text_to_image.py` script shows how to fine-tune a Stable Diffusion model on your own dataset. | ||
|
||
___Note___: | ||
|
||
___This script is experimental. It fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. We recommend trying different hyperparameters to get the best results on your dataset.___ | ||
|
||
|
||
## Running locally with PyTorch | ||
### Installing the dependencies | ||
|
||
Before running the scripts, make sure to install the library's training dependencies: | ||
|
||
**Important** | ||
|
||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: | ||
```bash | ||
git clone https://github.com/huggingface/diffusers | ||
cd diffusers | ||
pip install . | ||
``` | ||
|
||
Then `cd` into the example folder and run: | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
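As a quick sanity check after installation, the sketch below reports which of the example's requirements are not yet installed. It is an illustration only: `missing_packages` is a hypothetical helper, not part of diffusers, and the package list is copied by hand from this example's `requirements.txt` with version pins omitted (so it does not verify `transformers>=4.25.1`).

```python
from importlib import metadata

# Package names copied from this example's requirements.txt (version pins omitted).
REQUIRED = [
    "accelerate",
    "torchvision",
    "transformers",
    "datasets",
    "ftfy",
    "tensorboard",
    "modelcards",
]


def missing_packages(names):
    """Return the distribution names from `names` that are not installed."""
    missing = []
    for name in names:
        try:
            metadata.version(name)  # raises if the distribution is absent
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("missing:", ", ".join(missing) if missing else "none")
```

If anything is reported as missing, rerunning `pip install -r requirements.txt` inside the same virtual environment is the usual fix.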
|
||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: | ||
|
||
```bash | ||
accelerate config | ||
``` | ||
|
||
### Pokemon example | ||
|
||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. | ||
|
||
You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). | ||
|
||
Run the following command to authenticate your token: | ||
|
||
```bash | ||
huggingface-cli login | ||
``` | ||
|
||
If you have already cloned the repo, then you won't need to go through these steps. | ||
|
||
<br> | ||
|
||
## Use ONNXRuntime to accelerate training | ||
To use ONNXRuntime to accelerate training, run `train_text_to_image.py`. | ||
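For orientation, ONNX Runtime's PyTorch training integration is typically used by wrapping a `torch.nn.Module` in `ORTModule`. The sketch below shows that pattern with a fallback to plain PyTorch. It is an assumption about the general approach, not a copy of what `train_text_to_image.py` does; `ort_training_available` and `wrap_for_ort` are hypothetical helper names introduced here for illustration.

```python
import importlib.util


def ort_training_available() -> bool:
    """Return True if the onnxruntime training package can be located."""
    try:
        if importlib.util.find_spec("onnxruntime") is None:
            return False
        return importlib.util.find_spec("onnxruntime.training") is not None
    except ImportError:
        # Importing the parent package failed, so treat ORT as unavailable.
        return False


def wrap_for_ort(module):
    """Wrap a torch.nn.Module in ORTModule when possible, else return it unchanged.

    Hypothetical helper: the fallback keeps the same code path usable
    with plain PyTorch when onnxruntime-training is not installed.
    """
    if not ort_training_available():
        return module
    try:
        from onnxruntime.training.ortmodule import ORTModule
    except ImportError:
        return module
    return ORTModule(module)
```

In a training script, a call such as `unet = wrap_for_ort(unet)` would go after the model is loaded and before it is handed to the optimizer and training loop.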
|
||
The command to fine-tune the Stable Diffusion UNet (`UNet2DConditionModel`) on the Pokemon dataset with ONNXRuntime: | ||
|
||
```bash | ||
export MODEL_NAME="CompVis/stable-diffusion-v1-4" | ||
export dataset_name="lambdalabs/pokemon-blip-captions" | ||
accelerate launch --mixed_precision="fp16" train_text_to_image.py \ | ||
--pretrained_model_name_or_path=$MODEL_NAME \ | ||
--dataset_name=$dataset_name \ | ||
--use_ema \ | ||
--resolution=512 --center_crop --random_flip \ | ||
--train_batch_size=1 \ | ||
--gradient_accumulation_steps=4 \ | ||
--gradient_checkpointing \ | ||
--max_train_steps=15000 \ | ||
--learning_rate=1e-05 \ | ||
--max_grad_norm=1 \ | ||
--lr_scheduler="constant" --lr_warmup_steps=0 \ | ||
--output_dir="sd-pokemon-model" | ||
``` | ||
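The batch-related flags in the command above combine into an effective batch size. The sketch below works through that arithmetic. It assumes a single process (the real process count depends on your `accelerate config`) and assumes `--max_train_steps` counts optimizer updates rather than forward passes; both function names are hypothetical.

```python
def effective_batch_size(
    train_batch_size: int,
    gradient_accumulation_steps: int,
    num_processes: int = 1,
) -> int:
    """Samples contributing to one optimizer update across all processes."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def samples_processed(max_train_steps: int, effective_batch: int) -> int:
    """Total samples seen, assuming each step is one optimizer update."""
    return max_train_steps * effective_batch


if __name__ == "__main__":
    # Values taken from the command above; num_processes=1 is an assumption.
    batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
    print(batch)                            # 4
    print(samples_processed(15000, batch))  # 60000
```

Under these assumptions, raising `--train_batch_size` while lowering `--gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged and trades memory for speed.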
|
||
Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub with any questions. |
7 changes: 7 additions & 0 deletions
examples/research_projects/onnxruntime/text_to_image/requirements.txt
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
accelerate | ||
torchvision | ||
transformers>=4.25.1 | ||
datasets | ||
ftfy | ||
tensorboard | ||
modelcards |