Commit 85e2dd9 (1 parent: ed2deb8)
kyubyong park authored and committed on Feb 18, 2019
Showing 51 changed files with 233,460 additions and 208,952 deletions.
@@ -1,91 +1,82 @@
# A TensorFlow Implementation of the Transformer: Attention Is All You Need
# **[UPDATED]** A TensorFlow Implementation of [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

When I opened this repository in 2017, there was no official code yet.
I tried to implement the paper as I understood it, but unsurprisingly
it had several bugs. I realized most of them thanks to people who raised issues here, so
I'm very grateful to all of them. Though there is now the [official implementation](https://github.com/tensorflow/tensor2tensor) as well as
several other unofficial GitHub repos, I decided to update my own.
This update focuses on:
* readable / understandable code
* modularization (but not too much)
* fixing known bugs (masking, positional encoding, ...)
* updating to TF1.12 (tf.data, ...)
* adding some missing components (bpe, shared weight matrix, ...)
* including useful comments in the code

I still stick to IWSLT 2016 de-en. If you'd like to test on big data such
as WMT, you would probably rely on the official implementation.
After all, it's pleasant to be able to check quickly whether your model works.
The initial code for TF1.2 has been moved to the [tf1.2_legacy](tf1.2_legacy) folder for the record.

## Requirements
* NumPy >= 1.11.1
* TensorFlow >= 1.2 (1.1 should probably work too, though I didn't test it)
* regex
* nltk

## Why This Project?
I tried to implement the idea in [Attention Is All You Need](https://arxiv.org/abs/1706.03762). The authors claimed that their model, the Transformer, outperformed the state-of-the-art in machine translation using only attention, with no CNNs and no RNNs. How cool is that! At the end of the paper they promise to make their code available soon, but apparently it is not so yet. I have two goals with this project. One is to gain a full understanding of the paper; it is often hard for me to grasp a model well before writing some code for it. The other is to share my code with people who are interested in this model before the official code is unveiled.

## Differences from the original paper
I don't intend to replicate the paper exactly. Rather, I aim to implement its main ideas and verify them in a SIMPLE and QUICK way. In this respect, some parts of my code differ from the paper. Among them are:
* I used the IWSLT 2016 de-en dataset, not the WMT dataset, because the former is much smaller and requires no special preprocessing.
* I constructed the vocabulary from words, not subwords, for simplicity. Of course, you can try bpe or word-piece if you want.
* I parameterized the positional encoding. The paper used a sinusoidal formula, but Noam, one of the authors, says both work. See the [discussion on reddit](https://www.reddit.com/r/MachineLearning/comments/6gwqiw/r_170603762_attention_is_all_you_need_sota_nmt/). (A small sketch of the sinusoidal variant follows this list.)
* The paper adjusted the learning rate according to the global step. I fixed the learning rate to a small number, 0.0001, simply because training was reasonably fast with the small dataset (only a couple of hours on a single GTX 1060!).
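
For reference, here is a minimal NumPy sketch of the sinusoidal positional encoding from the paper. It illustrates the formula only; it is not the parameterized encoding this repo actually uses, and `max_len` / `d_model` are placeholder names.
```python
import numpy as np

def sinusoidal_position_encoding(max_len, d_model):
    """Return a (max_len, d_model) array of sinusoidal position encodings."""
    pos = np.arange(max_len)[:, None]            # (max_len, 1)
    i = np.arange(d_model)[None, :]              # (1, d_model)
    angle_rates = 1.0 / np.power(10000.0, (2 * (i // 2)) / np.float32(d_model))
    angles = pos * angle_rates                   # (max_len, d_model)
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    pe[:, 0::2] = np.sin(angles[:, 0::2])        # even dimensions: sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])        # odd dimensions: cosine
    return pe
```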

## File description
* `hyperparams.py` includes all hyper parameters that are needed.
* `prepro.py` creates vocabulary files for the source and the target.
* `data_load.py` contains functions regarding loading and batching data.
* `modules.py` has all building blocks for encoder/decoder networks.
* `train.py` has the model.
* `eval.py` is for evaluation.
* python==3.x (Let's move on to python 3 if you still use python 2)
* tensorflow==1.12.0
* numpy>=1.15.4
* sentencepiece==0.1.8
* tqdm>=4.28.1

## Training
* STEP 1. Download the [IWSLT 2016 German–English parallel corpus](https://wit3.fbk.eu/download.php?release=2016-01&type=texts&slang=de&tlang=en) and extract it to the `corpora/` folder.
```sh
wget -qO- --show-progress https://wit3.fbk.eu/archive/2016-01//texts/de/en/de-en.tgz | tar xz; mv de-en corpora
```
* STEP 1. Run `bash download.sh` to download the [IWSLT 2016 German–English parallel corpus](https://wit3.fbk.eu/download.php?release=2016-01&type=texts&slang=de&tlang=en). It should be extracted to the `iwslt2016/de-en` folder automatically.
* STEP 2. Run the command below to create preprocessed train/eval/test data.
* STEP 2. Adjust hyperparameters in `hyperparams.py` if necessary.
* STEP 3. Run `prepro.py` to generate vocabulary files in the `preprocessed` folder.
* STEP 4. Run `train.py` or download the [pretrained files](https://www.dropbox.com/s/fo5wqgnbmvalwwq/logdir.zip?dl=0).

## Training Loss and Accuracy
* Training Loss
<img src="fig/mean_loss.png">

* Training Accuracy
<img src="fig/accuracy.png">

## Evaluation
* Run `eval.py`.

## Results
I got a BLEU score of 17.14. (Recall that I trained with a small dataset and a limited vocabulary.) Some of the evaluation results are as follows. Details are available in the `results` folder.

source: Sie war eine jährige Frau namens Alex<br>
expected: She was a yearold woman named Alex<br>
got: She was a woman named yearold name

```
python prepro.py
```
If you want to change the vocabulary size (default: 32000), do this:
```
python prepro.py --vocab_size 8000
```
It should create two folders, `iwslt2016/prepro` and `iwslt2016/segmented`.
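
The vocabulary itself is built with sentencepiece (listed in the requirements). Roughly, `prepro.py` presumably does something like the following; the input path and model prefix here are illustrative placeholders, not the script's actual arguments:
```python
import sentencepiece as spm

# Train a BPE model with the requested vocabulary size (paths are placeholders).
spm.SentencePieceTrainer.Train(
    "--input=iwslt2016/prepro/train.de.en "    # concatenated training text (assumed name)
    "--model_prefix=iwslt2016/segmented/bpe "  # where the .model/.vocab files go (assumed)
    "--vocab_size=32000 "
    "--model_type=bpe"
)
```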

source: Und als ich das hörte war ich erleichtert<br>
expected: Now when I heard this I was so relieved<br>
got: And when I heard that I was an <UNK>

* STEP 3. Run the following command.
```
python train.py
```
Don't forget to check TensorBoard (scalar / image / text).
Check `hparams.py` to see which parameters are available. For example,
```
python train.py --logdir myLog --batch_size 256 --dropout_rate 0.5
```

source: Meine Kommilitonin bekam nämlich einen Brandstifter als ersten Patienten<br>
expected: My classmate got an arsonist for her first client<br>
got: Because my first <UNK> came from an in patients

* STEP 3. Or download the pretrained models.
```
wget -qO- --show-progress https://dl.dropbox.com/s/4o7zwef7kzma4q4/log.tar.gz | tar xz
```

source: Das kriege ich hin dachte ich mir<br>
expected: This I thought I could handle<br>
got: I'll go ahead and I thought

## Training Loss Curve
<img src="fig/loss.png">

source: Aber ich habe es nicht hingekriegt<br>
expected: But I didn't handle it<br>
got: But I didn't <UNK> it

## Learning rate
<img src="fig/lr.png">

source: Ich hielt dagegen<br>
expected: I pushed back<br>
got: I thought about it

## Bleu score on devset
<img src="fig/bleu.png">

source: Das ist es was Psychologen einen AhaMoment nennen<br>
expected: That's what psychologists call an Aha moment<br>
got: That's what a <UNK> like a <UNK>

source: Meldet euch wenn ihr in euren ern seid<br>
expected: Raise your hand if you're in your s<br>
got: Get yourself in your s

## Inference (=test)
* Run
```
python test.py --ckpt log/1/iwslt2016_E19L2.62-29146 (OR yourCkptFile OR yourCkptFileDirectory)
```

source: Ich möchte ein paar von euch sehen<br>
expected: I really want to see some twentysomethings here<br>
got: I want to see some of you

## Results
* Typically, machine translation is evaluated with the Bleu score. (An illustrative snippet for computing Bleu appears after the score table below.)
* All evaluation results are available in [eval/1](eval/1) and [test/1](test/1).

source: Oh yeah Ihr seid alle unglaublich<br>
expected: Oh yay Y'all's awesome<br>
got: Oh yeah you all are incredibly

| tst2013 (dev) | tst2014 (test) |
|--|--|
| 26.93 | 23.16 |
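
For illustration only, a corpus-level Bleu score can be computed with nltk (from the old requirements list). This is just a sketch with placeholder file names, not necessarily how `eval.py`/`test.py` produced the numbers above:
```python
from nltk.translate.bleu_score import corpus_bleu

# Each hypothesis is paired with a list of reference translations (one per line here).
with open("refs.txt") as f_ref, open("hyps.txt") as f_hyp:  # placeholder paths
    references = [[line.split()] for line in f_ref]         # [[ref_tokens], ...]
    hypotheses = [line.split() for line in f_hyp]            # [hyp_tokens, ...]

print("Bleu: %.2f" % (corpus_bleu(references, hypotheses) * 100))
```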

source: Dies ist nicht meine Meinung Das sind Fakten<br>
expected: This is not my opinion These are the facts<br>
got: This is not my opinion These are facts

## Notes
* Beam decoding will be added soon.
* I'm going to update the code when TF2.0 comes out, if possible.
@@ -1,93 +1,150 @@
# -*- coding: utf-8 -*-
#/usr/bin/python2
#/usr/bin/python3
'''
June 2017 by kyubyong park.
Feb. 2019 by kyubyong park.
[email protected].
https://www.github.com/kyubyong/transformer
Note.
Where applicable, entities on the source side carry the suffix 1 and those on the target side the suffix 2, for convenience.
For example, fpath1 and fpath2 mean the source file path and the target file path, respectively.
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import regex

def load_de_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1])>=hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def load_en_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/en.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1])>=hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def create_data(source_sents, target_sents):
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Index
    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()] # 1: OOV, </S>: End of Text
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]
        if max(len(x), len(y)) <= hp.maxlen:
            x_list.append(np.array(x))
            y_list.append(np.array(y))
            Sources.append(source_sent)
            Targets.append(target_sent)

    # Pad
    X = np.zeros([len(x_list), hp.maxlen], np.int32)
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        X[i] = np.lib.pad(x, [0, hp.maxlen-len(x)], 'constant', constant_values=(0, 0))
        Y[i] = np.lib.pad(y, [0, hp.maxlen-len(y)], 'constant', constant_values=(0, 0))

    return X, Y, Sources, Targets

def load_train_data():
    de_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.source_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
    en_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.target_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Y

def load_test_data():
    def _refine(line):
        line = regex.sub("<[^>]+>", "", line)
        line = regex.sub("[^\s\p{Latin}']", "", line)
        return line.strip()

    de_sents = [_refine(line) for line in codecs.open(hp.source_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]
    en_sents = [_refine(line) for line in codecs.open(hp.target_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Sources, Targets # (1064, 150)

def get_batch_data():
    # Load data
    X, Y = load_train_data()

    # calc total batch count
    num_batch = len(X) // hp.batch_size

    # Convert to tensor
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.int32)

    # Create Queues
    input_queues = tf.train.slice_input_producer([X, Y])

    # create batch queues
    x, y = tf.train.shuffle_batch(input_queues,
                                  num_threads=8,
                                  batch_size=hp.batch_size,
                                  capacity=hp.batch_size*64,
                                  min_after_dequeue=hp.batch_size*32,
                                  allow_smaller_final_batch=False)

    return x, y, num_batch # (N, T), (N, T), ()
from utils import calc_num_batches

def load_vocab(vocab_fpath):
    '''Loads vocabulary file and returns idx<->token maps
    vocab_fpath: string. vocabulary file path.
    Note that these are reserved
    0: <pad>, 1: <unk>, 2: <s>, 3: </s>
    Returns
    two dictionaries.
    '''
    vocab = [line.split()[0] for line in open(vocab_fpath, 'r').read().splitlines()]
    token2idx = {token: idx for idx, token in enumerate(vocab)}
    idx2token = {idx: token for idx, token in enumerate(vocab)}
    return token2idx, idx2token
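
# Illustrative usage only (the vocab path below is an assumption, not necessarily
# the file prepro.py actually writes); the first four lines of the vocab file are
# expected to be the reserved tokens <pad>, <unk>, <s>, </s>:
#   token2idx, idx2token = load_vocab("iwslt2016/segmented/bpe.vocab")
#   token2idx["<pad>"]  # -> 0
#   idx2token[3]        # -> "</s>"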

def load_data(fpath1, fpath2, maxlen1, maxlen2):
    '''Loads source and target data and filters out too lengthy samples.
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    Returns
    sents1: list of source sents
    sents2: list of target sents
    '''
    sents1, sents2 = [], []
    with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
        for sent1, sent2 in zip(f1, f2):
            if len(sent1.split()) + 1 > maxlen1: continue # 1: </s>
            if len(sent2.split()) + 1 > maxlen2: continue # 1: </s>
            sents1.append(sent1.strip())
            sents2.append(sent2.strip())
    return sents1, sents2
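
# Example (hypothetical values): with maxlen1=maxlen2=4, the pair
# ("ein kleiner test", "a tiny test case") is skipped entirely, because the
# target has 4 tokens + 1 (</s>) = 5 > 4, even though the source would fit.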

def encode(inp, type, dict):
    '''Converts string to number. Used for `generator_fn`.
    inp: 1d byte array.
    type: "x" (source side) or "y" (target side)
    dict: token2idx dictionary
    Returns
    list of numbers
    '''
    inp_str = inp.decode("utf-8")
    if type=="x": tokens = inp_str.split() + ["</s>"]
    else: tokens = ["<s>"] + inp_str.split() + ["</s>"]

    x = [dict.get(t, dict["<unk>"]) for t in tokens]
    return x
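
# Example (illustrative): with token2idx = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "ich": 4},
#   encode(b"ich bin", "x", token2idx) -> [4, 1, 3]     # unknown "bin" maps to <unk>; </s> is appended
#   encode(b"i am", "y", token2idx)    -> [2, 1, 1, 3]  # target side also gets a leading <s>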

def generator_fn(sents1, sents2, vocab_fpath):
    '''Generates training / evaluation data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    yields
    xs: tuple of
        x: list of source token ids in a sent
        x_seqlen: int. sequence length of x
        sent1: str. raw source (=input) sentence
    labels: tuple of
        decoder_input: list of encoded decoder inputs
        y: list of target token ids in a sent
        y_seqlen: int. sequence length of y
        sent2: str. target sentence
    '''
    token2idx, _ = load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        x = encode(sent1, "x", token2idx)
        y = encode(sent2, "y", token2idx)
        decoder_input, y = y[:-1], y[1:]

        x_seqlen, y_seqlen = len(x), len(y)
        yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)

def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
    '''Batchify data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N,)
        sents2: str tensor. (N,)
    '''
    shapes = (([None], (), ()),
              ([None], [None], (), ()))
    types = ((tf.int32, tf.int32, tf.string),
             (tf.int32, tf.int32, tf.int32, tf.string))
    paddings = ((0, 0, ''),
                (0, 0, 0, ''))

    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(sents1, sents2, vocab_fpath)) # <- arguments for generator_fn. converted to np string arrays

    if shuffle: # for training
        dataset = dataset.shuffle(128*batch_size)

    dataset = dataset.repeat()  # iterate forever
    dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)

    return dataset
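
# Note (added comment): the 0 entries in `paddings` above pad the integer id fields
# with the reserved <pad> id (0), and the '' entries pad the raw-sentence string
# fields with an empty string.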

def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
    '''Gets training / evaluation mini-batches
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    batches
    num_batches: number of mini-batches
    num_samples
    '''
    sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
    batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size)
    return batches, num_batches, len(sents1)
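
# Minimal usage sketch (illustrative only; the file paths and hyperparameter values
# below are placeholders, not the ones train.py actually uses):
#   train_batches, num_train_batches, num_train_samples = get_batch(
#       "iwslt2016/segmented/train.de.bpe", "iwslt2016/segmented/train.en.bpe",
#       maxlen1=100, maxlen2=100,
#       vocab_fpath="iwslt2016/segmented/bpe.vocab",
#       batch_size=128, shuffle=True)
#   iterator = train_batches.make_one_shot_iterator()  # TF 1.12 tf.data API
#   xs, ys = iterator.get_next()  # xs = (x, x_seqlens, sents1), ys = (decoder_inputs, y, y_seqlens, sents2)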