Transformer with integrated pointer-generator network (#2529)
Summary:
This pull request implements a variant of the Transformer model that uses an attention distribution for pointing to input words. The attention distribution over the input words is interpolated with the normal output distribution over the vocabulary words, as in [See et al. (2017)](https://arxiv.org/abs/1704.04368). This allows the model to generate words that appear in the input, even if they don't appear in the vocabulary, helping especially with small vocabularies.

The mechanism for copying out-of-vocabulary words from the input has been implemented differently from that of See et al. In their [implementation](https://github.com/abisee/pointer-generator) they convey the word identities through the model in order to be able to produce out-of-vocabulary words. We wanted to minimize changes to the Fairseq code base and took a different approach, which I'll describe below. The entire implementation is contained in one file (plus there's one new test).

Copying out-of-vocabulary words is possible by pre-processing the input and post-processing the output. The user may add special words to the end of the vocabulary that can be used in place of `<unk>` tokens to identify different input positions (e.g. `<unk-0>`, `<unk-1>`, `<unk-2>`, ...). The number of these special words is given to the model with the `--source-position-markers` argument; the model simply maps all of them to the same word embedding as `<unk>`. With simple post-processing the user may retrieve the word at position N in the original text and use it in place of `<unk-N>`.

I didn't find a good place to document the usage of this model, so let me know if you think the documentation should be improved somewhere.

This feature has not yet been discussed via a GitHub issue, but I'll open a new issue for discussion.

Pull Request resolved: facebookresearch/fairseq#2529

Reviewed By: ngoyal2707

Differential Revision: D23398430

Pulled By: myleott

fbshipit-source-id: f2f26c8ce8802ae6cf95515637660348ff3fc457
senarvi authored and facebook-github-bot committed Sep 25, 2020
1 parent 53f1357 commit 3b7d85c
Showing 8 changed files with 972 additions and 0 deletions.
82 changes: 82 additions & 0 deletions examples/pointer_generator/README.md
@@ -0,0 +1,82 @@
# Transformer with Pointer-Generator Network

This page describes the `transformer_pointer_generator` model that incorporates
a pointing mechanism in the Transformer model that facilitates copying of input
words to the output. This architecture is described in [Enarvi et al. (2020)](https://www.aclweb.org/anthology/2020.nlpmc-1.4/).

## Background

The pointer-generator network was introduced in [See et al. (2017)](https://arxiv.org/abs/1704.04368)
for RNN encoder-decoder attention models. A similar mechanism can be
incorporated in a Transformer model by reusing one of the many attention
distributions for pointing. The attention distribution over the input words is
interpolated with the normal output distribution over the vocabulary words. This
allows the model to generate words that appear in the input, even if they don't
appear in the vocabulary, helping especially with small vocabularies.
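
Concretely, the final output distribution is a mixture of the generation distribution and the copy (attention) distribution. Below is a minimal sketch of this interpolation step, not the actual fairseq implementation; `vocab_probs`, `attn`, `src_tokens` and `p_gen` are assumed to be produced by the model.

```python
import torch

def interpolate(vocab_probs, attn, src_tokens, p_gen):
    """Mix the vocabulary distribution with the copy distribution.

    vocab_probs: (batch, tgt_len, vocab_size) softmax over the vocabulary
    attn:        (batch, tgt_len, src_len)    attention over source positions
    src_tokens:  (batch, src_len)             source token indices
    p_gen:       (batch, tgt_len, 1)          probability of generating from the vocabulary
    """
    gen_probs = vocab_probs * p_gen
    copy_probs = attn * (1 - p_gen)
    # Add the attention mass to the vocabulary entries of the corresponding source tokens.
    index = src_tokens.unsqueeze(1).expand_as(attn)
    return gen_probs.scatter_add(2, index, copy_probs)
```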

## Implementation

The mechanism for copying out-of-vocabulary words from the input has been
implemented differently from that of See et al. In their [implementation](https://github.com/abisee/pointer-generator)
they convey the word identities through the model in order to be able to produce
words that appear in the input sequence but not in the vocabulary. A different
approach was taken in the Fairseq implementation to keep it self-contained in
the model file, avoiding any changes to the rest of the code base. Copying
out-of-vocabulary words is possible by pre-processing the input and
post-processing the output. This is described in detail in the next section.

## Usage

The training and evaluation procedure is outlined below. You can also find a
more detailed example for the XSum dataset on [this page](README.xsum.md).

##### 1. Create a vocabulary and extend it with source position markers

The pointing mechanism is especially helpful with small vocabularies, if we are
able to recover the identities of any out-of-vocabulary words that are copied
from the input. For this purpose, the model allows extending the vocabulary with
special tokens that can be used in place of `<unk>` tokens to identify different
input positions. For example, the user may add `<unk-0>`, `<unk-1>`, `<unk-2>`,
etc. to the end of the vocabulary, after the normal words. Below is an example
of how to create a vocabulary of the 10000 most common words and add 1000 input
position markers.

```bash
vocab_size=10000
position_markers=1000
export LC_ALL=C
cat train.src train.tgt |
tr -s '[:space:]' '\n' |
sort |
uniq -c |
sort -k1,1bnr -k2 |
head -n "$((vocab_size - 4))" |
awk '{ print $2 " " $1 }' >dict.pg.txt
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt
```

##### 2. Preprocess the text data

The idea is that any `<unk>` token in the text is replaced with `<unk-0>` if
it appears in the first input position, with `<unk-1>` if it appears in the
second input position, and so on. This can be achieved using the `preprocess.py` script
that is provided in this directory.
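
For illustration, here is a minimal sketch of the source-side replacement (the provided `preprocess.py` script implements the full logic, including the target side); the function name is hypothetical and `vocab` is assumed to be the set of in-vocabulary words.

```python
def add_position_markers(tokens, vocab):
    """Replace out-of-vocabulary tokens with <unk-N>, where N is the input position."""
    return [
        token if token in vocab else "<unk-{}>".format(position)
        for position, token in enumerate(tokens)
    ]

# add_position_markers("de roon moved to teesside".split(), vocab={"de", "moved", "to"})
# -> ['de', '<unk-1>', 'moved', 'to', '<unk-4>']
```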

##### 3. Train a model

The number of these special tokens is given to the model with the
`--source-position-markers` argument—the model simply maps all of these to the
same word embedding as `<unk>`.
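
As an illustration, this mapping amounts to something like the following sketch (not the model's actual code), where `num_embeddings` is the vocabulary size without the position markers and `unk_index` is the dictionary index of `<unk>`:

```python
import torch

def collapse_position_markers(tokens, num_embeddings, unk_index):
    # Indices of the <unk-N> markers lie past the end of the embedding table,
    # so they are replaced with the index of <unk> before the embedding lookup.
    return torch.where(tokens < num_embeddings, tokens, torch.full_like(tokens, unk_index))

# collapse_position_markers(torch.tensor([5, 42, 10003]), num_embeddings=10000, unk_index=3)
# -> tensor([5, 42, 3])
```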

The attention distribution that is used for pointing is selected using the
`--alignment-heads` and `--alignment-layer` command-line arguments in the same
way as with the `transformer_align` model.

##### 4. Generate text and postprocess it

When using the model to generate text, you want to preprocess the input text in
the same way that the training data was processed, replacing out-of-vocabulary words
with `<unk-N>` tokens. If any of these tokens are copied to the output, the
actual words can be retrieved from the unprocessed input text. Any `<unk-N>`
token should be replaced with the word at position N in the original input
sequence. This can be achieved using the `postprocess.py` script.
180 changes: 180 additions & 0 deletions examples/pointer_generator/README.xsum.md
@@ -0,0 +1,180 @@
## Training a pointer-generator model on the Extreme Summarization dataset

##### 1. Download the Extreme Summarization data and preprocess it

Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to obtain
the original Extreme Summarization dataset. You should have six files,
`{train,validation,test}.{document,summary}`.

##### 2. Create a vocabulary and extend it with source position markers

```bash
vocab_size=10000
position_markers=1000
export LC_ALL=C
cat train.document train.summary |
tr -s '[:space:]' '\n' |
sort |
uniq -c |
sort -k1,1bnr -k2 |
head -n "$((vocab_size - 4))" |
awk '{ print $2 " " $1 }' >dict.pg.txt
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt
```

This creates the file `dict.pg.txt`, which contains the 10k most frequent words,
followed by 1k source position markers:

```
the 4954867
. 4157552
, 3439668
to 2212159
a 1916857
of 1916820
and 1823350
...
<unk-0> 0
<unk-1> 0
<unk-2> 0
<unk-3> 0
<unk-4> 0
...
```

##### 3. Preprocess the text data

```bash
./preprocess.py --source train.document --target train.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out train.pg.src --target-out train.pg.tgt
./preprocess.py --source validation.document --target validation.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out valid.pg.src --target-out valid.pg.tgt
./preprocess.py --source test.document --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out test.pg.src
```

The data should now contain `<unk-N>` tokens in place of out-of-vocabulary words.

##### 4. Binarize the dataset

```bash
fairseq-preprocess \
--source-lang src \
--target-lang tgt \
--trainpref train.pg \
--validpref valid.pg \
--destdir bin \
--workers 60 \
--srcdict dict.pg.txt \
--joined-dictionary
```

##### 5. Train a model

```bash
total_updates=20000
warmup_updates=500
lr=0.001
max_tokens=4096
update_freq=4
pointer_layer=-2

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train bin \
--user-dir examples/pointer_generator/src \
--max-tokens "$max_tokens" \
--task translation \
--source-lang src --target-lang tgt \
--truncate-source \
--layernorm-embedding \
--share-all-embeddings \
--encoder-normalize-before \
--decoder-normalize-before \
--required-batch-size-multiple 1 \
--arch transformer_pointer_generator \
--alignment-layer "$pointer_layer" \
--alignment-heads 1 \
--source-position-markers 1000 \
--criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
--clip-norm 0.1 \
--lr-scheduler inverse_sqrt --lr "$lr" --max-update "$total_updates" --warmup-updates "$warmup_updates" \
--update-freq "$update_freq" \
--skip-invalid-size-inputs-valid-test
```

Above we specify that our dictionary contains 1000 source position markers, and
that we want to use one attention head from the penultimate decoder layer for
pointing. Training should take about 5.5 hours on one node with eight 32GB V100
GPUs. The logged messages confirm that dictionary indices from 10000 upwards
will be mapped to the `<unk>` embedding:

```
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [src] dictionary: 11000 types
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [tgt] dictionary: 11000 types
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.src
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.tgt
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | bin valid src-tgt 11332 examples
2020-09-24 20:43:53 | INFO | fairseq.models.transformer_pg | dictionary indices from 10000 to 10999 will be mapped to 3
```

##### 6. Summarize the test sequences

```bash
batch_size=32
beam_size=6
max_length=60
length_penalty=1.0

fairseq-interactive bin \
--user-dir examples/pointer_generator/src \
--batch-size "$batch_size" \
--task translation \
--source-lang src --target-lang tgt \
--path checkpoints/checkpoint_last.pt \
--input test.pg.src \
--buffer-size 200 \
--max-len-a 0 \
--max-len-b "$max_length" \
--lenpen "$length_penalty" \
--beam "$beam_size" \
--skip-invalid-size-inputs-valid-test |
tee generate.out
grep ^H generate.out | cut -f 3- >generate.hyp
```

Now you should have the generated sequences in `generate.hyp`. They contain
`<unk-N>` tokens that the model has copied from the source sequence. In order to
retrieve the original words, we need the unprocessed source sequences from
`test.document`.

##### 7. Process the generated output

Since overly long inputs were skipped when producing `generate.hyp`, we also
have to skip the same sequences when reading `test.document`, so that the source
and generated sequences stay aligned.

```bash
./postprocess.py \
--source <(awk 'NF<1024' test.document) \
--target generate.hyp \
--target-out generate.hyp.processed
```

The final sequences are now in `generate.hyp.processed`, with each `<unk-N>`
token replaced by the original word from the source sequence.

##### An example of a summarized sequence

The original source document in `test.document`:

> de roon moved to teesside in june 2016 for an initial # 8.8 m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .

The preprocessed source document in `test.pg.src`:

> de \<unk-1> moved to \<unk-4> in june 2016 for an initial # \<unk-12> m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .

The generated summary in `generate.hyp`:

> middlesbrough striker \<unk> de \<unk-1> has joined spanish side \<unk> on a season-long loan .

The generated summary after postprocessing in `generate.hyp.processed`:

> middlesbrough striker \<unk> de roon has joined spanish side \<unk> on a season-long loan .
96 changes: 96 additions & 0 deletions examples/pointer_generator/postprocess.py
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
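#
# Example usage (file names are illustrative; see the XSum README for the full command):
#
#   ./postprocess.py --source test.document --target generate.hyp --target-out generate.hyp.processed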

import sys
import re
import argparse


class OOVIndexError(IndexError):
def __init__(self, pos, source_seq, target_seq):
super(OOVIndexError, self).__init__(
"A <unk-N> tag in the target sequence refers to a position that is "
"outside the source sequence. Most likely there was a mismatch in "
"provided source and target sequences. Otherwise this would mean that "
"the pointing mechanism somehow attended to a position that is past "
"the actual sequence end."
)
self.source_pos = pos
self.source_seq = source_seq
self.target_seq = target_seq


def replace_oovs(source_in, target_in, target_out):
"""Replaces <unk-N> tokens in the target text with the corresponding word in
the source text.
"""

oov_re = re.compile("^<unk-([0-9]+)>$")

for source_seq, target_seq in zip(source_in, target_in):
target_seq_out = []

pos_to_word = source_seq.strip().split()
for token in target_seq.strip().split():
m = oov_re.match(token)
if m:
pos = int(m.group(1))
if pos >= len(pos_to_word):
raise OOVIndexError(pos, source_seq, target_seq)
token_out = pos_to_word[pos]
else:
token_out = token
target_seq_out.append(token_out)
target_out.write(" ".join(target_seq_out) + "\n")


def main():
parser = argparse.ArgumentParser(
description="Replaces <unk-N> tokens in target sequences with words from "
"the corresponding position in the source sequence."
)
parser.add_argument(
"--source", type=str, help="text file with source sequences", required=True
)
parser.add_argument(
"--target", type=str, help="text file with target sequences", required=True
)
parser.add_argument(
"--target-out",
type=str,
help="where to write target sequences without <unk-N> " "entries",
required=True,
)
args = parser.parse_args()

with open(args.source, "r", encoding="utf-8") as source_in, open(
args.target, "r", encoding="utf-8"
) as target_in, open(args.target_out, "w", encoding="utf-8") as target_out:
replace_oovs(source_in, target_in, target_out)


if __name__ == "__main__":
try:
main()
except OOVIndexError as e:
print(e, file=sys.stderr)
print("Source sequence:", e.source_seq.strip(), file=sys.stderr)
print("Target sequence:", e.target_seq.strip(), file=sys.stderr)
print(
"Source sequence length:",
len(e.source_seq.strip().split()),
file=sys.stderr,
)
print("The offending tag points to:", e.source_pos)
sys.exit(2)
