Skip to content

Commit

Permalink
ignore outliers
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Jun 29, 2024
1 parent fa32d5a commit 2f22c74
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def evaluator(datasets, model, folder, args):
losses = [np.inf] * len(datasets) # initialize with infinity
dists = [np.inf] * len(datasets) # initialize with infinity
def evaluate(onlyImproved=False):
totalLoss = totalDist = 0.0
totalLoss = []
totalDist = []
losses_dist = []
for i, dataset in enumerate(datasets):
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
Expand All @@ -94,14 +95,16 @@ def evaluate(onlyImproved=False):
pass

dists[i] = min(dist, dists[i]) # track the best distance
totalLoss += loss
totalDist += dist
# filter the results by the distance, to ignore the outliers
if dists[i] < 0.1:
totalLoss.append(loss)
totalDist.append(dist)
continue
if not onlyImproved:
print('Mean loss: %.5f | Mean distance: %.5f' % (
totalLoss / len(datasets), totalDist / len(datasets)
np.mean(totalLoss), np.mean(totalDist)
))
return totalLoss / len(datasets), losses_dist
return np.mean(totalLoss), losses_dist
return evaluate

def _modelTrainingLoop(model, dataset):
Expand Down

0 comments on commit 2f22c74

Please sign in to comment.