Skip to content

Commit

Permalink
add checkpoint step argument / modify default ckpt step from 1000 to 10000
Browse files Browse the repository at this point in the history
  • Loading branch information
khanrc committed Sep 11, 2017
1 parent 4352895 commit d0e529c
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def build_parser():
parser.add_argument('--model', help=models_str, required=True) # DRAGAN, CramerGAN
parser.add_argument('--name', help='default: name=model')
parser.add_argument('--dataset', help='CelebA / LSUN', required=True)
parser.add_argument('--ckpt_step', default=10000, help='# of steps for saving checkpoint (default: 10000)', type=int)
parser.add_argument('--renew', action='store_true', help='train model from scratch - \
clean saved checkpoints and summaries', default=False)

Expand All @@ -35,14 +36,14 @@ def sample_z(shape):
return np.random.normal(size=shape)


def train(model, input_op, num_epochs, batch_size, n_examples, renew=False):
def train(model, dataset, input_op, num_epochs, batch_size, n_examples, ckpt_step, renew=False):
# n_examples = 202599 # same as util.num_examples_from_tfrecords(glob.glob('./data/celebA_tfrecords/*.tfrecord'))
# 1 epoch = 1583 steps
print("\n# of examples: {}".format(n_examples))
print("steps per epoch: {}\n".format(n_examples//batch_size))

summary_path = os.path.join('./summary/', FLAGS.dataset, model.name)
ckpt_path = os.path.join('./checkpoints', FLAGS.dataset, model.name)
summary_path = os.path.join('./summary/', dataset, model.name)
ckpt_path = os.path.join('./checkpoints', dataset, model.name)
if renew:
if os.path.exists(summary_path):
tf.gfile.DeleteRecursively(summary_path)
Expand All @@ -66,7 +67,7 @@ def train(model, input_op, num_epochs, batch_size, n_examples, renew=False):
# make config_summary before define of summary_writer - bypass bug of tensorboard

# It seems that batch_size should have been contained in the model config ...
config_list = [('batch_size', batch_size), ('dataset', FLAGS.dataset)]
config_list = [('batch_size', batch_size), ('dataset', dataset)]
model_config_list = [[k, str(w)] for k, w in sorted(model.args.items()) + config_list]
model_config_summary_op = tf.summary.text('config', tf.convert_to_tensor(model_config_list), collections=[])
model_config_summary = sess.run(model_config_summary_op)
Expand Down Expand Up @@ -103,7 +104,7 @@ def train(model, input_op, num_epochs, batch_size, n_examples, renew=False):
if global_step % 10 == 0:
pbar.update(10)

if global_step % 1000 == 0:
if global_step % ckpt_step == 0:
saver.save(sess, ckpt_path+'/'+model.name, global_step=global_step)

except tf.errors.OutOfRangeError:
Expand Down Expand Up @@ -132,5 +133,5 @@ def train(model, input_op, num_epochs, batch_size, n_examples, renew=False):
X = input_pipeline(dataset_pattern, batch_size=FLAGS.batch_size,
num_threads=FLAGS.num_threads, num_epochs=FLAGS.num_epochs)
model = config.get_model(FLAGS.model, FLAGS.name, training=True)
train(model=model, input_op=X, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size,
n_examples=n_examples, renew=FLAGS.renew)
train(model=model, dataset=FLAGS.dataset, input_op=X, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size,
n_examples=n_examples, ckpt_step=FLAGS.ckpt_step, renew=FLAGS.renew)

0 comments on commit d0e529c

Please sign in to comment.