refactor Image processor for x4 upscaler (huggingface#3692)
* refactor x4 upscaler

* style

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
yiyixuxu authored Jun 6, 2023
1 parent 8669e83 commit 017ee16
Showing 2 changed files with 42 additions and 21 deletions.
@@ -33,6 +33,11 @@

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
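Both changed files add the same guard: the free-standing `preprocess` keeps working but now emits a `FutureWarning` pointing callers at `VaeImageProcessor.preprocess`. A minimal sketch of the replacement call, using only the public import path this commit relies on (the `vae_scale_factor=4` value is an illustrative assumption, not read from a checkpoint):

import PIL.Image
from diffusers.image_processor import VaeImageProcessor

# Replacement for the deprecated module-level preprocess(image):
# a PIL input comes back as a batched NCHW tensor normalized to [-1, 1].
processor = VaeImageProcessor(vae_scale_factor=4, resample="bicubic")
tensor = processor.preprocess(PIL.Image.new("RGB", (512, 512)))
print(tensor.shape)  # torch.Size([1, 3, 512, 512])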
@@ -21,6 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
@@ -34,6 +35,11 @@


 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
@@ -125,6 +131,8 @@ def __init__(
             watermarker=watermarker,
             feature_extractor=feature_extractor,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
         self.register_to_config(max_noise_level=max_noise_level)

     def enable_sequential_cpu_offload(self, gpu_id=0):
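The new attribute is derived from the VAE config rather than hard-coded: each down block after the first halves the spatial resolution, so the pixel-to-latent ratio is `2 ** (len(block_out_channels) - 1)`. A worked sketch of that arithmetic (the config values are illustrative, not read from a real checkpoint):

# Each VAE down block after the first halves height and width.
block_out_channels = [128, 256, 512]  # three resolution levels (illustrative)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 4  # e.g. a 512px image maps to a 128px latent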
@@ -432,14 +440,15 @@ def check_inputs(
         if (
             not isinstance(image, torch.Tensor)
             and not isinstance(image, PIL.Image.Image)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+                f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
             )

-        # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        # verify batch size of prompt and image are same if image is a list or tensor or numpy array
+        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
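`check_inputs` now admits NumPy arrays and lists them in the error message. A standalone sketch of the same gate; `check_image` is a hypothetical helper written for illustration, not the pipeline method itself:

import numpy as np
import PIL.Image
import torch

def check_image(image):
    # Same four accepted kinds as the updated check_inputs.
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `np.ndarray`,"
            f" `PIL.Image.Image` or `list` but is {type(image)}"
        )

check_image(np.zeros((128, 128, 3), dtype=np.float32))  # accepted after this commit
# check_image("low_res.png")  # would raise ValueError: str is not accepted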
@@ -483,7 +492,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
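With the widened annotation, `image` may be a tensor, a PIL image, a NumPy array, or a list of any of these. A usage sketch with a NumPy input; the checkpoint name is the x4 upscaler this pipeline targets, and the concrete sizes and prompt are illustrative:

import numpy as np
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler"
)
# HWC float array in [0, 1] — accepted directly after this change.
low_res = np.random.rand(128, 128, 3).astype(np.float32)
upscaled = pipe(prompt="a white cat", image=low_res).images[0]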
@@ -506,7 +522,7 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image` or tensor representing an image batch to be upscaled.
             num_inference_steps (`int`, *optional*, defaults to 75):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -627,7 +643,7 @@ def __call__(
         )

         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
         image = image.to(dtype=prompt_embeds.dtype, device=device)

         # 5. set timesteps
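Behaviorally, `VaeImageProcessor.preprocess` subsumes the old helper: inputs become a batched NCHW tensor normalized to `[-1, 1]`. A rough sketch of that contract for a single NumPy image (simplified: the real method also resizes to a multiple of `vae_scale_factor` and handles PIL images, tensors, and lists):

import numpy as np
import torch

def preprocess_sketch(image: np.ndarray) -> torch.Tensor:
    # image: HWC float array in [0, 1], e.g. np.asarray(pil_image) / 255
    tensor = torch.from_numpy(image).permute(2, 0, 1)[None]  # -> NCHW, batch of 1
    return 2.0 * tensor - 1.0  # normalize to [-1, 1]

out = preprocess_sketch(np.random.rand(128, 128, 3).astype(np.float32))
assert out.shape == (1, 3, 128, 128) and float(out.min()) >= -1.0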
@@ -723,25 +739,25 @@ def __call__(
         else:
             latents = latents.float()

-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.decode_latents(latents)
-
+        # post-processing
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            image = self.numpy_to_pil(image)
-
-            # 11. Apply watermark
-            if self.watermarker is not None:
-                image = self.watermarker.apply_watermark(image)
-        elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents
-            image = self.vae.decode(latents).sample
-            has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents)
+            image = latents
             has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # 11. Apply watermark
+        if output_type == "pil" and self.watermarker is not None:
+            image = self.watermarker.apply_watermark(image)

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
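The per-image `do_denormalize` flags let `postprocess` map ordinary images from `[-1, 1]` back to `[0, 1]` while leaving safety-checker-blanked images untouched (denormalizing their zeros would turn black into mid-gray). A condensed sketch of that step, simplified from what `VaeImageProcessor.postprocess` does:

import torch

def denormalize_sketch(image: torch.Tensor, do_denormalize: list) -> torch.Tensor:
    # image: NCHW batch in [-1, 1]; entries with a False flag pass through.
    return torch.stack(
        [(img / 2 + 0.5).clamp(0, 1) if flag else img
         for img, flag in zip(image, do_denormalize)]
    )

batch = torch.zeros(2, 3, 8, 8)
out = denormalize_sketch(batch, [True, False])
assert float(out[0].max()) == 0.5 and float(out[1].max()) == 0.0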
