Skip to content

Commit

Permalink
feat(f1-score): add pred_file flag handling
Browse files Browse the repository at this point in the history
  • Loading branch information
agrudkow committed May 29, 2021
1 parent b39d4d4 commit 536338a
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions run_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@

# ists
flags.DEFINE_bool("calc_ists_metrics", default=False, help="Calculate metrics for prediotions in iSTS task")
flags.DEFINE_string("dataset", default=None,
help="Specifies tsv file with predictions. If None, ")
flags.DEFINE_string("pred_file", default=None,
help="Specifies tsv file with predictions. If None, ")

Expand Down Expand Up @@ -914,9 +912,14 @@ def tokenize_fn(text):
json.dump(predict_results, fp, indent=4)

if FLAGS.calc_ists_metrics:
predictions = model_utils.get_predictions()
predictions = []

if not FLAGS.pred_file:
predictions = model_utils.get_predictions()
else:
predictions = [model_utils.extract_global_step(FLAGS.pred_file[:-4]), FLAGS.pred_file]

for global_step, pred_file_path in predictions:
for global_step, pred_file_path in sorted(predictions, key=lambda x: x[0]):
f1_scores = calc_ists_metrics(pred_file_path, FLAGS.data_dir + "/test.tsv")
tf.logging.info( 'Dataset: ' + FLAGS.data_dir.split("/")[-1] + " Step: " + str(global_step) + ' [F1 Type]: {} \n [F1 Score]: {} \n [F1 T+S]: {}' % f1_scores)

Expand Down

0 comments on commit 536338a

Please sign in to comment.