[Feature] add normalization for ppl inference
wyx authored and wyx committed Mar 26, 2023
1 parent 21d4894 commit cd48038
Showing 4 changed files with 66 additions and 11 deletions.
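
For context, here is a minimal usage sketch of the feature this commit adds (not part of the diff; the dataset, template text, and model name are illustrative assumptions). The template marks the boundary between context and answer with the new sep_token, and inference() is called with a normalizing_str against which the answer span is re-scored:

# Hypothetical usage sketch; dataset, template text, and model name are assumptions.
from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer

dataset = load_dataset('imdb')
data = DatasetReader(dataset, input_columns=['text'], output_column='label')

# '</sep>' splits each label prompt into a context part and an answer part.
template = PromptTemplate(
    {0: '</E>Review: </text> </sep>It was terrible.',
     1: '</E>Review: </text> </sep>It was great.'},
    column_token_map={'text': '</text>'},
    ice_token='</E>',
    sep_token='</sep>',
)

retriever = TopkRetriever(data, ice_num=4)
inferencer = PPLInferencer(model_name='distilgpt2')

# With normalizing_str set, each label is scored by
# PPL(answer | context) - PPL(answer | normalizing_str).
predictions = inferencer.inference(retriever, ice_template=template,
                                   normalizing_str='Answer:')
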
51 changes: 46 additions & 5 deletions openicl/icl_inferencer/icl_ppl_inferencer.py
@@ -54,14 +54,13 @@ def __init__(self,

def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> List:
output_json_filename: Optional[str] = None, normalizing_str: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = PPLInferencerOutputHandler(self.accelerator)

sub_predictions = []
ppl = []
ice = []
index = 0

if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
@@ -87,11 +86,14 @@ def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTempl
index = 0
prompt_list = []
sub_ppl_list = []
normalizing_prompt_list = []
context_length_list = []

# 5.1 Generate prompts of current label and truncate
for idx in range(len(ice_idx_list)):
prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template,
prompt_template=prompt_template)
prompt_template=prompt_template,
remain_sep=normalizing_str is not None)
if self.max_model_token_num is not None and self.api_name != 'gpt3':
prompt_token_num = self.get_input_token_num(prompt)
while len(ice_idx_list[idx]) > 0 and prompt_token_num > self.max_model_token_num:
@@ -100,14 +102,44 @@ def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTempl
prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template,
prompt_template=prompt_template)
prompt_token_num = self.get_input_token_num(prompt)

if normalizing_str is not None:
prompt_sep = prompt
if prompt_template is not None:
sep_token = prompt_template.sep_token
else:
sep_token = ice_template.sep_token
sep_pos = prompt_sep.find(sep_token)

context = prompt_sep[0:sep_pos]
answer = prompt_sep[sep_pos:].replace(sep_token, '')
prompt = context + answer
normalizing_prompt = normalizing_str + answer

context_length_list.append(self.get_input_token_num(context))
normalizing_prompt_list.append(normalizing_prompt)
prompt_list.append(prompt)

if normalizing_str is not None:
normalizing_str_len = self.get_input_token_num(normalizing_str)

# 5.2 Get PPL
logger.info(f"Calculating PPL for prompts labeled '{label}'")
for idx in trange(0, len(prompt_list), self.batch_size, disable=not self.is_main_process):
sub_prompt_list = prompt_list[idx:idx + self.batch_size]
if normalizing_str is not None:
sub_context_length_list = context_length_list[idx:idx + self.batch_size]
sub_normalizing_prompt_list = normalizing_prompt_list[idx:idx + self.batch_size]

with torch.no_grad():
sub_res = self.__get_ppl(sub_prompt_list).tolist()
if normalizing_str is not None:
res1 = self.__get_ppl(input_texts=sub_prompt_list, mask_length=sub_context_length_list)
res2 = self.__get_ppl(input_texts=sub_normalizing_prompt_list,
mask_length=[normalizing_str_len for i in range(len(sub_prompt_list))]
)
sub_res = res1 - res2
else:
sub_res = self.__get_ppl(sub_prompt_list).tolist()
for res, prompt in zip(sub_res, sub_prompt_list):
sub_ppl_list.append(res)
output_handler.save_prompt_and_ppl(label, prompt[len(ice[idx]):], prompt, res, index)
@@ -129,7 +161,7 @@ def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTempl

return [sample['prediction'] for sample in output_handler.results_dict.values()]

def __get_ppl(self, input_texts: List[str]):
def __get_ppl(self, input_texts: List[str], mask_length=None):
if self.call_api:
return api_get_ppl(self.api_name, input_texts)
self.tokenizer.padding_side = "right"
@@ -144,6 +176,15 @@ def __get_ppl(self, input_texts: List[str]):
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(
shift_labels.size())

if mask_length is not None:
mask = torch.zeros_like(shift_labels) # [batch,seqlen]
for i in range(len(mask)):
for j in range(mask_length[i] - 1, len(mask[i])):
mask[i][j] = 1
loss = loss * mask

lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
if mask_length is not None:
lens -= np.array(mask_length)
ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
return ce_loss
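
For reference, a standalone sketch of the context-masking trick that the new mask_length branch in __get_ppl implements (the helper name and the toy tensors are illustrative, not from the commit): per-token losses over the context are zeroed out, and the remaining loss is averaged over the answer tokens only.

# Illustrative sketch only; masked_ce and the toy tensors are not part of the commit.
import torch

def masked_ce(per_token_loss, seq_lens, mask_length):
    # per_token_loss: [batch, seq_len - 1] losses over the shifted labels.
    # seq_lens: unpadded sequence lengths; mask_length: context tokens per sample.
    mask = torch.zeros_like(per_token_loss)
    for i, m in enumerate(mask_length):
        mask[i, m - 1:] = 1                      # keep answer-token positions only
    lens = seq_lens - torch.tensor(mask_length)  # number of answer tokens per sample
    return (per_token_loss * mask).sum(-1) / lens

# Two sequences of 6 tokens (5 shifted-loss positions), contexts of 3 and 2 tokens.
loss = torch.tensor([[0.5, 0.4, 1.2, 0.8, 0.1],
                     [0.3, 0.9, 1.1, 0.7, 0.6]])
print(masked_ce(loss, seq_lens=torch.tensor([6, 6]), mask_length=[3, 2]))  # tensor([0.7000, 0.8250])

The normalized score for a label is then the masked loss of the full prompt minus the masked loss of normalizing_str plus the answer, which factors out how probable the answer string is on its own.
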
18 changes: 16 additions & 2 deletions openicl/icl_prompt_template.py
@@ -21,13 +21,15 @@ def __init__(self,
column_token_map: Dict,
selected_column_name: Optional[str] = None,
selected_column_map: Optional[Dict] = None,
ice_token: Optional[str] = None
ice_token: Optional[str] = None,
sep_token: Optional[str] = None,
) -> None:
self.template = _check_type_list(template, [Dict, str])
self.column_token_map = _check_dict(column_token_map)
self.selected_column_name = _check_type_list(selected_column_name, [None, str])
self.selected_column_map = _check_type_list(selected_column_map, [None, Dict])
self.ice_token = _check_type_list(ice_token, [None, str])
self.sep_token = _check_type_list(sep_token, [None, str])
if (self.selected_column_name is not None and self.selected_column_map is None) or \
self.selected_column_name is None and self.selected_column_map is not None:
raise ValueError("self.selected_column_name and self.selected_column_map should be set together")
@@ -63,6 +65,9 @@ def generate_ice_item(self, entry: Dict, label: Hashable) -> str:
"""
# Select the corresponding template
tp = self.template[label] if isinstance(self.template, Dict) else self.template
# Remove sep token
if self.sep_token is not None:
tp = tp.replace(self.sep_token, '')
# Remove ice_token
if self.ice_token is not None:
tp = tp.replace(self.ice_token, '')
@@ -74,13 +79,15 @@ def generate_ice_item(self, entry: Dict, label: Hashable) -> str:
tp = tp.replace(token, str(entry[key]))
return tp

def generate_label_prompt_item(self, entry: Dict, ice: str, label: Hashable) -> str:
def generate_label_prompt_item(self, entry: Dict, ice: str, label: Hashable, remain_sep: Optional[bool] = False) -> str:
"""Generate prompt based on :obj:`entry` data, :obj:`ice` in-context example, and the corresponding :obj:`label`.
Args:
entry (:obj:`Dict`): A piece of data containing the input field content.
ice (:obj:`str`): The generated in-context example.
label (:obj:`Hashable`): The value of the output field.
remain_sep (:obj:`bool`): Whether to keep the :obj:`sep_token` in the generated prompt. Defaults to ``False``.
Raises:
ValueError: If the :obj:`ice_token` attribute of the :obj:`PromptTemplate` instance is :obj:`None`.
@@ -92,6 +99,9 @@ def generate_label_prompt_item(self, entry: Dict, ice: str, label: Hashable) ->
raise ValueError("PromptTemplate.ice_token should be not None when generates prompt")
# Select the corresponding template
tp = self.template[label] if isinstance(self.template, Dict) else self.template
# Remove sep token
if not remain_sep and self.sep_token is not None:
tp = tp.replace(self.sep_token, '')
# Insert in-context examples
tp = tp.replace(self.ice_token, ice)
# Replace context token
@@ -102,6 +112,7 @@ def generate_label_prompt_item(self, entry: Dict, ice: str, label: Hashable) ->
tp = tp.replace(token, str(entry[key]))
return tp


def generate_item(self, entry: Dict, output_field: Optional[Hashable] = None,
output_field_replace_token: Optional[str] = '',
ice_field_replace_token: Optional[str] = '') -> str:
@@ -129,6 +140,9 @@ def generate_item(self, entry: Dict, output_field: Optional[Hashable] = None,
tp = self.template[list(self.template.keys())[0]]
if self.ice_token is not None:
tp = tp.replace(self.ice_token, ice_field_replace_token)
# Remove sep token
if self.sep_token is not None:
tp = tp.replace(self.sep_token, '')
for key, token in self.column_token_map.items():
if output_field is not None and key == output_field:
tp = tp.replace(token, output_field_replace_token)
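To illustrate the new sep_token handling in PromptTemplate (the template text and entry are made-up assumptions): remain_sep=True keeps the marker so the inferencer can later split the prompt into context and answer, while the default path strips it, as generate_ice_item does for in-context examples.

# Hypothetical example; template text, entry, and the '</sep>' marker are assumptions.
from openicl import PromptTemplate

template = PromptTemplate(
    {1: '</E>Review: </text> </sep>It was great.'},
    column_token_map={'text': '</text>'},
    ice_token='</E>',
    sep_token='</sep>',
)
entry = {'text': 'A moving film.'}

# remain_sep=True preserves the separator for later context/answer splitting.
print(template.generate_label_prompt_item(entry, ice='', label=1, remain_sep=True))
# Review: A moving film. </sep>It was great.

# By default the separator is removed from the generated prompt.
print(template.generate_label_prompt_item(entry, ice='', label=1))
# Review: A moving film. It was great.
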
6 changes: 3 additions & 3 deletions openicl/icl_retriever/icl_base_retriever.py
@@ -114,11 +114,11 @@ def generate_prompt(self, idx: int, ice: str, ice_template: Optional[PromptTempl
return prompt_list, labels

def generate_label_prompt(self, idx: int, ice: str, label, ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None) -> str:
prompt_template: Optional[PromptTemplate] = None, remain_sep: Optional[bool] = False) -> str:
if prompt_template is not None:
return prompt_template.generate_label_prompt_item(self.test_ds[idx], ice, label) + self.prompt_eos_token
return prompt_template.generate_label_prompt_item(self.test_ds[idx], ice, label, remain_sep) + self.prompt_eos_token
elif ice_template is not None and ice_template.ice_token is not None:
return ice_template.generate_label_prompt_item(self.test_ds[idx], ice, label) + self.prompt_eos_token
return ice_template.generate_label_prompt_item(self.test_ds[idx], ice, label, remain_sep) + self.prompt_eos_token
else:
prefix_prompt = ' '.join(
list(map(str, [self.test_ds[idx][ctx] for ctx in self.dataset_reader.input_columns])))
2 changes: 1 addition & 1 deletion openicl/icl_retriever/icl_mdl_retriever.py
@@ -116,7 +116,7 @@ def topk_search(self):
def retrieve(self):
return self.topk_search()

def cal_ce(self, input_texts: List[List], mask_length=None):
def cal_ce(self, input_texts: List[str], mask_length=None):
if self.metric_model is None:
logger.info(f'Load model {self.metric_model} for calculating MDL...')
self.metric_model = AutoModelForCausalLM.from_pretrained(self.ce_model_name)