From 99f38a6d3fd8b10474e939071a6345be63f16e03 Mon Sep 17 00:00:00 2001 From: Emily Dinan Date: Fri, 22 Dec 2017 13:41:10 -0500 Subject: [PATCH] undid changes to train_model --- examples/train_model.py | 3 +-- parlai/agents/seq2seq/seq2seq.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/train_model.py b/examples/train_model.py index e115c364281..b926bce4bb2 100644 --- a/examples/train_model.py +++ b/examples/train_model.py @@ -145,8 +145,7 @@ def validate(self): valid_report, self.valid_world = run_eval( self.agent, opt, 'valid', opt['validation_max_exs'], valid_world=self.valid_world) - if valid_report[opt['validation_metric']] > self.best_valid or \ - valid_report[opt['validation_metric']] == 1: + if valid_report[opt['validation_metric']] > self.best_valid: self.best_valid = valid_report[opt['validation_metric']] self.impatience = 0 print('[ new best {}: {} ]'.format( diff --git a/parlai/agents/seq2seq/seq2seq.py b/parlai/agents/seq2seq/seq2seq.py index 868e4b709bb..533688578d4 100644 --- a/parlai/agents/seq2seq/seq2seq.py +++ b/parlai/agents/seq2seq/seq2seq.py @@ -341,7 +341,7 @@ def predict(self, xs, ys=None, cands=None, valid_cands=None, lm=False): self.zero_grad() loss = 0 predictions, scores, _ = self.model(xs, ys) - for i in range(ys.size(1)): + for i in range(scores.size(1)): # sum loss per-token score = scores.select(1, i) y = ys.select(1, i)