Skip to content

Commit

Permalink
more advanced loss function for the model
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Jul 11, 2024
1 parent 130d1e5 commit 039902b
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions Core/CModelTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,52 @@ def _pointLoss(self, ytrue, ypred):
tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
return tf.reduce_mean(loss, axis=-1)

def _trainStep(self, Data):
print('Instantiate _trainStep')
###############
x, (y, ) = Data
y = y[..., 0, :]
losses = {}
with tf.GradientTape() as tape:
data = x['augmented']
def _trainOn(self, data, y_list):
def calculate_loss(predictions):
# select the smallest loss from the list of suggested points
losses = []
for y in y_list:
loss = self._pointLoss(y, predictions)[..., None]
losses.append(loss)
continue
losses = tf.concat(losses, axis=-1)
shp = tf.shape(y_list[0])
tf.assert_equal(tf.shape(losses), tf.concat([shp[:-1], [len(y_list)]], axis=0))
losses = tf.reduce_min(losses, axis=-1)
tf.assert_equal(tf.shape(losses), shp[:-1])
return tf.reduce_mean(losses)

data = self._replaceByEmbeddings(data)
predictions = self._model(data, training=True)
intermediate = predictions['intermediate']
losses['final'] = tf.reduce_mean(self._pointLoss(y, predictions['result']))
finalPredictions = predictions['result']
losses = {}
losses['final'] = calculate_loss(finalPredictions)
for name, encoder in self._intermediateEncoders.items():
latent = intermediate[name]
pts = encoder(latent, training=True)
loss = self._pointLoss(y, pts)
loss = calculate_loss(pts)
losses['loss-%s' % name] = tf.reduce_mean(loss)
continue
loss = sum(losses.values())
losses['loss'] = loss
return losses, tf.stop_gradient(finalPredictions)

def _trainStep(self, Data):
print('Instantiate _trainStep')
###############
x, (y, ) = Data
y = y[..., 0, :]
losses = {}
with tf.GradientTape() as tape:
lossesClean, y_clean = self._trainOn(x['clean'], [y])
# ensure that the augmentations are not affect predictions
lossesAugmented, _ = self._trainOn(x['augmented'], [y, y_clean])
assert lossesClean.keys() == lossesAugmented.keys(), 'Losses keys mismatch'
# combine losses
losses = {k: lossesClean[k] + lossesAugmented[k] for k in lossesClean.keys()}
# calculate total loss and final loss
losses['total-clean'] = sum(lossesClean.values())
losses['total-augmented'] = sum(lossesAugmented.values())
losses['loss'] = loss = sum([losses['total-clean'], losses['total-augmented']])

self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
###############
Expand Down

0 comments on commit 039902b

Please sign in to comment.