Skip to content

Commit

Permalink
Stable Diffusion image-to-image and inpaint using onnx. (huggingface#552
Browse files Browse the repository at this point in the history
)

* * Stabe Diffusion img2img using onnx.

* * Stabe Diffusion inpaint using onnx.

* Export vae_encoder, upgrade img2img, add test

* updated inpainting pipeline + test

* style

Co-authored-by: anton-l <[email protected]>
  • Loading branch information
zledas and anton-l authored Oct 18, 2022
1 parent fbe807b commit a9908ec
Show file tree
Hide file tree
Showing 8 changed files with 837 additions and 2 deletions.
7 changes: 7 additions & 0 deletions scripts/convert_stable_diffusion_checkpoint_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
},
opset=opset,
)
del pipeline.text_encoder

# UNET
unet_path = output_path / "unet" / "model.onnx"
Expand Down Expand Up @@ -125,6 +126,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
location="weights.pb",
convert_attribute=False,
)
del pipeline.unet

# VAE ENCODER
vae_encoder = pipeline.vae
Expand Down Expand Up @@ -157,6 +159,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
},
opset=opset,
)
del pipeline.vae

# SAFETY CHECKER
safety_checker = pipeline.safety_checker
Expand All @@ -173,8 +176,10 @@ def convert_models(model_path: str, output_path: str, opset: int):
},
opset=opset,
)
del pipeline.safety_checker

onnx_pipeline = StableDiffusionOnnxPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
Expand All @@ -187,6 +192,8 @@ def convert_models(model_path: str, output_path: str, opset: int):
onnx_pipeline.save_pretrained(output_path)
print("ONNX pipeline saved to", output_path)

del pipeline
del onnx_pipeline
_ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
print("ONNX pipeline is loadable")

Expand Down
7 changes: 6 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@
from .utils.dummy_torch_and_transformers_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

Expand Down
7 changes: 6 additions & 1 deletion src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
)

if is_transformers_available() and is_onnx_available():
from .stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)

if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class StableDiffusionPipelineOutput(BaseOutput):

if is_transformers_available() and is_onnx_available():
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline

if is_transformers_available() and is_flax_available():
import flax
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):

def __init__(
self,
vae_encoder: OnnxRuntimeModel,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
Expand All @@ -36,6 +37,7 @@ def __init__(
):
super().__init__()
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

68 changes: 68 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
LDMPipeline,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionPipeline,
PNDMPipeline,
PNDMScheduler,
Expand Down Expand Up @@ -2025,6 +2027,72 @@ def test_stable_diffusion_onnx(self):
expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@slow
def test_stable_diffusion_img2img_onnx(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))

pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)
pipe.set_progress_bar_config(disable=None)

prompt = "A fantasy landscape, trending on artstation"

np.random.seed(0)
output = pipe(
prompt=prompt,
init_image=init_image,
strength=0.75,
guidance_scale=7.5,
num_inference_steps=8,
output_type="np",
)
images = output.images
image_slice = images[0, 255:258, 383:386, -1]

assert images.shape == (1, 512, 768, 3)
expected_slice = np.array([[0.4806, 0.5125, 0.5453, 0.4846, 0.4984, 0.4955, 0.4830, 0.4962, 0.4969]])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@slow
def test_stable_diffusion_inpaint_onnx(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)
pipe.set_progress_bar_config(disable=None)

prompt = "A red cat sitting on a park bench"

np.random.seed(0)
output = pipe(
prompt=prompt,
init_image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
num_inference_steps=8,
output_type="np",
)
images = output.images
image_slice = images[0, 255:258, 255:258, -1]

assert images.shape == (1, 512, 512, 3)
expected_slice = np.array([0.3524, 0.3289, 0.3464, 0.3872, 0.4129, 0.3566, 0.3709, 0.4128, 0.3734])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_text2img_intermediate_state(self):
Expand Down

0 comments on commit a9908ec

Please sign in to comment.