Commit b2ee110
Summary: Pull Request resolved: fairinternal/fairseq-py#1164

Reviewed By: ngoyal2707

Differential Revision: D21373232

Pulled By: myleott

fbshipit-source-id: f31c65c6f2ebd9a603099e0cbe9e32c47585f50d
myleott authored and facebook-github-bot committed May 4, 2020
1 parent 7a6519f commit b2ee110
Showing 10 changed files with 51 additions and 19 deletions.
4 changes: 3 additions & 1 deletion fairseq/benchmark/dummy_lm.py
@@ -57,7 +57,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
                 'id': 1,
                 'net_input': {
                     'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
-                    'src_lengths': torch.full((bsz, ), self.args.tokens_per_sample),
+                    'src_lengths': torch.full(
+                        (bsz, ), self.args.tokens_per_sample, dtype=torch.long
+                    ),
                 },
                 'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
                 'nsentences': bsz,
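
The explicit `dtype=torch.long` here (and the identical change in `dummy_masked_lm.py` below) matters because `torch.full` does not infer an integer dtype from an integer fill value in the PyTorch releases this commit targets. A minimal sketch of the behavior, assuming a PyTorch build from around this era:

```python
import torch

# Without an explicit dtype, torch.full returns a float tensor even for
# an integer fill value (PyTorch ~1.5 also emits a deprecation warning),
# so downstream code expecting integer lengths would break.
lengths = torch.full((4,), 512)
print(lengths.dtype)  # torch.float32

lengths = torch.full((4,), 512, dtype=torch.long)
print(lengths.dtype)  # torch.int64
```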
4 changes: 3 additions & 1 deletion fairseq/benchmark/dummy_masked_lm.py
@@ -68,7 +68,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
                 'id': 1,
                 'net_input': {
                     'src_tokens': torch.stack([self.dummy_src for _ in range(bsz)]),
-                    'src_lengths': torch.full((bsz, ), self.args.tokens_per_sample),
+                    'src_lengths': torch.full(
+                        (bsz, ), self.args.tokens_per_sample, dtype=torch.long
+                    ),
                 },
                 'target': torch.stack([self.dummy_tgt for _ in range(bsz)]),
                 'nsentences': bsz,
14 changes: 12 additions & 2 deletions fairseq/data/dictionary.py
@@ -60,7 +60,14 @@ def index(self, sym):
             return self.indices[sym]
         return self.unk_index

-    def string(self, tensor, bpe_symbol=None, escape_unk=False, extra_symbols_to_ignore=None):
+    def string(
+        self,
+        tensor,
+        bpe_symbol=None,
+        escape_unk=False,
+        extra_symbols_to_ignore=None,
+        unk_string=None,
+    ):
         """Helper for converting a tensor of token indices to a string.

         Can optionally remove BPE symbols or escape <unk> words.
@@ -73,7 +80,10 @@ def string(self, tensor, bpe_symbol=None, escape_unk=False, extra_symbols_to_ignore=None):

         def token_string(i):
             if i == self.unk():
-                return self.unk_string(escape_unk)
+                if unk_string is not None:
+                    return unk_string
+                else:
+                    return self.unk_string(escape_unk)
             else:
                 return self[i]
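
A hedged usage sketch of the new `unk_string` argument (the symbol and placeholder below are illustrative, not from the diff): it overrides how unknown indices are rendered for a single call, instead of the dictionary-wide `self.unk_string(escape_unk)` behavior.

```python
import torch
from fairseq.data import Dictionary

d = Dictionary()
d.add_symbol('hello')
toks = torch.LongTensor([d.index('hello'), d.unk()])

print(d.string(toks))                                  # 'hello <unk>'
print(d.string(toks, unk_string='UNKNOWNTOKENINHYP'))  # 'hello UNKNOWNTOKENINHYP'
```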
2 changes: 1 addition & 1 deletion fairseq/models/lstm_lm.py
@@ -96,7 +96,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
             num_layers=args.decoder_layers,
             dropout_in=args.decoder_dropout_in,
             dropout_out=args.decoder_dropout_out,
-            attention=options.eval_bool(args.decoder_attention),
+            attention=False,  # decoder-only language model doesn't support attention
             encoder_output_units=0,
             pretrained_embed=pretrained_decoder_embed,
             share_input_output_embed=args.share_decoder_input_output_embed,
12 changes: 12 additions & 0 deletions fairseq/modules/linearized_convolution.py
@@ -26,6 +26,18 @@ def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
         self._linearized_weight = None
         self.register_backward_hook(self._clear_linearized_weight)

+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars)
+        # don't store redundant _linearized_weight in checkpoints
+        if prefix + '_linearized_weight' in state:
+            del state[prefix + '_linearized_weight']
+        return state
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        prefix = name + '.' if name != '' else ''
+        if prefix + '_linearized_weight' in state_dict:
+            del state_dict[prefix + '_linearized_weight']
+
     def forward(self, input, incremental_state=None):
         """
         Args:
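
The two overrides above implement a common pattern: a lazily derived tensor (here `_linearized_weight`, which becomes a registered `Parameter` once first built) should not be serialized, and stale copies in old checkpoints should be scrubbed on load. A generic sketch of the same idea, with illustrative names not taken from fairseq:

```python
import torch.nn as nn

class CachedLinear(nn.Linear):
    """Linear layer that lazily caches a transposed copy of its weight."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._cached_weight = None

    def _get_cached_weight(self):
        if self._cached_weight is None:
            # Assigning an nn.Parameter attribute registers it on the
            # module, which is why it would otherwise leak into state_dict.
            self._cached_weight = nn.Parameter(self.weight.t().contiguous())
        return self._cached_weight

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        # Drop the derived tensor: it is cheap to rebuild, bloats
        # checkpoints, and breaks loading into a fresh module that
        # hasn't built its cache yet.
        state.pop(prefix + '_cached_weight', None)
        return state
```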
8 changes: 4 additions & 4 deletions fairseq/optim/adam.py
@@ -176,8 +176,8 @@ def step(self, closure=None):
                 state['step'] += 1

                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 if amsgrad:
                     # Maintains the maximum of all 2nd moment running avg. till now
                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
@@ -191,9 +191,9 @@ def step(self, closure=None):
                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                 if group['weight_decay'] != 0:
-                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

-                p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

                 # TODO: remove check once pyTorch avoids a copy for this case
                 if p.data_ptr() != p_data_fp32.data_ptr():
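
These optimizer edits (and the matching ones in `adamax.py` and `nag.py` below) migrate from the overload that took the scalar multiplier as the first positional argument, deprecated around PyTorch 1.5, to the keyword form. The two spellings are numerically identical; a small sketch:

```python
import torch

beta1 = 0.9
grad = torch.full((3,), 2.0)

exp_avg = torch.zeros(3)
# Old, deprecated:  exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)   # exp_avg += (1 - beta1) * grad

exp_avg_sq = torch.zeros(3)
# Old, deprecated:  exp_avg_sq.addcmul_(1 - 0.999, grad, grad)
exp_avg_sq.addcmul_(grad, grad, value=1 - 0.999)  # exp_avg_sq += value * grad * grad

p = torch.ones(3)
denom = torch.ones(3)
# Old, deprecated:  p.addcdiv_(-0.01, exp_avg, denom)
p.addcdiv_(exp_avg, denom, value=-0.01)           # p += value * exp_avg / denom
```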
6 changes: 3 additions & 3 deletions fairseq/optim/adamax.py
@@ -131,7 +131,7 @@ def step(self, closure=None):
                 state['step'] += 1

                 # Update biased first moment estimate.
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                 # Update the exponentially weighted infinity norm.
                 torch.max(
@@ -146,9 +146,9 @@ def step(self, closure=None):
                 step_size /= bias_correction

                 if group['weight_decay'] != 0:
-                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

-                p_data_fp32.addcdiv_(-step_size, exp_avg, exp_inf.add(eps))
+                p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size)

                 p.data.copy_(p_data_fp32)
6 changes: 3 additions & 3 deletions fairseq/optim/nag.py
@@ -88,10 +88,10 @@ def step(self, closure=None):

                 if weight_decay != 0:
                     p_data_fp32.mul_(1 - lr * weight_decay)
-                p_data_fp32.add_(momentum * momentum * lr_correct, buf)
-                p_data_fp32.add_(-(1 + momentum) * lr, d_p)
+                p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct)
+                p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr)

-                buf.mul_(momentum * lr_correct).add_(-lr, d_p)
+                buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)

                 # TODO: remove check once pyTorch avoids a copy for this case
                 if p.data_ptr() != p_data_fp32.data_ptr():
13 changes: 10 additions & 3 deletions fairseq/tasks/translation.py
@@ -181,14 +181,14 @@ def add_args(parser):
         parser.add_argument('--eval-bleu', action='store_true',
                             help='evaluation with BLEU scores')
         parser.add_argument('--eval-bleu-detok', type=str, default="space",
-                            help='detokenizer before computing BLEU (e.g., "moses"); '
+                            help='detokenize before computing BLEU (e.g., "moses"); '
                                  'required if using --eval-bleu; use "space" to '
                                  'disable detokenization; see fairseq.data.encoders '
                                  'for other options')
         parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                             help='args for building the tokenizer, if needed')
         parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
-                            help='if setting, we compute tokenized BLEU instead of sacrebleu')
+                            help='compute tokenized BLEU instead of sacrebleu')
         parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                             help='remove BPE before computing BLEU')
         parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
@@ -351,7 +351,14 @@ def decode(toks, escape_unk=False):
             s = self.tgt_dict.string(
                 toks.int().cpu(),
                 self.args.eval_bleu_remove_bpe,
-                escape_unk=escape_unk,
+                # The default unknown string in fairseq is `<unk>`, but
+                # this is tokenized by sacrebleu as `< unk >`, inflating
+                # BLEU scores. Instead, we use a somewhat more verbose
+                # alternative that is unlikely to appear in the real
+                # reference, but doesn't get split into multiple tokens.
+                unk_string=(
+                    "UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"
+                ),
             )
             if self.tokenizer:
                 s = self.tokenizer.decode(s)
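
The problem the new comment describes can be seen end to end. A hedged sketch with toy sentences (exact scores depend on the sacrebleu version): under the default `13a` tokenization, `<unk>` is split into `< unk >`, so the hypothesis contributes three tokens where the model produced one unknown, skewing n-gram counts relative to a single-token placeholder.

```python
import sacrebleu

ref = [["the cat sat on the mat"]]

# '<unk>' becomes '< unk >' after default tokenization.
print(sacrebleu.corpus_bleu(["the <unk> sat on the mat"], ref).score)

# A single improbable placeholder stays one token.
print(sacrebleu.corpus_bleu(["the UNKNOWNTOKENINHYP sat on the mat"], ref).score)
```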
1 change: 0 additions & 1 deletion fairseq_cli/eval_lm.py
@@ -248,7 +248,6 @@ def main(parsed_args, **unused_kwargs):

 def cli_main():
     parser = options.get_eval_lm_parser()
-    add_distributed_training_args(parser)
     args = options.parse_args_and_arch(parser)
     distributed_utils.call_main(args, main)
