Skip to content

Commit

Permalink
[Fix] Vqa fix complete, fix vqa score calculation in vlmeval.evaluate.…
Browse files Browse the repository at this point in the history
…vqa_eval (open-compass#114)

* Update vqa_eval.py

Modify the VQA score calculation in the hit_calculate function.

* fix vqa score calculation in vlmeval.evaluate.vqa_eval
  • Loading branch information
FangXinyu-0913 authored Mar 15, 2024
1 parent e7af632 commit 44dd617
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions vlmeval/evaluate/vqa_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,16 @@ def _process_digit_article(inText):
return outText


def hit_calculate(result, dataset_name, full_score_weight=0.3, anls_threshold = 0.5):
def hit_calculate(result, dataset_name, vqa_score_threshold = 3, anls_threshold = 0.5):
if listinstr(['TextVQA'], dataset_name):
return [np.mean(x['match']) >= full_score_weight for x in result]
return [np.mean(x['match']) for x in result]
elif listinstr(['DocVQA'], dataset_name):
# return [1 - np.min(x['match']) >= anls_threshold for x in result]
return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result ]
elif listinstr(['ChartQA','OCRVQA'], dataset_name):
return [np.max(x['match']) for x in result]
else: #default using vqa_score to calculate score
return [np.mean(x['match']) >= full_score_weight for x in result]
return [np.mean(x['match']) for x in result]

# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
Expand Down Expand Up @@ -253,7 +253,17 @@ def process_line(line, method = 'vqa_score'):
if method == 'vqa_score':
ret['gt'] = [process_answer(x) for x in answers]
ret['pred'] = process_answer(line['prediction'])
ret['match'] = [x == ret['pred'] for x in ret['gt']]
ret['match'] = []
for current_idx, gtAnsDatum in enumerate(ret['gt']):
otherGTAns = [
item for ret_gt_idx, item in enumerate(ret['gt'])
if ret_gt_idx != current_idx
]
matchingAns = [
item for item in otherGTAns if item == ret['pred']
]
acc = min(1, float(len(matchingAns)) / 3)
ret['match'].append(acc)
elif method == 'anls':
ret['gt'] = answers
ret['pred'] = line['prediction']
Expand Down

0 comments on commit 44dd617

Please sign in to comment.