Skip to content

Commit

Permalink
update rounding, train logging (facebookresearch#431)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexholdenmiller authored Dec 6, 2017
1 parent 578833d commit 8796266
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ def train(self):
if opt['num_epochs'] > 0 and self.max_parleys is not None and (
(self.max_parleys > 0 and self.parleys >= self.max_parleys)
or self.total_epochs >= opt['num_epochs']):
self.log()
print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
opt['num_epochs'], self.train_time.time()))
self.log()
break
if opt['max_train_time'] > 0 and self.train_time.time() > opt['max_train_time']:
print('[ max_train_time elapsed:{}s ]'.format(self.train_time.time()))
Expand Down
2 changes: 1 addition & 1 deletion parlai/agents/seq2seq/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def predict(self, xs, ys=None, cands=None, valid_cands=None, lm=False):
loss.backward()
self.update_params()
losskey = 'loss' if not lm else 'lmloss'
loss_dict = {losskey: loss.mul_(len(xs)).data[0]}
loss_dict = {losskey: loss.mul_(len(xs)).data}
else:
self.model.eval()
predictions, scores, text_cand_inds = self.model(xs, ys, cands,
Expand Down
11 changes: 6 additions & 5 deletions parlai/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ def round_sigfigs(x, sigfigs=4):
try:
if x == 0:
return 0
except RuntimeError:
if x in [float('inf'), float('-inf'), float('NaN')]:
return x
return round(x, -math.floor(math.log10(abs(x)) - sigfigs + 1))
except TypeError:
# handle 1D torch tensors
x = x[0]
if x in [float('inf'), float('-inf'), float('NaN')]:
return x
return round(x, -math.floor(math.log10(abs(x)) - sigfigs + 1))
# if anything else breaks here please file an issue on Github
return round_sigfigs(x[0], sigfigs)


def flatten(teacher, context_length=-1, include_labels=True):
Expand Down

0 comments on commit 8796266

Please sign in to comment.