Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Jun 1, 2024
1 parent be83836 commit 3787ab3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Core/CModelTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class CModelTrainer(CModelWrapper):
def __init__(self, timesteps, model='simple', **kwargs):
super().__init__(timesteps, model=model, **kwargs)
self._compile()
self.compile()
# add signatures to help tensorflow optimize the graph
specification = self._modelRaw['inputs specification']
self._trainStep = tf.function(
Expand All @@ -27,7 +27,7 @@ def __init__(self, timesteps, model='simple', **kwargs):
)
return

def _compile(self):
def compile(self):
self._model.compile(optimizer=NNU.createOptimizer())
return

Expand Down
6 changes: 4 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ def evaluate(onlyImproved=False):
totalLoss = totalDist = 0.0
for i, dataset in enumerate(datasets):
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
if not onlyImproved:
isImproved = loss < losses[i]
if (not onlyImproved) or isImproved:
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
i + 1, len(datasets), T, loss, losses[i], dist
))
if loss < losses[i]:
if isImproved:
print('Test %d / %d | Improved %.5f => %.5f' % (i + 1, len(datasets), losses[i], loss))
model.save(folder, postfix='best-%d' % i) # save the model separately
losses[i] = loss
Expand Down Expand Up @@ -197,6 +198,7 @@ def averageModels(folder, model, noiseStd=0.0):
# average the weights
TV = [(x / N) + np.random.normal(0.0, noiseStd, x.shape) for x in TV]
model._model.set_weights(TV)
model.compile() # recompile the model with the new weights
return

def main(args):
Expand Down

0 comments on commit 3787ab3

Please sign in to comment.