From 5b11c5dc779b7e42022d8353b1b1aa6fb9b758f3 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sat, 23 Sep 2023 09:07:50 -1000 Subject: [PATCH] fix the add_noise function for dpm-multi et al (#5158) * remove to _device() for sigmas * update add_noise to use simgas --------- Co-authored-by: yiyixuxu --- src/diffusers/schedulers/scheduling_deis_multistep.py | 11 ++++++----- .../schedulers/scheduling_dpmsolver_multistep.py | 10 +++++----- .../scheduling_dpmsolver_multistep_inverse.py | 9 +++++---- .../schedulers/scheduling_dpmsolver_singlestep.py | 9 +++++---- .../schedulers/scheduling_unipc_multistep.py | 11 ++++++----- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index c7a94bce88eb..af9b0381dcc4 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -243,8 +243,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) - self.timesteps = torch.from_numpy(timesteps).to(device) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.num_inference_steps = len(timesteps) @@ -707,12 +707,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) @@ -730,7 +730,8 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 264ee268ae17..726ad138ad84 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -263,8 +263,8 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) - self.timesteps = torch.from_numpy(timesteps).to(device) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.num_inference_steps = len(timesteps) @@ -840,12 +840,11 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) @@ -863,7 +862,8 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 7c740234fa40..c0b286a37060 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -274,7 +274,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc _, unique_indices = np.unique(timesteps, return_index=True) timesteps = timesteps[np.sort(unique_indices)] - self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.num_inference_steps = len(timesteps) @@ -858,12 +858,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) @@ -881,7 +881,8 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 10f7ab34e0a4..6744a68b4c4b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -275,7 +275,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.sigmas = torch.from_numpy(sigmas).to(device=device) - self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.model_outputs = [None] * self.config.solver_order self.sample = None @@ -870,12 +870,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) @@ -893,7 +893,8 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 2dcca2ecaece..2b5bd4fd60db 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -254,8 +254,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) - self.timesteps = torch.from_numpy(timesteps).to(device) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.num_inference_steps = len(timesteps) @@ -801,12 +801,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.FloatTensor, + timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) @@ -824,7 +824,8 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self):