Fix generation for hf_gpt2 (facebookresearch#2139)
Summary:

# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes generation for hf_gpt2.

Before fix:

```
env CUDA_VISIBLE_DEVICES="2,3" fairseq-interactive data-bin/wikitext-103 \
    --task language_modeling \
    --path $PREFIX/checkpoints_gpt2/transformer_wikitext-103/checkpoint_best.pt
```

![image](https://user-images.githubusercontent.com/22627794/82102310-05021380-96c4-11ea-8073-1ae6919559ba.png)

After fix:

![image](https://user-images.githubusercontent.com/22627794/82102316-092e3100-96c4-11ea-825a-c41254ce9efe.png)

This test follows the [language model example](https://github.com/pytorch/fairseq/tree/master/examples/language_model), but with hf_gpt2. Trained for one epoch:

```
env CUDA_VISIBLE_DEVICES="2,3" fairseq-train --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir $PREFIX/checkpoints_gpt2/transformer_wikitext-103 \
    --dropout 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
    --lr 0.0005 --reset-optimizer --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --tokens-per-sample 1024 --sample-break-mode none \
    --max-tokens 1024 --update-freq 16 \
    --fp16 \
    --arch hf_gpt2 --max-target-positions 1024 \
    --skip-invalid-size-inputs-valid-test
```

Details of fix:

- Add the `encoder_out` keyword argument to `forward()`, which previously rejected it as an unexpected keyword argument.
- `incremental_state` is `{}`, not `None`, so change the check to handle this case.

## PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding

Pull Request resolved: facebookresearch#2139

Reviewed By: ngoyal2707

Differential Revision: D21663260

Pulled By: myleott

fbshipit-source-id: bafbfc7b37d0b49a459e0b64e90da5c13a991d6d
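
The two changes listed under "Details of fix" can be illustrated with a small, dependency-free sketch. This is not the fairseq code: `ToyIncrementalDecoder`, its token-doubling "features", and the `"past"` key are hypothetical stand-ins chosen only to show the pattern. It assumes the caller passes `encoder_out` as a keyword and starts generation with an empty dict as `incremental_state`.

```python
from typing import Dict, List, Optional


class ToyIncrementalDecoder:
    """Hypothetical stand-in for a decoder wrapper; not the fairseq class."""

    def forward(
        self,
        prev_output_tokens: List[int],
        src_lengths: Optional[List[int]] = None,
        incremental_state: Optional[Dict[str, List[int]]] = None,
        encoder_out: Optional[object] = None,  # accepted so callers may pass it; unused here
    ) -> List[int]:
        return self.extract_features(prev_output_tokens, incremental_state)

    def extract_features(
        self,
        prev_output_tokens: List[int],
        incremental_state: Optional[Dict[str, List[int]]] = None,
    ) -> List[int]:
        # Truthiness test: both None and {} mean "no cached state yet".
        # An `is not None` test would treat {} as a populated cache and
        # fail on the lookup below during the first generation step.
        if incremental_state:
            past = incremental_state["past"]
        else:
            past = []

        # Only tokens beyond the cached prefix need processing (toy: double them).
        new_features = [2 * t for t in prev_output_tokens[len(past):]]
        features = past + new_features

        # Populate the cache only when the caller supplied a state dict.
        if incremental_state is not None:
            incremental_state["past"] = features
        return features


decoder = ToyIncrementalDecoder()
state: Dict[str, List[int]] = {}  # empty dict on the first step, as in the commit message
first = decoder.forward([1, 2], incremental_state=state, encoder_out=None)
second = decoder.forward([1, 2, 3], incremental_state=state, encoder_out=None)
```

In this sketch the first call sees an empty dict and skips the cache lookup, while the second call reuses the cached prefix and only processes the new token.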