Skip to content

Commit

Permalink
Fix decode first token bug
Browse files: browse the repository at this point in the history
  • Loading branch information
BeyonderXX committed Mar 20, 2023
1 parent 9a8357b commit 7b31fbd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
1 change: 0 additions & 1 deletion src/run_uie.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from model.bloom import BloomForCausalLM_WithLoss
from model.codegen import CodeGenForCausalLM_WithLoss
from uie_collator import DataCollatorForUIE
from uie_collator import SUPPORTED_DECODER_MODELS, check_model

from uie_trainer import UIETrainer, DenserEvalCallback, skip_instructions
from compute_metrics import compute_metrics, compute_grouped_metrics
Expand Down
3 changes: 1 addition & 2 deletions src/uie_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def seq2seq_call(self, batch, return_tensors):
return model_inputs

def decoder_call(self, batch, return_tensors):
# decoder模型处理
self.tokenizer.padding_side = 'left'
sources = []
label_lens = []
Expand Down Expand Up @@ -194,7 +193,7 @@ def decoder_call(self, batch, return_tensors):
max_len = min(max_len, limit_input_len)
loss_mask = torch.ones((label_mask.shape))
for k, label_len in enumerate(label_lens):
loss_mask[k, : max_len - label_len - 2] = 0
loss_mask[k, : max_len - label_len - 1] = 0
model_inputs['loss_mask'] = loss_mask.masked_fill(~label_mask, 0)

self._save_samples(model_inputs, sources, labels)
Expand Down
9 changes: 4 additions & 5 deletions src/uie_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
logger = datasets.logging.get_logger(__name__)
TASK_CONFIG_FILES = {"train": "train_tasks.json", "dev": "dev_tasks.json", "test": "test_tasks.json"}
INSTRUCTION_STRATEGIES = ['single', 'multiple']
SINGLE_QUOTES_SUBSTITUTE = "#$%#"
ANSWER_PREFIX = "Answer: "
ANSWER_PREFIX = "Answer:"


def check_path(path):
Expand Down Expand Up @@ -252,7 +251,7 @@ def load_NER_dataset(self, dataset_path, labels_path, dataset_name, sampling_str
for idx, instance in enumerate(instances):
example = sample_template.copy()
instruction = self._get_instruction('NER')
instruction += "Option:" + labels_str + " \n " + "Text: " + "{0}" + "\n" + "Answer: "
instruction += "Option:" + labels_str + " \n" + "Text: " + "{0}" + "\n" + "Answer:"
kv_pairs = []

for entity in instance['entities']:
Expand Down Expand Up @@ -285,7 +284,7 @@ def load_RE_dataset(self, dataset_path, labels_path, dataset_name, sampling_stra
for idx, instance in enumerate(instances):
example = sample_template.copy()
instruction = self._get_instruction('RE')
instruction += "Option:" + labels_str + " \n " + "Text: " + "{0}" + "\n" + "Answer: "
instruction += "Option:" + labels_str + " \n" + "Text: " + "{0}" + "\n" + "Answer:"
relation_pairs = []

for relation in instance['relations']:
Expand Down Expand Up @@ -320,7 +319,7 @@ def load_EE_dataset(self, dataset_path, labels_path, dataset_name, sampling_stra
for idx, instance in enumerate(instances):
example = sample_template.copy()
instruction = self._get_instruction('RE')
instruction += "Option:" + labels_str + " \n " + "Text: " + "{0}" + "\n" + "Answer: "
instruction += "Option:" + labels_str + " \n" + "Text: " + "{0}" + "\n" + "Answer:"
event_pairs = []

for k, event in enumerate(instance['events']):
Expand Down

0 comments on commit 7b31fbd

Please sign in to comment.