Skip to content

Commit

Permalink
Fix MPS scheduler indexing when using mps (huggingface#450)
Browse files Browse the repository at this point in the history
* Fix LMS scheduler indexing in `add_noise` huggingface#358.

* Fix DDIM and DDPM indexing with mps device.

* Verify format is PyTorch before using `.to()`
  • Loading branch information
pcuenca authored Sep 14, 2022
1 parent 7c4b38b commit 1a69c6f
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def add_noise(
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ def add_noise(
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def set_timesteps(self, num_inference_steps: int):
frac = np.mod(self.timesteps, 1.0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
self.sigmas = np.concatenate([sigmas, [0.0]])
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

self.derivatives = []

Expand Down Expand Up @@ -183,6 +183,8 @@ def add_noise(
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.sigmas.device)
sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ def add_noise(
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> torch.Tensor:
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
timesteps = timesteps.to(self.alphas_cumprod.device)
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
Expand Down

0 comments on commit 1a69c6f

Please sign in to comment.