From 95c3b6977aa412171ae4ac90ee7c85875ede2219 Mon Sep 17 00:00:00 2001 From: Carlos Chinchilla Date: Sat, 12 May 2018 08:04:09 -0700 Subject: [PATCH] save progress --- pix2pix-trainer/no_flip.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/pix2pix-trainer/no_flip.py b/pix2pix-trainer/no_flip.py index c528a5a..8ac0059 100644 --- a/pix2pix-trainer/no_flip.py +++ b/pix2pix-trainer/no_flip.py @@ -341,8 +341,19 @@ def train(): if sv.should_stop(): break - print('saving model') - saver.save(sess, os.path.join(output_path, 'model'), global_step=sv.global_step) + # save model + if train_epoch % 10 == 0 and train_step == 1: + print('saving model') + output_path = '{}/{}/{}'.format(output_path, job_name, train_epoch) + if not os.path.exists(output_path): + os.makedirs(output_path) + saver.save(sess, os.path.join(output_path, 'model'), global_step=sv.global_step) + + # save model + output_path = '{}/{}/{}'.format(output_path, job_name, 'final') + if not os.path.exists(output_path): + os.makedirs(output_path) + saver.save(sess, os.path.join(output_path, 'model'), global_step=sv.global_step) if __name__ == '__main__': @@ -381,9 +392,4 @@ def train(): l1_weight = args.l1_weight gan_weight = args.gan_weight - # create job output directory - output_path = '{}/{}'.format(output_path, job_name) - if not os.path.exists(output_path): - os.makedirs(output_path) - train()