Skip to content

Commit

Permalink
[ONNX] Improve ONNXPipeline scheduler compatibility, fix safety_check…
Browse files Browse the repository at this point in the history
…er (huggingface#1173)

* [ONNX] Improve ONNX scheduler compatibility, fix safety_checker

* typo
  • Loading branch information
anton-l authored Nov 8, 2022
1 parent 555203e commit 11f7d6f
Show file tree
Hide file tree
Showing 10 changed files with 346 additions and 89 deletions.
73 changes: 49 additions & 24 deletions scripts/convert_stable_diffusion_checkpoint_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
output_path = Path(output_path)

# TEXT ENCODER
num_tokens = pipeline.text_encoder.config.max_position_embeddings
text_hidden_size = pipeline.text_encoder.config.hidden_size
text_input = pipeline.tokenizer(
"A sample prompt",
padding="max_length",
Expand All @@ -103,13 +105,15 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
del pipeline.text_encoder

# UNET
unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / "model.onnx"
onnx_export(
pipeline.unet,
model_args=(
torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
torch.LongTensor([0, 1]).to(device=device),
torch.randn(2, 77, 768).to(device=device, dtype=dtype),
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
torch.randn(2).to(device=device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
False,
),
output_path=unet_path,
Expand Down Expand Up @@ -142,11 +146,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F

# VAE ENCODER
vae_encoder = pipeline.vae
vae_in_channels = vae_encoder.config.in_channels
vae_sample_size = vae_encoder.config.sample_size
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
Expand All @@ -158,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F

# VAE DECODER
vae_decoder = pipeline.vae
vae_latent_channels = vae_decoder.config.latent_channels
vae_out_channels = vae_decoder.config.out_channels
# forward only through the decoder part
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
model_args=(
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
Expand All @@ -174,24 +188,35 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
del pipeline.vae

# SAFETY CHECKER
safety_checker = pipeline.safety_checker
safety_checker.forward = safety_checker.forward_onnx
onnx_export(
pipeline.safety_checker,
model_args=(
torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
),
output_path=output_path / "safety_checker" / "model.onnx",
ordered_input_names=["clip_input", "images"],
output_names=["out_images", "has_nsfw_concepts"],
dynamic_axes={
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
},
opset=opset,
)
del pipeline.safety_checker
if pipeline.safety_checker is not None:
safety_checker = pipeline.safety_checker
clip_num_channels = safety_checker.config.vision_config.num_channels
clip_image_size = safety_checker.config.vision_config.image_size
safety_checker.forward = safety_checker.forward_onnx
onnx_export(
pipeline.safety_checker,
model_args=(
torch.randn(
1,
clip_num_channels,
clip_image_size,
clip_image_size,
).to(device=device, dtype=dtype),
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
),
output_path=output_path / "safety_checker" / "model.onnx",
ordered_input_names=["clip_input", "images"],
output_names=["out_images", "has_nsfw_concepts"],
dynamic_axes={
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
},
opset=opset,
)
del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
else:
safety_checker = None

onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
Expand All @@ -200,7 +225,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
safety_checker=safety_checker,
feature_extractor=pipeline.feature_extractor,
)

Expand Down
28 changes: 26 additions & 2 deletions src/diffusers/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from huggingface_hub import hf_hub_download

from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging


if is_onnx_available():
Expand All @@ -33,13 +33,28 @@

logger = logging.get_logger(__name__)

ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(int16)": np.int16,
"tensor(uint16)": np.uint16,
"tensor(int32)": np.int32,
"tensor(uint32)": np.uint32,
"tensor(int64)": np.int64,
"tensor(uint64)": np.uint64,
"tensor(float16)": np.float16,
"tensor(float)": np.float32,
"tensor(double)": np.float64,
}


class OnnxRuntimeModel:
def __init__(self, model=None, **kwargs):
logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
self.model = model
self.model_save_dir = kwargs.get("model_save_dir", None)
self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

def __call__(self, **kwargs):
inputs = {k: np.array(v) for k, v in kwargs.items()}
Expand Down Expand Up @@ -84,6 +99,15 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional
except shutil.SameFileError:
pass

# copy external weights (for models >2GB)
src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
if src_path.exists():
dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
try:
shutil.copyfile(src_path, dst_path)
except shutil.SameFileError:
pass

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
if not is_pipeline_module and passed_class_obj[name] is not None:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import Callable, List, Optional, Union

import numpy as np
import torch

from transformers import CLIPFeatureExtractor, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...onnx_utils import OnnxRuntimeModel
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
Expand Down Expand Up @@ -186,7 +187,7 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

latents = latents * self.scheduler.init_noise_sigma
latents = latents * np.float(self.scheduler.init_noise_sigma)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
Expand All @@ -197,15 +198,20 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.cpu().numpy()

# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
)
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
noise_pred = noise_pred[0]

# perform guidance
Expand All @@ -214,7 +220,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = np.array(latents)

# call the callback, if provided
Expand All @@ -235,6 +241,9 @@ def __call__(
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)

image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image)

# There will throw an error if use safety_checker batchsize>1
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers import CLIPFeatureExtractor, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...onnx_utils import OnnxRuntimeModel
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
Expand Down Expand Up @@ -338,14 +338,21 @@ def __call__(
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].numpy()

timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.cpu().numpy()

# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
)[0]

# perform guidance
Expand All @@ -354,7 +361,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()

# call the callback, if provided
Expand All @@ -375,7 +382,7 @@ def __call__(
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# There will throw an error if use safety_checker batchsize>1
# safety_checker does not support batched inputs yet
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers import CLIPFeatureExtractor, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...onnx_utils import OnnxRuntimeModel
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
Expand Down Expand Up @@ -352,7 +352,7 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
latents = latents * np.float(self.scheduler.init_noise_sigma)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
Expand All @@ -363,17 +363,23 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latnets in the channel dimension
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.numpy()
latent_model_input = latent_model_input.cpu().numpy()

# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
)[0]

# perform guidance
Expand All @@ -382,7 +388,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()

# call the callback, if provided
Expand All @@ -403,7 +409,7 @@ def __call__(
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# There will throw an error if use safety_checker batchsize>1
# safety_checker does not support batched inputs yet
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
Expand Down
Loading

0 comments on commit 11f7d6f

Please sign in to comment.