
Commit

Merge pull request EvolvingLMMs-Lab#247 from EvolvingLMMs-Lab/dev/fix_tags

Add new LMMS evaluation task for wild vision benchmark
Luodian authored Sep 13, 2024
2 parents 8ec13bd + 8541b16 commit e77fb31
Showing 7 changed files with 177 additions and 17 deletions.
1 change: 0 additions & 1 deletion lmms_eval/tasks/ok_vqa/ok_vqa_val2014.yaml
@@ -1,4 +1,3 @@
-group: ok_vqa
 task: ok_vqa_val2014
 test_split: val2014
 include: _default_template_vqa_yaml
1 change: 0 additions & 1 deletion lmms_eval/tasks/ok_vqa/ok_vqa_val2014_lite.yaml
@@ -1,4 +1,3 @@
-group: ok_vqa
 task: ok_vqa_val2014_lite
 test_split: lite
 dataset_path: lmms-lab/LMMs-Eval-Lite
1 change: 0 additions & 1 deletion lmms_eval/tasks/vizwiz_vqa/vizwiz_vqa_test.yaml
@@ -1,4 +1,3 @@
-group: vizwiz_vqa
 task: vizwiz_vqa_test
 test_split: test
 include: _default_template_vqa_yaml
1 change: 0 additions & 1 deletion lmms_eval/tasks/vizwiz_vqa/vizwiz_vqa_val.yaml
@@ -1,4 +1,3 @@
-group: vizwiz_vqa
 task: vizwiz_vqa_val
 test_split: val
 include: _default_template_vqa_yaml
177 changes: 164 additions & 13 deletions lmms_eval/tasks/wild_vision_bench/utils.py
@@ -3,6 +3,7 @@
 import os
 import re
 import time
+from collections import defaultdict
 from copy import deepcopy
 from io import BytesIO
 from pathlib import Path
@@ -11,6 +12,7 @@
 import requests
 import yaml
 from loguru import logger as eval_logger
+from scipy import stats

 NUM_SECONDS_TO_SLEEP = 5

@@ -97,7 +99,7 @@ def get_chat_response(base64_image, prompt, max_retries=5, wait_time=10):
             response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
             response.raise_for_status()
             response_data = response.json()
-            print(response_data)
+            # print(response_data)
             return response_data["choices"][0]["message"]["content"], GPT_EVAL_MODEL_NAME
         except requests.exceptions.RequestException as e:
             print(f"Request failed on attempt {attempt+1}: {e}")
@@ -156,30 +158,179 @@ def wild_vision_process_results(doc, results):
     score = resps

     if "A>B" in score:
-        final_score = -1
+        winner = "model_a"
+        judgement = "Worse"  # Baseline better
     elif "A>>B" in score:
-        final_score = -2
+        winner = "model_a"
+        judgement = "Worse++"
     elif "A=B" in score:
-        final_score = 0
+        winner = "tie"
+        judgement = "Tie"
     elif "B>A" in score:
-        final_score = 1
+        winner = "model_b"
+        judgement = "Better"
     elif "B>>A" in score:
-        final_score = 2
+        winner = "model_b"
+        judgement = "Better++"
     else:
-        final_score = 0
+        winner = "tie"
+        judgement = "Unclear"

-    return {"gpt_eval_score": {"question": doc["instruction"], "score": final_score, "gpt_resps": resps, "ans_1": doc[BASELINE_MODEL_NAME], "ans_2": pred, "filtered_resps": score, "judgement": judgement}}
+    return {
+        "elo_scores": {
+            "question": doc["instruction"],
+            "model_a": BASELINE_MODEL_NAME,
+            "model_b": "evaluation_model",
+            "winner": winner,
+            "gpt_resps": resps,
+            "model_resps": pred,
+            "judgement": judgement,
+        },
+        "win_rates": {
+            "question": doc["instruction"],
+            "model_a": BASELINE_MODEL_NAME,
+            "model_b": "evaluation_model",
+            "winner": winner,
+        },
+        "judgements_better": {
+            "judgement": judgement,
+        },
+        "judgements_better_plus": {
+            "judgement": judgement,
+        },
+        "judgements_worse": {
+            "judgement": judgement,
+        },
+        "judgements_worse_plus": {
+            "judgement": judgement,
+        },
+        "judgements_tie": {
+            "judgement": judgement,
+        },
+        "judgements_unclear": {
+            "judgement": judgement,
+        },
+    }
+
+
+import math
+
+import numpy as np
+import pandas as pd
+from sklearn.linear_model import LogisticRegression
+
+
+def prepare_elo_data(results):
+    battles = []
+    for result in results:
+        battles.append({"model_a": result["model_a"], "model_b": result["model_b"], "winner": result["winner"]})
+    return pd.DataFrame(battles)
+
+
+def compute_mle_elo(df, baseline, SCALE=400, BASE=10, INIT_RATING=1000):
+    models = pd.concat([df["model_a"], df["model_b"]]).unique()
+    models = pd.Series(np.arange(len(models)), index=models)
+
+    # duplicate battles
+    df = pd.concat([df, df], ignore_index=True)
+    p = len(models.index)
+    n = df.shape[0]
+
+    X = np.zeros([n, p])
+    X[np.arange(n), models[df["model_a"]]] = +math.log(BASE)
+    X[np.arange(n), models[df["model_b"]]] = -math.log(BASE)
+
+    # one A win => two A win
+    Y = np.zeros(n)
+    Y[df["winner"] == "model_a"] = 1.0
+
+    # one tie => one A win + one B win
+    # find tie + tie (both bad) index
+    tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)")
+    tie_idx[len(tie_idx) // 2 :] = False
+    Y[tie_idx] = 1.0
+
+    lr = LogisticRegression(fit_intercept=False, penalty=None, tol=1e-8)
+    try:
+        lr.fit(X, Y)
+        elo_scores = SCALE * lr.coef_[0] + INIT_RATING
+    except ValueError as e:
+        eval_logger.warning(f"Error in LogisticRegression: {e}")
+        eval_logger.warning("Falling back to default ELO scores")
+        elo_scores = np.full(p, INIT_RATING)
+
+    # set anchor as gpt-4-0314 = 1000
+    if baseline in models.index:
+        elo_scores += 1000 - elo_scores[models[baseline]]
+
+    # Create a DataFrame with "model" and "score" columns
+    elo_df = pd.DataFrame({"model": models.index, "score": elo_scores})
+
+    return elo_df.sort_values("score", ascending=False)
+
+
+def predict_win_rate(elo_ratings, SCALE=400, BASE=10, INIT_RATING=1000):
+    names = sorted(list(elo_ratings.keys()))
+    wins = defaultdict(lambda: defaultdict(lambda: 0))
+    for a in names:
+        for b in names:
+            ea = 1 / (1 + BASE ** ((elo_ratings[b] - elo_ratings[a]) / SCALE))
+            wins[a][b] = ea
+            wins[b][a] = 1 - ea
+
+    data = {a: [wins[a][b] if a != b else np.nan for b in names] for a in names}
+
+    df = pd.DataFrame(data, index=names)
+    df.index.name = "model_a"
+    df.columns.name = "model_b"
+    return df.T
+
+
+def get_win_rate_column(df, column, baseline):
+    to_dict = df.set_index("model")[column].to_dict()
+    win_rate_table = predict_win_rate(to_dict)
+    return win_rate_table[baseline].fillna(0.5).apply(lambda x: round(x * 100, 2))
+
+
+def wild_vision_aggregation_elo_scores(results):
+    battles = prepare_elo_data(results)
+    elo_ratings = compute_mle_elo(battles, BASELINE_MODEL_NAME)
+    elo_score = get_win_rate_column(elo_ratings, "score", BASELINE_MODEL_NAME)
+    return elo_score["evaluation_model"]
+
+
+def wild_vision_aggregation_win_rates(results):
+    battles = prepare_elo_data(results)
+    win_rates = battles.groupby("model_b").apply(lambda x: (x["winner"] == "model_b").mean()).to_dict()
+    win_rates[BASELINE_MODEL_NAME] = battles.groupby("model_a").apply(lambda x: (x["winner"] == "model_a").mean()).get(BASELINE_MODEL_NAME, 0)
+    return win_rates["evaluation_model"] * 100
+
+
+def wild_vision_aggregation_judgements_better(results):
+    judgements = pd.DataFrame(results)["judgement"].value_counts(normalize=True).to_dict()
+    return judgements["Better"] * 100 if "Better" in judgements else 0
+
+
+def wild_vision_aggregation_judgements_better_plus(results):
+    judgements = pd.DataFrame(results)["judgement"].value_counts(normalize=True).to_dict()
+    return judgements["Better++"] * 100 if "Better++" in judgements else 0
+
+
+def wild_vision_aggregation_judgements_worse(results):
+    judgements = pd.DataFrame(results)["judgement"].value_counts(normalize=True).to_dict()
+    return judgements["Worse"] * 100 if "Worse" in judgements else 0
+
+
+def wild_vision_aggregation_judgements_worse_plus(results):
+    judgements = pd.DataFrame(results)["judgement"].value_counts(normalize=True).to_dict()
+    return judgements["Worse++"] * 100 if "Worse++" in judgements else 0
+
+
+def wild_vision_aggregation_judgements_tie(results):
+    judgements = pd.DataFrame(results)["judgement"].value_counts(normalize=True).to_dict()
+    return judgements["Tie"] * 100 if "Tie" in judgements else 0
+
-def wild_vision_aggregation(results):
-    score = 0
-    for res in results:
-        score += res["score"]

-    return score / len(results)
+
+def wild_vision_aggregation_judgements_unclear(results):
+    judgements = pd.DataFrame(results)["judgement"].value_counts(normalize=True).to_dict()
+    return judgements["Unclear"] * 100 if "Unclear" in judgements else 0
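Taken together, the new helpers implement an Arena-style pairwise evaluation: compute_mle_elo fits a Bradley-Terry model via logistic regression to turn the recorded battles into Elo-scaled ratings anchored at the baseline, and predict_win_rate converts a rating gap into an expected win probability, P(a beats b) = 1 / (1 + BASE**((R_b - R_a) / SCALE)). Below is a toy walk-through of that pipeline under hypothetical battle records (the dicts mirror the "win_rates" entries produced by wild_vision_process_results, and "baseline_model" is a placeholder for BASELINE_MODEL_NAME, which the real task reads from its config); it assumes prepare_elo_data, compute_mle_elo and get_win_rate_column from the file above are in scope.

# Toy illustration only -- not part of the commit.
baseline = "baseline_model"  # placeholder for BASELINE_MODEL_NAME from the task config
results = [
    {"question": "q1", "model_a": baseline, "model_b": "evaluation_model", "winner": "model_b"},
    {"question": "q2", "model_a": baseline, "model_b": "evaluation_model", "winner": "model_a"},
    {"question": "q3", "model_a": baseline, "model_b": "evaluation_model", "winner": "tie"},
]

battles = prepare_elo_data(results)          # pandas DataFrame of model_a / model_b / winner rows
elo_df = compute_mle_elo(battles, baseline)  # "model"/"score" table, baseline anchored at 1000
win_rates = get_win_rate_column(elo_df, "score", baseline)
print(win_rates["evaluation_model"])         # predicted win rate (%) of the evaluated model vs. the baseline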
9 changes: 9 additions & 0 deletions lmms_eval/tasks/wild_vision_bench/wild_vision_bench0630.yaml
@@ -0,0 +1,9 @@
+task: wildvision_0630
+dataset_name: release_bench_0630_with_modelresponse
+test_split: test500
+output_type: generate_until
+include: _default_template_yaml
+lmms_eval_specific_kwargs:
+  default:
+    pre_prompt: ""
+    post_prompt: ""
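This config registers wildvision_0630 on the test500 split of release_bench_0630_with_modelresponse and inherits everything else from _default_template_yaml, with empty pre_prompt/post_prompt so the benchmark instruction reaches the model unchanged. A small sketch of how such per-task kwargs are typically consumed follows; the YAML content is taken from the diff above, while the inline loading and the prompt-building helper are illustrative assumptions (the real task loader also resolves the include: key, which plain yaml.safe_load does not).

import yaml

CONFIG = """\
task: wildvision_0630
dataset_name: release_bench_0630_with_modelresponse
test_split: test500
output_type: generate_until
include: _default_template_yaml
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: ""
"""

cfg = yaml.safe_load(CONFIG)
kwargs = cfg["lmms_eval_specific_kwargs"]["default"]


def build_prompt(instruction: str) -> str:
    # Assumed pattern: wrap the dataset's instruction with the configured prompts.
    return f"{kwargs['pre_prompt']}{instruction}{kwargs['post_prompt']}"


print(build_prompt("Describe the image in detail."))  # prints the instruction unchanged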
4 changes: 4 additions & 0 deletions lmms_eval/tasks/wild_vision_bench/wildvision_bench.yaml
@@ -0,0 +1,4 @@
+group: wildvision
+task:
+- wildvision_0617
+- wildvision_0630
