Skip to content

Commit

Permalink
Add examples with Intel optimizations (huggingface#1579)
Browse files Browse the repository at this point in the history
* Add examples with Intel optimizations (BF16 fine-tuning and inference)

* Remove unused package

* Add README for intel_opts and refine the description for research projects

* Add notes of intel opts for diffusers
  • Loading branch information
hshen14 authored Dec 15, 2022
1 parent c5f04d4 commit c891330
Show file tree
Hide file tree
Showing 6 changed files with 790 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ For such examples, we are more lenient regarding the philosophy defined above an
Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.

## Research Projects

We also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer the extended capabilities which are complementary to the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.

## Important note

To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
Expand Down
17 changes: 17 additions & 0 deletions examples/research_projects/intel_opts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## Diffusers examples with Intel optimizations

**This research project is not actively maintained by the diffusers team. For any questions or comments, please make sure to tag @hshen14 .**

This aims to provide diffusers examples with Intel optimizations such as Bfloat16 for training/fine-tuning acceleration and 8-bit integer (INT8) for inference acceleration on Intel platforms.

## Accelerating the fine-tuning for textual inversion

We accelereate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.

## Accelerating the inference for Stable Diffusion using Bfloat16

We start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is generally designed to support standard Stable Diffusion models with Bfloat16 support.

## Accelerating the inference for Stable Diffusion using INT8

Coming soon ...
49 changes: 49 additions & 0 deletions examples/research_projects/intel_opts/inference_bf16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch

import intel_extension_for_pytorch as ipex
from diffusers import StableDiffusionPipeline
from PIL import Image


def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols

w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
grid_w, grid_h = grid.size

for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid


prompt = ["a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings"]
batch_size = 8
prompt = prompt * batch_size

device = "cpu"
model_id = "path-to-your-trained-model"
model = StableDiffusionPipeline.from_pretrained(model_id)
model = model.to(device)

# to channels last
model.unet = model.unet.to(memory_format=torch.channels_last)
model.vae = model.vae.to(memory_format=torch.channels_last)
model.text_encoder = model.text_encoder.to(memory_format=torch.channels_last)
model.safety_checker = model.safety_checker.to(memory_format=torch.channels_last)

# optimize with ipex
model.unet = ipex.optimize(model.unet.eval(), dtype=torch.bfloat16, inplace=True)
model.vae = ipex.optimize(model.vae.eval(), dtype=torch.bfloat16, inplace=True)
model.text_encoder = ipex.optimize(model.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
model.safety_checker = ipex.optimize(model.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)

# compute
seed = 666
generator = torch.Generator(device).manual_seed(seed)
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
images = model(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator).images

# save image
grid = image_grid(images, rows=2, cols=4)
grid.save(model_id + ".png")
68 changes: 68 additions & 0 deletions examples/research_projects/intel_opts/textual_inversion/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
## Textual Inversion fine-tuning example

[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.

## Training with Intel Extension for PyTorch

Intel Extension for PyTorch provides the optimizations for faster training and inference on CPUs. You can leverage the training example "textual_inversion.py". Follow the [instructions](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) to get the model and [dataset](https://huggingface.co/sd-concepts-library/dicoo2) before running the script.

The example supports both single node and multi-node distributed training:

### Single node training

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR="path-to-dir-containing-dicoo-images"

python textual_inversion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
--placeholder_token="<dicoo>" --initializer_token="toy" \
--seed=7 \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--max_train_steps=3000 \
--learning_rate=2.5e-03 --scale_lr \
--output_dir="textual_inversion_dicoo"
```

Note: Bfloat16 is available on Intel Xeon Scalable Processors Cooper Lake or Sapphire Rapids. You may not get performance speedup without Bfloat16 support.

### Multi-node distributed training

Before running the scripts, make sure to install the library's training dependencies successfully:

```bash
python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
```

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR="path-to-dir-containing-dicoo-images"

oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
source $oneccl_bindings_for_pytorch_path/env/setvars.sh

python -m intel_extension_for_pytorch.cpu.launch --distributed \
--hostfile hostfile --nnodes 2 --nproc_per_node 2 textual_inversion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
--placeholder_token="<dicoo>" --initializer_token="toy" \
--seed=7 \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--max_train_steps=750 \
--learning_rate=2.5e-03 --scale_lr \
--output_dir="textual_inversion_dicoo"
```
The above is a simple distributed training usage on 2 nodes with 2 processes on each node. Add the right hostname or ip address in the "hostfile" and make sure these 2 nodes are reachable from each other. For more details, please refer to the [user guide](https://github.com/intel/torch-ccl).


### Reference

We publish a [Medium blog](https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13) on how to create your own Stable Diffusion model on CPUs using textual inversion. Try it out now, if you have interests.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
accelerate
torchvision
transformers>=4.21.0
ftfy
tensorboard
modelcards
intel_extension_for_pytorch>=1.13
Loading

0 comments on commit c891330

Please sign in to comment.