Make sure all pipelines can run with batched input (huggingface#1669)
* [SD] Make sure batched input works correctly

* fix mask stuff

* finish

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>

patrickvonplaten and pcuenca authored Dec 13, 2022
1 parent b417042 commit b345c74
Showing 24 changed files with 336 additions and 152 deletions.
1 change: 1 addition & 0 deletions src/diffusers/models/unet_1d.py
@@ -218,6 +218,7 @@ def forward(
         else:
             timestep_embed = timestep_embed[..., None]
             timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
+            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
 
         # 2. down
         down_block_res_samples = ()
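The added `broadcast_to` expands a singleton batch dimension so a single scalar timestep can drive a batched `sample`. A minimal shape sketch with bare tensors (dimensions are made up, not the real UNet):

```python
import torch

sample = torch.randn(4, 14, 16)          # (batch, channels, width)
timestep_embed = torch.randn(1, 32, 16)  # built from one scalar timestep

# Same expression as the added line: views batch 1 as batch 4 without copying.
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
print(timestep_embed.shape)  # torch.Size([4, 32, 16])
```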
@@ -249,9 +249,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -44,13 +44,24 @@

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
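The rewritten helper accepts a single PIL image, a list of PIL images, or (a list of) tensors, and always returns a batched `(N, 3, H, W)` tensor scaled to `[-1, 1]`. A usage sketch, assuming the `preprocess` above is in scope:

```python
import PIL.Image

# Two hypothetical 512x512 RGB inputs.
images = [PIL.Image.new("RGB", (512, 512), color=(128, 128, 128)) for _ in range(2)]

batch = preprocess(images)
print(batch.shape)                             # torch.Size([2, 3, 512, 512])
print(float(batch.min()), float(batch.max()))  # values inside [-1.0, 1.0]
```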
@@ -81,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
Expand Down Expand Up @@ -246,9 +257,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -510,8 +521,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
1 change: 0 additions & 1 deletion src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -46,7 +46,6 @@ def __call__(
         use_clipped_model_output: Optional[bool] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Args:
src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -109,16 +109,18 @@ def prepare_mask_and_masked_image(image, mask):
             raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
         if isinstance(image, PIL.Image.Image):
-            image = np.array(image.convert("RGB"))
+            image = [image]
 
-        image = image[None].transpose(0, 3, 1, 2)
+        image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
+        image = image.transpose(0, 3, 1, 2)
         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
 
         # preprocess mask
         if isinstance(mask, PIL.Image.Image):
-            mask = np.array(mask.convert("L"))
-            mask = mask.astype(np.float32) / 255.0
+            mask = [mask]
 
-        mask = mask[None, None]
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0
 
         # paint-by-example inverses the mask
         mask = 1 - mask
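Images and masks now follow the same list-then-stack pattern as `preprocess`: images land in `[-1, 1]` with shape `(N, 3, H, W)`, masks in `[0, 1]` with shape `(N, 1, H, W)` before the paint-by-example inversion. A shape sketch with PIL inputs:

```python
import numpy as np
import PIL.Image

masks = [PIL.Image.new("L", (64, 64), color=255) for _ in range(2)]

mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in masks], axis=0)
mask = mask.astype(np.float32) / 255.0
print(mask.shape)              # (2, 1, 64, 64) -- batch, channel, height, width
print((1 - mask)[0, 0, 0, 0])  # white pixels invert to 0.0
```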
@@ -159,7 +161,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
@@ -323,8 +325,22 @@ def prepare_mask_latents(
             masked_image_latents = 0.18215 * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        mask = mask.repeat(batch_size, 1, 1, 1)
-        masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the number of images that you pass is divisible by the total requested batch size."
+                )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
 
         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
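Masks and masked-image latents are no longer blindly repeated `batch_size` times; they are tiled only up to the requested batch, and only when that size is a whole multiple of what was passed (e.g. 2 masks can serve a batch of 4, but not a batch of 3). The check, extracted into a hypothetical standalone helper:

```python
import torch

def tile_to_batch(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Illustrative version of the divisibility check in the diff above."""
    if x.shape[0] < batch_size:
        if batch_size % x.shape[0] != 0:
            raise ValueError(
                f"cannot tile {x.shape[0]} items to a batch of {batch_size}; "
                "the batch size must be a whole multiple"
            )
        x = x.repeat(batch_size // x.shape[0], 1, 1, 1)
    return x

mask = torch.rand(2, 1, 64, 64)
print(tile_to_batch(mask, 4).shape)  # torch.Size([4, 1, 64, 64])
```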
@@ -351,7 +367,7 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free

         if do_classifier_free_guidance:
             uncond_embeddings = self.image_encoder.uncond_vector
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
             uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
 
             # For classifier free guidance, we need to do two forward passes.
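The unconditional embedding is now repeated once per item already in `image_embeddings` (which was duplicated `num_images_per_prompt` times upstream), so the negative branch of classifier-free guidance matches the batched positive branch. A shape sketch with made-up dimensions:

```python
import torch

image_embeddings = torch.randn(4, 1, 768)  # e.g. 2 images x 2 generations each
uncond_vector = torch.randn(1, 1, 768)     # learned unconditional embedding

uncond = uncond_vector.repeat(1, image_embeddings.shape[0], 1).view(4, 1, -1)
guidance_input = torch.cat([uncond, image_embeddings])  # both passes in one batch
print(guidance_input.shape)  # torch.Size([8, 1, 768])
```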
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -35,14 +35,26 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
@@ -279,9 +291,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -551,8 +563,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -32,13 +32,26 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
@@ -77,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
@@ -325,8 +338,7 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -248,9 +248,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -41,14 +41,26 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
@@ -189,9 +201,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -366,12 +378,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt

     def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
         if isinstance(image, PIL.Image.Image):
-            width, height = image.size
-            width, height = map(lambda dim: dim - dim % 32, (width, height))  # resize to integer multiple of 32
-            image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-            width, height = image.size
+            image = [image]
+        else:
+            image = [img for img in image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            width, height = image[0].size
+        else:
+            width, height = image[0].shape[-2:]
 
         if depth_map is None:
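Inputs are normalized to a list first, so PIL and tensor images share one code path when reading dimensions. One caveat worth noting: PIL's `.size` is `(width, height)` while a tensor's `shape[-2:]` is `(height, width)`, so the tensor branch above only agrees with the PIL branch for square inputs:

```python
import torch
import PIL.Image

pil = PIL.Image.new("RGB", (640, 480))
print(pil.size)             # (640, 480) -> (width, height)

t = torch.zeros(3, 480, 640)
print(tuple(t.shape[-2:]))  # (480, 640) -> (height, width)
```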
@@ -493,7 +506,7 @@ def __call__(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
 
-        # 4. Prepare depth mask
+        # 4. Preprocess image
         depth_mask = self.prepare_depth_map(
             image,
             depth_map,
@@ -503,11 +516,8 @@
             device,
         )
 
-        # 5. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
-        else:
-            image = 2.0 * (image / 255.0) - 1.0
+        # 5. Prepare depth mask
+        image = preprocess(image)
 
         # 6. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -65,7 +65,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
(remaining changed files not shown)
