Skip to content

Commit

Permalink
more data to track training progress
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Jun 28, 2024
1 parent 8e28a38 commit bcb125f
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,35 @@ def _eval(dataset, model, plotFilename, args):

def evaluator(datasets, model, folder, args):
losses = [np.inf] * len(datasets) # initialize with infinity
dists = [np.inf] * len(datasets) # initialize with infinity
def evaluate(onlyImproved=False):
totalLoss = totalDist = 0.0
losses_dist = []
for i, dataset in enumerate(datasets):
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
losses_dist.append((loss, losses[i], dist, dists[i]))
isImproved = loss < losses[i]
if (not onlyImproved) or isImproved:
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
i + 1, len(datasets), T, loss, losses[i], dist
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (
i + 1, len(datasets), T, loss, losses[i], dist, dists[i]
))
if isImproved:
print('Test %d / %d | Improved %.5f => %.5f' % (i + 1, len(datasets), losses[i], loss))
print('Test %d / %d | Improved %.5f => %.5f, Distance: %.5f => %.5f' % (
i + 1, len(datasets), losses[i], loss, dists[i], dist
))
model.save(folder, postfix='best-%d' % i) # save the model separately
losses[i] = loss
pass

dists[i] = min(dist, dists[i]) # track the best distance
totalLoss += loss
totalDist += dist
continue
if not onlyImproved:
print('Mean loss: %.5f | Mean distance: %.5f' % (
totalLoss / len(datasets), totalDist / len(datasets)
))
return totalLoss / len(datasets)
return totalLoss / len(datasets), losses_dist
return evaluate

def _modelTrainingLoop(model, dataset):
Expand Down Expand Up @@ -247,15 +253,20 @@ def main(args):
for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/'))
]
eval = evaluator(evalDatasets, model, folder, args)
bestLoss = eval() # evaluate loaded model
bestLoss, _ = eval() # evaluate loaded model
bestEpoch = 0
# wrapper for the evaluation function. It saves the model if it is better
def evalWrapper(eval):
def f(epoch, onlyImproved=False):
nonlocal bestLoss, bestEpoch
newLoss = eval(onlyImproved=onlyImproved)
newLoss, losses = eval(onlyImproved=onlyImproved)
if newLoss < bestLoss:
print('Improved %.5f => %.5f' % (bestLoss, newLoss))
if onlyImproved: #details
for i, (loss, bestLoss_, dist, bestDist) in enumerate(losses):
print('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1, loss, bestLoss_, dist, bestDist))
continue
print('-' * 80)
bestLoss = newLoss
bestEpoch = epoch
model.save(folder, postfix='best')
Expand Down

0 comments on commit bcb125f

Please sign in to comment.