[ONNX] Stable Diffusion exporter and pipeline (huggingface#399)
* initial export and design

* update imports

* custom provider, import fixes

* Update src/diffusers/onnx_utils.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/onnx_utils.py

Co-authored-by: Patrick von Platen <[email protected]>

* remove push_to_hub

* Update src/diffusers/onnx_utils.py

Co-authored-by: Suraj Patil <[email protected]>

* remove torch_device

* numpify the rest of the pipeline

* torchify the safety checker

* revert tensor

* Code review suggestions + quality

* fix tests

* fix provider, add an end-to-end test

* style

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
3 people authored Sep 8, 2022
1 parent 7bcc873 commit 8d9c4a5
Showing 13 changed files with 657 additions and 6 deletions.
196 changes: 196 additions & 0 deletions scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -0,0 +1,196 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

import torch
from torch.onnx import export

from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version


is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )
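
# Illustrative example (not part of this commit): a call to the helper above for
# a tiny module with a dynamic batch axis might look like
#
#     onnx_export(
#         torch.nn.Linear(4, 2),
#         model_args=(torch.randn(1, 4),),
#         output_path=Path("linear/model.onnx"),
#         ordered_input_names=["x"],
#         output_names=["y"],
#         dynamic_axes={"x": {0: "batch"}},
#         opset=14,
#     )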


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int):
    pipeline = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=True)
    output_path = Path(output_path)

    # TEXT ENCODER
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    onnx_export(
        pipeline.text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        model_args=(text_input.input_ids.to(torch.int32)),
        output_path=output_path / "text_encoder" / "model.onnx",
        ordered_input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        opset=opset,
    )

    # UNET
    onnx_export(
        pipeline.unet,
        model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
        output_path=output_path / "unet" / "model.onnx",
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

    # VAE ENCODER
    vae_encoder = pipeline.vae
    # need to get the raw tensor output (sample) from the encoder
    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
    onnx_export(
        vae_encoder,
        model_args=(torch.randn(1, 3, 512, 512), False),
        output_path=output_path / "vae_encoder" / "model.onnx",
        ordered_input_names=["sample", "return_dict"],
        output_names=["latent_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    # VAE DECODER
    vae_decoder = pipeline.vae
    # forward only through the decoder part
    vae_decoder.forward = vae_encoder.decode
    onnx_export(
        vae_decoder,
        model_args=(torch.randn(1, 4, 64, 64), False),
        output_path=output_path / "vae_decoder" / "model.onnx",
        ordered_input_names=["latent_sample", "return_dict"],
        output_names=["sample"],
        dynamic_axes={
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    # SAFETY CHECKER
    safety_checker = pipeline.safety_checker
    safety_checker.forward = safety_checker.forward_onnx
    onnx_export(
        pipeline.safety_checker,
        model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
        output_path=output_path / "safety_checker" / "model.onnx",
        ordered_input_names=["clip_input", "images"],
        output_names=["out_images", "has_nsfw_concepts"],
        dynamic_axes={
            "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "images": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    onnx_pipeline = StableDiffusionOnnxPipeline(
        vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
        text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
        tokenizer=pipeline.tokenizer,
        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
        scheduler=pipeline.scheduler,
        safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
        feature_extractor=pipeline.feature_extractor,
    )

    onnx_pipeline.save_pretrained(output_path)
    print("ONNX pipeline saved to", output_path)

    _ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
    print("ONNX pipeline is loadable")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
    )

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )

    args = parser.parse_args()

    convert_models(args.model_path, args.output_path, args.opset)
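
A minimal usage sketch for the script above (hypothetical: the checkpoint ID, output directory, and prompt are illustrative, and the snippet assumes the pipeline's usual output object with an `images` list):

# Hypothetical shell invocation (checkpoint ID and output dir are illustrative):
#   python scripts/convert_stable_diffusion_checkpoint_to_onnx.py \
#       --model_path CompVis/stable-diffusion-v1-4 --output_path ./stable_diffusion_onnx
from diffusers import StableDiffusionOnnxPipeline

# Load the exported pipeline through ONNX Runtime on CPU and run a prompt.
pipe = StableDiffusionOnnxPipeline.from_pretrained("./stable_diffusion_onnx", provider="CPUExecutionProvider")
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")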
2 changes: 1 addition & 1 deletion setup.py
@@ -170,7 +170,7 @@ def run(self):
extras["quality"] = ["black==22.3", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]

install_requires = [
15 changes: 14 additions & 1 deletion src/diffusers/__init__.py
@@ -1,10 +1,17 @@
-from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available
+from .utils import (
+    is_inflect_available,
+    is_onnx_available,
+    is_scipy_available,
+    is_transformers_available,
+    is_unidecode_available,
+)


__version__ = "0.3.0.dev0"

from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from .onnx_utils import OnnxRuntimeModel
from .optimization import (
    get_constant_schedule,
    get_constant_schedule_with_warmup,
@@ -44,3 +51,9 @@
)
else:
    from .utils.dummy_transformers_objects import *  # noqa F403
+
+
+if is_transformers_available() and is_onnx_available():
+    from .pipelines import StableDiffusionOnnxPipeline
+else:
+    from .utils.dummy_transformers_and_onnx_objects import *  # noqa F403
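
The `dummy_transformers_and_onnx_objects` module imported in the fallback branch is one of the 13 changed files but is not shown in this excerpt. The usual diffusers pattern is a placeholder that fails loudly at instantiation; a sketch of that pattern (an assumption, not the file's actual contents):

# Sketch of the dummy-object fallback (assumed; the real file is not shown in this excerpt).
class StableDiffusionOnnxPipeline:
    def __init__(self, *args, **kwargs):
        raise ImportError(
            "StableDiffusionOnnxPipeline requires `transformers` and `onnxruntime` to be installed."
        )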