Skip to content

Commit

Permalink
Fix prediction too long length bug
Browse files Browse the repository at this point in the history
  • Loading branch information
BeyonderXX committed Mar 21, 2023
1 parent 4bf89d1 commit bc30ce5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
3 changes: 0 additions & 3 deletions src/run_uie.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,7 @@ def main():
task_config_dir=data_args.task_config_dir,
instruction_file=data_args.instruction_file,
instruction_strategy=data_args.instruction_strategy,
# keep_in_memory=True,
cache_dir=data_cache_dir, # for debug, change dataset size, otherwise open it
# verification_mode=datasets.VerificationMode.NONE,
# ignore_verifications=True,
max_num_instances_per_task=data_args.max_num_instances_per_task,
max_num_instances_per_eval_task=data_args.max_num_instances_per_eval_task,
num_examples=data_args.num_examples
Expand Down
9 changes: 6 additions & 3 deletions src/uie_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ def skip_instructions(model, predictions_ids, tokenizer, ignore_idx=-100):
final_predictions = []
if check_model(model.config._name_or_path, SUPPORTED_DECODER_MODELS):
for pred in predictions:
assert ANSWER_PREFIX in pred
splits = pred.split(ANSWER_PREFIX)
final_predictions.append(splits[-1].strip())

if ANSWER_PREFIX in pred:
splits = pred.split(ANSWER_PREFIX)
final_predictions.append(splits[-1].strip())
else:
final_predictions.append('')
else:
final_predictions = predictions

Expand Down

0 comments on commit bc30ce5

Please sign in to comment.