simplified diffusion model
GreenWizard2015 committed Jul 6, 2024
1 parent 84b5493 commit 6d66496
Showing 3 changed files with 304 additions and 6 deletions.
272 changes: 272 additions & 0 deletions Core/CModelDiffusion.py
@@ -0,0 +1,272 @@
import os
import numpy as np
import NN.networks as networks
import tensorflow as tf
import tensorflow_probability as tfp
import NN.Utils as NNU
import time
from tensorflow.keras import layers as L

# TODO: Implement the standard diffusion process (prediction of the noise, a proper sampling schedule, etc.)
class CModelDiffusion:
'''
Wrapper around the Face2LatentModel network that predicts the gaze point
via a simplified diffusion process. The diffusion step T is mapped to the
stddev of the Gaussian noise added to the gaze points.
'''
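# Simplified scheme: during training the ground-truth gaze points are corrupted with
# Gaussian noise whose stddev is derived from a sampled diffusion step, and the network
# is trained to output the clean points directly. Inference starts from zeros and
# repeatedly re-noises/denoises the estimate while the stddev is annealed towards zero.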
def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=False, **kwargs):
if user is None:
user = {
'userId': 0,
'placeId': 0,
'screenId': 0,
}
else:
user = {
'userId': stats['userId'].index(user['userId']),
'placeId': stats['placeId'].index(user['placeId']),
'screenId': stats['screenId'].index(user['screenId']),
}
self._user = user

self._modelID = model
self._timesteps = timesteps
embeddings = {
'userId': len(stats['userId']),
'placeId': len(stats['placeId']),
'screenId': len(stats['screenId']),
'size': 64,
}
self._modelRaw = networks.Face2LatentModel(
steps=timesteps, latentSize=64, embeddings=embeddings,
diffusion=True
)
self._model = self._modelRaw['main']
self._embeddings = {
'userId': L.Embedding(len(stats['userId']), embeddings['size']),
'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
}
self._intermediateEncoders = {}
if use_encoders:
shapes = self._modelRaw['intermediate shapes']
for name, shape in shapes.items():
enc = networks.IntermediatePredictor(name='%s-encoder' % name)
enc.build(shape)
self._intermediateEncoders[name] = enc
continue

self._maxDiffusionT = 100.0
if 'weights' in kwargs:
self.load(**kwargs['weights'])
self.compile()
# add signatures to help tensorflow optimize the graph
specification = self._modelRaw['inputs specification']
self._trainStep = tf.function(
self._trainStep,
input_signature=[
(
{ 'clean': specification, 'augmented': specification, },
( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
)
]
)
self._eval = tf.function(
self._eval,
input_signature=[(
specification,
( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
)]
)

return

def _step2mean(self, step):
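# linear schedule: step 0 maps to the minimum stddev (clipped to 1e-3), step == _maxDiffusionT maps to 1.0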
step = tf.cast(step, tf.float32) / self._maxDiffusionT
step = tf.cast(step, tf.float32) + 1e-6
# step = tf.pow(step, 2.0) # make it decrease faster
return tf.clip_by_value(step, 1e-3, 1.0)

def _replaceByEmbeddings(self, data):
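# replace the integer id tensors of shape (..., 1) with the corresponding learned embedding vectors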
data = dict(**data) # copy
for name, emb in self._embeddings.items():
data[name] = emb(data[name][..., 0])
continue
return data

def _makeGaussian(self, mean, stddev):
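# isotropic 2D Gaussian: the same stddev is applied to both the x and y coordinates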
stddev = tf.concat([stddev, stddev], axis=-1)
return tfp.distributions.MultivariateNormalDiag(mean, stddev)

@tf.function
def _infer(self, data, training=False):
print('Instantiate _infer')
data = self._replaceByEmbeddings(data)
shp = tf.shape(data['userId'])
B, N = shp[0], self.timesteps
result = tf.zeros((B, N, 2), dtype=tf.float32)
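# simplified reverse process: starting from zeros, repeatedly re-noise the current estimate
# with a decreasing stddev and let the network produce a refined prediction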
for step in tf.range(self._maxDiffusionT, -1, -5):
mean = self._step2mean(
tf.fill((B, N, 1), step)
)
stepData = dict(**data)
stepData['diffusionT'] = mean
stepData['diffusionPoints'] = tf.random.normal((B, N, 2), mean=result, stddev=mean)
result = self._model(stepData, training=training)['result']
return result

def predict(self, data, **kwargs):
B = self._timesteps
userId = kwargs.get('userId', self._user['userId'])
placeId = kwargs.get('placeId', self._user['placeId'])
screenId = kwargs.get('screenId', self._user['screenId'])
# put them as (1, B, ?)
data['userId'] = np.full((1, B, 1), userId, dtype=np.int32)
data['placeId'] = np.full((1, B, 1), placeId, dtype=np.int32)
data['screenId'] = np.full((1, B, 1), screenId, dtype=np.int32)

# NOTE: _infer already replaces the ids with embeddings, so the raw ids are passed through as-is here

result = self._infer(data)
return result.numpy()

def __call__(self, data, startPos=None):
predictions = self.predict(data)
return {
'coords': predictions[0, -1, :],
}

def compile(self):
self._optimizer = NNU.createOptimizer()
return

def _modelFilename(self, folder, postfix=''):
postfix = '-' + postfix if postfix else ''
return os.path.join(folder, '%s-%s%s.h5' % (self._modelID, 'model', postfix))

def save(self, folder=None, postfix=''):
path = self._modelFilename(folder, postfix)
self._model.save_weights(path)
embeddings = {}
for nm in self._embeddings.keys():
weights = self._embeddings[nm].get_weights()[0]
embeddings[nm] = weights
continue
np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
# save intermediate encoders
if self._intermediateEncoders:
encoders = {}
for nm, encoder in self._intermediateEncoders.items():
# save each variable separately
for ww in encoder.trainable_variables:
encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
continue
np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
return

def load(self, folder=None, postfix='', embeddings=False):
path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
self._model.load_weights(path)
if embeddings:
embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
for nm, emb in self._embeddings.items():
w = embeddings[nm]
if not emb.built: emb.build((None, w.shape[0]))
emb.set_weights([w]) # replace embeddings
continue

if self._intermediateEncoders:
encodersName = path.replace('.h5', '-intermediate-encoders.npz')
if os.path.isfile(encodersName):
encoders = np.load(encodersName)
for nm, encoder in self._intermediateEncoders.items():
for ww in encoder.trainable_variables:
w = encoders['%s-%s' % (nm, ww.name)]
ww.assign(w)
continue
return

def lock(self, isLocked):
self._model.trainable = not isLocked
return

@property
def timesteps(self):
return self._timesteps

def trainable_variables(self):
parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
return sum([p.trainable_variables for p in parts], [])

def _pointLoss(self, ytrue, ypred):
# pseudo huber loss
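# loss = sqrt((ytrue - ypred)^2 + delta^2) - delta: quadratic near zero, roughly linear for large errors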
delta = 0.01
tf.assert_equal(tf.shape(ytrue), tf.shape(ypred))
diff = tf.square(ytrue - ypred)
loss = tf.sqrt(diff + delta ** 2) - delta
tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
return tf.reduce_mean(loss, axis=-1)

def _trainStep(self, Data):
print('Instantiate _trainStep')
###############
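# one training step: sample a diffusion step per batch element, corrupt the ground-truth
# points with Gaussian noise of the matching stddev, and train the network to recover the
# clean points (Gaussian NLL + pseudo-Huber loss)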
x, (y, ) = Data
y = y[..., 0, :]
losses = {}
with tf.GradientTape() as tape:
data = x['augmented']
data = self._replaceByEmbeddings(data)
# add sampled T
B = tf.shape(y)[0]
N = self.timesteps
maxT = int(self._maxDiffusionT)
diffusionT = tf.random.uniform((B, 1), minval=0, maxval=maxT, dtype=tf.int32)
# (B, 1) -> (B, N, 1)
diffusionT = tf.tile(diffusionT, (1, N))[..., None]
diffusionT = self._step2mean(diffusionT)
tf.assert_equal(tf.shape(diffusionT), (B, N, 1))

# store the diffusion parameters
data['diffusionT'] = diffusionT
# sample the points
data['diffusionPoints'] = tf.random.normal((B, N, 2), mean=y, stddev=diffusionT)
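# the network receives the noisy points plus the noise level and predicts the clean points directly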
predictions = self._model(data, training=True)
# intermediate = predictions['intermediate']
# assert len(intermediate) == 0, 'Intermediate predictions are not supported'

predictedMean = predictions['result']
gaussian = self._makeGaussian(predictedMean, diffusionT)
losses['log_prob'] = tf.reduce_mean(
-gaussian.log_prob(y)
)
losses['points'] = tf.reduce_mean(self._pointLoss(y, predictedMean)) # reduce to a scalar so it can be summed with the NLL term
loss = sum(losses.values())
losses['loss'] = loss

self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
###############
return losses

def fit(self, data):
t = time.time()
losses = self._trainStep(data)
losses = {k: v.numpy() for k, v in losses.items()}
return {'time': int((time.time() - t) * 1000), 'losses': losses}

def _eval(self, xy):
print('Instantiate _eval')
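# run the full reverse process and report a bounded NLL, the predicted points,
# and the per-point distance to the ground truth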
x, (y,) = xy
y = y[:, :, 0]
B, N = tf.shape(y)[0], tf.shape(y)[1]

predictions = self._infer(x)

mean = self._step2mean(tf.fill((B, N, 1), 0))
gaussian = self._makeGaussian(predictions, mean)
loss = tf.nn.sigmoid( -gaussian.log_prob(y) )
points = predictions
_, dist = NNU.normVec(y - predictions)
return loss, points, dist

def eval(self, data):
loss, sampled, dist = self._eval(data)
return loss.numpy(), sampled.numpy(), dist.numpy()
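
A minimal usage sketch (illustrative only; the stats dict, timestep count, and input data below are hypothetical placeholders, while the real values come from the repository's data pipeline):

# hypothetical example, not part of the commit
stats = {'userId': [0], 'placeId': [0], 'screenId': [0]}
model = CModelDiffusion(timesteps=5, stats=stats)
# `data` would be a dict of face/eye inputs shaped (1, timesteps, ...), as produced by the dataset code
# model(data)['coords'] would then return the predicted 2D gaze point for the last timestep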
34 changes: 29 additions & 5 deletions NN/networks.py
@@ -20,6 +20,11 @@ def call(self, T):
return T[..., 0, :]

class IntermediatePredictor(tf.keras.layers.Layer):
def __init__(self, shift=0.5, **kwargs):
super().__init__(**kwargs)
self._shift = shift
return

def build(self, input_shape):
self._mlp = sMLP(
sizes=[128, 64, 32], activation='relu',
@@ -33,7 +38,7 @@ def call(self, x):
B = tf.shape(x)[0]
N = tf.shape(x)[1]
x = self._mlp(x)
x = 0.5 + self._decodePoints(x) # [0, 0] -> [0.5, 0.5]
x = self._shift + self._decodePoints(x) # [0, 0] -> [0.5, 0.5]
tf.assert_equal(tf.shape(x), (B, N, 2))
return x
# End of IntermediatePredictor
@@ -138,7 +143,8 @@ def _InputSpec():

def Face2LatentModel(
pointsN=478, eyeSize=32, steps=None, latentSize=64,
embeddings=None
embeddings=None,
diffusion=False # whether to use diffusion model
):
points = L.Input((steps, pointsN, 2))
eyeL = L.Input((steps, eyeSize, eyeSize, 1))
@@ -149,6 +155,14 @@ def Face2LatentModel(
screenIdEmb = L.Input((steps, embeddings['size']))

emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
if diffusion:
diffusionT = L.Input((steps, 1))
diffusionPoints = L.Input((steps, 2))
encodedDT = CTimeEncoderLayer()(diffusionT)
# shared transformation for all points
encodedDP = CCoordsEncodingLayer(32, sharedTransformation=True)(diffusionPoints)
# add diffusion features to the embeddings
emb = L.Concatenate(-1)([emb, encodedDT, encodedDP])

Face2Step = Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize=emb.shape[-1])
Step2Latent = Step2LatentModel(latentSize, embeddingsSize=emb.shape[-1])
@@ -179,15 +193,25 @@ def Face2LatentModel(
'placeId': placeIdEmb,
'screenId': screenIdEmb,
}

res['result'] = IntermediatePredictor()(res['latent'])
res['result'] = IntermediatePredictor(
shift=0.0 if diffusion else 0.5 # shift points to the center, if not using diffusion
)(
L.Concatenate(-1)([res['latent'], emb])
)

if diffusion:
inputs['diffusionT'] = diffusionT
inputs['diffusionPoints'] = diffusionPoints
# residual prediction: the network outputs a correction added to the noisy input points
res['result'] = diffusionPoints + res['result']

main = tf.keras.Model(inputs=inputs, outputs=res)
return {
'intermediate shapes': {k: v.shape for k, v in res['intermediate'].items()},
'main': main,
'Face2Step': Face2Step,
'Step2Latent': Step2Latent,
'inputs specification': _InputSpec(),
'inputs specification': _InputSpec()
}

if __name__ == '__main__':
4 changes: 3 additions & 1 deletion scripts/train.py
@@ -12,6 +12,7 @@
from collections import defaultdict
import time
from Core.CModelTrainer import CModelTrainer
from Core.CModelDiffusion import CModelDiffusion
import tqdm
import json
import glob
@@ -186,6 +187,7 @@ def F(epoch):

def _trainer_from(args):
if args.trainer == 'default': return CModelTrainer
if args.trainer == 'diffusion': return CModelDiffusion
raise Exception('Unknown trainer: %s' % (args.trainer, ))

def averageModels(folder, model, noiseStd=0.0):
@@ -330,7 +332,7 @@ def performRandomSearch(epoch=0):
parser.add_argument('--modelId', type=str)
parser.add_argument(
'--trainer', type=str, default='default',
choices=['default']
choices=['default', 'diffusion']
)
parser.add_argument(
'--schedule', type=str, default=None,
