Skip to content

Commit

Permalink
added normalisation
Browse files Browse the repository at this point in the history
removed useless branch
Improved training script
  • Loading branch information
GreenWizard2015 committed Jun 4, 2024
1 parent c4501c0 commit ed74a15
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 53 deletions.
2 changes: 2 additions & 0 deletions NN/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def __init__(self, mlp=None, **kwargs):
super().__init__(**kwargs)
if mlp is None: mlp = lambda x: x
self._mlp = mlp
self._norm = L.LayerNormalization()
return

def build(self, input_shapes):
Expand All @@ -243,6 +244,7 @@ def build(self, input_shapes):
def call(self, x):
assert isinstance(x, list), "expected list of inputs"
xhat = tf.concat(x, axis=-1)
xhat = self._norm(xhat)
xhat = self._mlp(xhat)
xhat = self._lastDense(xhat)
x0 = x[0]
Expand Down
29 changes: 0 additions & 29 deletions NN/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,36 +185,7 @@ def Face2LatentModel(
IP = lambda x: IntermediatePredictor()(x) # own IntermediatePredictor for each output
res['intermediate'] = {k: IP(x) for k, x in intermediate.items()}
res['result'] = IP(res['latent'])
###################################
# TODO: figure out whether this is helpful or not
# branch for global coordinates transformation
# predict shift, rotation, scale
emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
emb = sMLP(sizes=[64, 64, 64, 64, 32], activation='relu')(emb[:, 0])
shift = L.Dense(2, name='GlobalShift')(emb)[:, None]
rotation = L.Dense(1, name='GlobalRotation', activation='sigmoid')(emb)[:, None] * np.pi
scale = L.Dense(2, name='GlobalScale')(emb)[:, None]

shifted = res['result'] + shift - 0.5 # [0.5, 0.5] -> [0, 0]
# Rotation matrix components
cos_rotation = L.Lambda(lambda x: tf.cos(x))(rotation)
sin_rotation = L.Lambda(lambda x: tf.sin(x))(rotation)
rotation_matrix = L.Lambda(lambda x: tf.stack([x[0], x[1]], axis=-1))([cos_rotation, sin_rotation])

# Apply rotation
rotated = L.Lambda(
lambda x: tf.einsum('isj,iomj->isj', x[0], x[1])
)([shifted, rotation_matrix]) + 0.5 # [0, 0] -> [0.5, 0.5] back

# Apply scale
scaled = rotated * scale
def clipWithGradient(x):
res = tf.clip_by_value(x, 0.0, 1.0)
return x + tf.stop_gradient(res - x)

res['result'] = L.Lambda(clipWithGradient)(scaled)
###################################

main = tf.keras.Model(inputs=inputs, outputs=res)
return {
'main': main,
Expand Down
57 changes: 33 additions & 24 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,31 +239,52 @@ def main(args):
model = trainer(**model)
model._model.summary()

if args.average:
averageModels(folder, model)
# find folders with the name "/test-*/"
evalDatasets = [
CTestLoader(nm)
for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/'))
]
eval = evaluator(evalDatasets, model, folder, args)
bestLoss = eval()
bestLoss = eval() # evaluate loaded model
bestEpoch = 0
# wrapper for the evaluation function. It saves the model if it is better
def evalWrapper(eval):
def f(epoch, onlyImproved=False):
nonlocal bestLoss, bestEpoch
newLoss = eval(onlyImproved=onlyImproved)
if newLoss < bestLoss:
print('Improved %.5f => %.5f' % (bestLoss, newLoss))
bestLoss = newLoss
bestEpoch = epoch
model.save(folder, postfix='best')
return
return f

eval = evalWrapper(eval)

def performRandomSearch(epoch=0):
nonlocal bestLoss, bestEpoch
averageModels(folder, model, noiseStd=0.0)
eval(epoch=epoch, onlyImproved=True) # evaluate the averaged model
for _ in range(args.restarts):
# and add some noise
averageModels(folder, model, noiseStd=args.noise)
# re-evaluate the model with the new weights
eval(epoch=epoch, onlyImproved=True)
continue
return

if args.average:
performRandomSearch()

trainStep = _modelTrainingLoop(model, trainDataset)
for epoch in range(args.epochs):
trainStep(
desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
sampleParams=getSampleParams(epoch)
)
model.save(folder, postfix='latest')

testLoss = eval()
if testLoss < bestLoss:
print('Improved %.5f => %.5f' % (bestLoss, testLoss))
bestLoss = testLoss
bestEpoch = epoch
model.save(folder, postfix='best')
continue
eval(epoch)

print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
if args.patience <= (epoch - bestEpoch):
Expand All @@ -272,19 +293,7 @@ def main(args):
break
if 'reset' == args.on_patience:
print('Resetting the model to the average of the best models')
bestEpoch = epoch # reset the best epoch
for _ in range(args.restarts):
# and add some noise
averageModels(folder, model, noiseStd=args.noise)
# re-evaluate the model with the new weights
testLoss = eval(onlyImproved=True)
if testLoss < bestLoss:
print('Improved %.5f => %.5f' % (bestLoss, testLoss))
bestLoss = testLoss
bestEpoch = epoch
model.save(folder, postfix='best')
continue
continue
performRandomSearch(epoch=epoch)
continue
return

Expand Down

0 comments on commit ed74a15

Please sign in to comment.