Fuzzy Alignments in Directed Acyclic Graph for Non-autoregressive Machine Translation

Implementation for the ICLR 2023 paper "Fuzzy Alignments in Directed Acyclic Graph for Non-autoregressive Machine Translation".

Abstract: We introduce a fuzzy alignment objective between the directed acyclic graph and the reference sentence based on n-gram matching, aiming to handle the multi-modality problem in the training data.
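
As a rough, discrete illustration of n-gram matching (background intuition only; the paper's objective works with expected n-gram counts over the DAG, which this sketch does not attempt to reproduce, and the helper below is hypothetical):

from collections import Counter

def ngram_match_count(hyp, ref, n=2):
    # Clipped count of n-grams shared by a hypothesis and a reference (illustration only).
    hyp_ngrams = Counter(tuple(hyp[i:i + n]) for i in range(len(hyp) - n + 1))
    ref_ngrams = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
    return sum(min(c, ref_ngrams[g]) for g, c in hyp_ngrams.items())

print(ngram_match_count("a b c d".split(), "a b d c".split(), n=2))  # 1 (only "a b" matches)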

Highlights:

  • FA-DAT outperforms the DA-Transformer baseline by 1.1 BLEU on the raw WMT17 ZH<->EN dataset.
  • FA-DAT matches the performance of the autoregressive Transformer on the raw WMT14 EN<->DE and WMT17 ZH<->EN datasets with fully parallel decoding (13× speedup).

Features:

  • We provide optional numba implementations of the dynamic programming used to compute the fuzzy alignment, which speed up training (a rough sketch of this kind of dynamic programming follows this list).
  • The same functions are also implemented in PyTorch and are used by default.
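
To give a rough idea of what these implementations accelerate, here is a minimal, self-contained sketch of a forward pass over a DAG transition matrix that accumulates the probability of reaching each vertex. It only illustrates the kind of dynamic programming involved; the function name, shapes, and the toy transition matrix are assumptions, not the repo's API.

import torch

def forward_reach_prob(transition):
    # transition[i, j]: probability of moving from vertex i to vertex j (upper-triangular).
    # Returns reach[j]: probability that a path starting at vertex 0 reaches vertex j.
    L = transition.size(0)
    reach = torch.zeros(L)
    reach[0] = 1.0
    for j in range(1, L):
        reach[j] = torch.dot(reach[:j], transition[:j, j])
    return reach

# Hypothetical 4-vertex DAG with paths 0->1->3, 0->2->3 and 0->3
transition = torch.tensor([
    [0.0, 0.5, 0.3, 0.2],
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 0.0],
])
print(forward_reach_prob(transition))  # -> [1.0, 0.5, 0.3, 1.0]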

Files:

  • Most of the framework code comes from DA-Transformer. We mainly add the following files as plugins.

    FA-DAT
    └── fs_plugins
        └── criterions
            ├── nat_dag_loss_ngram.py                   # fuzzy alignment loss
            └── pass_prob.py                            # numba implementation of dynamic programming in loss
    
    
  • This repo is forked from DA-Transformer, which is in turn modified from fairseq:5175fd. You can refer to the above repos for more information.

Below is a guide to replicating the results reported in the paper. We give an example of experiments on the WMT14 En-De dataset.

Requirements & Installation

Requirements

  • Python >= 3.7
  • PyTorch == 1.10.1 (tested with CUDA 11.3)
  • gcc >= 7.0.0 (for compiling CUDA operations in NLL pretraining, as recommended in DA-Transformer)
  • (Optional) numba == 0.56.2

Installation

  • git clone --recurse-submodules https://github.com/ictnlp/FA-DAT.git
  • cd FA-DAT && pip install -e .
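
After installation, an optional sanity check (illustrative only, not part of the repo) is to confirm that fairseq and PyTorch import correctly and that CUDA is visible:

# Optional post-install sanity check.
import torch
import fairseq

print("fairseq:", fairseq.__version__)
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())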

Data Preprocess

Training requires binarized data generated by the following script:

input_dir=path/to/raw_data        # directory of raw text data
data_dir=path/to/binarized_data   # directory of the generated binarized data
src=en                            # source language id
tgt=de                            # target language id
fairseq-preprocess --source-lang ${src} --target-lang ${tgt} \
    --trainpref ${input_dir}/train --validpref ${input_dir}/valid --testpref ${input_dir}/test \
    --src-dict ${input_dir}/dict.${src}.txt --tgt-dict ${input_dir}/dict.${tgt}.txt \
    --destdir ${data_dir} --workers 32

Negative Log-likelihood Pre-training

We use a batch size of approximately 64k tokens, so GPU number * max_tokens * update_freq should equal about 64k.
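
As a quick check of this relation (an illustrative calculation using the numbers from the 8-GPU example below, not part of the training code):

# Effective batch size = GPU number * max_tokens * update_freq
num_gpus, max_tokens, update_freq = 8, 4096, 2
print(num_gpus * max_tokens * update_freq)  # 65536, i.e. about 64k tokens per update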

For example, if you have 8 V100-32 GPUs, run the following script for training:

data_dir=/path/to/binarized/data/dir
checkpoint_dir=/path/to/checkpoint/dir

fairseq-train ${data_dir}  \
    --user-dir fs_plugins \
    \
    --task translation_lev_modified  --noise full_mask \
    \
    --arch glat_decomposed_link_base \
    --decoder-learned-pos --encoder-learned-pos \
    --share-all-embeddings --activation-fn gelu \
    --apply-bert-init \
    --links-feature feature:position --decode-strategy lookahead \
    --max-source-positions 128 --max-target-positions 1024 --src-upsample-scale 8.0 \
    \
    --criterion nat_dag_loss \
    --length-loss-factor 0 --max-transition-length 99999 \
    --glat-p 0.5:0.1@200k --glance-strategy number-random \
    \
    --optimizer adam --adam-betas '(0.9,0.999)' --fp16 \
    --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1 \
    --lr-scheduler inverse_sqrt  --warmup-updates 10000   \
    --clip-norm 0.1 --lr 0.0005 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \
    --ddp-backend c10d \
    \
    --max-tokens 4096  --update-freq 2 --grouped-shuffling \
    --max-update 300000 --max-tokens-valid 4096 \
    --save-interval 1  --save-interval-updates 10000  \
    --seed 0 \
    \
    --valid-subset valid \
    --validate-interval 1       --validate-interval-updates 10000 \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok space --eval-bleu-remove-bpe --eval-bleu-print-samples \
    --fixed-validation-seed 7 \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric   \
    --keep-best-checkpoints 5 --save-dir ${checkpoint_dir} \
    --log-format 'simple' --log-interval 100

Fuzzy Alignment Training

ngram_order=2
exp=ende_raw_lambda_8_ngram_${ngram_order}_256kbsz
data_dir=~/dataset/wmt14_ende_ori/data-bin-en-de
checkpoint_dir=./checkpoints/$exp


nohup fairseq-train ${data_dir}  \
    --user-dir fs_plugins \
    --task translation_lev_modified  --noise full_mask \
    --arch glat_decomposed_link_base \
    --decoder-learned-pos --encoder-learned-pos \
    --share-all-embeddings --activation-fn gelu \
    --finetune-from-model ./checkpoints/ende_raw_lambda_8/average.pt \
    --links-feature feature:position --decode-strategy lookahead \
    --max-source-positions 128 --max-target-positions 1024 --src-upsample-scale 8.0 \
    --criterion nat_dag_loss_ngram --max-ngram-order ${ngram_order} \
    --length-loss-factor 0 --max-transition-length 99999 \
    --glat-p 0.1:0.1@4k --glance-strategy number-random \
    --optimizer adam --adam-betas '(0.9,0.999)' --fp16 \
    --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1 \
    --lr-scheduler inverse_sqrt  --warmup-updates 500   \
    --clip-norm 0.1 --lr 0.0002 --warmup-init-lr '1e-07' --stop-min-lr '1e-09' \
    --ddp-backend=legacy_ddp \
    --max-tokens 200  --update-freq 328 --grouped-shuffling \
    --max-update 5000 --max-tokens-valid 256 \
    --save-interval 1  --save-interval-updates 500 \
    --seed 0 \
    --valid-subset valid \
    --skip-invalid-size-inputs-valid-test \
    --validate-interval 1       --validate-interval-updates 500 \
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_with_beam": 1}' \
    --eval-bleu-detok moses --eval-bleu-remove-bpe \
    --fixed-validation-seed 7 \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric  \
    --keep-best-checkpoints 5 --save-dir ${checkpoint_dir} \
    --keep-interval-updates 5 --keep-last-epochs 5 \
    --log-format 'simple' --log-interval 2 > logs/$exp.txt &

Explanations of some arguments

--task translation_lev_modified   # Task for NATs
--optimizer adam                  # The optimizer. You can use "ls_adam" instead to enable LightSeq's Optimizer

--arch glat_decomposed_link_base  # The Model Architecture. You can use "ls_glat_decomposed_link_base" to enable LightSeq's Transformer
--links-feature feature:position  # Features used to predict links (transitions). We use transformer outputs with learnable positional embeddings.
--max-transition-length 99999       # Max transition distance. -1 means no limitation, which does not support cuda operations. To use cuda operations with no limitation, please set a very large number such as 99999.
--src-upsample-scale 8.0          # The upsampling rate (\lambda), Graph Size = \lambda * src_length
--max-source-positions 128        # Max length of encoder
--max-target-positions 1024       # Max length of decoder. Should be at least \lambda * the max target length in samples. The samples not satisfying the limitation will be dropped.
--decode-strategy lookahead       # Decoding strategy. Possible values: greedy, lookahead, viterbi, jointviterbi, beamsearch.

--criterion nat_dag_loss          # The Criterion for DA-Transformer
--label-smoothing 0.0             # DA-Transformer does not use label smoothing for now
--length-loss-factor 0            # Disable the length predictor
--glat-p 0.5:0.1@200k             # Glancing probability. It indicates annealing p from 0.5 to 0.1 in 200k steps.
--glance-strategy number-random   # Glancing strategy. Possible values: "number-random" or "None" or "CMLM"
--torch-dag-loss                  # Use the torch implementation of the dag loss instead of the cuda implementation. It may be slower and consume more memory.
--torch-dag-best-alignment        # Use the torch implementation of best-alignment instead of the cuda implementation. It may be slower and consume more memory.
--torch-dag-logsoftmax-gather     # Use the torch implementation of logsoftmax-gather instead of the cuda implementation. It may be slower and consume more memory.
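
As a concrete illustration of how the upsampling rate determines the DAG size (hypothetical numbers; the variable names below are descriptive, not repo options):

# Graph Size = lambda * src_length (see --src-upsample-scale above)
src_upsample_scale = 8.0   # lambda
src_length = 64            # a hypothetical source length in tokens
graph_size = int(src_upsample_scale * src_length)
print(graph_size)          # 512 decoder vertices for this sample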

Speed up with Lightseq

You only need to change two options:

  • --arch glat_decomposed_link_base to --arch ls_glat_decomposed_link_base
  • --optimizer adam to --optimizer ls_adam

Lightseq usually brings about a 1.5x speedup in training.

Note: Lightseq does not provide all of the Transformer variants in Fairseq. If you change the model architecture, it may behave incorrectly without any warnings. Read the code carefully before making any changes.

Inference

DA-Transformer provides four decoding strategies:

  • Greedy: Fastest; uses argmax for both token prediction and transition prediction.
  • Lookahead: Similar speed to Greedy but higher quality; considers transition prediction and token prediction together (a minimal sketch follows this list).
  • Viterbi: Slightly slower than Lookahead but higher quality; supports a length penalty to control the output length.
  • BeamSearch: Slowest but highest quality; can be combined with an n-gram language model.
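
For intuition, here is a minimal sketch of lookahead decoding over a DAG, following the scoring rule quoted in the Lookahead command below (y_i, a_i = argmax [ log P(y_i|a_i) + beta * log P(a_i|a_{i-1}) ]). The tensor shapes and the eos id are assumptions for illustration; this is not the repo's implementation.

import torch

def lookahead_decode(token_logp, trans_logp, beta=1.0, eos_id=2):
    # token_logp: [L, V] log P(y | vertex); trans_logp: [L, L] log P(next vertex | vertex)
    L, V = token_logp.shape
    pos, tokens = 0, []
    while pos < L - 1:
        # Jointly score every (next vertex, token at that vertex) pair.
        scores = beta * trans_logp[pos].unsqueeze(1) + token_logp  # [L, V]
        scores[: pos + 1] = float("-inf")                          # edges only go forward in the DAG
        pos, tok = divmod(int(scores.argmax()), V)
        if tok == eos_id:
            break
        tokens.append(tok)
    return tokens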

Average Checkpoints

We average the parameters of the best 5 checkpoints, which empirically leads to better performance.

checkpoint_dir=/path/to/checkpoint/dir
average_checkpoint_path=/path/to/checkpoint/average.pt

python3 ./fs_plugins/scripts/average_checkpoints.py --inputs ${checkpoint_dir} \
                --max-metric --best-checkpoints-metric bleu --num-best-checkpoints-metric 5 --output ${average_checkpoint_path}

Greedy & Lookahead

# Greedy
data_dir=/path/to/binarized/data/dir
average_checkpoint_path=/path/to/checkpoint/average.pt

fairseq-generate ${data_dir} \
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \
    --remove-bpe --max-tokens 4096 --seed 0 \
    --model-overrides "{\"decode_strategy\":\"greedy\"}" \
    --path ${average_checkpoint_path}

# Lookahead
# ``decode_beta`` scales the score of logits. Specifically: y_i, a_i = argmax [ log P(y_i|a_i) + beta * log P(a_i|a_{i-1}) ]

fairseq-generate ${data_dir} \
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \
    --remove-bpe --max-tokens 4096 --seed 0 \
    --model-overrides "{\"decode_strategy\":\"lookahead\",\"decode_beta\":1}" \
    --path ${average_checkpoint_path}

For WMT17 En-Zh, we add --source-lang en --target-lang zh --tokenizer moses --scoring sacrebleu --sacrebleu-tokenizer zh.

Note: the latency in our paper is evaluated with a batch size of 1, i.e., replace --max-tokens 4096 with --batch-size 1.

Viterbi Decoding

The Viterbi decoding algorithms were proposed in "Viterbi Decoding of Directed Acyclic Transformer for Non-Autoregressive Machine Translation". decode_viterbibeta is the length penalty that controls the output length. Viterbi decoding finds the path that maximizes P(A|X) / |Y|^{beta}. Joint-Viterbi finds the output that maximizes P(A,Y|X) / |Y|^{beta}.

You can set decode_strategy to viterbi or jointviterbi to enable Viterbi decoding. jointviterbi is usually recommended because it jointly considers the transition and token probabilities, similar to lookahead decoding.
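
A minimal sketch of the role of decode_viterbibeta: given the best path log-probability for each candidate output length (hypothetical inputs), the decoder keeps the length that maximizes the length-normalized score. This only illustrates the objective P(A|X) / |Y|^{beta} in log space; it is not the actual Viterbi dynamic programming.

import math

def pick_length(best_logprob_by_length, beta=1.0):
    # best_logprob_by_length: {output_length: best path log-probability} (hypothetical input)
    # maximize log P(A|X) - beta * log |Y|, i.e. P(A|X) / |Y|^{beta}
    return max(best_logprob_by_length,
               key=lambda n: best_logprob_by_length[n] - beta * math.log(n))

print(pick_length({8: -10.2, 9: -10.5, 10: -11.4}, beta=1.0))  # -> 8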

# Viterbi
data_dir=/path/to/binarized/data/dir
average_checkpoint_path=/path/to/checkpoint/average.pt

fairseq-generate ${data_dir} \
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \
    --remove-bpe --max-tokens 4096 --seed 0 \
    --model-overrides "{\"decode_strategy\":\"viterbi\", \"decode_viterbibeta\":1.0}" \
    --path ${average_checkpoint_path}

# Joint-Viterbi

fairseq-generate ${data_dir} \
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \
    --remove-bpe --max-tokens 4096 --seed 0 \
    --model-overrides "{\"decode_strategy\":\"jointviterbi\",\"decode_viterbibeta\":1.0}" \
    --path ${average_checkpoint_path}

BeamSearch

Please install dag_search first; see ./dag_search/install.sh for requirements.

If you want to use n-gram LM in BeamSearch, see this to build one before generation.

# OMP_NUM_THREADS sets the number of threads in inference.
# For batch_size = 1, the optimal number is 2. In most cases, do not use more than 8 (it will be slow due to resource conflicts).

# The algorithm finds the sentence that maximizes: 1 / |Y|^{alpha} * [ log P(Y) + gamma * log P_{n-gram}(Y) ]
# ``decode_beta`` scales the score of logits. Specifically: log P(Y, A) := sum log P(y_i|a_i) + beta * sum log P(a_i|a_{i-1})
# ``decode_alpha`` is used for the length penalty. ``decode_gamma`` is used for the n-gram language model score.
# ``decode_lm_path`` is the path to the language model
# ``decode_beamsize`` is the beam size; ``decode_top_cand_n`` sets the number of top candidates when considering transitions
# ``decode_top_p`` sets the max probability of top candidates when considering transitions
# ``decode_max_beam_per_length`` limits the number of beams with the same length in each step
# ``decode_max_batchsize`` should not be smaller than the real batch size (the value is used for memory allocation)
# ``dedup`` enables token deduplication.

# BeamSearch without LM
data_dir=/path/to/binarized/data/dir
average_checkpoint_path=/path/to/checkpoint/average.pt
OMP_NUM_THREADS=8  # The number of CPU threads used for beamsearch

fairseq-generate  ${data_dir} \
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \
    --remove-bpe --batch-size 32 --seed 0 \
    --model-overrides "{\"decode_strategy\": \"beamsearch\", \"decode_beta\": 1, \
        \"decode_alpha\": 1.1, \"decode_gamma\": 0, \
        \"decode_lm_path\": None, \
        \"decode_beamsize\": 200, \"decode_top_cand_n\": 5, \"decode_top_p\": 0.9, \
        \"decode_max_beam_per_length\": 10, \"decode_max_batchsize\": 32, \"decode_dedup\": True}" \
    --path ${average_checkpoint_path}

# BeamSearch with LM
# You should first build the n-gram language model and save it to /path/to/ngram_lm.arpa
fairseq-generate ${data_dir} \
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified \
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1 \
    --remove-bpe --batch-size 32 --seed 0 \
    --model-overrides "{\"decode_strategy\": \"beamsearch\", \"decode_beta\": 1, \
        \"decode_alpha\": 1.1, \"decode_gamma\": 0.1, \
        \"decode_lm_path\": \"/path/to/ngram_lm.arpa\", \
        \"decode_beamsize\": 200, \"decode_top_cand_n\": 5, \"decode_top_p\": 0.9, \
        \"decode_max_beam_per_length\": 10, \"decode_max_batchsize\": 32, \"decode_dedup\": True}" \
    --path ${average_checkpoint_path}

For WMT17 En-Zh, we add --source-lang en --target-lang zh --tokenizer moses --scoring sacrebleu --sacrebleu-tokenizer zh.

Note: We tune decode_alpha from 1 to 1.4 on the valid set to achieve the maximum BLEU, and then apply it to the test set. You can use ./fs_plugins/scripts/test_tradeoff.py to search for the optimal alpha.
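
For reference, decode_alpha and decode_gamma enter the BeamSearch objective described above as follows (an illustrative helper, not the dag_search API):

def beam_score(log_p, log_p_ngram, length, alpha=1.1, gamma=0.1):
    # score(Y) = (log P(Y) + gamma * log P_ngram(Y)) / |Y|^alpha
    return (log_p + gamma * log_p_ngram) / (length ** alpha)

print(beam_score(log_p=-12.0, log_p_ngram=-30.0, length=10))  # -> about -1.19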

Note: the latency in our paper is evaluated with a batch size of 1, i.e., replace --batch-size 32 with --batch-size 1.

Other Scripts

We provide a script for converting a LightSeq checkpoint to a Fairseq checkpoint:

python3 ./fs_plugins/scripts/convert_ls_to_fairseq.py --input path/to/ls_checkpoint.pt --output path/to/fairseq_checkpoint.pt --user_dir ./fs_plugins

FAQs

  1. CUDA compilation failed: error: invalid static_cast from type ...

    Check your gcc version. We recommend using at least gcc 7, since PyTorch no longer seems to support lower gcc versions.

    If you do not want to upgrade your gcc, here is a workaround (https://zhuanlan.zhihu.com/p/468605263, in Chinese):

    • Find the header file /PATH/TO/PYTHONLIB/torch/include/torch/csrc/api/include/torch/nn/cloneable.h

    • Modify lines 46, 58, and 70. Original code:

      copy->parameters_.size() == parameters_.size()
      copy->buffers_.size() == buffers_.size()
      copy->children_.size() == children_.size()
      

      Modified code:

      copy->parameters_.size() == this->parameters_.size()
      copy->buffers_.size() == this->buffers_.size()
      copy->children_.size() == this->children_.size()
      
    • Rerun your script

Citing

Please cite us if you find our papers or code useful.

DA-Transformer:

@inproceedings{huang2022DATransformer,
  author = {Fei Huang and Hao Zhou and Yang Liu and Hang Li and Minlie Huang},
  title = {Directed Acyclic Transformer for Non-Autoregressive Machine Translation},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning, {ICML} 2022},
  year = {2022}
}

Viterbi Decoding:

@inproceedings{shao2022viterbi,
  author = {Chenze Shao and Zhengrui Ma and Yang Feng},
  title = {Viterbi Decoding of Directed Acyclic Transformer for Non-Autoregressive Machine Translation},
  booktitle = {Findings of EMNLP 2022},
  year = {2022}
}
