Skip to content

Commit

Permalink
Fix for InstructPix2PixPipeline to allow for prompt embeds to be pass…
Browse files Browse the repository at this point in the history
…ed in without prompts. (huggingface#2456)

* fix check inputs to allow prompt embeds in instruct pix2pix

* linting

* add reference comment to check inputs

* remove comment

* style changes

---------

Co-authored-by: Will Berman <[email protected]>
  • Loading branch information
DN6 and williamberman authored Mar 5, 2023
1 parent 2ea1da8 commit e4c356d
Showing 1 changed file with 37 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,19 @@ def __call__(
(nsfw) content, according to the `safety_checker`.
"""
# 0. Check inputs
self.check_inputs(prompt, callback_steps)
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

if image is None:
raise ValueError("`image` input cannot be undefined.")

# 1. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]

device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
Expand Down Expand Up @@ -640,10 +646,9 @@ def decode_latents(self, latents):
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

def check_inputs(self, prompt, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

def check_inputs(
self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
Expand All @@ -652,6 +657,32 @@ def check_inputs(self, prompt, callback_steps):
f" {type(callback_steps)}."
)

if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
Expand Down

0 comments on commit e4c356d

Please sign in to comment.