update gpt4 eval scripts with batch evals
howard-yen committed Jan 13, 2025
1 parent 810b630 commit f4744f0
Showing 4 changed files with 74 additions and 72 deletions.
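
Both Python scripts below replace per-example calls to `model.generate(prompt=p)` with a single `model.generate_batch(prompt=all_inputs, batch_file=output_file+".batch")` call. The wrapper that implements `generate_batch` presumably lives in the fourth changed file, which is not rendered in this view, so the following is only a minimal sketch of what such a helper might look like if it targets the OpenAI Batch API. The class name (`OpenAIModelSketch`), the method signature, and the JSONL layout are assumptions for illustration, not the repository's actual implementation.

```python
# Hypothetical sketch of a batch-generation helper; NOT the repository's code.
# Assumes the OpenAI Batch API (client.files / client.batches) and a chat model.
import json
import time
from openai import OpenAI


class OpenAIModelSketch:
    def __init__(self, model_name, temperature=0.1, generation_max_length=4096):
        self.client = OpenAI()
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = generation_max_length

    def generate_batch(self, prompt, batch_file):
        # 1. Write one JSONL request per prompt; custom_id preserves ordering.
        with open(batch_file, "w") as f:
            for i, p in enumerate(prompt):
                f.write(json.dumps({
                    "custom_id": f"req-{i}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": self.model_name,
                        "messages": [{"role": "user", "content": p}],
                        "temperature": self.temperature,
                        "max_tokens": self.max_tokens,
                    },
                }) + "\n")

        # 2. Upload the request file and create the batch job.
        batch_input = self.client.files.create(file=open(batch_file, "rb"), purpose="batch")
        job = self.client.batches.create(
            input_file_id=batch_input.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

        # 3. Poll until the job reaches a terminal state.
        while job.status not in ("completed", "failed", "expired", "cancelled"):
            time.sleep(60)
            job = self.client.batches.retrieve(job.id)
        if job.status != "completed":
            return [None] * len(prompt)

        # 4. Read results and return them in the original prompt order,
        #    as dicts with an "output" key (matching how the eval scripts read them).
        by_id = {}
        for line in self.client.files.content(job.output_file_id).text.splitlines():
            rec = json.loads(line)
            body = (rec.get("response") or {}).get("body") or {}
            choices = body.get("choices") or []
            by_id[rec["custom_id"]] = (
                {"output": choices[0]["message"]["content"]} if choices else None
            )
        return [by_id.get(f"req-{i}") for i in range(len(prompt))]
```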
30 changes: 18 additions & 12 deletions scripts/eval_gpt4_longqa.py
@@ -66,10 +66,15 @@ def check_metrics(model, results_file, output_file):

sum_score = 0
count_score = 0
for idx, d in enumerate(tqdm(results["data"])):
p = judge_prompt.format(question=d['question'], correct_answers=d['answer'], parsed_output=parse_output(d['output']))

o = model.generate(prompt=p)
all_inputs = []
for d in results["data"]:
p = judge_prompt.format(question=d['question'], correct_answers=d['answer'], parsed_output=parse_output(d['output']))
all_inputs.append(p)

outputs = model.generate_batch(prompt=all_inputs, batch_file=output_file+".batch")
for idx, o in enumerate(outputs):
d = results["data"][idx]
s = None

if o is not None:
@@ -84,22 +89,22 @@ def check_metrics(model, results_file, output_file):
sum_score += scores["fluency"] * scores["correctness"]
count_score += 1

d["gpt4-scores"] = s
d["gpt-4-scores"] = s

if idx < 10:
print("=====================================")
print(f"Prompt: {p}")
print(f"Prompt: {all_inputs[idx]}")
print(f"Output: {o['output']}")
print(f"Final score: {s}")

results["averaged_metrics"]["gpt-4-score"] = sum_score / count_score
with open(output_file, "w") as f:
json.dump(results, f, indent=4)

return results

if __name__ == "__main__":
model = OpenAIModel("azure/gpt-4o-2024-05-13", temperature=0.1)
model = OpenAIModel("gpt-4o-2024-05-13", temperature=0.1)
parser = argparse.ArgumentParser()
parser.add_argument("--num_shards", type=int, default=1)
parser.add_argument("--shard_idx", type=int, default=0)
@@ -108,14 +113,15 @@ def check_metrics(model, results_file, output_file):
shard_idx = args.shard_idx

# instruct models
model_to_check = ['gpt-4-0125-preview', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'gpt-4o-mini-2024-07-18', 'claude-3-5-sonnet-20240620', 'gemini-1.5-flash-001', 'gemini-1.5-pro-001', 'Meta-Llama-3-8B-Instruct', 'Meta-Llama-3-8B-Instruct-Theta8M', 'Meta-Llama-3-70B-Instruct-Theta8M', 'Meta-Llama-3.1-8B-Instruct', 'Meta-Llama-3.1-70B-Instruct', 'Mistral-7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2', 'Mistral-7B-Instruct-v0.3', 'Mistral-Nemo-Instruct-2407', 'Phi-3-mini-128k-instruct', 'Phi-3-small-128k-instruct', 'Phi-3-medium-128k-instruct', 'Phi-3.5-mini-instruct', 'Qwen2-7B-Instruct', 'Qwen2-57B-A14B-Instruct', 'c4ai-command-r-v01', 'AI21-Jamba-1.5-Mini', 'prolong-64k-instruct', 'prolong-512k-instruct-20b-theta128m', "MegaBeam-Mistral-7B-512k"]
model_to_check =['gpt-4-0125-preview','gpt-4o-mini-2024-07-18','gpt-4o-2024-05-13','gpt-4o-2024-08-06','claude-3-5-sonnet-20240620','gemini-1.5-flash-001','gemini-1.5-pro-001','Llama-2-7B-32K-Instruct','Meta-Llama-3-8B-Instruct','Meta-Llama-3-8B-Instruct-Theta16M','Meta-Llama-3-70B-Instruct-Theta16M','Llama-3.1-8B-Instruct','Llama-3.1-70B-Instruct','Llama-3.3-70B-Instruct','Llama-3.2-1B-Instruct','Llama-3.2-3B-Instruct','Mistral-7B-Instruct-v0.1','Mistral-7B-Instruct-v0.2','Mistral-7B-Instruct-v0.3','Ministral-8B-Instruct-2410','Mistral-Nemo-Instruct-2407','MegaBeam-Mistral-7B-512k','Phi-3-mini-128k-instruct','Phi-3-small-128k-instruct','Phi-3-medium-128k-instruct','Phi-3.5-mini-instruct','Qwen2-7B-Instruct','Qwen2-57B-A14B-Instruct','Qwen2.5-1.5B-Instruct','Qwen2.5-3B-Instruct','Qwen2.5-7B-Instruct','Qwen2.5-72B-Instruct','Llama-3-8B-ProLong-512k-Instruct','gemma-2-9b-it','gemma-2-9b-it-Theta320K','gemma-2-27b-it','gemma-2-27b-it-Theta320K','c4ai-command-r-v01','AI21-Jamba-1.5-Mini']

# all models
model_to_check = ['gpt-4-0125-preview', 'gpt-4o-mini-2024-07-18', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'claude-3-5-sonnet-20240620', 'gemini-1.5-flash-001', 'gemini-1.5-pro-001', 'LLaMA-2-7B-32K', 'Llama-2-7B-32K-Instruct', 'llama-2-7b-80k-basefixed', 'Yarn-Llama-2-7b-64k', 'Yarn-Llama-2-7b-128k', 'Meta-Llama-3-8B', 'Meta-Llama-3-8B-Instruct', 'Meta-Llama-3-8B-Theta8M', 'Meta-Llama-3-8B-Instruct-Theta8M', 'Meta-Llama-3-70B-Theta8M', 'Meta-Llama-3-70B-Instruct-Theta8M', 'Meta-Llama-3.1-8B', 'Meta-Llama-3.1-8B-Instruct', 'Meta-Llama-3.1-70B', 'Meta-Llama-3.1-70B-Instruct', 'Llama-3.2-1B', 'Llama-3.2-1B-Instruct', 'Llama-3.2-3B', 'Llama-3.2-3B-Instruct', 'Mistral-7B-v0.1', 'Mistral-7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2', 'Mistral-7B-v0.3', 'Mistral-7B-Instruct-v0.3', 'Mistral-Nemo-Base-2407', 'Mistral-Nemo-Instruct-2407', 'MegaBeam-Mistral-7B-512k', 'Yi-6B-200K', 'Yi-9B-200K', 'Yi-34B-200K', 'Yi-1.5-9B-32K', 'Phi-3-mini-128k-instruct', 'Phi-3-small-128k-instruct', 'Phi-3-medium-128k-instruct', 'Phi-3.5-mini-instruct', 'Qwen2-7B', 'Qwen2-7B-Instruct', 'Qwen2-57B-A14B', 'Qwen2-57B-A14B-Instruct', 'c4ai-command-r-v01', 'Jamba-v0.1', 'AI21-Jamba-1.5-Mini', 'prolong-64k-instruct', 'prolong-512k-instruct-20b-theta128m']

# customize this line according to the file pahts that you want to check
all_paths = [glob.glob(f"output/{m}/narrativeqa_*.json") for m in model_to_check]
model_to_check = ['gpt-4-0125-preview','gpt-4o-mini-2024-07-18','gpt-4o-2024-05-13','gpt-4o-2024-08-06','claude-3-5-sonnet-20240620','gemini-1.5-flash-001','gemini-1.5-pro-001','Llama-2-7B-32K','Llama-2-7B-32K-Instruct','llama-2-7b-80k','Yarn-Llama-2-7b-64k','Yarn-Llama-2-7b-128k','Meta-Llama-3-8B','Meta-Llama-3-8B-Instruct','Meta-Llama-3-8B-Theta16M','Meta-Llama-3-8B-Instruct-Theta16M','Meta-Llama-3-70B-Theta16M','Meta-Llama-3-70B-Instruct-Theta16M','Llama-3.1-8B','Llama-3.1-8B-Instruct','Llama-3.1-70B','Llama-3.1-70B-Instruct','Llama-3.3-70B-Instruct','Llama-3.2-1B','Llama-3.2-1B-Instruct','Llama-3.2-3B','Llama-3.2-3B-Instruct','Mistral-7B-v0.1','Mistral-7B-Instruct-v0.1','Mistral-7B-Instruct-v0.2','Mistral-7B-v0.3','Mistral-7B-Instruct-v0.3','Ministral-8B-Instruct-2410','Mistral-Nemo-Base-2407','Mistral-Nemo-Instruct-2407','MegaBeam-Mistral-7B-512k','Yi-6B-200K','Yi-9B-200K','Yi-34B-200K','Yi-1.5-9B-32K','Phi-3-mini-128k-instruct','Phi-3-small-128k-instruct','Phi-3-medium-128k-instruct','Phi-3.5-mini-instruct','Qwen2-7B','Qwen2-7B-Instruct','Qwen2-57B-A14B','Qwen2-57B-A14B-Instruct','Qwen2.5-1.5B','Qwen2.5-1.5B-Instruct','Qwen2.5-3B','Qwen2.5-3B-Instruct','Qwen2.5-7B','Qwen2.5-7B-Instruct','Qwen2.5-72B-Instruct','Llama-3-8B-ProLong-512k-Instruct','gemma-2-9b','gemma-2-9b-it','gemma-2-9b-it-Theta320K','gemma-2-27b','gemma-2-27b-it','gemma-2-27b-it-Theta320K','c4ai-command-r-v01','Jamba-v0.1','AI21-Jamba-1.5-Mini']

# customize this line according to the file paths that you want to check
# in this example, we look for files with v1 tag
all_paths = [glob.glob(f"output/{m}/narrativeqa_*_v1_*.json") for m in model_to_check]
all_paths = [item for sublist in all_paths for item in sublist]
all_paths = [p for p in all_paths if not os.path.exists(p.replace(".json", "-gpt4eval_o.json"))]
all_paths = all_paths[shard_idx::num_shards]
print(f"Found {len(all_paths)} path")
2 changes: 1 addition & 1 deletion scripts/eval_gpt4_longqa.sh
@@ -1 +1 @@
for i in {0..15}; do python scripts/eval_gpt4_longqa.py --num_shards 16 --shard_idx $i & done
shards=30; for i in $(seq 0 $((shards-1))); do python scripts/eval_gpt4_longqa.py --num_shards $shards --shard_idx $i & done
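
Both the launcher above and the Python scripts split work by a simple stride: shard `i` of `n` evaluates `all_paths[i::n]`, so shard indices run from 0 to n-1 and together cover every file exactly once. A small, self-contained illustration of that scheme (the file names below are made up for the example):

```python
# Illustration of the strided sharding used by the eval scripts (hypothetical paths).
paths = [f"output/model/narrativeqa_{i}.json" for i in range(7)]

num_shards = 3
shards = [paths[i::num_shards] for i in range(num_shards)]
# shards[0] -> files 0, 3, 6
# shards[1] -> files 1, 4
# shards[2] -> files 2, 5
assert sorted(sum(shards, [])) == sorted(paths)  # every file covered exactly once
```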
112 changes: 54 additions & 58 deletions scripts/eval_gpt4_summ.py
@@ -158,7 +158,7 @@
Overall, the book is a sweeping narrative that spans multiple generations and continents. It is a story about identity, culture, family, and history, and it raises important questions about the human experience.<end of summary>
Reasoning: The summary incorrectly identifies the protagonist as "Cal Stephanides" instead of "Cal Margaret", so key point 1 is not supported. It does not mention key point 2. The summary mentions that Raul and Harris are silbings and that they eventually marry and settle down in Detroit so key point 3 is supported. It also mentions the Turkish attack and how they escape from Smyrna ot America so key point 5 is supported. It does not talk about the ship where they are wed so key point 6 is not supported. The summary then stops discussing the plot and so it does not mention key point 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, or 26. Thus, the only supported key points are 3 and 5, so recall is 2.
Reasoning: The summary incorrectly identifies the protagonist as "Cal Stephanides" instead of "Cal Margaret", so key point 1 is not supported. It does not mention key point 2. The summary mentions that Raul and Harris are silbings and that they eventually marry and settle down in Detroit so key point 3 is supported. It also mentions the Turkish attack and how they escape from Smyrna to America so key point 5 is supported. It does not talk about the ship where they are wed so key point 6 is not supported. The summary then stops discussing the plot and so it does not mention key point 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, or 26. Thus, the only supported key points are 3 and 5, so recall is 2.
Output: {{"supported_key_points": [3, 5], "recall": 2}}
@@ -334,6 +334,10 @@ def parse_json(text):
json.loads(matches[-1])
except:
matches = re.findall(r"(?:```json)(.+)(?:```)", text, re.DOTALL)
try:
json.loads(matches[-1])
except:
return None
return json.loads(matches[-1])
return None

@@ -353,6 +357,7 @@ def check_metrics(model, results_file, output_file):
d = json.loads(line)
keypoints[d["id"]] = d["summary/short_keypoints"]

all_inputs = []

for idx, d in enumerate(tqdm(results["data"])):
d["keypoints"] = keypoints[d["id"]]
@@ -365,63 +370,54 @@
fp = fluency_prompt.format(text=d["output"].strip())
rp = recall_prompt.format(keypoints="\n".join([f"{i+1}. {kp}" for i, kp in enumerate(d["keypoints"])]), summary=d["output"].strip())
pp = precision_prompt.format(expert_summary=d["summary/long"], summary=d["output"].strip())

all_inputs.extend([fp, rp, pp])

def get_score(prompt, tries=2):
o = None
for _ in range(tries):
o = model.generate(prompt=prompt)
if o is not None and o["output"] is not None:
ret = parse_json(o["output"])
if ret is not None:
return ret, o
return None, o

f, fo = get_score(fp)
if f is None:
continue
r, ro = get_score(rp)
if r is None:
continue
p, po = get_score(pp)
if p is None:
outputs = model.generate_batch(prompt=all_inputs, batch_file=output_file+".batch")
for idx, d in enumerate(tqdm(results["data"])):
os = outputs[idx*3:idx*3+3]
if any([x is None or x.get("output") is None for x in os]):
continue
fo, ro, po = os

if f is not None and r is not None and p is not None:
rec = r["recall"] / len(d["keypoints"]) if len(d["keypoints"]) > 0 else 0
prec = p["precision"] / p["sentence_count"] if p["sentence_count"] > 0 else 0
f1 = f["fluency"] * 2 * (rec * prec) / (rec + prec) if rec + prec > 0 else 0
d["gpt4-scores"] = {
"fluency": f["fluency"],
"recall_total": len(d["keypoints"]),
"recall_found": r["recall"],
"precision_total": p["sentence_count"],
"precision_found": p["precision"],
"recall": rec,
"precision": prec,
"f1": f1,
"flunecy_output": fo["output"],
"recall_output": ro["output"],
"precision_output": po["output"],
}

if idx < 10:
print("=====================================")
print(f"Fluency: {fo['output']}")
print(f"Recall: {ro['output']}")
print(f"Precision: {po['output']}")
print(f"Scores: {d['gpt4-scores']}")
else:
print("Warning! Couldn't get a score")
rets = [parse_json(o["output"]) for o in os]
if any([r is None for r in rets]):
print(f"GPT-4 output: \n---fluency call---\n{fo['output']}\n---recall call---\n{ro['output']}\n---precision call---\n{po['output']}\n------")
# import pdb; pdb.set_trace()
if len([d for d in results["data"] if "gpt4-scores" in d]) == 0:
continue
f, r, p = rets

rec = r["recall"] / len(d["keypoints"]) if len(d["keypoints"]) > 0 else 0
prec = p["precision"] / p["sentence_count"] if p["sentence_count"] > 0 else 0
f1 = f["fluency"] * 2 * (rec * prec) / (rec + prec) if rec + prec > 0 else 0
d["gpt-4-scores"] = {
"fluency": f["fluency"],
"recall_total": len(d["keypoints"]),
"recall_found": r["recall"],
"precision_total": p["sentence_count"],
"precision_found": p["precision"],
"recall": rec,
"precision": prec,
"f1": f1,
"flunecy_output": fo["output"],
"recall_output": ro["output"],
"precision_output": po["output"],
}

if idx < 10:
print("=====================================")
print(f"Fluency: {fo['output']}")
print(f"Recall: {ro['output']}")
print(f"Precision: {po['output']}")
print(f"Scores: {d['gpt-4-scores']}")

if len([d for d in results["data"] if "gpt-4-scores" in d]) == 0:
raise Exception("No scores found")

averaged = {
"gpt4-recall": np.mean([d["gpt4-scores"]["recall"] for d in results["data"] if "gpt4-scores" in d]),
"gpt4-precision": np.mean([d["gpt4-scores"]["precision"] for d in results["data"] if "gpt4-scores" in d]),
"gpt4-fluency": np.mean([d["gpt4-scores"]["fluency"] for d in results["data"] if "gpt4-scores" in d]),
"gpt4-f1": np.mean([d["gpt4-scores"]["f1"] for d in results["data"] if "gpt4-scores" in d]),
"gpt-4-recall": np.mean([d["gpt-4-scores"]["recall"] for d in results["data"] if "gpt-4-scores" in d]),
"gpt-4-precision": np.mean([d["gpt-4-scores"]["precision"] for d in results["data"] if "gpt-4-scores" in d]),
"gpt-4-fluency": np.mean([d["gpt-4-scores"]["fluency"] for d in results["data"] if "gpt-4-scores" in d]),
"gpt-4-f1": np.mean([d["gpt-4-scores"]["f1"] for d in results["data"] if "gpt-4-scores" in d]),
}
results["averaged_metrics"].update(averaged)

@@ -432,7 +428,7 @@ def get_score(prompt, tries=2):
return results

if __name__ == "__main__":
model = OpenAIModel("azure/gpt-4o-2024-05-13", temperature=0.1, generation_max_length=4096)
model = OpenAIModel("gpt-4o-2024-05-13", temperature=0.1, generation_max_length=4096)

parser = argparse.ArgumentParser()
parser.add_argument("--num_shards", type=int, default=1)
@@ -441,13 +437,14 @@ def get_score(prompt, tries=2):
num_shards = args.num_shards
shard_idx = args.shard_idx

# this is all of our chat models
model_to_check = ['gpt-4-0125-preview', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'gpt-4o-mini-2024-07-18', 'claude-3-5-sonnet-20240620', 'gemini-1.5-flash-001', 'gemini-1.5-pro-001', 'Meta-Llama-3-8B-Instruct', 'Meta-Llama-3-8B-Instruct-Theta8M', 'Meta-Llama-3-70B-Instruct-Theta8M', 'Meta-Llama-3.1-8B-Instruct', 'Meta-Llama-3.1-70B-Instruct', 'Mistral-7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2', 'Mistral-7B-Instruct-v0.3', 'Mistral-Nemo-Instruct-2407', 'Phi-3-mini-128k-instruct', 'Phi-3-small-128k-instruct', 'Phi-3-medium-128k-instruct', 'Phi-3.5-mini-instruct', 'Qwen2-7B-Instruct', 'Qwen2-57B-A14B-Instruct', 'c4ai-command-r-v01', 'AI21-Jamba-1.5-Mini', 'prolong-64k-instruct', 'prolong-512k-instruct-20b-theta128m', "MegaBeam-Mistral-7B-512k"]
# chat models
model_to_check =['gpt-4-0125-preview','gpt-4o-mini-2024-07-18','gpt-4o-2024-05-13','gpt-4o-2024-08-06','claude-3-5-sonnet-20240620','gemini-1.5-flash-001','gemini-1.5-pro-001','Llama-2-7B-32K-Instruct','Meta-Llama-3-8B-Instruct','Meta-Llama-3-8B-Instruct-Theta16M','Meta-Llama-3-70B-Instruct-Theta16M','Llama-3.1-8B-Instruct','Llama-3.1-70B-Instruct','Llama-3.3-70B-Instruct','Llama-3.2-1B-Instruct','Llama-3.2-3B-Instruct','Mistral-7B-Instruct-v0.1','Mistral-7B-Instruct-v0.2','Mistral-7B-Instruct-v0.3','Ministral-8B-Instruct-2410','Mistral-Nemo-Instruct-2407','MegaBeam-Mistral-7B-512k','Phi-3-mini-128k-instruct','Phi-3-small-128k-instruct','Phi-3-medium-128k-instruct','Phi-3.5-mini-instruct','Qwen2-7B-Instruct','Qwen2-57B-A14B-Instruct','Qwen2.5-1.5B-Instruct','Qwen2.5-3B-Instruct','Qwen2.5-7B-Instruct','Qwen2.5-72B-Instruct','Llama-3-8B-ProLong-512k-Instruct','gemma-2-9b-it','gemma-2-9b-it-Theta320K','gemma-2-27b-it','gemma-2-27b-it-Theta320K','c4ai-command-r-v01','AI21-Jamba-1.5-Mini']

model_to_check = ['gpt-4-0125-preview', 'gpt-4o-2024-05-13', 'gpt-4o-2024-08-06', 'gpt-4o-mini-2024-07-18', 'claude-3-5-sonnet-20240620', 'gemini-1.5-flash-001', 'gemini-1.5-pro-001', 'Meta-Llama-3-8B-Theta8M', 'Meta-Llama-3-8B-Instruct-Theta8M', 'Meta-Llama-3-70B-Theta8M', 'Meta-Llama-3-70B-Instruct-Theta8M', 'Meta-Llama-3.1-8B', 'Meta-Llama-3.1-8B-Instruct', 'Meta-Llama-3.1-70B', 'Meta-Llama-3.1-70B-Instruct', "Llama-3.2-1B", "Llama-3.2-1B-Instruct", "Llama-3.2-3B", "Llama-3.2-3B-Instruct", 'llama-2-7b-80k-basefixed', 'Yarn-Llama-2-7b-128k', 'Mistral-7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2', 'Mistral-7B-v0.3', 'Mistral-7B-Instruct-v0.3', 'Mistral-Nemo-Instruct-2407', 'MegaBeam-Mistral-7B-512k', 'Phi-3-mini-128k-instruct', 'Phi-3-small-128k-instruct', 'Phi-3-medium-128k-instruct', 'Phi-3.5-mini-instruct', 'Yi-6B-200K', 'Yi-9B-200K', 'Yi-34B-200K', 'Qwen2-7B-Instruct', 'Qwen2-57B-A14B-Instruct', 'AI21-Jamba-1.5-Mini', 'prolong-512k-instruct-20b-theta128m',]
# all models
model_to_check = ['gpt-4-0125-preview','gpt-4o-mini-2024-07-18','gpt-4o-2024-05-13','gpt-4o-2024-08-06','claude-3-5-sonnet-20240620','gemini-1.5-flash-001','gemini-1.5-pro-001','Llama-2-7B-32K','Llama-2-7B-32K-Instruct','llama-2-7b-80k','Yarn-Llama-2-7b-64k','Yarn-Llama-2-7b-128k','Meta-Llama-3-8B','Meta-Llama-3-8B-Instruct','Meta-Llama-3-8B-Theta16M','Meta-Llama-3-8B-Instruct-Theta16M','Meta-Llama-3-70B-Theta16M','Meta-Llama-3-70B-Instruct-Theta16M','Llama-3.1-8B','Llama-3.1-8B-Instruct','Llama-3.1-70B','Llama-3.1-70B-Instruct','Llama-3.3-70B-Instruct','Llama-3.2-1B','Llama-3.2-1B-Instruct','Llama-3.2-3B','Llama-3.2-3B-Instruct','Mistral-7B-v0.1','Mistral-7B-Instruct-v0.1','Mistral-7B-Instruct-v0.2','Mistral-7B-v0.3','Mistral-7B-Instruct-v0.3','Ministral-8B-Instruct-2410','Mistral-Nemo-Base-2407','Mistral-Nemo-Instruct-2407','MegaBeam-Mistral-7B-512k','Yi-6B-200K','Yi-9B-200K','Yi-34B-200K','Yi-1.5-9B-32K','Phi-3-mini-128k-instruct','Phi-3-small-128k-instruct','Phi-3-medium-128k-instruct','Phi-3.5-mini-instruct','Qwen2-7B','Qwen2-7B-Instruct','Qwen2-57B-A14B','Qwen2-57B-A14B-Instruct','Qwen2.5-1.5B','Qwen2.5-1.5B-Instruct','Qwen2.5-3B','Qwen2.5-3B-Instruct','Qwen2.5-7B','Qwen2.5-7B-Instruct','Qwen2.5-72B-Instruct','Llama-3-8B-ProLong-512k-Instruct','gemma-2-9b','gemma-2-9b-it','gemma-2-9b-it-Theta320K','gemma-2-27b','gemma-2-27b-it','gemma-2-27b-it-Theta320K','c4ai-command-r-v01','Jamba-v0.1','AI21-Jamba-1.5-Mini']

#just replace the glob pattern
all_paths = [glob.glob(f"output/{m}/multi_lexsum_*_v12_*max400min*.json") for m in model_to_check] + [glob.glob(f"output/{m}/infbench_sum_*_v12_*max1200min*.json") for m in model_to_check]
# just replace the glob pattern
all_paths = [glob.glob(f"output/{m}/multi_lexsum_*_v1_*.json") for m in model_to_check] + [glob.glob(f"output/{m}/infbench_sum_*_v1_*.json") for m in model_to_check]

all_paths = [item for sublist in all_paths for item in sublist if item.endswith(".json")]
all_paths = [p for p in all_paths if not os.path.exists(p.replace(".json", "-gpt4eval_o.json"))]
@@ -459,4 +456,3 @@
newp = p.replace(".json", "-gpt4eval_o.json")
print("evaluating")
check_metrics(model, p, newp)
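
For the summarization judge, the batched version issues three prompts per example (fluency, recall, precision), flattens them into one `all_inputs` list, and slices the returned outputs back out three at a time, so the batch must return results in the same order the prompts were appended. A minimal sketch of that flatten-and-slice pattern; the example data and the fake "judge" responses below are placeholders, not real outputs:

```python
# Sketch of the three-prompts-per-example batching used in eval_gpt4_summ.py.
examples = [{"id": 0, "output": "summary A"}, {"id": 1, "output": "summary B"}]

all_inputs = []
for d in examples:
    fp = f"Rate the fluency of: {d['output']}"
    rp = f"Rate the recall of: {d['output']}"
    pp = f"Rate the precision of: {d['output']}"
    all_inputs.extend([fp, rp, pp])  # order: fluency, recall, precision

# Placeholder for model.generate_batch(...); preserves input order.
outputs = [{"output": f"response to: {p}"} for p in all_inputs]

for idx, d in enumerate(examples):
    fo, ro, po = outputs[idx * 3 : idx * 3 + 3]  # same order they were appended
    print(d["id"], fo["output"], ro["output"], po["output"])
```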

