forked from NVlabs/EAGLE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert_vqav2_for_submission.py
66 lines (50 loc) · 1.99 KB
/
convert_vqav2_for_submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import argparse
import json
# for debug
import sys
sys.path.append(os.getcwd())
from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
def parse_args():
    """Parse the command-line options for the VQAv2 submission converter.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``src``,
        ``save_path``, ``dir``, ``ckpt`` and ``split``.
    """
    cli = argparse.ArgumentParser()
    # Input predictions file and output submission path; both mandatory.
    for flag in ('--src', '--save_path'):
        cli.add_argument(flag, type=str, required=True)
    # Directory expected to contain llava_vqav2_mscoco_test2015.jsonl.
    cli.add_argument('--dir', type=str, default="./playground/data/eval/vqav2")
    # NOTE(review): still required, although the conversion code in this file
    # only referenced ckpt/split in commented-out path templates.
    for flag in ('--ckpt', '--split'):
        cli.add_argument(flag, type=str, required=True)
    return cli.parse_args()
def load_answers(path):
    """Load model predictions from a JSONL file.

    Each non-blank line is expected to be a JSON object with ``question_id``
    and ``text`` keys. Lines that are not valid JSON are counted and skipped
    rather than aborting the conversion; blank lines are ignored.

    Args:
        path: Path to the predictions JSONL file.

    Returns:
        A ``(answers, num_errors)`` tuple: ``answers`` maps each
        ``question_id`` to its raw predicted ``text`` (a later line wins on a
        duplicate id) and ``num_errors`` is the number of unparsable lines.

    Raises:
        KeyError: if a parsed record lacks ``question_id`` or ``text``.
    """
    records = []
    num_errors = 0
    with open(path, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                # Tolerate truncated/garbled lines, but report how many.
                num_errors += 1
    return {r['question_id']: r['text'] for r in records}, num_errors


def build_submission(questions, answers, answer_processor):
    """Build one answer entry per test-split question.

    Args:
        questions: Iterable of test-split records, each with a ``question_id``.
        answers: Mapping of ``question_id`` to raw predicted text.
        answer_processor: Callable that turns a raw prediction into the
            submitted answer string.

    Returns:
        A list of ``{'question_id': ..., 'answer': ...}`` dicts in the order of
        ``questions``. Questions with no prediction get an empty answer, so
        every question in the split appears in the output.
    """
    submission = []
    for question in questions:
        qid = question['question_id']
        if qid in answers:
            answer = answer_processor(answers[qid])
        else:
            answer = ''
        submission.append({'question_id': qid, 'answer': answer})
    return submission


def write_submission(path, submission):
    """Write ``submission`` to ``path`` as a single JSON array.

    Creates the parent directory when ``path`` has one. A bare filename has an
    empty ``dirname`` (``os.makedirs('')`` would raise), so that step is
    skipped and the file is written to the current working directory.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Dump through the managed handle so the data is flushed and the file
    # closed as soon as the block exits.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(submission, f)


if __name__ == '__main__':
    args = parse_args()
    # NOTE: --ckpt/--split are still parsed but unused; input and output paths
    # come directly from --src and --save_path.
    src = args.src
    dst = args.save_path
    test_split_path = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl')

    results, error_line = load_answers(src)
    with open(test_split_path, encoding='utf-8') as f:
        test_split = [json.loads(line) for line in f]
    print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')

    all_answers = build_submission(test_split, results, EvalAIAnswerProcessor())
    write_submission(dst, all_answers)
    print(f"successfully saving results to {dst}")