Skip to content

Commit

Permalink
fix predict example label map error
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruchen Zhang committed Nov 7, 2018
1 parent a21d484 commit 08509ea
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions run_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ def _create_examples(self, lines, set_type):
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
label = tokenization.convert_to_unicode(line[-1])
if set_type == "test":
label = "contradiction"
else:
label = tokenization.convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
Expand Down Expand Up @@ -301,7 +304,10 @@ def _create_examples(self, lines, set_type):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
text_b = tokenization.convert_to_unicode(line[4])
label = tokenization.convert_to_unicode(line[0])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
Expand Down Expand Up @@ -335,7 +341,10 @@ def _create_examples(self, lines, set_type):
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
label = tokenization.convert_to_unicode(line[1])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
Expand Down Expand Up @@ -855,6 +864,7 @@ def main(_):
for key in sorted(result.keys()):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))

if FLAGS.do_predict:
predict_examples = processor.get_test_examples(FLAGS.data_dir)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
Expand Down

0 comments on commit 08509ea

Please sign in to comment.