From 0f7e9e42f5fdea8e862e5ee3e08d20f954a5471e Mon Sep 17 00:00:00 2001 From: GreenWizard Date: Sat, 29 Jun 2024 22:48:18 +0200 Subject: [PATCH] misc --- scripts/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index 51aafe1..96cf8d6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -96,9 +96,15 @@ def evaluate(onlyImproved=False): dists[i] = min(dist, dists[i]) # track the best distance # filter the results by the distance, to ignore the outliers - if dists[i] < 0.1: + maxValue = 0.1 + if dists[i] < maxValue: totalLoss.append(loss) totalDist.append(dist) + else: + # prevent the big "jumps" in the loss and distance when the model is becoming better + # assuming that maxValue is bigger than the corresponding loss + totalLoss.append(maxValue) + totalDist.append(maxValue) continue if not onlyImproved: print('Mean loss: %.5f | Mean distance: %.5f' % (