updated to TF1.12
kyubyong park authored and kyubyong park committed Feb 18, 2019
1 parent ed2deb8 commit 85e2dd9
Showing 51 changed files with 233,460 additions and 208,952 deletions.
143 changes: 67 additions & 76 deletions README.md
@@ -1,91 +1,82 @@
# A TensorFlow Implementation of the Transformer: Attention Is All You Need
# **[UPDATED]** A TensorFlow Implementation of [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

When I opened this repository in 2017, there was no official code yet.
I tried to implement the paper as I understood it, and, to no one's surprise,
my code had several bugs. I found most of them thanks to the people who opened issues here, so
I'm very grateful to all of them. Though there is now the [official implementation](https://github.com/tensorflow/tensor2tensor) as well as
several other unofficial GitHub repos, I decided to update my own.
This update focuses on:
* readable / understandable code
* modularization (but not too much)
* fixing known bugs (masking, positional encoding, ...)
* updating to TF1.12 (tf.data, ...)
* adding some missing components (BPE, shared weight matrix, ...)
* including useful comments in the code

I still stick to IWSLT 2016 de-en. If you'd like to test on a large dataset such
as WMT, you would probably want to rely on the official implementation.
After all, it's pleasant to be able to check quickly whether your model works.
The initial code for TF1.2 has been moved to the [tf1.2_legacy](tf1.2_legacy) folder for the record.

## Requirements
* NumPy >= 1.11.1
* TensorFlow >= 1.2 (Probably 1.1 should work, too, though I didn't test it)
* regex
* nltk

## Why This Project?
I tried to implement the idea in [Attention Is All You Need](https://arxiv.org/abs/1706.03762). The authors claimed that their model, the Transformer, outperformed the state of the art in machine translation using attention alone: no CNNs, no RNNs. How cool is that! At the end of the paper, they promised to make their code available soon, but at the time it had not been released yet. I had two goals with this project. One was to gain a full understanding of the paper; it's often hard for me to grasp a paper well before writing some code for it. The other was to share my code with people who are interested in this model before the official code was unveiled.

## Differences with the original paper
I don't intend to replicate the paper exactly. Rather, I aim to implement the main ideas in the paper and verify them in a SIMPLE and QUICK way. In this respect, some parts of my code differ from the paper. Among them:
* I used the IWSLT 2016 de-en dataset, not the WMT dataset, because the former is much smaller and requires no special preprocessing.
* I constructed the vocabulary from words, not subwords, for simplicity. Of course, you can try BPE or word-piece if you want.
* I parameterized the positional encoding. The paper used a sinusoidal formula, but Noam, one of the authors, says both work; see the [discussion on reddit](https://www.reddit.com/r/MachineLearning/comments/6gwqiw/r_170603762_attention_is_all_you_need_sota_nmt/) and the sketch after this list.
* The paper adjusted the learning rate according to the global step. I fixed the learning rate to a small constant, 0.0001, simply because training was reasonably fast with the small dataset (only a couple of hours on a single GTX 1060!).
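
For reference, here is a minimal NumPy sketch of the sinusoidal encoding mentioned above; the parameterized alternative simply replaces this fixed table with a trainable position-embedding matrix. The function name and shapes are illustrative, not code from this repo.
```python
import numpy as np

def sinusoidal_positional_encoding(max_len, d_model):
    """Builds the (max_len, d_model) sinusoidal position table from the paper."""
    pos = np.arange(max_len)[:, None]                    # (max_len, 1)
    dim = np.arange(d_model)[None, :]                    # (1, d_model)
    angles = pos / np.power(10000.0, (2 * (dim // 2)) / d_model)
    table = np.zeros((max_len, d_model))
    table[:, 0::2] = np.sin(angles[:, 0::2])             # even dimensions: sine
    table[:, 1::2] = np.cos(angles[:, 1::2])             # odd dimensions: cosine
    return table
```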

## File description
* `hyperparams.py` includes all hyperparameters that are needed.
* `prepro.py` creates vocabulary files for the source and the target.
* `data_load.py` contains functions regarding loading and batching data.
* `modules.py` has all building blocks for encoder/decoder networks.
* `train.py` has the model.
* `eval.py` is for evaluation.
* python==3.x (Let's move on to python 3 if you still use python 2)
* tensorflow==1.12.0
* numpy>=1.15.4
* sentencepiece==0.1.8
* tqdm>=4.28.1

## Training
* STEP 1. Download the [IWSLT 2016 German–English parallel corpus](https://wit3.fbk.eu/download.php?release=2016-01&type=texts&slang=de&tlang=en) and extract it to the `corpora/` folder.
```sh
wget -qO- --show-progress https://wit3.fbk.eu/archive/2016-01//texts/de/en/de-en.tgz | tar xz; mv de-en corpora
```
* STEP 1. Run `bash download.sh` to download the [IWSLT 2016 German–English parallel corpus](https://wit3.fbk.eu/download.php?release=2016-01&type=texts&slang=de&tlang=en). It should be extracted to the `iwslt2016/de-en` folder automatically.
* STEP 2. Run `python prepro.py` (shown further below) to create the preprocessed train/eval/test data.
* STEP 2. Adjust hyperparameters in `hyperparams.py` if necessary.
* STEP 3. Run `prepro.py` to generate vocabulary files in the `preprocessed` folder.
* STEP 4. Run `train.py` or download the [pretrained files](https://www.dropbox.com/s/fo5wqgnbmvalwwq/logdir.zip?dl=0).

## Training Loss and Accuracy
* Training Loss
<img src="fig/mean_loss.png">

* Training Accuracy
<img src="fig/accuracy.png">

## Evaluation
* Run `eval.py`.

## Results
I got a BLEU score of 17.14. (Recall that I trained with a small dataset and a limited vocabulary.) Some of the evaluation results are as follows. Details are available in the `results` folder.

source: Sie war eine jährige Frau namens Alex<br>
expected: She was a yearold woman named Alex<br>
got: She was a woman named yearold name
```
python prepro.py
```
If you want to change the vocabulary size (default: 32000), do this:
```
python prepro.py --vocab_size 8000
```
It should create two folders, `iwslt2016/prepro` and `iwslt2016/segmented`.
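
For reference, the subword segmentation that `prepro.py` performs is based on sentencepiece; a minimal sketch of that kind of call looks like the following (the paths and flags here are illustrative, not necessarily the script's exact ones).
```python
import sentencepiece as spm

# Train a BPE model on the preprocessed training text (illustrative path and flags).
spm.SentencePieceTrainer.Train(
    '--input=iwslt2016/prepro/train.de --model_prefix=bpe --vocab_size=32000 --model_type=bpe')

sp = spm.SentencePieceProcessor()
sp.Load("bpe.model")
print(sp.EncodeAsPieces("Sie war eine jährige Frau namens Alex"))
```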

source: Und als ich das hörte war ich erleichtert<br>
expected: Now when I heard this I was so relieved<br>
got: And when I heard that I was an <UNK>
* STEP 3. Run the following command.
```
python train.py
```
Don't forget to check TensorBoard (scalars / images / text).
Check `hparams.py` to see which parameters are available. For example:
```
python train.py --logdir myLog --batch_size 256 --dropout_rate 0.5
```

source: Meine Kommilitonin bekam nämlich einen Brandstifter als ersten Patienten<br>
expected: My classmate got an arsonist for her first client<br>
got: Because my first <UNK> came from an in patients
* STEP 3. Or download the pretrained models.
```
wget -qO- --show-progress https://dl.dropbox.com/s/4o7zwef7kzma4q4/log.tar.gz | tar xz
```

source: Das kriege ich hin dachte ich mir<br>
expected: This I thought I could handle<br>
got: I'll go ahead and I thought
## Training Loss Curve
<img src="fig/loss.png">

source: Aber ich habe es nicht hingekriegt<br>
expected: But I didn't handle it<br>
got: But I didn't <UNK> it
## Learning rate
<img src="fig/lr.png">

source: Ich hielt dagegen<br>
expected: I pushed back<br>
got: I thought about it
## BLEU score on devset
<img src="fig/bleu.png">

source: Das ist es was Psychologen einen AhaMoment nennen<br>
expected: That's what psychologists call an Aha moment<br>
got: That's what a <UNK> like a <UNK>

source: Meldet euch wenn ihr in euren ern seid<br>
expected: Raise your hand if you're in your s<br>
got: Get yourself in your s
## Inference (=test)
* Run
```
python test.py --ckpt log/1/iwslt2016_E19L2.62-29146 (OR yourCkptFile OR yourCkptFileDirectory)
```

source: Ich möchte ein paar von euch sehen<br>
expected: I really want to see some twentysomethings here<br>
got: I want to see some of you
## Results
* Typically, machine translation is evaluated with the BLEU score (a minimal scoring sketch follows the table below).
* All evaluation results are available in [eval/1](eval/1) and [test/1](test/1).

source: Oh yeah Ihr seid alle unglaublich<br>
expected: Oh yay Y'all's awesome<br>
got: Oh yeah you all are incredibly
|tst2013 (dev) | tst2014 (test) |
|--|--|
|26.93|23.16|
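
For reference, a corpus-level BLEU score can be computed with NLTK roughly as follows; this is a minimal sketch, and the repo's own evaluation script may compute the score differently.
```python
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def bleu(references, hypotheses):
    """references / hypotheses: lists of tokenized sentences (lists of strings)."""
    # corpus_bleu expects one list of reference sentences per hypothesis
    return 100 * corpus_bleu([[ref] for ref in references],
                             hypotheses,
                             smoothing_function=SmoothingFunction().method3)

refs = [["she", "was", "a", "woman", "named", "alex"]]
hyps = [["she", "was", "a", "woman", "named", "alex"]]
print(bleu(refs, hyps))  # 100.0 for a perfect match
```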

source: Dies ist nicht meine Meinung Das sind Fakten<br>
expected: This is not my opinion These are the facts<br>
got: This is not my opinion These are facts
## Notes
* Beam decoding will be added soon.
* I'm going to update the code when TF2.0 comes out, if possible.
229 changes: 143 additions & 86 deletions data_load.py
@@ -1,93 +1,150 @@
# -*- coding: utf-8 -*-
#/usr/bin/python2
#/usr/bin/python3
'''
June 2017 by kyubyong park.
Feb. 2019 by kyubyong park.
[email protected].
https://www.github.com/kyubyong/transformer
Note.
Where it is unambiguous, source-side entities carry the suffix 1 and target-side entities the suffix 2, for convenience.
For example, fpath1 and fpath2 mean the source file path and the target file path, respectively.
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import regex

def load_de_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def load_en_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/en.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def create_data(source_sents, target_sents):
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Index
    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()]  # 1: OOV, </S>: End of Text
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]
        if max(len(x), len(y)) <= hp.maxlen:
            x_list.append(np.array(x))
            y_list.append(np.array(y))
            Sources.append(source_sent)
            Targets.append(target_sent)

    # Pad
    X = np.zeros([len(x_list), hp.maxlen], np.int32)
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        X[i] = np.lib.pad(x, [0, hp.maxlen-len(x)], 'constant', constant_values=(0, 0))
        Y[i] = np.lib.pad(y, [0, hp.maxlen-len(y)], 'constant', constant_values=(0, 0))

    return X, Y, Sources, Targets

def load_train_data():
    de_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.source_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
    en_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.target_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Y

def load_test_data():
    def _refine(line):
        line = regex.sub("<[^>]+>", "", line)
        line = regex.sub("[^\s\p{Latin}']", "", line)
        return line.strip()

    de_sents = [_refine(line) for line in codecs.open(hp.source_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]
    en_sents = [_refine(line) for line in codecs.open(hp.target_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Sources, Targets  # (1064, 150)

def get_batch_data():
    # Load data
    X, Y = load_train_data()

    # calc total batch count
    num_batch = len(X) // hp.batch_size

    # Convert to tensor
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.int32)

    # Create Queues
    input_queues = tf.train.slice_input_producer([X, Y])

    # create batch queues
    x, y = tf.train.shuffle_batch(input_queues,
                                  num_threads=8,
                                  batch_size=hp.batch_size,
                                  capacity=hp.batch_size*64,
                                  min_after_dequeue=hp.batch_size*32,
                                  allow_smaller_final_batch=False)

    return x, y, num_batch  # (N, T), (N, T), ()
from utils import calc_num_batches

def load_vocab(vocab_fpath):
    '''Loads vocabulary file and returns idx<->token maps
    vocab_fpath: string. vocabulary file path.
    Note that these are reserved
    0: <pad>, 1: <unk>, 2: <s>, 3: </s>
    Returns
    two dictionaries.
    '''
    vocab = [line.split()[0] for line in open(vocab_fpath, 'r').read().splitlines()]
    token2idx = {token: idx for idx, token in enumerate(vocab)}
    idx2token = {idx: token for idx, token in enumerate(vocab)}
    return token2idx, idx2token
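
# Example (assuming a vocab file whose first four lines are <pad>, <unk>, <s>, </s>):
#   token2idx, idx2token = load_vocab("iwslt2016/segmented/bpe.vocab")  # hypothetical path
#   then token2idx["<pad>"] == 0 and idx2token[1] == "<unk>".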

def load_data(fpath1, fpath2, maxlen1, maxlen2):
    '''Loads source and target data and filters out too lengthy samples.
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    Returns
    sents1: list of source sents
    sents2: list of target sents
    '''
    sents1, sents2 = [], []
    with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
        for sent1, sent2 in zip(f1, f2):
            if len(sent1.split()) + 1 > maxlen1: continue  # 1: </s>
            if len(sent2.split()) + 1 > maxlen2: continue  # 1: </s>
            sents1.append(sent1.strip())
            sents2.append(sent2.strip())
    return sents1, sents2


def encode(inp, type, dict):
    '''Converts string to number. Used for `generator_fn`.
    inp: 1d byte array.
    type: "x" (source side) or "y" (target side)
    dict: token2idx dictionary
    Returns
    list of numbers
    '''
    inp_str = inp.decode("utf-8")
    if type == "x": tokens = inp_str.split() + ["</s>"]
    else: tokens = ["<s>"] + inp_str.split() + ["</s>"]

    x = [dict.get(t, dict["<unk>"]) for t in tokens]
    return x
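
# Example (assuming token2idx from load_vocab): encode(b"ich bin", "x", token2idx)
# returns the ids of ["ich", "bin", "</s>"]; out-of-vocabulary tokens map to token2idx["<unk>"].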

def generator_fn(sents1, sents2, vocab_fpath):
    '''Generates training / evaluation data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    yields
    xs: tuple of
        x: list of source token ids in a sent
        x_seqlen: int. sequence length of x
        sent1: str. raw source (=input) sentence
    labels: tuple of
        decoder_input: list of encoded decoder inputs
        y: list of target token ids in a sent
        y_seqlen: int. sequence length of y
        sent2: str. target sentence
    '''
    token2idx, _ = load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        x = encode(sent1, "x", token2idx)
        y = encode(sent2, "y", token2idx)
        decoder_input, y = y[:-1], y[1:]

        x_seqlen, y_seqlen = len(x), len(y)
        yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)
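
# Example of one yielded pair (token ids are hypothetical; sentences arrive as byte strings
# when this generator is driven by tf.data below):
#   xs     = ([7, 42, 3], 3, b"ich bin")           # x, x_seqlen, raw source sentence
#   labels = ([2, 9, 15], [9, 15, 3], 3, b"i am")  # decoder_input, y, y_seqlen, raw target sentence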

def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
    '''Batchify data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N, )
        sents2: str tensor. (N,)
    '''
    shapes = (([None], (), ()),
              ([None], [None], (), ()))
    types = ((tf.int32, tf.int32, tf.string),
             (tf.int32, tf.int32, tf.int32, tf.string))
    paddings = ((0, 0, ''),
                (0, 0, 0, ''))

    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(sents1, sents2, vocab_fpath))  # <- arguments for generator_fn. converted to np string arrays

    if shuffle:  # for training
        dataset = dataset.shuffle(128*batch_size)

    dataset = dataset.repeat()  # iterate forever
    dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)

    return dataset
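
# Note: the dataset repeats indefinitely (dataset.repeat()), so consumers are expected to bound
# iteration with the num_batches value returned by get_batch() below.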

def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
    '''Gets training / evaluation mini-batches
    fpath1: source file path. string.
    fpath2: target file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    batches
    num_batches: number of mini-batches
    num_samples
    '''
    sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
    batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size)
    return batches, num_batches, len(sents1)
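
# --- A minimal usage sketch of get_batch() under TF1.12 (not part of the original file). -----
# The file paths and hyperparameter values below are illustrative assumptions, not the repo's
# actual defaults.
if __name__ == "__main__":
    train_batches, num_train_batches, num_train_samples = get_batch(
        "iwslt2016/segmented/train.de.bpe", "iwslt2016/segmented/train.en.bpe",  # hypothetical paths
        maxlen1=100, maxlen2=100,
        vocab_fpath="iwslt2016/segmented/bpe.vocab",                             # hypothetical path
        batch_size=128, shuffle=True)
    iterator = train_batches.make_one_shot_iterator()
    xs, ys = iterator.get_next()  # xs = (x, x_seqlens, sents1), ys = (decoder_input, y, y_seqlens, sents2)
    with tf.Session() as sess:
        (x, x_seqlens, sents1), (decoder_input, y, y_seqlens, sents2) = sess.run([xs, ys])
        print(x.shape, y.shape)  # (N, T1), (N, T2)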