Code for MaskGAN: Better Text Generation via Filling in the ______, published at ICLR 2018.
- TensorFlow >= v1.5
Warning: The open-source version of this code is still being tested; pretraining may not work correctly.
For training on PTB:
1. Pretrain an LM on PTB and store the checkpoint in `/tmp/pretrain-lm/`. Instructions WIP; a data-preparation sketch follows below.
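The commands below read PTB from `/tmp/ptb`. Here is a minimal sketch for fetching the standard Mikolov splits; the assumption that the data reader expects `ptb.train.txt`, `ptb.valid.txt`, and `ptb.test.txt` directly inside `--data_dir` has not been verified against this repo's loader:

```python
import os
import shutil
import tarfile
import urllib.request

# Standard PTB distribution used by most LM papers (Mikolov's simple-examples).
URL = "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
DATA_DIR = "/tmp/ptb"

os.makedirs(DATA_DIR, exist_ok=True)
archive, _ = urllib.request.urlretrieve(URL, "/tmp/simple-examples.tgz")
with tarfile.open(archive) as tar:
    tar.extractall("/tmp")  # unpacks to /tmp/simple-examples/
for split in ("train", "valid", "test"):
    # Copy the three raw-text splits into the layout assumed above.
    shutil.copy(os.path.join("/tmp/simple-examples/data", "ptb.%s.txt" % split),
                DATA_DIR)
```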
2. Run MaskGAN in MLE pretraining mode. If step 1 was not run, set `language_model_ckpt_dir` to empty.
```bash
python train_mask_gan.py \
  --data_dir='/tmp/ptb' \
  --batch_size=20 \
  --sequence_length=20 \
  --base_directory='/tmp/maskGAN' \
  --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.00074876,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=1,gen_learning_rate_decay=0.95" \
  --mode='TRAIN' \
  --max_steps=100000 \
  --language_model_ckpt_dir=/tmp/pretrain-lm/ \
  --generator_model='seq2seq_vd' \
  --discriminator_model='rnn_zaremba' \
  --is_present_rate=0.5 \
  --summaries_every=10 \
  --print_every=250 \
  --max_num_to_print=3 \
  --gen_training_strategy=cross_entropy \
  --seq2seq_share_embedding
```
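`--is_present_rate=0.5` means roughly half of the input tokens stay visible while the rest are blanked out for the generator to fill in. Below is a schematic numpy sketch of the two masking schemes the `--mask_strategy` flag refers to; the function names are illustrative, not this repo's API:

```python
import numpy as np

def random_mask(batch, length, is_present_rate, rng):
    # Each position is kept visible independently with prob. is_present_rate.
    return rng.random((batch, length)) < is_present_rate

def contiguous_mask(batch, length, is_present_rate, rng):
    # Blank out one contiguous span covering (1 - is_present_rate) of the
    # sequence; this is the variant selected by --mask_strategy=contiguous
    # in the GAN step below.
    present = np.ones((batch, length), dtype=bool)
    span = int(round((1.0 - is_present_rate) * length))
    for i in range(batch):
        start = rng.integers(0, length - span + 1)
        present[i, start:start + span] = False
    return present

rng = np.random.default_rng(0)
print(contiguous_mask(1, 20, 0.5, rng).astype(int))  # ten 0s in a row
```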
3. Run MaskGAN in GAN mode. If step 2 was not run, set `maskgan_ckpt` to empty.
```bash
python train_mask_gan.py \
  --data_dir='/tmp/ptb' \
  --batch_size=128 \
  --sequence_length=20 \
  --base_directory='/tmp/maskGAN' \
  --mask_strategy=contiguous \
  --maskgan_ckpt='/tmp/maskGAN' \
  --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.000038877,gen_learning_rate_decay=1.0,gen_full_learning_rate_steps=2000000,gen_vd_keep_prob=0.33971,rl_discount_rate=0.89072,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=2,dis_pretrain_learning_rate=0.005,critic_learning_rate=5.1761e-7,dis_vd_keep_prob=0.71940" \
  --mode='TRAIN' \
  --max_steps=100000 \
  --generator_model='seq2seq_vd' \
  --discriminator_model='seq2seq_vd' \
  --is_present_rate=0.5 \
  --summaries_every=250 \
  --print_every=250 \
  --max_num_to_print=3 \
  --gen_training_strategy='reinforce' \
  --seq2seq_share_embedding=true \
  --baseline_method=critic \
  --attention_option=luong
```
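With `--gen_training_strategy='reinforce'` the generator gets no cross-entropy target on the blanks; instead the discriminator's score for each filled-in token acts as a per-token reward (the paper uses the log of the discriminator's probability), discounted by `rl_discount_rate`, and the critic selected by `--baseline_method=critic` predicts a baseline that is subtracted to reduce gradient variance. A schematic numpy sketch of that advantage computation, with illustrative names only:

```python
import numpy as np

def discounted_returns(rewards, gamma):
    # rewards: float array [batch, seq_len] of per-token discriminator rewards.
    returns = np.zeros_like(rewards)
    running = np.zeros(rewards.shape[0])
    for t in reversed(range(rewards.shape[1])):
        running = rewards[:, t] + gamma * running
        returns[:, t] = running
    return returns

def reinforce_advantages(rewards, critic_values, gamma=0.89072):
    # gamma matches rl_discount_rate above. The generator is then trained on
    # -sum_t advantage[t] * log pi(token_t), holding the critic's values fixed.
    return discounted_returns(rewards, gamma) - critic_values
```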
4. Generate samples:
```bash
python generate_samples.py \
  --data_dir='/tmp/ptb' \
  --data_set=ptb \
  --batch_size=256 \
  --sequence_length=20 \
  --base_directory='/tmp/ptbsample' \
  --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,gen_vd_keep_prob=0.33971" \
  --generator_model=seq2seq_vd \
  --discriminator_model=seq2seq_vd \
  --is_present_rate=0.0 \
  --maskgan_ckpt=/tmp/maskGAN \
  --seq2seq_share_embedding=true \
  --dis_share_embedding=true \
  --attention_option=luong \
  --mask_strategy=contiguous \
  --baseline_method=critic \
  --number_epochs=4
```
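Note that `--is_present_rate=0.0` masks every position, so sampling is unconditional: the generator sees no ground-truth tokens and conditions only on what it has already emitted. A tiny sanity check of that reading, using the same mask construction sketched earlier:

```python
import numpy as np

# At is_present_rate=0.0 the "present" mask is all False: with batch_size=256
# and sequence_length=20, all 256*20 positions must be filled by the generator.
present = np.random.default_rng(0).random((256, 20)) < 0.0
assert not present.any()
```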
- Liam Fedus, @liamb315
- Andrew M. Dai, @a-dai