Skip to content

Commit

Permalink
support disabling logs/plots/saves
Browse files Browse the repository at this point in the history
Needed since I broke plotting for heterogeneous DA (domain adaptation)
  • Loading branch information
floft committed Apr 27, 2020
1 parent c0712ef commit aba20df
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
flags.DEFINE_string("uid", None, "A unique ID saved in the log/model folder names to avoid conflicts")
flags.DEFINE_integer("steps", 30000, "Number of training steps to run")
flags.DEFINE_float("gpumem", 3350, "GPU memory to let TensorFlow use, in MiB (0 for all)")
flags.DEFINE_integer("model_steps", 4000, "Save the model every so many steps")
flags.DEFINE_integer("log_train_steps", 500, "Log training information every so many steps")
flags.DEFINE_integer("log_val_steps", 4000, "Log validation information every so many steps (also saves model)")
flags.DEFINE_integer("log_plots_steps", 4000, "Log plots every so many steps")
flags.DEFINE_integer("model_steps", 0, "Save the model every so many steps (0 for only when log_val_steps)")
flags.DEFINE_integer("log_train_steps", 500, "Log training information every so many steps (0 for never)")
flags.DEFINE_integer("log_val_steps", 4000, "Log validation information every so many steps (also saves model, 0 for only at end)")
flags.DEFINE_integer("log_plots_steps", 0, "Log plots every so many steps (0 for never)")
flags.DEFINE_boolean("test", False, "Use real test set for evaluation rather than validation set")
flags.DEFINE_boolean("subdir", True, "Save models/logs in subdirectory of prefix")
flags.DEFINE_boolean("debug", False, "Start new log/model/images rather than continuing from previous run")
Expand Down Expand Up @@ -148,26 +148,28 @@ def main(argv):
sys.stdout.flush() # otherwise waits till the end to flush on Kamiak

# Metrics on training/validation data
if i%FLAGS.log_train_steps == 0:
if FLAGS.log_train_steps != 0 and i%FLAGS.log_train_steps == 0:
metrics.train(data_sources, data_target, global_step, t)

# Evaluate every log_val_steps but also at the last step
validation_accuracy_source = None
validation_accuracy_target = None
if i%FLAGS.log_val_steps == 0 or i == FLAGS.steps:
if (FLAGS.log_val_steps != 0 and i%FLAGS.log_val_steps == 0) \
or i == FLAGS.steps:
validation_accuracy_source, validation_accuracy_target \
= metrics.test(global_step)

# Checkpoints -- Save either if at the right model step or if we found
# a new validation accuracy. If this is better than the previous best
# model, we need to make a new checkpoint so we can restore from this
# step with the best accuracy.
if i%FLAGS.model_steps == 0 or validation_accuracy_source is not None:
if (FLAGS.model_steps != 0 and i%FLAGS.model_steps == 0) \
or validation_accuracy_source is not None:
checkpoint_manager.save(int(global_step-1),
validation_accuracy_source, validation_accuracy_target)

# Plots
if i%FLAGS.log_plots_steps == 0:
if FLAGS.log_plots_steps != 0 and i%FLAGS.log_plots_steps == 0:
metrics.plots(global_step)

# We're done -- used for hyperparameter tuning
Expand Down

0 comments on commit aba20df

Please sign in to comment.