Skip to content

Commit

Permalink
fix a bug in from_pretrained when loading optional components (hugging…
Browse files Browse the repository at this point in the history
…face#4745)

* fix
---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
3 people authored Aug 25, 2023
1 parent 3bba44d commit b3b2d30
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,8 +1012,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
# define init kwargs and make sure that optional component modules are filtered out
init_kwargs = {
k: init_dict.pop(k)
for k in optional_kwargs
if k in init_dict and k not in pipeline_class._optional_components
}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

# remove `null` components
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import gc
import random
import tempfile
import unittest

import numpy as np
Expand Down Expand Up @@ -315,6 +316,52 @@ def test_stable_diffusion_upscale_fp16(self):
expected_height_width = low_res_image.size[0] * 4
assert image.shape == (1, expected_height_width, expected_height_width, 3)

def test_stable_diffusion_upscale_from_save_pretrained(self):
    """Check that an upscale pipeline survives a save/load round trip.

    A pipeline is built from the dummy components, written to a temporary
    directory with `save_pretrained`, and read back with `from_pretrained`.
    Both pipelines are then run with identically seeded generators and the
    bottom-right 3x3 slice of the last channel of each output must agree
    to within 1e-3.
    """
    device = "cpu"  # ensure determinism for the device-dependent torch.Generator

    low_res_scheduler = DDPMScheduler()
    scheduler = DDIMScheduler(prediction_type="v_prediction")
    tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    original_pipe = StableDiffusionUpscalePipeline(
        unet=self.dummy_cond_unet_upscale,
        low_res_scheduler=low_res_scheduler,
        scheduler=scheduler,
        vae=self.dummy_vae,
        text_encoder=self.dummy_text_encoder,
        tokenizer=tokenizer,
        max_noise_level=350,
    )
    original_pipe = original_pipe.to(device)

    # Round-trip through disk; the reloaded pipeline is a separate object.
    with tempfile.TemporaryDirectory() as tmpdirname:
        original_pipe.save_pretrained(tmpdirname)
        reloaded_pipe = StableDiffusionUpscalePipeline.from_pretrained(tmpdirname)
        reloaded_pipe = reloaded_pipe.to(device)

    prompt = "A painting of a squirrel eating a burger"
    channels_last = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
    low_res_image = Image.fromarray(np.uint8(channels_last)).convert("RGB").resize((64, 64))

    def corner_slice(pipe):
        # A fresh generator with the same seed for every pipeline keeps the
        # two runs comparable.
        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(
            [prompt],
            image=low_res_image,
            generator=generator,
            guidance_scale=6.0,
            noise_level=20,
            num_inference_steps=2,
            output_type="np",
        )
        return output.images[0, -3:, -3:, -1].flatten()

    image_slices = [corner_slice(pipe) for pipe in (original_pipe, reloaded_pipe)]

    assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3


@slow
@require_torch_gpu
Expand Down
23 changes: 23 additions & 0 deletions tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import copy
import tempfile
import unittest

import numpy as np
Expand Down Expand Up @@ -689,3 +690,25 @@ def test_stable_diffusion_xl_multi_prompts(self):

# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

def test_stable_diffusion_xl_save_from_pretrained(self):
    """Check that an SDXL pipeline survives a save/load round trip.

    The pipeline built from the dummy components is saved to a temporary
    directory and loaded back with `from_pretrained`. Both pipelines are
    run on the dummy inputs and the bottom-right 3x3 slice of the last
    channel of each output must agree to within 1e-3.
    """
    original_pipe = StableDiffusionXLPipeline(**self.get_dummy_components()).to(torch_device)

    # Round-trip through disk; the reloaded pipeline is a separate object.
    with tempfile.TemporaryDirectory() as tmpdirname:
        original_pipe.save_pretrained(tmpdirname)
        reloaded_pipe = StableDiffusionXLPipeline.from_pretrained(tmpdirname).to(torch_device)

    image_slices = []
    for pipe in (original_pipe, reloaded_pipe):
        # Same attention processor on both pipelines before comparing.
        pipe.unet.set_default_attn_processor()

        # Fresh dummy inputs are requested for every run.
        inputs = self.get_dummy_inputs(torch_device)
        output_images = pipe(**inputs).images

        image_slices.append(output_images[0, -3:, -3:, -1].flatten())

    first_slice, second_slice = image_slices
    assert np.abs(first_slice - second_slice).max() < 1e-3

0 comments on commit b3b2d30

Please sign in to comment.