fix a bug of prompt embeds in sdxl (huggingface#4099)
* fix bug in sdxl

* Update pipeline_stable_diffusion_xl_img2img.py

* Update pipeline_stable_diffusion_xl.py

* Update pipeline_stable_diffusion_xl_img2img.py

* Update pipeline_stable_diffusion_xl_inpaint.py

* Update pipeline_stable_diffusion_xl.py

* Update pipeline_stable_diffusion_xl_img2img.py

* Update pipeline_stable_diffusion_xl_inpaint.py

* Update pipeline_stable_diffusion_xl_img2img.py

* Update pipeline_controlnet_sd_xl.py

* Update pipeline_controlnet_sd_xl.py

* Update pipeline_stable_diffusion_xl.py

* Update pipeline_stable_diffusion_xl_img2img.py

* Update pipeline_stable_diffusion_xl_inpaint.py

* Update test_stable_diffusion_xl.py

* Update test_stable_diffusion_xl.py

* Update test_stable_diffusion_xl.py

* add test on prompt_embeds

---------

Co-authored-by: Patrick von Platen <[email protected]>
xiaohu2015 and patrickvonplaten authored Jul 24, 2023
1 parent 8e8954b commit 8e5921c
Showing 5 changed files with 90 additions and 84 deletions.
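
Background on the fix: in encode_prompt, the num_images_per_prompt duplication previously ran inside the "if prompt_embeds is None:" encoding branch, so embeddings passed in directly by the caller were never duplicated or cast to the text encoder's dtype. The diffs below move the duplication after that branch so both code paths share it. A minimal sketch of the mps-friendly repeat-then-view duplication used throughout (duplicate_for_images is a hypothetical helper for illustration, not a diffusers API):

import torch

def duplicate_for_images(prompt_embeds: torch.Tensor, num_images_per_prompt: int) -> torch.Tensor:
    # Repeat along the sequence axis, then reshape so each prompt's copies
    # end up adjacent in the batch dimension (avoids torch.repeat_interleave,
    # which was historically problematic on the mps backend).
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    return prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

embeds = torch.randn(2, 77, 2048)  # (batch, seq_len, hidden size) -- illustrative shapes
print(duplicate_for_images(embeds, 2).shape)  # torch.Size([4, 77, 2048])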
34 changes: 13 additions & 21 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -304,11 +304,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -361,26 +356,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)

             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
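
The pooled text embeddings are 2-D (batch, pooled dim), so the final hunk applies the same repeat-then-view trick without a sequence axis; a quick sketch with illustrative shapes:

import torch

pooled = torch.randn(2, 1280)  # (batch, pooled projection dim)
num_images_per_prompt = 2
pooled = pooled.repeat(1, num_images_per_prompt).view(2 * num_images_per_prompt, -1)
print(pooled.shape)  # torch.Size([4, 1280])

The three SDXL pipeline files below receive the same two hunks.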
34 changes: 13 additions & 21 deletions src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -319,11 +319,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -376,26 +371,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)

             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
34 changes: 13 additions & 21 deletions src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -327,11 +327,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -384,26 +379,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)

             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
34 changes: 13 additions & 21 deletions src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -433,11 +433,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]

-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)

             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -490,26 +485,23 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)

             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
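
The removed comment referenced the classifier-free-guidance batching trick; the pipelines still apply it later in __call__, concatenating unconditional and conditional embeddings into one batch so the UNet runs a single forward pass instead of two. In outline (a sketch with illustrative shapes, not the pipeline source):

import torch

negative_prompt_embeds = torch.randn(2, 77, 2048)  # unconditional
prompt_embeds = torch.randn(2, 77, 2048)           # conditional
# stack into one batch: one UNet forward pass covers both branches of CFG
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
print(prompt_embeds.shape)  # torch.Size([4, 77, 2048])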
38 changes: 38 additions & 0 deletions tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -148,6 +148,44 @@ def test_stable_diffusion_xl_euler(self):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_xl_prompt_embeds(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = 2 * [inputs["prompt"]]
+        inputs["num_images_per_prompt"] = 2
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        inputs = self.get_dummy_inputs(torch_device)
+        prompt = 2 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_xl_negative_prompt_embeds(self):
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionXLPipeline(**components)
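
For reference, a sketch of the workflow the new test exercises: precomputing embeddings once and reusing them together with num_images_per_prompt, which is exactly the path this commit fixes. The model id, dtype, and device here are assumptions for illustration, not part of the commit:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16  # assumed checkpoint
).to("cuda")  # assumed device

# encode once, reuse across calls; encode_prompt returns the same four
# tensors the test above unpacks
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("an astronaut riding a horse")

images = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_images_per_prompt=2,  # duplicated correctly for precomputed embeds after this fix
).images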
