-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Transformer with integrated pointer-generator network (#2529)
Summary: This pull request implements a variant of the Transformer model that uses an attention distribution for pointing to input words. The attention distribution over the input words is interpolated with the normal output distribution over the vocabulary words, as in [See et al. (2017)](https://arxiv.org/abs/1704.04368). This allows the model to generate words that appear in the input, even if they don't appear in the vocabulary, helping especially with small vocabularies. The mechanism for copying out-of-vocabulary words from the input has been implemented differently to See et al. In their [implementation](https://github.com/abisee/pointer-generator) they convey the word identities through the model in order to be able to produce out-of-vocabulary words. We wanted to minimize changes to the Fairseq code base and took a different approach, which I'll describe below. The entire implementation is contained in one file (plus there's one new test). Copying out-of-vocabulary words is possible by pre-processing the input and post-processing the output. The user may add special words to the end of the vocabulary that can be used in place of `<unk>` tokens to identify different input positions (e.g. `<unk-0>`, `<unk-1>`, `<unk-2>`, ...). The number of these special words is given to the model with the `--source-position-markers` argument—the model simply maps all of these to the same word embedding as `<unk>`. With a simple post-processing the user may retrieve word at position N in the original text and use it in place of `<unk-N>`. I didn't find a good place to document this usage of this model, so let me know if you think I should improve documentation somewhere. This feature has not yet been discussed via a GitHub issue, but I'll open a new issue for discussion. Pull Request resolved: facebookresearch/fairseq#2529 Reviewed By: ngoyal2707 Differential Revision: D23398430 Pulled By: myleott fbshipit-source-id: f2f26c8ce8802ae6cf95515637660348ff3fc457
- Loading branch information
1 parent
53f1357
commit 3b7d85c
Showing
8 changed files
with
972 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Transformer with Pointer-Generator Network | ||
|
||
This page describes the `transformer_pointer_generator` model that incorporates | ||
a pointing mechanism in the Transformer model that facilitates copying of input | ||
words to the output. This architecture is described in [Enarvi et al. (2020)](https://www.aclweb.org/anthology/2020.nlpmc-1.4/). | ||
|
||
## Background | ||
|
||
The pointer-generator network was introduced in [See et al. (2017)](https://arxiv.org/abs/1704.04368) | ||
for RNN encoder-decoder attention models. A similar mechanism can be | ||
incorporated in a Transformer model by reusing one of the many attention | ||
distributions for pointing. The attention distribution over the input words is | ||
interpolated with the normal output distribution over the vocabulary words. This | ||
allows the model to generate words that appear in the input, even if they don't | ||
appear in the vocabulary, helping especially with small vocabularies. | ||
|
||
## Implementation | ||
|
||
The mechanism for copying out-of-vocabulary words from the input has been | ||
implemented differently to See et al. In their [implementation](https://github.com/abisee/pointer-generator) | ||
they convey the word identities through the model in order to be able to produce | ||
words that appear in the input sequence but not in the vocabulary. A different | ||
approach was taken in the Fairseq implementation to keep it self-contained in | ||
the model file, avoiding any changes to the rest of the code base. Copying | ||
out-of-vocabulary words is possible by pre-processing the input and | ||
post-processing the output. This is described in detail in the next section. | ||
|
||
## Usage | ||
|
||
The training and evaluation procedure is outlined below. You can also find a | ||
more detailed example for the XSum dataset on [this page](README.xsum.md). | ||
|
||
##### 1. Create a vocabulary and extend it with source position markers | ||
|
||
The pointing mechanism is especially helpful with small vocabularies, if we are | ||
able to recover the identities of any out-of-vocabulary words that are copied | ||
from the input. For this purpose, the model allows extending the vocabulary with | ||
special tokens that can be used in place of `<unk>` tokens to identify different | ||
input positions. For example, the user may add `<unk-0>`, `<unk-1>`, `<unk-2>`, | ||
etc. to the end of the vocabulary, after the normal words. Below is an example | ||
of how to create a vocabulary of 10000 most common words and add 1000 input | ||
position markers. | ||
|
||
```bash | ||
vocab_size=10000 | ||
position_markers=1000 | ||
export LC_ALL=C | ||
cat train.src train.tgt | | ||
tr -s '[:space:]' '\n' | | ||
sort | | ||
uniq -c | | ||
sort -k1,1bnr -k2 | | ||
head -n "$((vocab_size - 4))" | | ||
awk '{ print $2 " " $1 }' >dict.pg.txt | ||
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt | ||
``` | ||
|
||
##### 2. Preprocess the text data | ||
|
||
The idea is that any `<unk>` tokens in the text are replaced with `<unk-0>` if | ||
it appears in the first input position, `<unk-1>` if it appears in the second | ||
input position, and so on. This can be achieved using the `preprocess.py` script | ||
that is provided in this directory. | ||
|
||
##### 3. Train a model | ||
|
||
The number of these special tokens is given to the model with the | ||
`--source-position-markers` argument—the model simply maps all of these to the | ||
same word embedding as `<unk>`. | ||
|
||
The attention distribution that is used for pointing is selected using the | ||
`--alignment-heads` and `--alignment-layer` command-line arguments in the same | ||
way as with the `transformer_align` model. | ||
|
||
##### 4. Generate text and postprocess it | ||
|
||
When using the model to generate text, you want to preprocess the input text in | ||
the same way that training data was processed, replacing out-of-vocabulary words | ||
with `<unk-N>` tokens. If any of these tokens are copied to the output, the | ||
actual words can be retrieved from the unprocessed input text. Any `<unk-N>` | ||
token should be replaced with the word at position N in the original input | ||
sequence. This can be achieved using the `postprocess.py` script. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
## Training a pointer-generator model on the Extreme Summarization dataset | ||
|
||
##### 1. Download the Extreme Summarization data and preprocess it | ||
|
||
Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to obtain | ||
the original Extreme Summarization dataset. You should have six files, | ||
{train,validation,test}.{document,summary}. | ||
|
||
##### 2. Create a vocabulary and extend it with source position markers | ||
|
||
```bash | ||
vocab_size=10000 | ||
position_markers=1000 | ||
export LC_ALL=C | ||
cat train.document train.summary | | ||
tr -s '[:space:]' '\n' | | ||
sort | | ||
uniq -c | | ||
sort -k1,1bnr -k2 | | ||
head -n "$((vocab_size - 4))" | | ||
awk '{ print $2 " " $1 }' >dict.pg.txt | ||
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt | ||
``` | ||
|
||
This creates the file dict.pg.txt that contains the 10k most frequent words, | ||
followed by 1k source position markers: | ||
|
||
``` | ||
the 4954867 | ||
. 4157552 | ||
, 3439668 | ||
to 2212159 | ||
a 1916857 | ||
of 1916820 | ||
and 1823350 | ||
... | ||
<unk-0> 0 | ||
<unk-1> 0 | ||
<unk-2> 0 | ||
<unk-3> 0 | ||
<unk-4> 0 | ||
... | ||
``` | ||
|
||
##### 2. Preprocess the text data | ||
|
||
```bash | ||
./preprocess.py --source train.document --target train.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out train.pg.src --target-out train.pg.tgt | ||
./preprocess.py --source validation.document --target validation.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out valid.pg.src --target-out valid.pg.tgt | ||
./preprocess.py --source test.document --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out test.pg.src | ||
``` | ||
|
||
The data should now contain `<unk-N>` tokens in place of out-of-vocabulary words. | ||
|
||
##### 3. Binarize the dataset: | ||
|
||
```bash | ||
fairseq-preprocess \ | ||
--source-lang src \ | ||
--target-lang tgt \ | ||
--trainpref train.pg \ | ||
--validpref valid.pg \ | ||
--destdir bin \ | ||
--workers 60 \ | ||
--srcdict dict.pg.txt \ | ||
--joined-dictionary | ||
``` | ||
|
||
##### 3. Train a model | ||
|
||
```bash | ||
total_updates=20000 | ||
warmup_updates=500 | ||
lr=0.001 | ||
max_tokens=4096 | ||
update_freq=4 | ||
pointer_layer=-2 | ||
|
||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train bin \ | ||
--user-dir examples/pointer_generator/src \ | ||
--max-tokens "$max_tokens" \ | ||
--task translation \ | ||
--source-lang src --target-lang tgt \ | ||
--truncate-source \ | ||
--layernorm-embedding \ | ||
--share-all-embeddings \ | ||
--encoder-normalize-before \ | ||
--decoder-normalize-before \ | ||
--required-batch-size-multiple 1 \ | ||
--arch transformer_pointer_generator \ | ||
--alignment-layer "$pointer_layer" \ | ||
--alignment-heads 1 \ | ||
--source-position-markers 1000 \ | ||
--criterion label_smoothed_cross_entropy \ | ||
--label-smoothing 0.1 \ | ||
--dropout 0.1 --attention-dropout 0.1 \ | ||
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \ | ||
--clip-norm 0.1 \ | ||
--lr-scheduler inverse_sqrt --lr "$lr" --max-update "$total_updates" --warmup-updates "$warmup_updates" \ | ||
--update-freq "$update_freq" \ | ||
--skip-invalid-size-inputs-valid-test | ||
``` | ||
|
||
Above we specify that our dictionary contains 1000 source position markers, and | ||
that we want to use one attention head from the penultimate decoder layer for | ||
pointing. It should run in 5.5 hours on one node with eight 32GB V100 GPUs. The | ||
logged messages confirm that dictionary indices above 10000 will be mapped to | ||
the `<unk>` embedding: | ||
|
||
``` | ||
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [src] dictionary: 11000 types | ||
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [tgt] dictionary: 11000 types | ||
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.src | ||
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.tgt | ||
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | bin valid src-tgt 11332 examples | ||
2020-09-24 20:43:53 | INFO | fairseq.models.transformer_pg | dictionary indices from 10000 to 10999 will be mapped to 3 | ||
``` | ||
|
||
##### 4. Summarize the test sequences | ||
|
||
```bash | ||
batch_size=32 | ||
beam_size=6 | ||
max_length=60 | ||
length_penalty=1.0 | ||
|
||
fairseq-interactive bin \ | ||
--user-dir examples/pointer_generator/src \ | ||
--batch-size "$batch_size" \ | ||
--task translation \ | ||
--source-lang src --target-lang tgt \ | ||
--path checkpoints/checkpoint_last.pt \ | ||
--input test.pg.src \ | ||
--buffer-size 200 \ | ||
--max-len-a 0 \ | ||
--max-len-b "$max_length" \ | ||
--lenpen "$length_penalty" \ | ||
--beam "$beam_size" \ | ||
--skip-invalid-size-inputs-valid-test | | ||
tee generate.out | ||
grep ^H generate.out | cut -f 3- >generate.hyp | ||
``` | ||
|
||
Now you should have the generated sequences in `generate.hyp`. They contain | ||
`<unk-N>` tokens that the model has copied from the source sequence. In order to | ||
retrieve the original words, we need the unprocessed source sequences from | ||
`test.document`. | ||
|
||
##### 5. Process the generated output | ||
|
||
Since we skipped too long inputs when producing `generate.hyp`, we also have to | ||
skip too long sequences now that we read `test.document`. | ||
|
||
```bash | ||
./postprocess.py \ | ||
--source <(awk 'NF<1024' test.document) \ | ||
--target generate.hyp \ | ||
--target-out generate.hyp.processed | ||
``` | ||
|
||
Now you'll find the final sequences from `generate.hyp.processed`, with | ||
`<unk-N>` replaced with the original word from the source sequence. | ||
|
||
##### An example of a summarized sequence | ||
|
||
The original source document in `test.document`: | ||
|
||
> de roon moved to teesside in june 2016 for an initial # 8.8 m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page . | ||
The preprocessed source document in `test.src.pg`: | ||
|
||
> de \<unk-1> moved to \<unk-4> in june 2016 for an initial # \<unk-12> m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page . | ||
The generated summary in `generate.hyp`: | ||
|
||
> middlesbrough striker \<unk> de \<unk-1> has joined spanish side \<unk> on a season-long loan . | ||
The generated summary after postprocessing in `generate.hyp.processed`: | ||
|
||
> middlesbrough striker \<unk> de roon has joined spanish side \<unk> on a season-long loan . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import sys | ||
import re | ||
import argparse | ||
|
||
|
||
class OOVIndexError(IndexError): | ||
def __init__(self, pos, source_seq, target_seq): | ||
super(OOVIndexError, self).__init__( | ||
"A <unk-N> tag in the target sequence refers to a position that is " | ||
"outside the source sequence. Most likely there was a mismatch in " | ||
"provided source and target sequences. Otherwise this would mean that " | ||
"the pointing mechanism somehow attended to a position that is past " | ||
"the actual sequence end." | ||
) | ||
self.source_pos = pos | ||
self.source_seq = source_seq | ||
self.target_seq = target_seq | ||
|
||
|
||
def replace_oovs(source_in, target_in, target_out): | ||
"""Replaces <unk-N> tokens in the target text with the corresponding word in | ||
the source text. | ||
""" | ||
|
||
oov_re = re.compile("^<unk-([0-9]+)>$") | ||
|
||
for source_seq, target_seq in zip(source_in, target_in): | ||
target_seq_out = [] | ||
|
||
pos_to_word = source_seq.strip().split() | ||
for token in target_seq.strip().split(): | ||
m = oov_re.match(token) | ||
if m: | ||
pos = int(m.group(1)) | ||
if pos >= len(pos_to_word): | ||
raise OOVIndexError(pos, source_seq, target_seq) | ||
token_out = pos_to_word[pos] | ||
else: | ||
token_out = token | ||
target_seq_out.append(token_out) | ||
target_out.write(" ".join(target_seq_out) + "\n") | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser( | ||
description="Replaces <unk-N> tokens in target sequences with words from " | ||
"the corresponding position in the source sequence." | ||
) | ||
parser.add_argument( | ||
"--source", type=str, help="text file with source sequences", required=True | ||
) | ||
parser.add_argument( | ||
"--target", type=str, help="text file with target sequences", required=True | ||
) | ||
parser.add_argument( | ||
"--target-out", | ||
type=str, | ||
help="where to write target sequences without <unk-N> " "entries", | ||
required=True, | ||
) | ||
args = parser.parse_args() | ||
|
||
target_in = ( | ||
open(args.target, "r", encoding="utf-8") if args.target is not None else None | ||
) | ||
target_out = ( | ||
open(args.target_out, "w", encoding="utf-8") | ||
if args.target_out is not None | ||
else None | ||
) | ||
with open(args.source, "r", encoding="utf-8") as source_in, open( | ||
args.target, "r", encoding="utf-8" | ||
) as target_in, open(args.target_out, "w", encoding="utf-8") as target_out: | ||
replace_oovs(source_in, target_in, target_out) | ||
|
||
|
||
if __name__ == "__main__": | ||
try: | ||
main() | ||
except OOVIndexError as e: | ||
print(e, file=sys.stderr) | ||
print("Source sequence:", e.source_seq.strip(), file=sys.stderr) | ||
print("Target sequence:", e.target_seq.strip(), file=sys.stderr) | ||
print( | ||
"Source sequence length:", | ||
len(e.source_seq.strip().split()), | ||
file=sys.stderr, | ||
) | ||
print("The offending tag points to:", e.source_pos) | ||
sys.exit(2) |
Oops, something went wrong.