Showing 8 changed files with 1,613 additions and 0 deletions.
@@ -0,0 +1,14 @@
# Download HotpotQA data
wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json
wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json
wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json

# Download GloVe
GLOVE_DIR=./
mkdir -p $GLOVE_DIR
wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O $GLOVE_DIR/glove.840B.300d.zip
unzip $GLOVE_DIR/glove.840B.300d.zip -d $GLOVE_DIR

# Download spaCy language model
python3 -m spacy download en
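
A quick sanity check after the script finishes (a minimal sketch; the file names come from the wget and unzip commands above):

import os

# Each of these should exist in the working directory once the downloads
# and the GloVe unzip have completed.
for path in ['hotpot_dev_distractor_v1.json',
             'hotpot_dev_fullwiki_v1.json',
             'hotpot_train_v1.1.json',
             'glove.840B.300d.txt']:
    print(path, 'OK' if os.path.exists(path) else 'MISSING')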
@@ -0,0 +1,130 @@
import sys
import ujson as json
import re
import string
from collections import Counter
import pickle


def normalize_answer(s):

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
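
# Example: normalize_answer("The Eiffel Tower!") -> "eiffel tower"
# (lowercased, punctuation stripped, articles removed, whitespace collapsed).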


def f1_score(prediction, ground_truth):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC
    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall
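
# Worked example: f1_score("Barack Obama", "Obama") compares the token sets
# {"barack", "obama"} and {"obama"}: one shared token, so precision = 1/2,
# recall = 1/1, and F1 = 2 * 0.5 * 1.0 / 1.5 = 2/3.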


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def update_answer(metrics, prediction, gold):
    em = exact_match_score(prediction, gold)
    f1, prec, recall = f1_score(prediction, gold)
    metrics['em'] += float(em)
    metrics['f1'] += f1
    metrics['prec'] += prec
    metrics['recall'] += recall
    return em, prec, recall


def update_sp(metrics, prediction, gold):
    cur_sp_pred = set(map(tuple, prediction))
    gold_sp_pred = set(map(tuple, gold))
    tp, fp, fn = 0, 0, 0
    for e in cur_sp_pred:
        if e in gold_sp_pred:
            tp += 1
        else:
            fp += 1
    for e in gold_sp_pred:
        if e not in cur_sp_pred:
            fn += 1
    prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
    em = 1.0 if fp + fn == 0 else 0.0
    metrics['sp_em'] += em
    metrics['sp_f1'] += f1
    metrics['sp_prec'] += prec
    metrics['sp_recall'] += recall
    return em, prec, recall


def eval(prediction_file, gold_file):
    with open(prediction_file) as f:
        prediction = json.load(f)
    with open(gold_file) as f:
        gold = json.load(f)

    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
               'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
               'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
    for dp in gold:
        cur_id = dp['_id']
        can_eval_joint = True
        if cur_id not in prediction['answer']:
            print('missing answer {}'.format(cur_id))
            can_eval_joint = False
        else:
            em, prec, recall = update_answer(
                metrics, prediction['answer'][cur_id], dp['answer'])
        if cur_id not in prediction['sp']:
            print('missing sp fact {}'.format(cur_id))
            can_eval_joint = False
        else:
            sp_em, sp_prec, sp_recall = update_sp(
                metrics, prediction['sp'][cur_id], dp['supporting_facts'])

        if can_eval_joint:
            joint_prec = prec * sp_prec
            joint_recall = recall * sp_recall
            if joint_prec + joint_recall > 0:
                joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
            else:
                joint_f1 = 0.
            joint_em = em * sp_em

            metrics['joint_em'] += joint_em
            metrics['joint_f1'] += joint_f1
            metrics['joint_prec'] += joint_prec
            metrics['joint_recall'] += joint_recall

    N = len(gold)
    for k in metrics.keys():
        metrics[k] /= N

    print(metrics)


if __name__ == '__main__':
    eval(sys.argv[1], sys.argv[2])
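
The evaluator expects two JSON files: the gold file in HotpotQA format (a list of examples, each with an _id, an answer string, and supporting_facts as [title, sentence_index] pairs) and a prediction file with top-level "answer" and "sp" dicts keyed by example id. A minimal end-to-end sketch (ids and file names are hypothetical; the evaluator imports ujson, but the standard json module writes compatible files):

import json

# A one-example gold set and a prediction that matches it exactly.
gold = [{'_id': 'example-1', 'answer': 'yes',
         'supporting_facts': [['Some Article', 0]]}]
pred = {'answer': {'example-1': 'yes'},
        'sp': {'example-1': [['Some Article', 0]]}}

with open('gold.json', 'w') as f:
    json.dump(gold, f)
with open('pred.json', 'w') as f:
    json.dump(pred, f)

# Running the evaluator on these files (python <evaluator>.py pred.json gold.json)
# should print 1.0 for every metric, since the prediction matches the gold exactly.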
@@ -0,0 +1,90 @@
import os
from prepro import prepro
from run import train, test
import argparse


parser = argparse.ArgumentParser()

glove_word_file = "glove.840B.300d.txt"

word_emb_file = "word_emb.json"
char_emb_file = "char_emb.json"
train_eval = "train_eval.json"
dev_eval = "dev_eval.json"
test_eval = "test_eval.json"
word2idx_file = "word2idx.json"
char2idx_file = "char2idx.json"
idx2word_file = 'idx2word.json'
idx2char_file = 'idx2char.json'
train_record_file = 'train_record.pkl'
dev_record_file = 'dev_record.pkl'
test_record_file = 'test_record.pkl'


parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--data_file', type=str)
parser.add_argument('--glove_word_file', type=str, default=glove_word_file)
parser.add_argument('--save', type=str, default='HOTPOT')

parser.add_argument('--word_emb_file', type=str, default=word_emb_file)
parser.add_argument('--char_emb_file', type=str, default=char_emb_file)
parser.add_argument('--train_eval_file', type=str, default=train_eval)
parser.add_argument('--dev_eval_file', type=str, default=dev_eval)
parser.add_argument('--test_eval_file', type=str, default=test_eval)
parser.add_argument('--word2idx_file', type=str, default=word2idx_file)
parser.add_argument('--char2idx_file', type=str, default=char2idx_file)
parser.add_argument('--idx2word_file', type=str, default=idx2word_file)
parser.add_argument('--idx2char_file', type=str, default=idx2char_file)

parser.add_argument('--train_record_file', type=str, default=train_record_file)
parser.add_argument('--dev_record_file', type=str, default=dev_record_file)
parser.add_argument('--test_record_file', type=str, default=test_record_file)

parser.add_argument('--glove_char_size', type=int, default=94)
parser.add_argument('--glove_word_size', type=int, default=int(2.2e6))
parser.add_argument('--glove_dim', type=int, default=300)
parser.add_argument('--char_dim', type=int, default=8)

parser.add_argument('--para_limit', type=int, default=1000)
parser.add_argument('--ques_limit', type=int, default=80)
parser.add_argument('--sent_limit', type=int, default=100)
parser.add_argument('--char_limit', type=int, default=16)

parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--checkpoint', type=int, default=1000)
parser.add_argument('--period', type=int, default=100)
parser.add_argument('--init_lr', type=float, default=0.5)
parser.add_argument('--keep_prob', type=float, default=0.8)
parser.add_argument('--hidden', type=int, default=80)
parser.add_argument('--char_hidden', type=int, default=100)
parser.add_argument('--patience', type=int, default=1)
parser.add_argument('--seed', type=int, default=13)

parser.add_argument('--sp_lambda', type=float, default=0.0)

parser.add_argument('--data_split', type=str, default='train')
parser.add_argument('--fullwiki', action='store_true')
parser.add_argument('--prediction_file', type=str)
parser.add_argument('--sp_threshold', type=float, default=0.3)

config = parser.parse_args()


def _concat(filename):
    if config.fullwiki:
        return 'fullwiki.{}'.format(filename)
    return filename
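
# Example: with --fullwiki set, _concat('dev_record.pkl') returns
# 'fullwiki.dev_record.pkl', so fullwiki runs read and write separate cached files.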
# config.train_record_file = _concat(config.train_record_file)
config.dev_record_file = _concat(config.dev_record_file)
config.test_record_file = _concat(config.test_record_file)
# config.train_eval_file = _concat(config.train_eval_file)
config.dev_eval_file = _concat(config.dev_eval_file)
config.test_eval_file = _concat(config.test_eval_file)

if config.mode == 'train':
    train(config)
elif config.mode == 'prepro':
    prepro(config)
elif config.mode == 'test':
    test(config)
elif config.mode == 'count':
    # NOTE: cnt_len is neither defined nor imported in this file, so this
    # branch raises NameError as written; it presumably lives elsewhere in the repo.
    cnt_len(config)
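
Typical invocations, assuming this script is saved as main.py (a hypothetical name) and the data files come from the download script above; flag names are taken from the parser, and output names like dev_pred.json are illustrative:

python main.py --mode prepro --data_file hotpot_train_v1.1.json --data_split train
python main.py --mode prepro --data_file hotpot_dev_distractor_v1.json --data_split dev
python main.py --mode train --sp_lambda 1.0
python main.py --mode test --data_split dev --prediction_file dev_pred.json --sp_threshold 0.3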