Skip to content

Commit

Permalink
[Model] Add stable diffusion model based on fastdeploy (PaddlePaddle#297
Browse files Browse the repository at this point in the history
)

* Add stable diffusion model base on fastdeploy

* Add sd infer

* pipelines->multimodal

* add create_ort_runtime

* use fp16 input

* fix pil

* Add optimize unet model

* add hf license

* Add workspace args

* Add profile func

* Add schedulers

* usrelace torch.Tenosr  byp.ndarray

* Add readme

* Add trt shape setting

* add dynamic shape

* Add dynamic shape for stable diffusion

* fix max shape setting

* rename tensorrt file suffix

* update dynamic shape setting

* Add scheduler output

* Add inference_steps and benchmark steps

* add diffuser benchmark

* Add paddle infer script

* Rename 1

* Rename infer.py to torch_onnx_infer.py

* Add export torch to onnx model

* renmove export model

* Add paddle export model for diffusion

* Fix export model

* mv torch onnx infer to infer

* Fix export model

* Fix infer

* modif create_trt_runtime create_ort_runtime

* update export torch

* update requirements

* add paddle inference backend

* Fix unet pp run

* remove print

* Add paddle model export and infer

* Add device id

* remove profile to utils

* Add -1 device id

* Add safety checker args

* remove safety checker temporarily

* Add export model description

* Add predict description

* Fix readme

* Fix device_id description

* add timestep shape

* add use fp16 precision

* move use gpu

* Add EulerAncestralDiscreteScheduler

* Use EulerAncestralDiscreteScheduler with v1-5 model

* Add export model readme

* Add link of exported model

* Update scheduler on README

* Addd stable-diffusion-v1-5
  • Loading branch information
joey12300 authored Nov 10, 2022
1 parent fa80734 commit d4995e5
Show file tree
Hide file tree
Showing 13 changed files with 2,301 additions and 0 deletions.
59 changes: 59 additions & 0 deletions examples/multimodal/stable_diffusion/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# FastDeploy Diffusion模型高性能部署

本部署示例使用⚡️`FastDeploy`在Huggingface团队[Diffusers](https://github.com/huggingface/diffusers)项目设计的`DiffusionPipeline`基础上,完成Diffusion模型的高性能部署。

### 部署模型准备

本示例需要使用训练模型导出后的部署模型。有两种部署模型的获取方式:

- 模型导出方式,可参考[模型导出文档](./export.md)导出部署模型。
- 下载部署模型。为了方便开发者快速测试本示例,我们已经将部分`Diffusion`模型预先导出,开发者只要下载模型就可以快速测试:

| 模型 | Scheduler |
|----------|--------------|
| [CompVis/stable-diffusion-v1-4](https://bj.bcebos.com/fastdeploy/models/stable-diffusion/CompVis/stable-diffusion-v1-4.tgz) | PNDM |
| [runwayml/stable-diffusion-v1-5](https://bj.bcebos.com/fastdeploy/models/stable-diffusion/runwayml/stable-diffusion-v1-5.tgz) | EulerAncestral |

## 环境依赖

在示例中使用了PaddleNLP的CLIP模型的分词器,所以需要执行以下命令安装依赖。

```shell
pip install paddlenlp paddlepaddle-gpu
```

### 快速体验

我们经过部署模型准备,可以开始进行测试。下面将指定模型目录以及推理引擎后端,运行`infer.py`脚本,完成推理。

```
python infer.py --model_dir stable-diffusion-v1-4/ --scheduler "pndm" --backend paddle
```

得到的图像文件为fd_astronaut_rides_horse.png。生成的图片示例如下(每次生成的图片都不相同,示例仅作参考):

![fd_astronaut_rides_horse.png](https://user-images.githubusercontent.com/10826371/200261112-68e53389-e0a0-42d1-8c3a-f35faa6627d7.png)

如果使用stable-diffusion-v1-5模型,则可执行以下命令完成推理:

```
python infer.py --model_dir stable-diffusion-v1-5/ --scheduler "euler_ancestral" --backend paddle
```

#### 参数说明

`infer.py` 除了以上示例的命令行参数,还支持更多命令行参数的设置。以下为各命令行参数的说明。

| 参数 |参数说明 |
|----------|--------------|
| --model_dir | 导出后模型的目录。 |
| --model_format | 模型格式。默认为`'paddle'`,可选列表:`['paddle', 'onnx']`|
| --backend | 推理引擎后端。默认为`paddle`,可选列表:`['onnx_runtime', 'paddle']`,当模型格式为`onnx`时,可选列表为`['onnx_runtime']`|
| --scheduler | StableDiffusion 模型的scheduler。默认为`'pndm'`。可选列表:`['pndm', 'euler_ancestral']`,StableDiffusio模型对应的scheduler可参考[ppdiffuser模型列表](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers/examples/textual_inversion)|
| --unet_model_prefix | UNet模型前缀。默认为`unet`|
| --vae_model_prefix | VAE模型前缀。默认为`vae_decoder`|
| --text_encoder_model_prefix | TextEncoder模型前缀。默认为`text_encoder`|
| --inference_steps | UNet模型运行的次数,默认为100。 |
| --image_path | 生成图片的路径。默认为`fd_astronaut_rides_horse.png`|
| --device_id | gpu设备的id。若`device_id`为-1,视为使用cpu推理。 |
| --use_fp16 | 是否使用fp16精度。默认为`False`。使用tensorrt或者paddle-tensorrt后端时可以设为`True`开启。 |
156 changes: 156 additions & 0 deletions examples/multimodal/stable_diffusion/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
from collections import OrderedDict
from typing import Any, Dict, Tuple, Union


class ConfigMixin:
r"""
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
- [`~ConfigMixin.from_config`]
- [`~ConfigMixin.save_config`]
Class attributes:
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
"""
config_name = None
ignore_for_config = []

def register_to_config(self, **kwargs):
if self.config_name is None:
raise NotImplementedError(
f"Make sure that {self.__class__} has defined a class name `config_name`"
)
kwargs["_class_name"] = self.__class__.__name__

# Special case for `kwargs` used in deprecation warning added to schedulers
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
kwargs.pop("kwargs", None)
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err

if not hasattr(self, "_internal_dict"):
internal_dict = kwargs
else:
previous_dict = dict(self._internal_dict)
internal_dict = { ** self._internal_dict, ** kwargs}
logger.debug(
f"Updating config from {previous_dict} to {internal_dict}")

self._internal_dict = FrozenDict(internal_dict)

@property
def config(self) -> Dict[str, Any]:
return self._internal_dict


class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

for key, value in self.items():
setattr(self, key, value)

self.__frozen = True

def __delitem__(self, *args, **kwargs):
raise Exception(
f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)

def setdefault(self, *args, **kwargs):
raise Exception(
f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)

def pop(self, *args, **kwargs):
raise Exception(
f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

def update(self, *args, **kwargs):
raise Exception(
f"You cannot use ``update`` on a {self.__class__.__name__} instance."
)

def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(
f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super().__setattr__(name, value)

def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(
f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance."
)
super().__setitem__(name, value)


def register_to_config(init):
r"""
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
shouldn't be registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""

@functools.wraps(init)
def inner_init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {
k: v
for k, v in kwargs.items() if not k.startswith("_")
}
init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
"not inherit from `ConfigMixin`.")
ignore = getattr(self, "ignore_for_config", [])
# Get positional arguments aligned with kwargs
new_kwargs = {}
signature = inspect.signature(init)
parameters = {
name: p.default
for i, (name, p) in enumerate(signature.parameters.items())
if i > 0 and name not in ignore
}
for arg, name in zip(args, parameters.keys()):
new_kwargs[name] = arg

# Then add all kwargs
new_kwargs.update({
k: init_kwargs.get(k, default)
for k, default in parameters.items()
if k not in ignore and k not in new_kwargs
})
getattr(self, "register_to_config")(**new_kwargs)

return inner_init
105 changes: 105 additions & 0 deletions examples/multimodal/stable_diffusion/export.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Diffusion模型导出教程

本项目支持两种模型导出方式:[PPDiffusers](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers)模型导出以及[Diffusers](https://github.com/huggingface/diffusers)模型导出。下面分别介绍这两种模型导出方式。

## PPDiffusers 模型导出

[PPDiffusers](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers)是一款支持跨模态(如图像与语音)训练和推理的扩散模型(Diffusion Model)工具箱,其借鉴了🤗 Huggingface团队的[Diffusers](https://github.com/huggingface/diffusers)的优秀设计,并且依托[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)框架和[PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)自然语言处理库。下面介绍如何使用FastDeploy将PPDiffusers提供的Diffusion模型进行高性能部署。

### 依赖安装

模型导出需要依赖`paddlepaddle`, `paddlenlp`以及`ppdiffusers`,可使用`pip`执行下面的命令进行快速安装。

```shell
pip install -r requirements_paddle.txt
```

### 模型导出

___注意:模型导出过程中,需要下载StableDiffusion模型。为了使用该模型与权重,你必须接受该模型所要求的License,请访问HuggingFace的[model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), 仔细阅读里面的License,然后签署该协议。___

___Tips: Stable Diffusion是基于以下的License: The CreativeML OpenRAIL M license is an Open RAIL M license, adapted from the work that BigScience and the RAIL Initiative are jointly carrying in the area of responsible AI licensing. See also the article about the BLOOM Open RAIL license on which this license is based.___

可执行以下命令行完成模型导出。

```shell
python export_model.py --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 --output_path stable-diffusion-v1-4
```

输出的模型目录结构如下:
```shell
stable-diffusion-v1-4/
├── text_encoder
│   ├── inference.pdiparams
│   ├── inference.pdiparams.info
│   └── inference.pdmodel
├── unet
│   ├── inference.pdiparams
│   ├── inference.pdiparams.info
│   └── inference.pdmodel
└── vae_decoder
├── inference.pdiparams
├── inference.pdiparams.info
└── inference.pdmodel
```

#### 参数说明

`export_model.py` 各命令行参数的说明。

| 参数 |参数说明 |
|----------|--------------|
|<div style="width: 230pt">--pretrained_model_name_or_path </div> | ppdiffuers提供的diffusion预训练模型。默认为:"CompVis/stable-diffusion-v1-4 "。更多diffusion预训练模型可参考[ppdiffuser模型列表](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/ppdiffusers/examples/textual_inversion)|
|--output_path | 导出的模型目录。 |


## Diffusers 模型导出

[Diffusers](https://github.com/huggingface/diffusers)是一款由HuggingFace打造的支持跨模态(如图像与语音)训练和推理的扩散模型(Diffusion Model)工具箱。其底层的模型代码提供PyTorch实现的版本以及Flax实现的版本两种版本。本示例将介绍如何使用FastDeploy将PyTorch实现的Diffusion模型进行高性能部署。

### 依赖安装

模型导出需要依赖`onnx`, `torch`, `diffusers`以及`transformers`,可使用`pip`执行下面的命令进行快速安装。

```shell
pip install -r requirements_torch.txt
```

### 模型导出

___注意:模型导出过程中,需要下载StableDiffusion模型。为了使用该模型与权重,你必须接受该模型所要求的License,并且获取HF Hub授予的Token。请访问HuggingFace的[model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), 仔细阅读里面的License,然后签署该协议。___

___Tips: Stable Diffusion是基于以下的License: The CreativeML OpenRAIL M license is an Open RAIL M license, adapted from the work that BigScience and the RAIL Initiative are jointly carrying in the area of responsible AI licensing. See also the article about the BLOOM Open RAIL license on which this license is based.___

若第一次导出模型,需要先登录HuggingFace客户端。执行以下命令进行登录:

```shell
huggingface-cli login
```

完成登录后,执行以下命令行完成模型导出。

```shell
python export_torch_to_onnx_model.py --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 --output_path torch_diffusion_model
```

输出的模型目录结构如下:

```shell
torch_diffusion_model/
├── text_encoder
│   └── inference.onnx
├── unet
│   └── inference.onnx
└── vae_decoder
└── inference.onnx
```

#### 参数说明

`export_torch_to_onnx_model.py` 各命令行参数的说明。

| 参数 |参数说明 |
|----------|--------------|
|<div style="width: 230pt">--pretrained_model_name_or_path </div> | ppdiffuers提供的diffusion预训练模型。默认为:"CompVis/stable-diffusion-v1-4 "。更多diffusion预训练模型可参考[HuggingFace模型列表说明](https://huggingface.co/CompVis/stable-diffusion-v1-4)|
|--output_path | 导出的模型目录。 |
Loading

0 comments on commit d4995e5

Please sign in to comment.