TensorFlow implementation of "Generating Sentences from a Continuous Space".
Requirements:
- Python 3.4 or higher
- TensorFlow r0.12
- NumPy
- Clone this repository:
git clone https://github.com/Chung-I/Variational-Recurrent-Autoencoder-Tensorflow.git
- Set up a conda environment:
conda create -n vrae python=3.6
conda activate vrae
- Install the Python package requirements:
pip install -r requirements.txt
Training:
python vrae.py --model_dir models --do train --new True
Reconstruct:
python vrae.py --model_dir models --do reconstruct --new False --input input.txt --output output.txt
Sample (this reads only the first line of input.txt, generates num_pts samples, and writes them to output.txt):
python vrae.py --model_dir models --do sample --new False --input input.txt --output output.txt
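A common way to get several sentences from one input in a VAE is to draw latent codes from the encoder's Gaussian posterior with the reparameterization trick, z = mu + sigma * eps, and decode each one. The NumPy sketch below assumes that is what the sample mode does; mu, log_var, and sample_latents are hypothetical names, not part of vrae.py.

```python
import numpy as np

def sample_latents(mu, log_var, num_pts, seed=None):
    """Draw num_pts latent vectors z = mu + sigma * eps with eps ~ N(0, I).

    mu and log_var are 1-D arrays of shape (latent_dim,), assumed to come
    from a (hypothetical) encoder run on the single input sentence.
    """
    rng = np.random.default_rng(seed)
    sigma = np.exp(0.5 * log_var)                    # std dev from log-variance
    eps = rng.standard_normal((num_pts, mu.shape[0]))
    return mu + sigma * eps                          # shape (num_pts, latent_dim)

# Usage with made-up encoder outputs (latent_dim = 4).
z = sample_latents(np.zeros(4), np.zeros(4), num_pts=5, seed=0)
print(z.shape)  # (5, 4)
```

If the variance is zero (the probablistic option describes this case when set to False), sigma is 0 and every sampled row collapses to mu, so all samples would decode identically under greedy decoding.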
Interpolate (input.txt must contain exactly two sentences; this generates num_pts interpolations between them and writes the interpolated sentences to output.txt):
python vrae.py --model_dir models --do interpolate --new False --input input.txt --output output.txt
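Interpolating between two sentences usually means encoding both, walking along the line between their latent codes, and decoding each intermediate point. A minimal NumPy sketch, assuming linear interpolation with both endpoints included; interpolate_latents is a hypothetical helper, and vrae.py may space or choose its points differently.

```python
import numpy as np

def interpolate_latents(z_start, z_end, num_pts):
    """Return num_pts evenly spaced points from z_start to z_end (inclusive).

    Row 0 equals z_start and the last row equals z_end.
    """
    t = np.linspace(0.0, 1.0, num_pts)[:, None]   # shape (num_pts, 1)
    return (1.0 - t) * z_start + t * z_end         # shape (num_pts, latent_dim)

# Usage with two made-up 2-D latent codes.
points = interpolate_latents(np.array([0.0, 0.0]), np.array([1.0, 2.0]), 3)
print(points)
```

Each row would then be fed to the decoder to produce one interpolated sentence.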
model_dir: location of the config file config.json and the checkpoint files.
do: accepts 4 values: train, reconstruct, sample, or interpolate.
new: if set to True, create models with fresh parameters; otherwise, read model parameters from the checkpoints in model_dir.
Unlike tensorflow/models/rnn/translate/translate.py, vrae.py does not take hyperparameters from the command line. Instead, it reads them from config.json in model_dir.
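Since all hyperparameters live in one JSON file grouped into sections such as model and train, reading them takes only the standard json module. The sketch shows the idea; load_config is a hypothetical helper rather than vrae.py's own code, and the values written in the usage part are placeholders, not recommended settings.

```python
import json
import os
import tempfile

def load_config(model_dir):
    """Read <model_dir>/config.json and return it as a nested dict."""
    with open(os.path.join(model_dir, "config.json")) as f:
        return json.load(f)

# Usage: write a tiny, made-up config to a temporary model_dir and read it back.
with tempfile.TemporaryDirectory() as model_dir:
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump({"model": {"latent_dim": 16}, "train": {"batch_size": 32}}, f)
    config = load_config(model_dir)
    print(config["model"]["latent_dim"], config["train"]["batch_size"])  # 16 32
```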
Below are hyperparameters in config.json:
- model:
    - size: embedding size, and encoder/decoder state size.
    - latent_dim: latent space size.
    - in_vocab_size: source vocabulary size.
    - out_vocab_size: target vocabulary size.
    - data_dir: path to the corpus.
    - num_layers: number of layers for the encoder and decoder.
    - use_lstm: whether to use LSTM cells for the encoder and decoder. BasicLSTMCell is used if set to True; otherwise GRUCell is used.
    - buckets: a list of [input size, output size] pairs, one per bucket.
    - bidirectional: bidirectional_rnn is used if set to True.
    - probablistic: the variance is set to zero if set to False.
    - orthogonal_initializer: orthogonal_initializer is used if set to True; otherwise uniform_unit_scaling_initializer is used.
    - iaf: inverse autoregressive flow is used if set to True.
    - activation: activation for the encoder-to-latent and latent-to-decoder layers:
        - elu: exponential linear unit.
        - prelu: parametric rectified linear unit (default).
        - None: linear.
- train:
    - batch_size: batch size.
    - beam_size: beam size for decoding. Warning: beam search is not fully implemented yet; a NotImplementedError is raised if beam_size is greater than 1.
    - learning_rate: learning rate passed to AdamOptimizer.
    - steps_per_checkpoint: a checkpoint is saved every steps_per_checkpoint steps.
    - anneal: KL cost annealing is applied if set to True.
    - kl_rate_rise_factor: the KL term weight is increased by this amount every steps_per_checkpoint steps.
    - max_train_data_size: limit on the size of the training data (0: no limit).
    - feed_previous: if True, only the first decoder input (the "GO" symbol) is used, and all subsequent decoder inputs are generated by next = embedding_lookup(embedding, argmax(previous_output)). In effect, this implements a greedy decoder. It can also be used during training to emulate http://arxiv.org/abs/1506.03099. If False, decoder_inputs are used as given (the standard decoder case).
    - kl_min: the minimum information constraint. Should be a non-negative float (0 means no constraint).
    - max_gradient_norm: gradients are clipped to at most this norm.
    - word_dropout_keep_prob: probability of keeping each conditioned-on word token; the remaining tokens are replaced with the generic unknown-word token UNK. When set to 0, the decoder sees no input.
- reconstruct:
    - feed_previous
    - word_dropout_keep_prob
- sample:
    - feed_previous
    - word_dropout_keep_prob
    - num_pts: number of points to sample.
- interpolate:
    - feed_previous
    - word_dropout_keep_prob
    - num_pts: number of interpolated points to generate.
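The train options anneal, kl_rate_rise_factor, kl_min, and word_dropout_keep_prob correspond to three standard tricks for training sentence VAEs: KL cost annealing, a minimum-information (free-bits style) floor on the KL term, and word dropout on decoder inputs. The NumPy sketch below illustrates them under stated assumptions: the KL weight starts at 0, rises linearly once per checkpoint interval, and is capped at 1.0; kl_min is applied as a scalar floor max(KL, kl_min); word dropout acts independently per token. The function names are hypothetical and vrae.py may implement these details differently.

```python
import numpy as np

def kl_weight(step, steps_per_checkpoint, kl_rate_rise_factor, anneal=True):
    """KL term weight, raised by kl_rate_rise_factor every
    steps_per_checkpoint steps. Start at 0 and cap at 1.0 are assumptions."""
    if not anneal:
        return 1.0
    return min(1.0, (step // steps_per_checkpoint) * kl_rate_rise_factor)

def vae_loss(recon_loss, kl, weight, kl_min=0.0):
    """Weighted loss with the KL term floored at kl_min: below kl_min the
    loss no longer depends on KL, so there is no gain in shrinking it further."""
    return recon_loss + weight * max(kl, kl_min)

def word_dropout(token_ids, keep_prob, unk_id, rng):
    """Keep each decoder-input token with probability keep_prob and replace
    the rest with unk_id; keep_prob = 0 replaces every token.
    (A real implementation may leave the GO symbol untouched.)"""
    token_ids = np.asarray(token_ids)
    keep = rng.random(token_ids.shape) < keep_prob
    return np.where(keep, token_ids, unk_id)

# Usage with made-up numbers.
rng = np.random.default_rng(0)
print(kl_weight(2500, 1000, 0.1))                   # 0.2
print(vae_loss(50.0, 1.0, 0.5, kl_min=3.0))         # 51.5
print(word_dropout([5, 6, 7], 0.0, unk_id=1, rng=rng))  # [1 1 1]
```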
The Penn Treebank corpus is included in the repo. We also provide a Chinese poem corpus, its preprocessed version (to use it, set {"model": {"data_dir": "<corpus_dir>"}} in <model_dir>/config.json, where <corpus_dir> is its location), and its pretrained model (to use it, set model_dir to its location), all of which can be found here.
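Pointing a model at a different corpus only requires changing model.data_dir in config.json. A small sketch using the standard json module; set_data_dir is a hypothetical helper, and it rewrites the file with two-space indentation, which may differ from the original formatting.

```python
import json
import os

def set_data_dir(model_dir, corpus_dir):
    """Set {"model": {"data_dir": corpus_dir}} in <model_dir>/config.json,
    leaving every other setting unchanged, and return the updated config."""
    path = os.path.join(model_dir, "config.json")
    with open(path) as f:
        config = json.load(f)
    config.setdefault("model", {})["data_dir"] = corpus_dir
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    return config
```

For example, set_data_dir("models", "<corpus_dir>") with <corpus_dir> replaced by the actual location of the preprocessed corpus.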