Skip to content

Commit

Permalink
Major Change in structure to adopt R-Net by HKUST-KnowComp
Browse files Browse the repository at this point in the history
  • Loading branch information
localminimum committed Apr 24, 2018
1 parent df4120a commit 80c5b86
Show file tree
Hide file tree
Showing 16 changed files with 963 additions and 1,531 deletions.
59 changes: 30 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# FAST AND ACCURATE READING COMPREHENSION WITHOUT RECURRENT NETWORKS
A Tensorflow implementation of Google's [Fast Reading Comprehension](https://openreview.net/pdf?id=B14TlG-RW) from [ICLR2018](https://openreview.net/forum?id=B14TlG-RW).
Without RNNs the model computes relatively quickly compared to [R-net](https://github.com/minsangkim142/R-net)(about 5 times faster in naive implementation).
After 12 epochs of training our model reaches dev EM/F1 = 57 / 72.
Training and preprocessing pipeline has been adopted from [R-Net by HKUST-KnowComp](https://github.com/HKUST-KnowComp/R-Net). Demo mode needs to be reimplemented. If you are here for the demo please use "dev" branch. The model reaches EM/F1 = 66/75 in 30k steps.

Due to memory issue, a single head dot-product attention is used as opposed to 8 heads multi-head attention as mentioned in the original paper. Also hidden size is reduced to 96 from 128 due to memory problems in GTX1080. (8GB GPU memory is insufficient. If you have a 12GB memory GPU please share your results with us.)

![Alt text](/../master/screenshots/figure.png?raw=true "Network Outline")

Expand All @@ -16,29 +17,41 @@ Pretrained [GloVe embeddings](https://nlp.stanford.edu/projects/glove/) obtained
* TensorFlow (1.2 or higher)
* spacy

## Downloads and Setup
Preprocessing step is identical to [R-net](https://github.com/minsangkim142/R-net).
Once you clone this repo, run the following lines from bash **just once** to process the dataset (SQuAD).
```shell
$ pip install -r requirements.txt
$ bash setup.sh
$ python process.py --process True --reduce_glove True
## Usage

To download and preprocess the data, run

```bash
# download SQuAD and Glove
sh download.sh
# preprocess the data
python config.py --mode prepro
```

## Training / Testing / Debugging / Demo
You can change the hyperparameters from params.py file to fit the model in your GPU. To train the model, run the following line.
To test or debug your model after training, change mode = "train" from params.py file and run the model.
```shell
$ python model.py
Just like [R-Net by HKUST-KnowComp](https://github.com/HKUST-KnowComp/R-Net), hyper parameters are stored in config.py. To debug/train/test the model, run

```bash
python evaluate-v1.1.py ~/data/squad/dev-v1.1.json log/answer/answer.json
```

**A working realtime demo is available at demo.py. To use web interface for live demo change use mode = "demo" and set batch_size to 1. (The code is taken from [R-net](https://github.com/minsangkim142/R-net))**
The default directory for tensorboard log file is `log/event`

## Detailed Implementaion

* The model adopts character level convolution - max pooling - highway network for input representations similar to [this paper by Yoon Kim](https://arxiv.org/pdf/1508.06615.pdf).
* Encoder consists of positional encoding - depthwise separable convolution - self attention - feed forward structure with layer norm in between.
* Stochastic depth dropout is used to drop the residual connection with respect to increasing depth of the network as this paper heavily relies on residual connection.
* Context-to-Query attention is used but Query-to-Context attention is not implemented as it is reported not to improve much on the performance.
* Learning rate increases from 0.0 to 0.001 in first 1000 steps in inverse exponential scale and fixed to 0.001 from 1000 steps.
* For regularization, dropout of 0.1 is used every 2 sub-layers and 2 blocks.
* During prediction, this model uses shadow variables maintained by exponential moving average of all global variables.
* [Taken from R-Net](https://github.com/HKUST-KnowComp/R-Net): To address efficiency issue, this implementation uses bucketing method (contributed by xiongyifan) and CudnnGRU. Due to a known bug [#13254](https://github.com/tensorflow/tensorflow/issues/13254) in Tensorflow, the weights of CudnnGRU may not be properly restored. Check the test score if you want to use it for prediction. The bucketing method can speedup the training, but will lower the F1 score by 0.3%.


## TODO's
- [x] Add trilinear function to Context-to-Query attention
- [x] Convergence testing
- [x] Apply dropouts + stochastic depth dropout
- [x] Realtime Demo
- [ ] Realtime Demo
- [ ] Query-to-context attention
- [ ] Data augmentation by paraphrasing

Expand All @@ -47,15 +60,3 @@ Run tensorboard for visualisation.
```shell
$ tensorboard --logdir=./
```

![Alt text](/../master/screenshots/tensorboard.png?raw=true "Training Curve")

## Note
**2/02/18**
The model quickly reaches EM/F1 = 55/69 on devset, but never gets beyond that even with strong regularization. Also the training speed (1.8 batch per second in GTX1080) is slower than the paper suggests (3.2 batch per second in P100).

**28/01/18**
The model reaches devset performance of EM/F1=44/58 1 hour into training without dropout. Next goal is to train with dropout every 2 layers.

**04/11/17**
Currently the model is not optimized and there is a memory leak so I strongly suggest only training if your memory is 16GB >. Also I haven't done convergence testing yet. The training time is 5 ~ 6x faster on naive implementation compared to [R-net](https://github.com/minsangkim142/R-net).
131 changes: 131 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import os
import tensorflow as tf

'''
This file is taken and modified from R-Net by HKUST-KnowComp
https://github.com/HKUST-KnowComp/R-Net
'''

from prepro import prepro
from main import train, test

flags = tf.flags

home = os.path.expanduser("~")
train_file = os.path.join(home, "data", "squad", "train-v1.1.json")
dev_file = os.path.join(home, "data", "squad", "dev-v1.1.json")
test_file = os.path.join(home, "data", "squad", "dev-v1.1.json")
glove_word_file = os.path.join(home, "data", "glove", "glove.840B.300d.txt")

target_dir = "data"
log_dir = "log/event"
save_dir = "log/model"
answer_dir = "log/answer"
train_record_file = os.path.join(target_dir, "train.tfrecords")
dev_record_file = os.path.join(target_dir, "dev.tfrecords")
test_record_file = os.path.join(target_dir, "test.tfrecords")
word_emb_file = os.path.join(target_dir, "word_emb.json")
char_emb_file = os.path.join(target_dir, "char_emb.json")
train_eval = os.path.join(target_dir, "train_eval.json")
dev_eval = os.path.join(target_dir, "dev_eval.json")
test_eval = os.path.join(target_dir, "test_eval.json")
dev_meta = os.path.join(target_dir, "dev_meta.json")
test_meta = os.path.join(target_dir, "test_meta.json")
answer_file = os.path.join(answer_dir, "answer.json")

if not os.path.exists(target_dir):
os.makedirs(target_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
if not os.path.exists(answer_dir):
os.makedirs(answer_dir)

flags.DEFINE_string("mode", "train", "Running mode train/debug/test")

flags.DEFINE_string("target_dir", target_dir, "Target directory for out data")
flags.DEFINE_string("log_dir", log_dir, "Directory for tf event")
flags.DEFINE_string("save_dir", save_dir, "Directory for saving model")
flags.DEFINE_string("train_file", train_file, "Train source file")
flags.DEFINE_string("dev_file", dev_file, "Dev source file")
flags.DEFINE_string("test_file", test_file, "Test source file")
flags.DEFINE_string("glove_word_file", glove_word_file, "Glove word embedding source file")

flags.DEFINE_string("train_record_file", train_record_file, "Out file for train data")
flags.DEFINE_string("dev_record_file", dev_record_file, "Out file for dev data")
flags.DEFINE_string("test_record_file", test_record_file, "Out file for test data")
flags.DEFINE_string("word_emb_file", word_emb_file, "Out file for word embedding")
flags.DEFINE_string("char_emb_file", char_emb_file, "Out file for char embedding")
flags.DEFINE_string("train_eval_file", train_eval, "Out file for train eval")
flags.DEFINE_string("dev_eval_file", dev_eval, "Out file for dev eval")
flags.DEFINE_string("test_eval_file", test_eval, "Out file for test eval")
flags.DEFINE_string("dev_meta", dev_meta, "Out file for dev meta")
flags.DEFINE_string("test_meta", test_meta, "Out file for test meta")
flags.DEFINE_string("answer_file", answer_file, "Out file for answer")


flags.DEFINE_integer("glove_char_size", 94, "Corpus size for Glove")
flags.DEFINE_integer("glove_word_size", int(2.2e6), "Corpus size for Glove")
flags.DEFINE_integer("glove_dim", 300, "Embedding dimension for Glove")
flags.DEFINE_integer("char_dim", 200, "Embedding dimension for char")

flags.DEFINE_integer("para_limit", 400, "Limit length for paragraph")
flags.DEFINE_integer("ques_limit", 50, "Limit length for question")
flags.DEFINE_integer("test_para_limit", 1000, "Limit length for paragraph in test file")
flags.DEFINE_integer("test_ques_limit", 100, "Limit length for question in test file")
flags.DEFINE_integer("char_limit", 16, "Limit length for character")
flags.DEFINE_integer("word_count_limit", -1, "Min count for word")
flags.DEFINE_integer("char_count_limit", -1, "Min count for char")

flags.DEFINE_integer("capacity", 15000, "Batch size of dataset shuffle")
flags.DEFINE_integer("num_threads", 4, "Number of threads in input pipeline")
flags.DEFINE_boolean("is_bucket", False, "build bucket batch iterator or not")
flags.DEFINE_list("bucket_range", [40, 401, 40], "the range of bucket")

flags.DEFINE_integer("batch_size", 32, "Batch size")
flags.DEFINE_integer("num_steps", 60000, "Number of steps")
flags.DEFINE_integer("checkpoint", 1000, "checkpoint to save and evaluate the model")
flags.DEFINE_integer("period", 100, "period to save batch loss")
flags.DEFINE_integer("val_num_batches", 150, "Number of batches to evaluate the model")
flags.DEFINE_float("dropout", 0.1, "Dropout prob across the layers")
flags.DEFINE_float("grad_clip", 5.0, "Global Norm gradient clipping rate")
flags.DEFINE_float("learning_rate", 0.001, "Learning rate")
flags.DEFINE_float("decay", 0.9999, "Exponential moving average decay")
flags.DEFINE_float("l2_norm", 3e-7, "L2 norm scale")
flags.DEFINE_integer("hidden", 128, "Hidden size")
flags.DEFINE_integer("num_heads", 1, "Number of heads in self attention")

# Extensions (Uncomment corresponding code in download.sh to download the required data)
glove_char_file = os.path.join(home, "data", "glove", "glove.840B.300d-char.txt")
flags.DEFINE_string("glove_char_file", glove_char_file, "Glove character embedding source file")
flags.DEFINE_boolean("pretrained_char", False, "Whether to use pretrained character embedding")

fasttext_file = os.path.join(home, "data", "fasttext", "wiki-news-300d-1M.vec")
flags.DEFINE_string("fasttext_file", fasttext_file, "Fasttext word embedding source file")
flags.DEFINE_boolean("fasttext", False, "Whether to use fasttext")


def main(_):
config = flags.FLAGS
if config.mode == "train":
train(config)
elif config.mode == "prepro":
prepro(config)
elif config.mode == "debug":
config.num_steps = 2
config.val_num_batches = 1
config.checkpoint = 1
config.period = 1
train(config)
elif config.mode == "test":
if config.use_cudnn:
print("Warning: Due to a known bug in Tensorlfow, the parameters of CudnnGRU may not be properly restored.")
test(config)
else:
print("Unknown mode")
exit(0)


if __name__ == "__main__":
tf.app.run()
Loading

0 comments on commit 80c5b86

Please sign in to comment.