Skip to content

Commit

Permalink
fix(run_classifier): move calc_ists_metrics to the top of the main to…
Browse files Browse the repository at this point in the history
… skip unnecessary script executions
  • Loading branch information
agrudkow committed May 29, 2021
1 parent 523f40b commit 04a727a
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions run_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,21 @@ def main(_):
if FLAGS.save_steps is not None:
FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps)

if FLAGS.calc_ists_metrics:
predictions = []

if not FLAGS.pred_file:
predictions = model_utils.get_predictions(FLAGS.predict_dir)
else:
predictions = [model_utils.extract_global_step(FLAGS.pred_file[:-4]), FLAGS.pred_file]

for global_step, pred_file_path in sorted(predictions, key=lambda x: x[0]):
f1_scores = calc_ists_metrics(pred_file_path, FLAGS.data_dir + "/test.tsv")
tf.logging.info( 'Dataset: ' + FLAGS.data_dir.split("/")[-1] + " Step: " + str(global_step) + ' [F1 Type]: {} \n [F1 Score]: {} \n [F1 T+S]: {}' % f1_scores)

# End execution after caclulations
return None

if FLAGS.do_predict:
predict_dir = FLAGS.predict_dir
if not tf.gfile.Exists(predict_dir):
Expand Down Expand Up @@ -910,18 +925,6 @@ def tokenize_fn(text):

with tf.gfile.Open(predict_json_path, "w") as fp:
json.dump(predict_results, fp, indent=4)

if FLAGS.calc_ists_metrics:
predictions = []

if not FLAGS.pred_file:
predictions = model_utils.get_predictions(FLAGS.predict_dir)
else:
predictions = [model_utils.extract_global_step(FLAGS.pred_file[:-4]), FLAGS.pred_file]

for global_step, pred_file_path in sorted(predictions, key=lambda x: x[0]):
f1_scores = calc_ists_metrics(pred_file_path, FLAGS.data_dir + "/test.tsv")
tf.logging.info( 'Dataset: ' + FLAGS.data_dir.split("/")[-1] + " Step: " + str(global_step) + ' [F1 Type]: {} \n [F1 Score]: {} \n [F1 T+S]: {}' % f1_scores)


if __name__ == "__main__":
Expand Down

0 comments on commit 04a727a

Please sign in to comment.