diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 29d1d707f55a..00083fb392ff 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -218,6 +218,7 @@ def forward( else: timestep_embed = timestep_embed[..., None] timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) # 2. down down_block_res_samples = () diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 96a4c1d98888..f149e2623fec 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -249,9 +249,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 383ec1ea6070..6d7e57c7cad3 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -44,13 +44,24 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker @@ -81,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker"] def __init__( self, @@ -246,9 +257,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" @@ -510,8 +521,7 @@ def __call__( ) # 4. Preprocess image - if isinstance(image, PIL.Image.Image): - image = preprocess(image) + image = preprocess(image) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index f5f2d404c2b4..873c7e97f038 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -46,7 +46,6 @@ def __call__( use_clipped_model_output: Optional[bool] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index cae3c4febcc4..e37f771455af 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -109,16 +109,18 @@ def prepare_mask_and_masked_image(image, mask): raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: if isinstance(image, PIL.Image.Image): - image = np.array(image.convert("RGB")) + image = [image] - image = image[None].transpose(0, 3, 1, 2) + image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0) + image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + # preprocess mask if isinstance(mask, PIL.Image.Image): - mask = np.array(mask.convert("L")) - mask = mask.astype(np.float32) / 255.0 + mask = [mask] - mask = mask[None, None] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 # paint-by-example inverses the mask mask = 1 - mask @@ -159,7 +161,7 @@ class PaintByExamplePipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker"] def __init__( self, @@ -323,8 +325,22 @@ def prepare_mask_latents( masked_image_latents = 0.18215 * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - mask = mask.repeat(batch_size, 1, 1, 1) - masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1) + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = ( @@ -351,7 +367,7 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free if do_classifier_free_guidance: uncond_embeddings = self.image_encoder.uncond_vector - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1) uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1) # For classifier free guidance, we need to do two forward passes. 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index fce24feaec31..c51ebf2611f2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -35,14 +35,26 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): @@ -279,9 +291,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" @@ -551,8 +563,7 @@ def __call__( ) # 4. Preprocess image - if isinstance(image, PIL.Image.Image): - image = preprocess(image) + image = preprocess(image) # 5. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 41128f29949e..86a4636389ee 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -32,13 +32,26 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - return 2.0 * image - 1.0 + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): @@ -77,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker: OnnxRuntimeModel feature_extractor: CLIPFeatureExtractor - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker"] def __init__( self, @@ -325,8 +338,7 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - if isinstance(image, PIL.Image.Image): - image = preprocess(image) + image = preprocess(image) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 631a2df39028..4681ffdc8252 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -248,9 +248,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 64fe7011ee54..6bff3e94949a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -41,14 +41,26 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): @@ -189,9 +201,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" @@ -366,12 +378,13 @@ def prepare_latents(self, image, timestep, batch_size, 
num_images_per_prompt, dt def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device): if isinstance(image, PIL.Image.Image): - width, height = image.size - width, height = map(lambda dim: dim - dim % 32, (width, height)) # resize to integer multiple of 32 - image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - width, height = image.size + image = [image] else: image = [img for img in image] + + if isinstance(image[0], PIL.Image.Image): + width, height = image[0].size + else: + width, height = image[0].shape[-2:] if depth_map is None: @@ -493,7 +506,7 @@ def __call__( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) # 4. Prepare depth mask depth_mask = self.prepare_depth_map( image, depth_map, @@ -503,11 +516,8 @@ def __call__( device, ) # 5. Preprocess image - if isinstance(image, PIL.Image.Image): - image = preprocess(image) - else: - image = 2.0 * (image / 255.0) - 1.0 + image = preprocess(image) # 6. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 71222f4afbec..1b6c8475d8ef 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -65,7 +65,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker"] def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 02dd463f6803..ffb8f7b60263 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -43,13 +43,24 @@ def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image class StableDiffusionImg2ImgPipeline(DiffusionPipeline): @@ -79,7 +90,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker"] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( @@ -248,9 +259,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" @@ -515,8 +526,7 @@ def __call__( ) # 4. Preprocess image - if isinstance(image, PIL.Image.Image): - image = preprocess(image) + image = preprocess(image) # 5. 
set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 34ee8e6840af..a909c10167f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -107,14 +107,29 @@ def prepare_mask_and_masked_image(image, mask): elif isinstance(mask, torch.Tensor): raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: - if isinstance(image, PIL.Image.Image): - image = np.array(image.convert("RGB")) - image = image[None].transpose(0, 3, 1, 2) + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - if isinstance(mask, PIL.Image.Image): - mask = np.array(mask.convert("L")) + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 - mask = mask[None, None] + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) @@ -151,7 +166,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker"] def __init__( self, @@ -313,9 +328,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" @@ -481,8 +496,22 @@ def prepare_mask_latents( masked_image_latents = 0.18215 * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - mask = mask.repeat(batch_size, 1, 1, 1) - masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1) + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. 
Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = ( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 722177a8111a..3a35163b2462 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -92,7 +92,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["feature_extractor"] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( @@ -261,9 +261,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index aa8626a57685..273688aeb7b4 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -192,9 +192,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 
logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 4eb53a7f607f..ce68af1375a8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -32,15 +32,23 @@ def preprocess(image): - # resize to multiple of 64 - width, height = image.size - width = width - width % 64 - height = height - height % 64 - image = image.resize((width, height)) - - image = np.array(image.convert("RGB")) - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) return image @@ -156,9 +164,9 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" @@ -407,10 +415,7 @@ def __call__( ) # 4. Preprocess image - image = [image] if isinstance(image, PIL.Image.Image) else image - if isinstance(image, list): - image = [preprocess(img) for img in image] - image = torch.cat(image, dim=0) + image = preprocess(image) image = image.to(dtype=text_embeddings.dtype, device=device) # 5.
set timesteps diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index 5b96e7f17508..b5188e904033 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -64,6 +64,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { + "batch_size": 1, "generator": generator, "num_inference_steps": 4, } diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index df721efcf46b..38ff4fefe60b 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -52,6 +52,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) inputs = { + "batch_size": 1, "generator": generator, "num_inference_steps": 2, "output_type": "numpy", diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py index 36da2374ac01..f53124a49351 100644 --- a/tests/pipelines/paint_by_example/test_paint_by_example.py +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -25,7 +25,7 @@ from diffusers.utils import floats_tensor, load_image, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from PIL import Image -from transformers import CLIPVisionConfig +from transformers import CLIPImageProcessor, CLIPVisionConfig from ...test_pipelines_common import PipelineTesterMixin @@ -76,6 +76,7 @@ def get_dummy_components(self): patch_size=4, ) image_encoder = PaintByExampleImageEncoder(config, proj_size=32) + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) components = { "unet": unet, @@ -83,7 +84,7 @@ def get_dummy_components(self): "vae": vae, "image_encoder": image_encoder, "safety_checker": None, - "feature_extractor": None, + "feature_extractor": feature_extractor, } return components @@ -100,7 +101,6 @@ def get_dummy_inputs(self, device="cpu", seed=0): init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32)) - example_image = self.convert_to_pt(example_image) if str(device).startswith("mps"): generator = torch.manual_seed(seed) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index 4be95322d6b6..8ed1f5968bf0 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -29,7 +29,8 @@ ) from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu -from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection +from PIL import Image +from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModelWithProjection from ...test_pipelines_common import PipelineTesterMixin @@ -74,19 +75,22 @@ def get_dummy_components(self): patch_size=4, ) image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) components = { "unet": unet, "scheduler": scheduler, "vae": vae, "image_encoder": image_encoder, + "feature_extractor": 
feature_extractor, "safety_checker": None, - "feature_extractor": None, } return components def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32)) if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: @@ -112,7 +116,7 @@ def test_stable_diffusion_img_variation_default_case(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904]) + expected_slice = np.array([0.5167, 0.5746, 0.4835, 0.4914, 0.5605, 0.4691, 0.5201, 0.4898, 0.4958]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_stable_diffusion_img_variation_multiple_images(self): @@ -123,7 +127,7 @@ def test_stable_diffusion_img_variation_multiple_images(self): sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["image"] = inputs["image"].repeat(2, 1, 1, 1) + inputs["image"] = 2 * [inputs["image"]] output = sd_pipe(**inputs) image = output.images @@ -131,7 +135,7 @@ def test_stable_diffusion_img_variation_multiple_images(self): image_slice = image[-1, -3:, -3:, -1] assert image.shape == (2, 64, 64, 3) - expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281]) + expected_slice = np.array([0.6568, 0.5470, 0.5684, 0.5444, 0.5945, 0.6221, 0.5508, 0.5531, 0.5263]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_stable_diffusion_img_variation_num_images_per_prompt(self): @@ -150,7 +154,7 @@ def test_stable_diffusion_img_variation_num_images_per_prompt(self): # test num_images_per_prompt=1 (default) for batch of images batch_size = 2 inputs = self.get_dummy_inputs(device) - inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1) + inputs["image"] = batch_size * [inputs["image"]] images = sd_pipe(**inputs).images assert images.shape == (batch_size, 64, 64, 3) @@ -165,7 +169,7 @@ def test_stable_diffusion_img_variation_num_images_per_prompt(self): # test num_images_per_prompt for batch of prompts batch_size = 2 inputs = self.get_dummy_inputs(device) - inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1) + inputs["image"] = batch_size * [inputs["image"]] images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 7ce06403fa05..a9d341a1387a 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -30,7 +30,7 @@ ) from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from ...test_pipelines_common import PipelineTesterMixin @@ -77,6 +77,7 @@ def get_dummy_components(self): ) text_encoder = CLIPTextModel(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + feature_extractor = 
CLIPImageProcessor(crop_size=32, size=32) components = { "unet": unet, @@ -85,7 +86,7 @@ def get_dummy_components(self): "text_encoder": text_encoder, "tokenizer": tokenizer, "safety_checker": None, - "feature_extractor": None, + "feature_extractor": feature_extractor, } return components diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index f331209e64f4..cc5749855701 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -31,7 +31,7 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from PIL import Image -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from ...test_pipelines_common import PipelineTesterMixin @@ -78,6 +78,7 @@ def get_dummy_components(self): ) text_encoder = CLIPTextModel(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) components = { "unet": unet, @@ -86,7 +87,7 @@ def get_dummy_components(self): "text_encoder": text_encoder, "tokenizer": tokenizer, "safety_checker": None, - "feature_extractor": None, + "feature_extractor": feature_extractor, } return components diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index df074f6c2da9..b4bcd118377c 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -136,7 +136,9 @@ def get_dummy_components(self): return components def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32)) if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: @@ -171,7 +173,7 @@ def test_save_load_local(self): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(output - output_loaded).max() - self.assertLess(max_diff, 3e-5) + self.assertLess(max_diff, 1e-4) @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): @@ -243,7 +245,7 @@ def test_cpu_offload_forward_pass(self): output_with_offload = pipe(**inputs)[0] max_diff = np.abs(output_with_offload - output_without_offload).max() - self.assertLess(max_diff, 3e-5, "CPU offloading should not affect the inference results") + self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results") @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") def test_dict_tuple_outputs_equivalent(self): @@ -260,7 +262,7 @@ def test_dict_tuple_outputs_equivalent(self): output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0] max_diff = np.abs(output - output_tuple).max() - self.assertLess(max_diff, 3e-5) + self.assertLess(max_diff, 1e-4) @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") def test_num_inference_steps_consistent(self): @@ -285,7 +287,7 @@ def 
test_stable_diffusion_depth2img_default_case(self): if torch_device == "mps": expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) else: - expected_slice = np.array([0.6907, 0.5135, 0.4688, 0.5169, 0.5738, 0.4600, 0.4435, 0.5640, 0.4653]) + expected_slice = np.array([0.6854, 0.3740, 0.4857, 0.7130, 0.7403, 0.5536, 0.4829, 0.6182, 0.5053]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_stable_diffusion_depth2img_negative_prompt(self): @@ -305,7 +307,7 @@ def test_stable_diffusion_depth2img_negative_prompt(self): if torch_device == "mps": expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) else: - expected_slice = np.array([0.755, 0.521, 0.473, 0.554, 0.629, 0.442, 0.440, 0.582, 0.449]) + expected_slice = np.array([0.6074, 0.3096, 0.4802, 0.7463, 0.7388, 0.5393, 0.4531, 0.5928, 0.4972]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_stable_diffusion_depth2img_multiple_init_images(self): @@ -317,7 +319,7 @@ def test_stable_diffusion_depth2img_multiple_init_images(self): inputs = self.get_dummy_inputs(device) inputs["prompt"] = [inputs["prompt"]] * 2 - inputs["image"] = inputs["image"].repeat(2, 1, 1, 1) + inputs["image"] = 2 * [inputs["image"]] image = sd_pipe(**inputs).images image_slice = image[-1, -3:, -3:, -1] @@ -326,7 +328,7 @@ def test_stable_diffusion_depth2img_multiple_init_images(self): if torch_device == "mps": expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) else: - expected_slice = np.array([0.6475, 0.6302, 0.5627, 0.5222, 0.4318, 0.5489, 0.5079, 0.4419, 0.4494]) + expected_slice = np.array([0.6681, 0.5023, 0.6611, 0.7605, 0.5724, 0.7959, 0.7240, 0.5871, 0.5383]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_stable_diffusion_depth2img_num_images_per_prompt(self): @@ -374,7 +376,6 @@ def test_stable_diffusion_depth2img_pil(self): inputs = self.get_dummy_inputs(device) - inputs["image"] = Image.fromarray(inputs["image"][0].permute(1, 2, 0).numpy().astype(np.uint8)) image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] @@ -452,7 +453,7 @@ def test_stable_diffusion_depth2img_pipeline_k_lms(self): image = output.images[0] assert image.shape == (480, 640, 3) - assert np.abs(expected_image - image).max() < 1e-3 + assert np.abs(expected_image - image).max() < 5e-3 def test_stable_diffusion_depth2img_pipeline_ddim(self): init_image = load_image( @@ -540,8 +541,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.reset_peak_memory_stats() init_image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/depth2img/sketch-mountains-input.jpg" + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/depth2img/two_cats.png" ) init_image = init_image.resize((768, 512)) @@ -565,7 +565,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): guidance_scale=7.5, generator=generator, output_type="np", - num_inference_steps=5, + num_inference_steps=2, ) mem_bytes = torch.cuda.max_memory_allocated() diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index b2d387cb6890..eb1bf3fe5905 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ 
b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device from diffusers.utils.testing_utils import require_torch_gpu, slow from PIL import Image -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from ...test_pipelines_common import PipelineTesterMixin @@ -78,6 +78,7 @@ def get_dummy_components(self): ) text_encoder = CLIPTextModel(text_encoder_config) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + feature_extractor = CLIPImageProcessor(crop_size=32, size=32) components = { "unet": unet, @@ -86,7 +87,7 @@ def get_dummy_components(self): "text_encoder": text_encoder, "tokenizer": tokenizer, "safety_checker": None, - "feature_extractor": None, + "feature_extractor": feature_extractor, } return components diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 93f8edb8f35d..9676b7f5451a 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -11,6 +11,7 @@ import numpy as np import torch +import diffusers from diffusers import ( CycleDiffusionPipeline, DanceDiffusionPipeline, @@ -18,6 +19,7 @@ StableDiffusionDepth2ImgPipeline, StableDiffusionImg2ImgPipeline, ) +from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_xformers_available from diffusers.utils.testing_utils import require_torch, torch_device @@ -25,6 +27,9 @@ torch.backends.cuda.matmul.allow_tf32 = False +ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"] + + @require_torch class PipelineTesterMixin: """ @@ -94,7 +99,80 @@ def test_save_load_local(self): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(output - output_loaded).max() - self.assertLess(max_diff, 1e-5) + self.assertLess(max_diff, 1e-4) + + def test_pipeline_call_implements_required_args(self): + assert hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method" + parameters = inspect.signature(self.pipeline_class.__call__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + required_parameters.pop("self") + required_parameters = set(required_parameters) + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + + for param in required_parameters: + if param == "kwargs": + # kwargs can be added if arguments of pipeline call function are deprecated + continue + assert param in ALLOWED_REQUIRED_ARGS + + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + + required_optional_params = ["generator", "num_inference_steps", "return_dict"] + for param in required_optional_params: + assert param in optional_parameters + + def test_inference_batch_consistent(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + for batch_size in [2, 4, 13]: + batched_inputs = {} + for name, value in inputs.items(): + if name in ALLOWED_REQUIRED_ARGS: + # prompt is string + if name == "prompt": + len_prompt = len(value) + # make unequal batch 
sizes + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + + # make last batch super long + batched_inputs[name][-1] = 2000 * "very long" + # or else we have images + else: + batched_inputs[name] = batch_size * [value] + elif name == "batch_size": + batched_inputs[name] = batch_size + else: + batched_inputs[name] = value + + batched_inputs["num_inference_steps"] = inputs["num_inference_steps"] + batched_inputs["output_type"] = None + + if self.pipeline_class.__name__ == "DanceDiffusionPipeline": + batched_inputs.pop("output_type") + + output = pipe(**batched_inputs) + + assert len(output[0]) == batch_size + + batched_inputs["output_type"] = "np" + + if self.pipeline_class.__name__ == "DanceDiffusionPipeline": + batched_inputs.pop("output_type") + + output = pipe(**batched_inputs)[0] + + assert output.shape[0] == batch_size + + logger.setLevel(level=diffusers.logging.WARNING) def test_dict_tuple_outputs_equivalent(self): if torch_device == "mps" and self.pipeline_class in ( @@ -118,13 +196,7 @@ def test_dict_tuple_outputs_equivalent(self): output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0] max_diff = np.abs(output - output_tuple).max() - self.assertLess(max_diff, 1e-5) - - def test_pipeline_call_implements_required_args(self): - required_args = ["num_inference_steps", "generator", "return_dict"] - - for arg in required_args: - self.assertTrue(arg in inspect.signature(self.pipeline_class.__call__).parameters) + self.assertLess(max_diff, 1e-4) def test_num_inference_steps_consistent(self): components = self.get_dummy_components() @@ -138,7 +210,7 @@ def test_num_inference_steps_consistent(self): outputs = [] times = [] - for num_steps in [3, 6, 9]: + for num_steps in [9, 6, 3]: inputs = self.get_dummy_inputs(torch_device) inputs["num_inference_steps"] = num_steps @@ -152,7 +224,7 @@ def test_num_inference_steps_consistent(self): # check that all outputs have the same shape self.assertTrue(all(outputs[0].shape == output.shape for output in outputs)) # check that the inference time increases with the number of inference steps - self.assertTrue(all(times[i] > times[i - 1] for i in range(1, len(times)))) + self.assertTrue(all(times[i] < times[i - 1] for i in range(1, len(times)))) def test_components_function(self): init_components = self.get_dummy_components() @@ -257,7 +329,7 @@ def test_save_load_optional_components(self): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(output - output_loaded).max() - self.assertLess(max_diff, 1e-5) + self.assertLess(max_diff, 1e-4) @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") def test_to_device(self): @@ -332,7 +404,7 @@ def test_cpu_offload_forward_pass(self): output_with_offload = pipe(**inputs)[0] max_diff = np.abs(output_with_offload - output_without_offload).max() - self.assertLess(max_diff, 1e-5, "CPU offloading should not affect the inference results") + self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results") @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), @@ -355,7 +427,7 @@ def test_xformers_attention_forward_pass(self): output_with_offload = pipe(**inputs)[0] max_diff = np.abs(output_with_offload - output_without_offload).max() - self.assertLess(max_diff, 1e-5, "XFormers attention should not affect the inference results") + self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results") def test_progress_bar(self): components = 
self.get_dummy_components()
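A note on the `preprocess` helper that is duplicated (via `# Copied from`) across the img2img-style pipelines in this patch: it now accepts a single PIL image, a list of PIL images, an already-batched tensor, or a list of tensors, and always returns a batched tensor scaled to `[-1, 1]`. A standalone mirror of that behaviour for reference (a sketch only, not the library function itself; the solid-colour test images and sizes are arbitrary):

```python
import numpy as np
import PIL.Image
import torch

def preprocess(image):
    # Accept a single PIL image, a list of PIL images, a batched tensor, or a list of tensors.
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = map(lambda x: x - x % 32, (w, h))  # resize to an integer multiple of 32
        image = [np.array(i.resize((w, h), resample=PIL.Image.LANCZOS))[None, :] for i in image]
        image = np.concatenate(image, axis=0).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        image = 2.0 * image - 1.0            # [0, 1] -> [-1, 1]
        return torch.from_numpy(image)
    return torch.cat(image, dim=0)           # list of tensors -> one batch

imgs = [PIL.Image.new("RGB", (96, 96)) for _ in range(3)]
print(preprocess(imgs).shape)     # torch.Size([3, 3, 96, 96])
print(preprocess(imgs[0]).shape)  # torch.Size([1, 3, 96, 96])
```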
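On the `_encode_prompt` change that recurs in every pipeline here: the second tokenizer pass switches from `padding="max_length"` to `padding="longest"`, and the truncation warning only fires when the untruncated encoding is at least as long as the `max_length` encoding and actually differs from it, so short prompts no longer trip the comparison. The check in isolation, assuming a CLIP tokenizer from `transformers` (the checkpoint name is only illustrative):

```python
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
prompt = "a photo of an astronaut riding a horse on mars"

text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

# "longest" pads to the longest prompt in the batch rather than to model_max_length,
# so a short prompt yields a shorter sequence and skips the warning below.
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
    print(f"Truncated part of the prompt: {removed_text}")
```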
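Finally, the one-line `unet_1d` change: `broadcast_to` expands the singleton batch dimension of the timestep embedding to the sample's batch size, so a single timestep can drive a batched sample without an explicit repeat. A toy shape check (all dimensions made up for illustration):

```python
import torch

sample = torch.randn(4, 2, 16)          # (batch, channels, length)
timestep_embed = torch.randn(1, 8, 1)   # embedding of a single timestep
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
# Expand the batch dimension to match `sample`; broadcast_to returns a view, no copy.
timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])
assert timestep_embed.shape == (4, 8, 16)
```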