forked from lllyasviel/stable-diffusion-webui-forge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsd_samplers_timesteps_impl.py
178 lines (134 loc) · 7.31 KB
/
sd_samplers_timesteps_impl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import torch
import tqdm
import k_diffusion.sampling
import numpy as np
from modules import shared
from modules.models.diffusion.uni_pc import uni_pc
from modules.torch_utils import float64
@torch.no_grad()
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    """Sample with DDIM over a sequence of discrete timesteps.

    Performs ``len(timesteps) - 1`` updates, walking ``timesteps`` from its last
    entry down to index 1. The model output at each step is used as the noise
    estimate in the DDIM update.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``. The cumulative
            alpha schedule is read from ``model.inner_model.inner_model.alphas_cumprod``.
        x: Starting latent batch; ``x.shape[0]`` is used as the batch size.
        timesteps: Integer indices into ``alphas_cumprod`` (presumably ascending;
            TODO confirm against the caller).
        extra_args: Extra keyword arguments forwarded to every model call.
        callback: Optional; called once per step with a dict holding ``x``,
            ``i`` (loop counter), ``sigma``/``sigma_hat`` (always 0) and
            ``denoised`` (the predicted x0).
        disable: Forwarded to ``tqdm.trange`` to hide the progress bar.
        eta: Scale of the stochastic term. ``eta=0`` makes every update deterministic.

    Returns:
        The latent after the final update.
    """
    schedule = model.inner_model.inner_model.alphas_cumprod
    acp = schedule[timesteps]
    # Entry k holds alphas_cumprod at timesteps[k - 1]; entry 0 is padded with index 0.
    prev_positions = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
    acp_prev = schedule[prev_positions].to(float64(x))
    noise_scale = torch.sqrt(1 - acp)

    # Per-step DDIM sigma(eta), evaluated on CPU values.
    acp_cpu = acp.cpu()
    acp_prev_np = acp_prev.cpu().numpy()
    step_sigmas = eta * np.sqrt((1 - acp_prev_np) / (1 - acp_cpu) * (1 - acp_cpu / acp_prev_np))

    if extra_args is None:
        extra_args = {}

    batch_ones = x.new_ones((x.shape[0]))
    per_sample = x.new_ones((x.shape[0], 1, 1, 1))
    last = len(timesteps) - 1

    for step in tqdm.trange(last, disable=disable):
        idx = last - step
        eps = model(x, timesteps[idx].item() * batch_ones, **extra_args)

        # Broadcast the scalar schedule values to shape (batch, 1, 1, 1).
        a_cur = acp[idx].item() * per_sample
        a_next = acp_prev[idx].item() * per_sample
        sigma = step_sigmas[idx].item() * per_sample
        sqrt_1m_a = noise_scale[idx].item() * per_sample

        x0_hat = (x - sqrt_1m_a * eps) / a_cur.sqrt()
        towards_xt = (1. - a_next - sigma ** 2).sqrt() * eps
        fresh_noise = sigma * k_diffusion.sampling.torch.randn_like(x)
        x = a_next.sqrt() * x0_hat + towards_xt + fresh_noise

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': 0, 'sigma_hat': 0, 'denoised': x0_hat})

    return x


@torch.no_grad()
def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
    """DDIM sampling with CFG++ (Manifold-constrained Classifier Free Guidance, 2024).

    The predicted x0 uses the guided noise estimate returned by ``model``.
    The direction term back toward x_t instead uses the unconditional noise
    prediction that the model exposes as ``model.last_noise_uncond``.

    Side effects on ``model`` (not restored afterwards):
        * ``cond_scale_miltiplier`` is set to ``1 / 12.5``, mapping a CFG scale
          in [0.0, 12.5] onto [0, 1.0].
        * ``need_last_noise_uncond`` is set to ``True`` so that the
          unconditional prediction is recorded on every call.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``. The cumulative
            alpha schedule is read from ``model.inner_model.inner_model.alphas_cumprod``.
        x: Starting latent batch; ``x.shape[0]`` is used as the batch size.
        timesteps: Integer indices into ``alphas_cumprod`` (presumably ascending;
            TODO confirm against the caller).
        extra_args: Extra keyword arguments forwarded to every model call.
        callback: Optional per-step hook receiving ``x``, ``i``,
            ``sigma``/``sigma_hat`` (always 0) and ``denoised`` (the predicted x0).
        disable: Forwarded to ``tqdm.trange`` to hide the progress bar.
        eta: Scale of the stochastic term (0 means deterministic).

    Returns:
        The latent after the final update.
    """
    schedule = model.inner_model.inner_model.alphas_cumprod
    acp = schedule[timesteps]
    # Entry k holds alphas_cumprod at timesteps[k - 1]; entry 0 is padded with index 0.
    prev_positions = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))
    acp_prev = schedule[prev_positions].to(float64(x))
    noise_scale = torch.sqrt(1 - acp)

    acp_cpu = acp.cpu()
    acp_prev_np = acp_prev.cpu().numpy()
    step_sigmas = eta * np.sqrt((1 - acp_prev_np) / (1 - acp_cpu) * (1 - acp_cpu / acp_prev_np))

    # Configure the CFG denoiser for CFG++.
    model.cond_scale_miltiplier = 1 / 12.5
    model.need_last_noise_uncond = True

    if extra_args is None:
        extra_args = {}

    batch_ones = x.new_ones((x.shape[0]))
    per_sample = x.new_ones((x.shape[0], 1, 1, 1))
    last = len(timesteps) - 1

    for step in tqdm.trange(last, disable=disable):
        idx = last - step
        eps_guided = model(x, timesteps[idx].item() * batch_ones, **extra_args)
        eps_uncond = model.last_noise_uncond

        a_cur = acp[idx].item() * per_sample
        a_next = acp_prev[idx].item() * per_sample
        sigma = step_sigmas[idx].item() * per_sample
        sqrt_1m_a = noise_scale[idx].item() * per_sample

        x0_hat = (x - sqrt_1m_a * eps_guided) / a_cur.sqrt()
        # CFG++: re-noise along the unconditional direction.
        towards_xt = (1. - a_next - sigma ** 2).sqrt() * eps_uncond
        fresh_noise = sigma * k_diffusion.sampling.torch.randn_like(x)
        x = a_next.sqrt() * x0_hat + towards_xt + fresh_noise

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': 0, 'sigma_hat': 0, 'denoised': x0_hat})

    return x


@torch.no_grad()
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
    """Sample with PLMS (pseudo linear multistep) over discrete timesteps.

    Step 0 uses a pseudo improved-Euler update: it takes a trial step and
    averages the noise estimates at both ends, which costs two model calls.
    Later steps combine the current noise estimate with up to three previous
    ones using Adams-Bashforth weights of order 2, 3 and 4. Each combined
    estimate is applied with a deterministic DDIM-style update.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``. The cumulative
            alpha schedule is read from ``model.inner_model.inner_model.alphas_cumprod``.
        x: Starting latent batch; ``x.shape[0]`` is used as the batch size.
        timesteps: Integer indices into ``alphas_cumprod`` (presumably ascending;
            TODO confirm against the caller).
        extra_args: Extra keyword arguments forwarded to every model call.
        callback: Optional per-step hook receiving ``x``, ``i``,
            ``sigma``/``sigma_hat`` (always 0) and ``denoised`` (the predicted x0).
        disable: Forwarded to ``tqdm.trange`` to hide the progress bar.

    Returns:
        The latent after the final update.
    """
    schedule = model.inner_model.inner_model.alphas_cumprod
    acp = schedule[timesteps]
    # Entry k holds alphas_cumprod at timesteps[k - 1]; entry 0 is padded with index 0.
    acp_prev = schedule[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
    noise_scale = torch.sqrt(1 - acp)

    if extra_args is None:
        extra_args = {}

    batch_ones = x.new_ones([x.shape[0]])
    per_sample = x.new_ones((x.shape[0], 1, 1, 1))

    def ddim_step(x_cur, eps, idx):
        """Deterministic DDIM update of x_cur with noise estimate eps; returns (x_prev, pred_x0)."""
        a_cur = acp[idx].item() * per_sample
        a_next = acp_prev[idx].item() * per_sample
        sqrt_1m_a = noise_scale[idx].item() * per_sample
        # Predicted x0, then the direction pointing back toward x_t.
        x0_hat = (x_cur - sqrt_1m_a * eps) / a_cur.sqrt()
        towards_xt = (1. - a_next).sqrt() * eps
        return a_next.sqrt() * x0_hat + towards_xt, x0_hat

    history = []  # previous raw noise estimates, oldest first, at most three kept
    last = len(timesteps) - 1

    for step in tqdm.trange(last, disable=disable):
        idx = last - step
        t_cur = timesteps[idx].item() * batch_ones
        t_following = timesteps[max(idx - 1, 0)].item() * batch_ones

        eps = model(x, t_cur, **extra_args)

        order = len(history)
        if order == 0:
            # Pseudo improved Euler (2nd order): average eps at both ends of a trial step.
            x_trial, _ = ddim_step(x, eps, idx)
            eps_next = model(x_trial, t_following, **extra_args)
            eps_combined = (eps + eps_next) / 2
        elif order == 1:
            # 2nd-order pseudo linear multistep (Adams-Bashforth)
            eps_combined = (3 * eps - history[-1]) / 2
        elif order == 2:
            # 3rd-order pseudo linear multistep (Adams-Bashforth)
            eps_combined = (23 * eps - 16 * history[-1] + 5 * history[-2]) / 12
        else:
            # 4th-order pseudo linear multistep (Adams-Bashforth)
            eps_combined = (55 * eps - 59 * history[-1] + 37 * history[-2] - 9 * history[-3]) / 24

        x, x0_hat = ddim_step(x, eps_combined, idx)

        history.append(eps)
        if len(history) >= 4:
            del history[0]

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': 0, 'sigma_hat': 0, 'denoised': x0_hat})

    return x


class UniPCCFG(uni_pc.UniPC):
    """UniPC solver that evaluates a CFG denoiser directly.

    ``None`` is passed to the base class as its model function. This subclass
    overrides ``model()`` to call ``cfg_model`` with the converted timestep and
    the stored ``extra_args``.
    """

    def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
        """
        Args:
            cfg_model: Callable invoked as ``cfg_model(x, t, **extra_args)``.
            extra_args: Keyword arguments for ``cfg_model``; ``None`` is treated as ``{}``.
            callback: Optional progress hook. It receives ``x``, ``i`` (update
                counter), ``sigma``/``sigma_hat`` (always 0) and ``denoised``.
            *args, **kwargs: Forwarded to ``uni_pc.UniPC`` after the model function.
        """
        super().__init__(None, *args, **kwargs)

        def after_update(x, model_x):
            # Presumably invoked by the UniPC base class after each solver update (confirm).
            # A missing callback must not crash sampling; the counter still advances.
            if callback is not None:
                callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
            self.index += 1

        self.cfg_model = cfg_model
        # model() unpacks this with **, so None would raise TypeError there.
        self.extra_args = {} if extra_args is None else extra_args
        self.callback = callback
        self.index = 0
        self.after_update = after_update

    def get_model_input_time(self, t_continuous):
        """Convert continuous solver time to the model's timestep input.

        Shifts by ``1 / total_N`` and scales by 1000. The factor 1000 presumably
        assumes a 1000-step training schedule; TODO confirm.
        """
        return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.

    def model(self, x, t):
        """Evaluate the CFG denoiser at continuous time ``t``."""
        t_input = self.get_model_input_time(t)
        return self.cfg_model(x, t_input, **self.extra_args)


def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
    """Sample with UniPC by delegating to ``UniPCCFG.sample``.

    Args:
        model: CFG denoiser. The cumulative alpha schedule is read from
            ``model.inner_model.inner_model.alphas_cumprod``.
        x: Starting latent batch.
        timesteps: Discrete timesteps; only ``len(timesteps)`` is used, plus
            ``timesteps[-1]`` for img2img.
        extra_args: Extra keyword arguments for every model call. ``None`` is
            treated as ``{}``, matching the other samplers in this module.
        callback: Optional per-update progress hook (see ``UniPCCFG``).
        disable: Accepted for signature compatibility with the other samplers;
            currently unused.
        is_img2img: If true, start from a partial time derived from ``timesteps[-1]``.

    Solver settings (variant, skip type, order, lower_order_final) come from
    ``shared.opts``.
    """
    # Without this, UniPCCFG.model would do **None and raise TypeError.
    extra_args = {} if extra_args is None else extra_args
    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod

    ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
    # NOTE(review): the original author flagged this start time as likely off by a bit.
    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None
    unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
    x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
    return x