Fix generation for hf_gpt2 (facebookresearch#2139)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes generation for hf_gpt2.

Before fix:
```
env CUDA_VISIBLE_DEVICES="2,3" fairseq-interactive data-bin/wikitext-103 \
  --task language_modeling \
  --path $PREFIX/checkpoints_gpt2/transformer_wikitext-103/checkpoint_best.pt
```
![image](https://user-images.githubusercontent.com/22627794/82102310-05021380-96c4-11ea-8073-1ae6919559ba.png)

After fix:
![image](https://user-images.githubusercontent.com/22627794/82102316-092e3100-96c4-11ea-825a-c41254ce9efe.png)

This test follows the [language model example](https://github.com/pytorch/fairseq/tree/master/examples/language_model), but with hf_gpt2.

Trained for one epoch:
```
env CUDA_VISIBLE_DEVICES="2,3" fairseq-train --task language_modeling \
  data-bin/wikitext-103 \
  --save-dir $PREFIX/checkpoints_gpt2/transformer_wikitext-103 \
  --dropout 0.1 \
  --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
  --lr 0.0005 --reset-optimizer --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
  --tokens-per-sample 1024 --sample-break-mode none \
  --max-tokens 1024 --update-freq 16 \
  --fp16 \
  --arch hf_gpt2 --max-target-positions 1024 \
  --skip-invalid-size-inputs-valid-test
```

Details of fix:
- Add the previously unexpected keyword argument `encoder_out` to `forward()`, so the decoder accepts the argument the generator passes in.
- During generation, `incremental_state` arrives as an empty dict `{}` rather than `None`, so the `is not None` checks in `extract_features()` are changed to truthiness checks that handle both cases.
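The second point hinges on Python truth-value semantics: an empty dict passes an `is not None` check but fails a plain truthiness check. A minimal illustration (plain Python, function names hypothetical, not the fairseq API):

```python
# Sketch of the bug: the generator passes incremental_state={} on the first
# decoding step, before any "past" has been cached.

def old_check(incremental_state):
    # Old condition: {} is not None, so this wrongly takes the
    # incremental branch and tries to read a cache that was never set.
    return incremental_state is not None

def new_check(incremental_state):
    # Fixed condition: both {} and None are falsy, so the first step
    # correctly skips the cached-state path.
    return bool(incremental_state)

print(old_check({}))   # True  -> bug: empty state treated as cached state
print(new_check({}))   # False -> first step handled correctly
print(new_check(None)) # False
```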

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃
Pull Request resolved: facebookresearch#2139

Reviewed By: ngoyal2707

Differential Revision: D21663260

Pulled By: myleott

fbshipit-source-id: bafbfc7b37d0b49a459e0b64e90da5c13a991d6d
timcheck authored and facebook-github-bot committed May 20, 2020
1 parent fa6a3cd commit eedf27a
Showing 1 changed file with 3 additions and 2 deletions.
fairseq/models/huggingface/hf_gpt2.py:
```diff
@@ -99,6 +99,7 @@ def forward(
         prev_output_tokens,
         src_lengths=None,
         incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
+        encoder_out=None,
     ):
         features = self.extract_features(prev_output_tokens, incremental_state)
         lm_logits = self.model.lm_head(features)
@@ -109,7 +110,7 @@ def extract_features(
         prev_output_tokens,
         incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
     ):
-        if incremental_state is not None:
+        if incremental_state:
             past = self.get_incremental_state("past")
         else:
             past = None
@@ -132,7 +133,7 @@ def extract_features(
         )
         last_hidden_states = outputs[0]

-        if incremental_state is not None:
+        if incremental_state:
             self.set_incremental_state(incremental_state, "past", outputs[1])

         return last_hidden_states
```
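For context, the code above follows the generic incremental-decoding pattern: the generator threads a mutable state dict through each step, and the decoder caches its past activations in it. A toy sketch of that pattern (plain Python stand-in, assumption: not the fairseq API):

```python
# Toy incremental decoder: an empty dict means "first step, nothing cached",
# and the cache is written back so the next step can reuse it.

def decode_step(token, incremental_state):
    # Read the cache only when the state is present and non-empty.
    past = incremental_state.get("past") if incremental_state else None
    hidden = (past or 0) + token  # stand-in for the real transformer forward
    if incremental_state is not None:
        incremental_state["past"] = hidden  # cache for the next step
    return hidden

state = {}
outputs = [decode_step(t, state) for t in [1, 2, 3]]
print(outputs)  # [1, 3, 6] -- each step builds on the cached result
```

The first call sees `state == {}`, which is exactly the case the old `is not None` check mishandled.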
