Skip to content

Commit

Permalink
Support different strength for Stable Diffusion TensorRT Inpainting pipeline (huggingface#4216)
Browse files Browse the repository at this point in the history

* Support different strength

* run make style
  • Loading branch information
jinwonkim93 authored Aug 3, 2023
1 parent d0b8de1 commit e391b78
Showing 1 changed file with 37 additions and 18 deletions.
55 changes: 37 additions & 18 deletions examples/community/stable_diffusion_tensorrt_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,14 +823,14 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dt

return self

def __initialize_timesteps(self, timesteps, strength):
    """Set up the scheduler and trim its timesteps according to `strength`.

    Returns the trimmed timestep tensor (moved to the pipeline device) and
    the start index `t_start` into the full schedule.
    """
    self.scheduler.set_timesteps(timesteps)
    # Some schedulers define a `steps_offset`; treat it as 0 when absent.
    offset = getattr(self.scheduler, "steps_offset", 0)
    # Number of steps implied by `strength`, capped at the full schedule.
    init_timestep = min(int(timesteps * strength) + offset, timesteps)
    t_start = max(timesteps - init_timestep + offset, 0)
    selected = self.scheduler.timesteps[t_start:].to(self.torch_device)
    return selected, t_start
def __initialize_timesteps(self, num_inference_steps, strength):
    """Configure scheduler timesteps and trim them based on `strength`.

    Args:
        num_inference_steps: total number of denoising steps requested.
        strength: fraction of the schedule to run (1.0 runs the full
            schedule, i.e. starts from pure noise for inpainting).

    Returns:
        Tuple `(timesteps, num_steps)` where `timesteps` is the tensor of
        scheduler timesteps to iterate (on the pipeline device) and
        `num_steps` is the number of denoising steps that will actually run.
    """
    self.scheduler.set_timesteps(num_inference_steps)
    # BUGFIX: `steps_offset` lives on the scheduler *config*, so probe the
    # config object — the original probed `self.scheduler` while reading
    # `self.scheduler.config`, which can silently fall back to 0 (or rely
    # on deprecated attribute passthrough) depending on the scheduler.
    offset = self.scheduler.config.steps_offset if hasattr(self.scheduler.config, "steps_offset") else 0
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    # Multiply by `order` to account for higher-order schedulers that
    # consume several entries of `scheduler.timesteps` per denoising step.
    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].to(self.torch_device)
    return timesteps, num_inference_steps - t_start

def __preprocess_images(self, batch_size, images=()):
init_images = []
Expand Down Expand Up @@ -953,7 +953,7 @@ def __call__(
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
strength: float = 0.75,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -1043,26 +1043,45 @@ def __call__(
latent_height = self.image_height // 8
latent_width = self.image_width // 8

# Pre-process input images
mask, masked_image, init_image = self.__preprocess_images(
batch_size,
prepare_mask_and_masked_image(
image,
mask_image,
self.image_height,
self.image_width,
return_image=True,
),
)
# print(mask)
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
mask = torch.cat([mask] * 2)

# Initialize timesteps
timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)

# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
latent_timestep = timesteps[:1].repeat(batch_size)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0

# Pre-initialize latents
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
latents_outputs = self.prepare_latents(
batch_size,
num_channels_latents,
self.image_height,
self.image_width,
torch.float32,
self.torch_device,
generator,
image=init_image,
timestep=latent_timestep,
is_strength_max=is_strength_max,
)

# Pre-process input images
mask, masked_image = self.__preprocess_images(batch_size, prepare_mask_and_masked_image(image, mask_image))
# print(mask)
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
mask = torch.cat([mask] * 2)

# Initialize timesteps
timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
latents = latents_outputs[0]

# VAE encode masked image
masked_latents = self.__encode_image(masked_image)
Expand Down

0 comments on commit e391b78

Please sign in to comment.