Update IWSLT configuration for transformer
edunov authored and myleott committed Jul 25, 2018
1 parent dbe9637 commit c279407
Showing 2 changed files with 31 additions and 6 deletions.
25 changes: 25 additions & 0 deletions examples/translation/README.md
@@ -36,6 +36,31 @@ $ python generate.py data-bin/iwslt14.tokenized.de-en \
```

To train a transformer model on IWSLT'14 German to English:
```
# Preparation steps are the same as for the fconv model.
# Train the model (this configuration is well-suited to a single-GPU setup):
$ mkdir -p checkpoints/transformer
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
-a transformer_iwslt_de_en --optimizer adam --lr 0.0005 -s de -t en \
--label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 \
--min-lr '1e-09' --lr-scheduler inverse_sqrt --weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --max-update 50000 \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--adam-betas '(0.9, 0.98)' --save-dir checkpoints/transformer
# Average the 10 latest checkpoints:
$ python scripts/average_checkpoints.py --inputs checkpoints/transformer \
--num-epoch-checkpoints 10 --output checkpoints/transformer/model.pt
# Generate:
$ python generate.py data-bin/iwslt14.tokenized.de-en \
--path checkpoints/transformer/model.pt \
--batch-size 128 --beam 5 --remove-bpe
```
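The flags `--lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr '1e-07'` specify a warmup-then-decay schedule: the learning rate grows linearly from the warmup-init value to the peak `--lr` over the warmup updates, then decays proportionally to the inverse square root of the update number. A minimal sketch of that rule, assuming the standard inverse-sqrt semantics (not fairseq's actual scheduler code):

```python
def inverse_sqrt_lr(num_updates, lr=5e-4, warmup_updates=4000, warmup_init_lr=1e-7):
    """Learning rate at a given update under an inverse-sqrt schedule."""
    if num_updates < warmup_updates:
        # Linear warmup from warmup_init_lr up to the peak lr.
        return warmup_init_lr + (lr - warmup_init_lr) * num_updates / warmup_updates
    # After warmup, decay proportionally to 1/sqrt(num_updates).
    return lr * (warmup_updates / num_updates) ** 0.5
```

Checkpoint averaging, in turn, is an element-wise mean over the parameter tensors of the selected checkpoints. A sketch of the idea behind `scripts/average_checkpoints.py` (illustrative only, not the script's actual implementation):

```python
import torch

def average_checkpoints(paths):
    # Assumes each file is a torch checkpoint whose 'model' entry maps
    # parameter names to tensors of identical shapes across checkpoints.
    avg = None
    for path in paths:
        state = torch.load(path, map_location='cpu')['model']
        if avg is None:
            avg = {name: t.clone().float() for name, t in state.items()}
        else:
            for name, t in state.items():
                avg[name] += t.float()
    return {name: t / len(paths) for name, t in avg.items()}
```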


### prepare-wmt14en2de.sh

12 changes: 6 additions & 6 deletions fairseq/models/transformer.py
@@ -588,14 +588,14 @@ def base_architecture(args):

 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
 def transformer_iwslt_de_en(args):
-    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
-    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
+    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
-    args.encoder_layers = getattr(args, 'encoder_layers', 3)
-    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
-    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
+    args.encoder_layers = getattr(args, 'encoder_layers', 6)
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
-    args.decoder_layers = getattr(args, 'decoder_layers', 3)
+    args.decoder_layers = getattr(args, 'decoder_layers', 6)
     base_architecture(args)
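
The repeated `getattr(args, name, default)` calls implement the convention behind `@register_model_architecture`: a named architecture only fills in hyperparameters the user has not already set on the command line, so explicit flags always take precedence over the architecture's defaults. A stripped-down illustration of that pattern (hypothetical function, not fairseq code):

```python
from argparse import Namespace

def apply_iwslt_defaults(args):
    # Each default applies only when the attribute is absent,
    # so command-line overrides are preserved.
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    return args

args = apply_iwslt_defaults(Namespace(encoder_layers=4))  # user set layers
print(args.encoder_embed_dim, args.encoder_layers)  # -> 512 4
```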


