Skip to content

Commit

Permalink
remove 64-bit DDIM
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Sep 17, 2023
1 parent 6de8b66 commit c55d78c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 56 deletions.
40 changes: 2 additions & 38 deletions NN/restorators/diffusion/CDDIMSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,16 @@
# https://github.com/filipbasara0/simple-diffusion/blob/main/scheduler/ddim.py
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py
class CDDIMSampler(IDiffusionSampler):
def __init__(self, stochasticity, noise_provider, steps, clipping, projectNoise, useFloat64=False):
def __init__(self, stochasticity, noise_provider, steps, clipping, projectNoise):
assert (0.0 <= stochasticity <= 1.0), 'Stochasticity must be in [0, 1] range'
self._eta = stochasticity
self._stepsConfig = steps
self._noise_provider = noise_provider
self._clipping = clipping
self._projectNoise = projectNoise
self._useFloat64 = useFloat64
return

def _reverseStep_float64(self, model, schedule, eta):
# use float64 and some tricks to improve numerical stability
def f(x, t, tPrev):
predictedNoise = model(x, t)
# based on https://github.com/filipbasara0/simple-diffusion/blob/main/scheduler/ddim.py
# obtain parameters for the current step and previous step
t = schedule.parametersForT(t, dtype=tf.float64)
tPrev = schedule.parametersForT(tPrev, dtype=tf.float64)

stepVariance = schedule.varianceBetween(t.alphaHat, tPrev.alphaHat)
sigma = tf.sqrt(stepVariance) * tf.cast(eta, dtype=stepVariance.dtype)
#######################################
noise_scale = tf.sqrt(1.0 - t.alphaHat)
coef2 = tf.sqrt(1.0 - tPrev.alphaHat - tf.square(sigma))
coef1 = tf.sqrt(tPrev.alphaHat / t.alphaHat)
# convert all tensors to x.dtype
coef1 = tf.cast(coef1, dtype=x.dtype)
coef2 = tf.cast(coef2, dtype=x.dtype)
noise_scale = tf.cast(noise_scale, dtype=x.dtype)
sigma = tf.cast(sigma, dtype=x.dtype)

x_minus_noise = x - (noise_scale * predictedNoise)
x_prev = ( (coef1 * x_minus_noise) + (coef2 * predictedNoise) )
tf.assert_equal(x.dtype, x_prev.dtype)
x_prev = tf.ensure_shape(x_prev, x.shape)
x0 = x_minus_noise / tf.cast(tf.sqrt(t.alphaHat), dtype=x.dtype)
return CFakeObject(x_prev=x_prev, sigma=sigma, x0=x0, x1=predictedNoise)
return f

def _reverseStep_float32(self, model, schedule, eta):
def _reverseStep(self, model, schedule, eta):
def f(x, t, tPrev):
predictedNoise = model(x, t)
# based on https://github.com/filipbasara0/simple-diffusion/blob/main/scheduler/ddim.py
Expand All @@ -73,12 +43,6 @@ def f(x, t, tPrev):
return CFakeObject(x_prev=x_prev, sigma=sigma, x0=pred_original_sample, x1=predictedNoise)
return f

def _reverseStep(self, model, schedule, eta):
if self._useFloat64:
return self._reverseStep_float64(model, schedule, eta)

return self._reverseStep_float32(model, schedule, eta)

def _valueUpdater(self, noise_provider, projectNoise):
if not projectNoise: # just add noise to the new value
return lambda step: step.x_prev + noise_provider(shape=tf.shape(step.x_prev), sigma=step.sigma)
Expand Down
3 changes: 1 addition & 2 deletions NN/restorators/diffusion/diffusion_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def sampler_from_config(config):
noise_provider=noise_provider_from_config(config['noise stddev']),
steps=config['steps skip type'],
clipping=config.get('clipping', None),
projectNoise=config.get('project noise', False),
useFloat64=config.get('use float64', False),
projectNoise=config.get('project noise', False)
)

raise ValueError('Unknown sampler: %s' % config)
17 changes: 1 addition & 16 deletions tests/test_diffusion_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ def fakeModel(x, t):
return fakeNoise + tf.cast(t + 1, tf.float32) * x
return { 'x': x, 'fakeModel': fakeModel, 'fakeNoise': fakeNoise }

def _fake_DDIM(stochasticity, K, useFloat64=False, noiseProjection=False):
def _fake_DDIM(stochasticity, K, noiseProjection=False):
return sampler_from_config({
'name': 'DDIM',
'stochasticity': stochasticity,
'noise stddev': 'zero',
'steps skip type': { 'name': 'uniform', 'K': K },
'use float64': useFloat64,
'project noise': noiseProjection,
})

Expand Down Expand Up @@ -101,20 +100,6 @@ def counter(*args, **kwargs):
tf.assert_equal(fakeModelB.calls, schedule.noise_steps)
return

def test_DDIM_float64():
schedule = CDPDiscrete( beta_schedule=get_beta_schedule('linear'), noise_steps=10 )
model = _fake_model(schedule.noise_steps)
x, fakeModel = model['x'], model['fakeModel']

ddimA = _fake_DDIM(stochasticity=1.0, K=1, useFloat64=False)
ddimB = _fake_DDIM(stochasticity=1.0, K=1, useFloat64=True)

A = ddimA.sample(value=x, model=fakeModel, schedule=schedule)
B = ddimB.sample(value=x, model=fakeModel, schedule=schedule)

tf.debugging.assert_near(A, B, atol=1e-6)
return

# test that noise projection does not change if noise is zero
@pytest.mark.parametrize(
'stochasticity,K',
Expand Down

0 comments on commit c55d78c

Please sign in to comment.