Skip to content

Commit

Permalink
Fix unipc karras sigmas exception - fixes huggingface#4580
Browse files Browse the repository at this point in the history
  • Loading branch information
reimager committed Aug 11, 2023
1 parent d5983a6 commit 6d387f1
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/diffusers/schedulers/scheduling_unipc_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

return sample

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)

# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]

# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1

low = log_sigmas[low_idx]
high = log_sigmas[high_idx]

# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)

# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""

sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()

rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
Expand Down

0 comments on commit 6d387f1

Please sign in to comment.