Skip to content

Commit

Permalink
[schedulers] handle dtype in add_noise (huggingface#767)
Browse files Browse the repository at this point in the history
* handle dtype in vae and image2image pipeline

* handle dtype in add noise

* don't modify vae and pipeline

* remove the if
  • Loading branch information
patil-suraj authored Oct 7, 2022
1 parent cb0bf0b commit ec831b6
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
8 changes: 3 additions & 5 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,9 @@ def add_noise(
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)

if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
Expand Down
8 changes: 3 additions & 5 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,9 @@ def add_noise(
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)

if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
Expand Down
10 changes: 7 additions & 3 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,13 @@ def add_noise(
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
sigmas = self.sigmas.to(original_samples.device)
schedule_timesteps = self.timesteps.to(original_samples.device)
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

schedule_timesteps = self.timesteps

if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
deprecate(
"timesteps as indices",
Expand All @@ -273,7 +277,7 @@ def add_noise(
else:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sigma = sigmas[step_indices].flatten()
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

Expand Down
8 changes: 3 additions & 5 deletions src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,9 @@ def add_noise(
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)

if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
Expand Down

0 comments on commit ec831b6

Please sign in to comment.