Huge update. Implemented discretized diffusion, binary embeddings, learnable embeddings, self-conditioning, and other stuff.
GreenWizard2015 committed May 28, 2024
1 parent 05140f0 commit 3874092
Showing 44 changed files with 1,296 additions and 100 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ __pycache__
/dataset/*.npy
/final_config.json
/visualized
/debug.log
5 changes: 3 additions & 2 deletions NN/RestorationModel/CRepeatedRestorator.py
@@ -17,11 +17,12 @@ def _withID(self, latents, idx, training):
return res

# only for training and building
def call(self, latents, pos, T, V, residual, training=None):
def call(self, latents, pos, T, V, residual, training=None, **kwargs):
for i in range(self._N):
V = self._restorator(
latents=self._withID(latents, i, training),
pos=pos, T=T, V=V, residual=residual, training=training
pos=pos, T=T, V=V, residual=residual, training=training,
**kwargs
)
continue
return V
18 changes: 11 additions & 7 deletions NN/RestorationModel/CRestorationModel.py
@@ -73,23 +73,27 @@ def _addRadius(self, latents, R=None, fakeR=0.0, training=False):

return latents

def call(self, latents, pos, T, V, residual, R=None, training=False):
def call(self, latents, pos, T, V, residual, R=None, training=False, extras=None, **kwargs):
EPos = self._encodePos(pos, training=training, args={})
t = self._encodeTime(T, training=training)
latents = self._addResiduals(latents, residual)
latents = self._addRadius(latents, R=R, training=training)
if extras is not None:
latents = tf.concat([latents, extras], axis=-1)
res = self._decoder(condition=latents, coords=EPos, timestep=t, V=V, training=training)
return self._withResidual(res, residual)

def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
EPos = self._encodePos(pos, training=training, args=reverseArgs.get('decoder', {}))
latents = self._addResiduals(latents, residual)

def denoiser(x, t, mask=None, **kwargs):
def denoiser(V, T, mask=None, extras=None, **kwargs):
fakeR = kwargs.get('blurRadius', reverseArgs.get('fakeR', 0.0))
latentsPlus = self._addRadius(latents, R=None, fakeR=fakeR, training=training)
condition = self._addRadius(latents, R=None, fakeR=fakeR, training=training)
if extras is not None:
condition = tf.concat([condition, extras], axis=-1)

args = dict(condition=latentsPlus, coords=EPos, timestep=t, V=x)
args = dict(condition=condition, coords=EPos, timestep=T, V=V)
residuals = residual
if mask is not None:
args = {k: masked(v, mask) for k, v in args.items()}
@@ -116,10 +120,10 @@ def train_step(self, x0, latents, positions, params, xT=None):
return self._restorator.train_step(
x0=x0,
xT=xT,
model=lambda T, V: self(
model=lambda **kwargs: self(
latents=latents, pos=positions,
T=T, V=V, residual=residual,
R=R,
residual=residual, R=R,
**kwargs,
training=True
),
**params
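A minimal sketch (not part of this commit) of the new `extras` conditioning path in CRestorationModel: whatever is passed as `extras` is simply concatenated to the per-point latents before the decoder runs, both in `call` and inside the reverse-time `denoiser`. Shapes and dimensions below are illustrative assumptions.

import tensorflow as tf

latents = tf.random.normal((4, 64))   # 4 query points, 64-dim latent per point (assumed sizes)
extras  = tf.random.normal((4, 3))    # e.g. a previous prediction fed back for self-conditioning
condition = tf.concat([latents, extras], axis=-1)   # [4, 67], used as the decoder's condition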
6 changes: 2 additions & 4 deletions NN/RestorationModel/CSequentialRestorator.py
@@ -7,11 +7,9 @@ def __init__(self, restorators, **kwargs):
return

# only for training and building
def call(self, latents, pos, T, V, residual, training=None):
def call(self, V, **kwargs):
for restorator in self._restorators:
V = restorator(
latents=latents, pos=pos, T=T, V=V, residual=residual, training=training
)
V = restorator(V=V, **kwargs)
continue
return V

115 changes: 115 additions & 0 deletions NN/layers/BinaryEmbeddings.py
@@ -0,0 +1,115 @@
import tensorflow as tf
from Utils.utils import CFakeObject
from NN.utils import normVec

class CBinaryEmbeddings(tf.keras.layers.Layer):
def __init__(self, input_dim, output_dim, name, **kwargs):
super().__init__(name=name, **kwargs)
assert input_dim == 256, 'Only 256 input_dim is supported'
assert output_dim == 8, 'Only 8 output_dim is supported'
self._N = input_dim
self.output_dim = output_dim
self._embeddings = tf.Variable(
initial_value=self._initEmbeddings(),
trainable=False,
name='%s/embeddings' % name
)
# self._scaleProbabilities = tf.Variable(
# initial_value=tf.zeros((1,)),
# trainable=True,
# name='%s/scaleProbabilities' % name
# )
return

def _initEmbeddings(self):
x = tf.range(256, dtype=tf.int32)
x = tf.reshape(x, (256, 1))

# To get the binary representation in an array format, you can use binary expansion
x_expanded = tf.reshape(
tf.stack([tf.bitwise.right_shift(x, i) & 1 for i in range(8)], axis=-1),
(256, 8)
)
x = tf.cast(x_expanded, tf.float32)
# Scale from 0 to 1 to -1 to 1
x = x * 2.0 - 1.0
tf.assert_equal(tf.shape(x), (256, 8))
return x

def normalize(self, x):
V, L = normVec(x)
L = tf.clip_by_value(L, clip_value_min=1e-6, clip_value_max=1.0)
return V * L

@property
def embeddings(self):
res = self._embeddings
return res

def call(self, inputs):
B = tf.shape(inputs)[0]
tf.assert_equal(tf.shape(inputs), (B, ))
res = tf.gather(self.embeddings, inputs)
tf.assert_equal(tf.shape(res), (B, self.output_dim))
return res

def _score(self, x):
x = self.normalize(x)
# Ensure `x` is 2D: [batch_size, num_features]
B = tf.shape(x)[0]
tf.assert_equal(tf.shape(x), (B, self.output_dim))

embeddings = self.embeddings
dot_product = tf.matmul(x, embeddings, transpose_b=True) # [B, N]
tf.assert_equal(tf.shape(dot_product), (B, self._N))

embLen = tf.reduce_sum(embeddings ** 2, axis=-1, keepdims=True)
embLen = tf.transpose(embLen)
tf.assert_equal(tf.shape(embLen), (1, self._N))

xLen = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)
tf.assert_equal(tf.shape(xLen), (B, 1))

distance = embLen + xLen - 2 * dot_product
distance = tf.maximum(distance, 0.0)
tf.assert_equal(tf.shape(distance), (B, self._N))

scale = -1. #tf.nn.softplus(self._scaleProbabilities)
res = tf.nn.softmax(distance * scale, axis=-1)
tf.assert_equal(tf.shape(res), (B, self._N))
return res

def separability(self):
return 0.0

@tf.function
def loss(self, x, target):
B = tf.shape(x)[0]
tf.assert_equal(tf.shape(x), (B, self.output_dim))
tf.assert_equal(tf.shape(target), (B, 1))

scores = self._score(x)
tf.assert_equal(tf.shape(scores), (B, self._N))
res = tf.losses.sparse_categorical_crossentropy(target, scores)
return res

def encode(self, color):
# color is in range [-1, 1], single value
N = tf.size(color)
x = tf.reshape(color, (N, 1))
x = tf.clip_by_value(x, clip_value_min=-1.0, clip_value_max=1.0)
x = (x + 1.0) / 2.0 # [0, 1]
idx = tf.cast(x * self._N, tf.int32)
idx = tf.clip_by_value(idx, clip_value_min=0, clip_value_max=self._N - 1)
return CFakeObject(indices=idx, embeddings=self(idx[:, 0]))

def decode(self, x):
B = tf.shape(x)[0]
tf.assert_equal(tf.shape(x), (B, self.output_dim))
scores = self._score(x)
tf.assert_equal(tf.shape(scores), (B, self._N))
idx = tf.argmax(scores, axis=-1)[..., None]
idx = tf.cast(idx, tf.int32)
x = tf.cast(idx, tf.float32) / self._N
x = x * 2.0 - 1.0
return CFakeObject(values=x, indices=idx)
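A short usage sketch (not part of the commit) of what CBinaryEmbeddings does: every byte value 0..255 is mapped to its fixed 8-bit binary code rescaled to {-1, +1}, and encode() quantizes a colour in [-1, 1] into one of those 256 bins before looking up its code. The standalone snippet below reproduces the table construction with plain TensorFlow; variable names are illustrative.

import tensorflow as tf

# same construction as _initEmbeddings: 8-bit binary expansion of 0..255, rescaled to {-1, +1}
values = tf.range(256, dtype=tf.int32)[:, None]                                    # [256, 1]
bits = tf.stack([tf.bitwise.right_shift(values, i) & 1 for i in range(8)], axis=-1)
codes = tf.cast(tf.reshape(bits, (256, 8)), tf.float32) * 2.0 - 1.0                # [256, 8]

# quantize a colour in [-1, 1] into one of 256 bins and look up its code,
# mirroring what encode() does before calling the layer
color = tf.constant([0.25])
idx = tf.clip_by_value(tf.cast((color + 1.0) / 2.0 * 256.0, tf.int32), 0, 255)     # [1]
embedding = tf.gather(codes, idx)                                                  # [1, 8]

decode() goes the other way: it scores an 8-dim vector against all 256 codes via a softmax over negative squared distances and takes the argmax bin.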
127 changes: 127 additions & 0 deletions NN/layers/ReversibleHyperEmbeddings.py
@@ -0,0 +1,127 @@
import tensorflow as tf
from Utils.utils import CFakeObject

class CReversibleHyperEmbeddings(tf.keras.layers.Layer):
def __init__(self, input_dim, output_dim, name, **kwargs):
super().__init__(name=name, **kwargs)
self._N = input_dim
self.output_dim = output_dim
self._ittr = tf.Variable(0., trainable=False, name='%s/ittr' % self.name)
self._embeddings = tf.Variable(
initial_value=tf.random.normal((input_dim, output_dim), dtype=tf.float32, stddev=0.1),
trainable=True,
name='%s/embeddings' % name
)
return

@property
def embeddings(self):
res = self._embeddings
return res

def call(self, inputs):
B = tf.shape(inputs)[0]
tf.assert_equal(tf.shape(inputs), (B, ))
res = tf.gather(self.embeddings, inputs)
tf.assert_equal(tf.shape(res), (B, self.output_dim))
return res

def _score(self, x):
# calculate the softmax of the distance between the embeddings and the input
# Ensure `x` is 2D: [batch_size, num_features]
B = tf.shape(x)[0]
tf.assert_equal(tf.shape(x), (B, self.output_dim))

embeddings = self.embeddings
dot_product = tf.matmul(x, embeddings, transpose_b=True) # [B, N]
tf.assert_equal(tf.shape(dot_product), (B, self._N))

embLen = tf.reduce_sum(embeddings ** 2, axis=-1, keepdims=True)
embLen = tf.transpose(embLen)
tf.assert_equal(tf.shape(embLen), (1, self._N))

xLen = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)
tf.assert_equal(tf.shape(xLen), (B, 1))

distance = embLen + xLen - 2 * dot_product
distance = tf.maximum(distance, 0.0)
tf.assert_equal(tf.shape(distance), (B, self._N))

res = tf.nn.softmax(-distance, axis=-1)
tf.assert_equal(tf.shape(res), (B, self._N))
return res

def separability(self):
scores = self._score(self.embeddings)
tf.assert_equal(tf.shape(scores), (self._N, self._N))
# maximize the separability of the embeddings
idx = tf.range(self._N)[..., None]
separability = tf.reduce_mean(
tf.losses.sparse_categorical_crossentropy(idx, scores)
)
# distance from the origin of the embeddings
# minimize the distance from the origin
distance = tf.reduce_sum(self.embeddings ** 2, axis=-1)
distance = tf.sqrt(distance)
distance = tf.reduce_mean(distance)
return separability + distance

@tf.function
def loss(self, x, target):
B = tf.shape(x)[0]
tf.assert_equal(tf.shape(x), (B, self.output_dim))
tf.assert_equal(tf.shape(target), (B, 1))
res = tf.losses.sparse_categorical_crossentropy(target, self._score(x))

self._ittr.assign_add(1.0)
# Each N iterations, print debug info
N = 1000
if tf.cast(self._ittr, tf.int32) % N == 0:
self.debug(x, target)
return res

def encode(self, color):
# color is in range [-1, 1], single value
N = tf.size(color)
x = tf.reshape(color, (N, 1))
x = tf.clip_by_value(x, clip_value_min=-1.0, clip_value_max=1.0)
x = (x + 1.0) / 2.0 # [0, 1]
idx = tf.cast(x * self._N, tf.int32)
idx = tf.clip_by_value(idx, clip_value_min=0, clip_value_max=self._N - 1)
return CFakeObject(indices=idx, embeddings=self(idx[:, 0]))

def decode(self, x):
B = tf.shape(x)[0]
tf.assert_equal(tf.shape(x), (B, self.output_dim))
scores = self._score(x)
tf.assert_equal(tf.shape(scores), (B, self._N))
idx = tf.argmax(scores, axis=-1)[..., None]
idx = tf.cast(idx, tf.int32)
x = tf.cast(idx, tf.float32) / self._N
x = x * 2.0 - 1.0
return CFakeObject(values=x, indices=idx)

def debug(self, x, target):
tf.print('-' * 80)
scores = self._score(x)
# top-1 accuracy
predIdx = tf.cast(tf.argmax(scores, axis=-1), tf.int32)[..., None]
acc = tf.reduce_mean(tf.cast(predIdx == target, tf.float32))
tf.print('[%s] Accuracy TOP-1:' % self.name, acc)
# find avg position of the correct answer
sortedInd = tf.argsort(-scores, axis=-1)
ranks = tf.argsort(sortedInd, axis=-1) # get rank positions
K = tf.gather_nd(ranks, target, batch_dims=1)[:, None] # get ranks of targets
K = tf.cast(K, tf.float32)
tf.assert_equal(tf.shape(K)[-1], 1)
tf.print('[%s] Avg. rank (K):' % self.name, tf.reduce_mean(K))
# find avg euclidean distance between the correct answer and the predicted one
correct = tf.stop_gradient(self(target[..., 0]))
dist = tf.reduce_sum((x - correct) ** 2, axis=-1)
dist = tf.sqrt(dist)
tf.print('[%s] Avg. distance:' % self.name, tf.reduce_mean(dist))
# find avg "radius" of the embeddings
radius = tf.reduce_sum(self.embeddings ** 2, axis=-1)
radius = tf.sqrt(radius)
tf.print('[%s] Avg. radius:' % self.name, tf.reduce_mean(radius))
return
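Both new embedding layers share the same distance-softmax scoring; a standalone sketch of it is shown below (an assumption for illustration: plain TensorFlow, hypothetical function name). decode() takes the argmax of these scores, and separability() in the learnable variant treats each embedding as its own target class under this scoring while also penalizing the mean embedding norm.

import tensorflow as tf

# hypothetical standalone version of the _score method used by both layers above
def score(x, embeddings):
  # squared Euclidean distance between every input row and every embedding row
  dot = tf.matmul(x, embeddings, transpose_b=True)             # [B, N]
  emb_len = tf.reduce_sum(embeddings ** 2, axis=-1)[None, :]   # [1, N]
  x_len = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)        # [B, 1]
  distance = tf.maximum(emb_len + x_len - 2.0 * dot, 0.0)
  return tf.nn.softmax(-distance, axis=-1)                     # nearer embeddings score higher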
12 changes: 6 additions & 6 deletions NN/restorators/CARProcess.py
@@ -13,15 +13,15 @@ def __init__(self, predictions, sourceDistribution, sampler):
self._sampler = sampler
return

def forward(self, x0, xT=None):
def forward(self, x0, xT=None, model=None):
B = tf.shape(x0)[0]
# source distribution need to know the shape of the input, so we need to ensure it explicitly
# x0 = tf.ensure_shape(x0, (None, self.predictions))
sampled = self._sourceDistribution.sampleFor(x0)
x1 = sampled['xT'] if xT is None else xT

# tf.assert_equal(tf.shape(x0), (B, self.predictions))
return self._sampler.train(x0=x0, x1=x1, T=sampled['T'], xT=xT)
return self._sampler.train(x0=x0, x1=x1, T=sampled['T'], xT=xT, model=model)

def calculate_loss(self, x_hat, predicted, **kwargs):
if hasattr(self._sampler, 'calculate_loss'):
@@ -34,11 +34,11 @@ def calculate_loss(self, x_hat, predicted, **kwargs):
def _makeDenoiser(self, model, modelT):
timeEncoder = make_time_encoder(modelT)

def denoiser(x, t=None, **kwargs):
B = tf.shape(x)[0]
T = timeEncoder(t=t, B=B)
def denoiser(V, T=None, mask=None, **kwargs):
B = tf.shape(V)[0]
T = timeEncoder(t=T, B=B)
tf.assert_equal(B, tf.shape(T)[0])
return model(x=x, t=T, mask=kwargs.get('mask', None))[:, :self.predictions]
return model(V=V, T=T, mask=mask, **kwargs)[:, :self.predictions]
return denoiser

def reverse(self, value, denoiser, modelT=None, index=0, **kwargs):
2 changes: 1 addition & 1 deletion NN/restorators/CSingleStepRestoration.py
@@ -6,7 +6,7 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
return

def forward(self, x0, xT=None):
def forward(self, x0, xT=None, model=None):
s = tf.shape(x0)
T = tf.zeros(s[:-1], tf.int32)[..., None]
if xT is None:
9 changes: 6 additions & 3 deletions NN/restorators/IRestorationProcess.py
@@ -29,7 +29,7 @@ def __init__(self, predictions, name=None, **kwargs):
def call(self, *args, **kwargs):
raise RuntimeError('IRestorationProcess object cannot be called directly')

def forward(self, x0, xT=None):
def forward(self, x0, xT=None, model=None):
raise NotImplementedError()

def reverse(self, value, denoiser, modelT=None, **kwargs):
@@ -59,8 +59,11 @@ def calculate_loss(self, x_hat, predicted, **kwargs):
raise NotImplementedError()

def train_step(self, x0, model, xT=None, **kwargs):
x_hat = self.forward(x0=x0, xT=xT)
values = model(T=x_hat['T'], V=x_hat['xT'])
x_hat = self.forward(
x0=x0, xT=xT,
model=lambda **kwargs: tf.stop_gradient(model(**kwargs)) # stop gradient for the model
)
values = model(T=x_hat['T'], V=x_hat['xT'], extras=x_hat.get('extras', None))

# we want calculate the loss WITH the residual
totalLoss = self.calculate_loss(x_hat, values, **kwargs)
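Taken together with the CRestorationModel changes above, the reworked train_step appears to run the model twice for self-conditioning; the sketch below is a rough reading of this diff (hypothetical forward signature, not the actual implementation elsewhere in the commit).

import tensorflow as tf

def train_step_sketch(x0, model, forward):
  # pass 1: the forward process may query the model for a preliminary prediction;
  # its gradient is stopped, so only the second pass updates the weights
  x_hat = forward(x0=x0, model=lambda **kw: tf.stop_gradient(model(**kw)))
  # pass 2: the preliminary prediction, if produced, is fed back as 'extras'
  return model(T=x_hat['T'], V=x_hat['xT'], extras=x_hat.get('extras', None))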
2 changes: 1 addition & 1 deletion NN/restorators/diffusion/CDDIMSampler.py
@@ -20,7 +20,7 @@ def __init__(self, stochasticity, noise_provider, steps, clipping, projectNoise)

def _reverseStep(self, model, schedule, eta):
def f(x, t, tPrev):
predictedNoise = model(x, t)
predictedNoise = model(V=x, T=schedule.to_continuous(t))
# based on https://github.com/filipbasara0/simple-diffusion/blob/main/scheduler/ddim.py
# obtain parameters for the current step and previous step
t = schedule.parametersForT(t)