PyTorch implementation of Generating Sentences from a Continuous Space by Bowman et al. 2015. Note: This implementation does not support LSTM's at the moment, but RNN's and GRU's.
Training was stopped after 4 epochs. The true ELBO was optimized for approximatily 1 epoch (as can bee see in the graph above). Results are averaged over entire split.
Split | NLL | KL |
---|---|---|
Train | 99.821 | 7.944 |
Validation | 103.220 | 7.346 |
Test | 103.967 | 7.269 |
Sentenes have been obtained after sampling from z ~ N(0, I).
Sentenes have been obtained after sampling twice from z ~ N(0, I) and the interpolating the two samples.
To run the training, please download the Penn Tree Bank data first (download from Tomas Mikolov's webpage). The code expects to find at least ptb.train.txt
and ptb.valid.txt
in the specified data directory.
Then training can be executed with the following command:
python3 train.py
The following arguments are available:
--data_dir
The path to the directory where PTB data is stored, and auxiliary data files will be stored.
--create_data
If provided, new auxiliary data files will be created form the source data.
--max_sequence_length
Specifies the cut off of long sentences.
--min_occ
If a word occurs less than "min_occ" times in the corpus, it will be replaced by the token.
--test
If provided, performance will also be measured on the test set.
-ep
, --epochs
-bs
, --batch_size
-lr
, --learning_rate
-eb
, --embedding_size
-rnn
, --rnn_type
Either 'rnn' or 'gru'.
-hs
, --hidden_size
-nl
, --num_layers
-bi
, --bidirectional
-ls
, --latent_size
-wd
, --word_dropout
Word dropout applied to the input of the Decoder.
-af
, --anneal_function
Either 'logistic' or 'linear'.
-k
, --k
Steepness of the logistic annealing function.
-x0
, --x0
For 'logistic', this is the mid-point (i.e. when the weight is 0.5); for 'linear' this is the denominator.
-v
, --print_every
-tb
, --tensorboard_logging
If provided, training progress is monitored with tensorboard.
-log
, --logdir
Directory of log files for tensorboard.
-bin
,--save_model_path
Directory where to store model checkpoints.