Skip to content

Commit

Permalink
add key2gen metric. pre-training dataset add dailydialog
Browse files Browse the repository at this point in the history
  • Loading branch information
zqwerty committed May 11, 2022
1 parent a1a2f24 commit e265e86
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 6 deletions.
45 changes: 45 additions & 0 deletions convlab2/base_models/gpt/keyword_extraction/eval_key2gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import json
import datasets
from tabulate import tabulate

def main(predict_result):
    """Score keyword-grounded generation results and print a metric table.

    Each non-blank line of ``predict_result`` is parsed as a JSON object with
    the fields ``keywords+context``, ``predictions`` and ``response``.  The
    first paragraph of ``keywords+context`` (text before the first blank
    line) is either ``keywords: a | b | ...`` or
    ``possible keywords: a | b | ...``.

    A "possible keywords" record is interpreted relative to the most recent
    "keywords" record: that record's keywords are the positive ones, and they
    are removed from the possible keywords to obtain the negative ones.  The
    file must therefore list the "keywords" record of a sample before its
    "possible keywords" record.

    Args:
        predict_result: path to the JSON-lines prediction file.

    Raises:
        ValueError: if a "possible keywords" record appears before any
            "keywords" record, or if it does not contain one of the positive
            keywords of the preceding "keywords" record.
    """
    data = {
        "keywords": {
            "positive_keywords": [], "negative_keywords": None,
            "predictions": [], "references": []
        },
        "possible keywords": {
            "positive_keywords": [], "negative_keywords": [],
            "predictions": [], "references": []
        }
    }
    # Keywords of the most recent "keywords" record; None until one is seen.
    positive_keywords = None
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(predict_result, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                # Tolerate blank / trailing empty lines instead of failing in json.loads.
                continue
            item = json.loads(line)
            prompt = item["keywords+context"]
            if prompt.startswith("keywords"):
                data["keywords"]["predictions"].append(item['predictions'].strip())
                data["keywords"]["references"].append(item['response'].strip())
                positive_keywords = [k for k in prompt.split('\n\n')[0][len("keywords: "):].split(' | ') if len(k) > 0]
                data["keywords"]["positive_keywords"].append(positive_keywords)
            elif prompt.startswith("possible keywords"):
                if positive_keywords is None:
                    raise ValueError(
                        f"{predict_result}:{line_no}: 'possible keywords' record "
                        "appears before any 'keywords' record"
                    )
                data["possible keywords"]["predictions"].append(item['predictions'].strip())
                data["possible keywords"]["references"].append(item['response'].strip())
                possible_keywords = [k for k in prompt.split('\n\n')[0][len("possible keywords: "):].split(' | ') if len(k) > 0]
                # Whatever is left after removing the positive keywords is the noise.
                for keyword in positive_keywords:
                    try:
                        possible_keywords.remove(keyword)
                    except ValueError as err:
                        raise ValueError(
                            f"{predict_result}:{line_no}: positive keyword {keyword!r} "
                            "is missing from the 'possible keywords' record; "
                            "records are probably misaligned"
                        ) from err
                data["possible keywords"]["positive_keywords"].append(positive_keywords)
                data["possible keywords"]["negative_keywords"].append(possible_keywords)
    metric = datasets.load_metric('./key2gen_metric.py')
    table = [{'prompt': "keywords", **metric.compute(**data["keywords"])}]
    if len(data["possible keywords"]["predictions"]) > 0:
        table.append({'prompt': "possible keywords", **metric.compute(**data["possible keywords"])})
    print(tabulate(table, headers='keys', tablefmt='github'))


if __name__ == '__main__':
    # Command-line entry point: parse the prediction-file path, echo the
    # parsed options, then run the evaluation.
    from argparse import ArgumentParser

    cli = ArgumentParser(
        description="evaluate keywords to response generation performance",
    )
    cli.add_argument(
        '--predict_result', '-p',
        type=str,
        required=True,
        help='path to the output file generated_predictions.json',
    )
    options = cli.parse_args()
    print(options)
    main(options.predict_result)
103 changes: 103 additions & 0 deletions convlab2/base_models/gpt/keyword_extraction/key2gen_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""key2gen Metric"""

import datasets
import sacrebleu

# TODO: Add BibTeX citation
_CITATION = """\
@inproceedings{post-2018-call,
title = "A Call for Clarity in Reporting {BLEU} Scores",
author = "Post, Matt",
booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
month = oct,
year = "2018",
address = "Belgium, Brussels",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/W18-6319",
pages = "186--191",
}
"""

_DESCRIPTION = """\
Metric to evaluate text-to-text models on the keywords grounded generation task.
"""

_KWARGS_DESCRIPTION = """
Calculates corpus-bleu4, positive keywords recall, negative keywords recall
Args:
positive_keywords: list of keywords (list of string) in the ground truth references
negative_keywords: list of keywords (list of string) in the random sampled references
predictions: list of predictions to score. Each predictions
should be a string.
references: list of reference for each prediction. Each
reference should be a string.
Returns:
bleu: corpus-bleu score
positive_keywords_recall: how many keywords in the ground truth response are generated, micro-averaged
negative_keywords_recall: how many keywords in the random sampled response are generated, micro-averaged
Examples:
>>> key2gen_metric = datasets.load_metric("key2gen_metric.py")
>>> predictions = ["hello there general kenobi", "foo bar foobar"]
>>> references = ["hello there kenobi", "foo bar foobar"]
>>> results = nlg_metric.compute(predictions=predictions, references=references)
>>> print(results)
{'bleu': 35.35533905932737}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Key2GenMetrics(datasets.Metric):
    """Metric to evaluate text-to-text models on the keywords grounded generation task."""

    def _info(self):
        """Build the metric metadata.

        Only ``predictions`` and ``references`` are declared as features.
        NOTE(review): the keyword lists are presumably forwarded to
        ``_compute`` by ``datasets.Metric.compute`` as extra keyword
        arguments -- confirm against the datasets library.
        """
        # This defines the format of each prediction and reference
        feature_spec = datasets.Features({
            'predictions': datasets.Value('string'),
            'references': datasets.Value('string'),
        })
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=feature_spec,
        )

    def _compute(self, predictions, references, positive_keywords, negative_keywords=None):
        """Returns the scores: bleu, positive_keywords_recall, negative_keywords_recall

        BLEU is the lowercased sacreBLEU corpus score of ``predictions``
        against ``references``.  A keyword counts as recalled when its
        lowercased form is a substring of the lowercased prediction; both
        recalls are hits divided by the total number of keywords over all
        samples, or 0 when there are no keywords of that kind.
        """
        if not negative_keywords:
            # No negative keywords supplied: pair every sample with an empty list.
            negative_keywords = [[]] * len(positive_keywords)
        bleu = sacrebleu.corpus_bleu(predictions, [references], lowercase=True).score

        pos_total = 0
        neg_total = 0
        pos_hits = 0
        neg_hits = 0
        for wanted, unwanted, generated in zip(positive_keywords, negative_keywords, predictions):
            pos_total += len(wanted)
            neg_total += len(unwanted)

            text = generated.lower()
            pos_hits += sum(1 for kw in wanted if kw.lower() in text)
            neg_hits += sum(1 for kw in unwanted if kw.lower() in text)

        if pos_total > 0:
            pos_recall = pos_hits / pos_total
        else:
            pos_recall = 0
        if neg_total > 0:
            neg_recall = neg_hits / neg_total
        else:
            neg_recall = 0

        return {
            "bleu": bleu,
            "positive_keywords_recall": pos_recall,
            "negative_keywords_recall": neg_recall,
        }
17 changes: 12 additions & 5 deletions convlab2/base_models/gpt/keyword_extraction/test_t5_key2gen.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
set -e
n_gpus=1
task_name="key2gen"
dataset_name="multiwoz21"
n_gpus=2
task_name="key2gen_shuffle_noisy"
dataset_name="dailydialog+metalwoz+sgd+tm1+tm2+tm3"
speaker="all"
model_type="gpt"
data_dir="data/${task_name}/${model_type}/${dataset_name}"
Expand All @@ -16,7 +16,7 @@ target_column="response"
truncation_side="left"
max_source_length=512
max_target_length=128
model_name_or_path="output/key2gen/gpt/metalwoz+sgd+tm1+tm2+tm3"
model_name_or_path="output/${task_name}/${model_type}/${dataset_name}"
per_device_train_batch_size=128
per_device_eval_batch_size=128
gradient_accumulation_steps=4
Expand All @@ -40,4 +40,11 @@ python -m torch.distributed.launch \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_eval_batch_size ${per_device_eval_batch_size}
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--debug underflow_overflow \
--adafactor \
--gradient_checkpointing
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ gradient_accumulation_steps=4
lr=1e-3
num_train_epochs=1

python -m torch.distributed.launch --master_port 23456\
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../../t5/run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
Expand Down

0 comments on commit e265e86

Please sign in to comment.