bbh_evaluation.py

import argparse

from vllm import LLM, SamplingParams

from src.bbh_evaluation import BBHEvaluation


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_model_path",
        type=str,
        help="Path or name of the base model to load with vLLM.",
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        help="Optional path to a LoRA adapter; omit to evaluate the base model only.",
    )
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--stopwords", nargs="*", default=[])
    parser.add_argument(
        "--min_cots",
        type=int,
        default=1,
        help="Minimum number of CoTs to generate (min = 1).",
    )
    parser.add_argument(
        "--max_cots",
        type=int,
        default=1,
        help="Maximum number of CoTs to generate (min = 1).",
    )
    parser.add_argument("--postprocess_responses", action="store_true")
    parser.add_argument(
        "--chat_format",
        type=str,
        help="Options: llama_chat_simple, llama_chat_v2, llama_cot_chat, None",
    )
    return parser.parse_args()
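
# Example invocation (the model and adapter paths below are hypothetical
# placeholders, not taken from this repo):
#
#   python bbh_evaluation.py \
#       --base_model_path meta-llama/Llama-2-7b-hf \
#       --lora_path ./checkpoints/bbh_lora \
#       --min_cots 1 --max_cots 3 \
#       --postprocess_responses \
#       --chat_format llama_chat_simple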


def run_self_consistency(ARGS):
    """Run a self-consistency evaluation.

    Note: this helper is never called from __main__ and depends on
    `BenchmarkEvaluator` as well as the arguments `split`,
    `self_consistency_prompt_k`, and `num_samples_self_consistency`,
    none of which are defined in this script.
    """
    sampling_params = SamplingParams(
        temperature=ARGS.temperature,
        max_tokens=ARGS.max_tokens,
        stop=ARGS.stopwords,
    )
    # Only enable LoRA support in vLLM if an adapter path was supplied.
    enable_lora = ARGS.lora_path is not None
    llm = LLM(model=ARGS.base_model_path, enable_lora=enable_lora, max_lora_rank=64)
    benchmark = BenchmarkEvaluator(
        ARGS.split, k=ARGS.self_consistency_prompt_k, chat_format=ARGS.chat_format
    )
    if enable_lora:
        results = benchmark.self_consistency(
            llm,
            sampling_params,
            ARGS.lora_path,
            postprocess_responses=ARGS.postprocess_responses,
            self_consistency_k=ARGS.num_samples_self_consistency,
        )
    else:
        results = benchmark.self_consistency(
            llm,
            sampling_params,
            lora_path=None,
            output_base_path=ARGS.base_model_path,
            postprocess_responses=ARGS.postprocess_responses,
            self_consistency_k=ARGS.num_samples_self_consistency,
        )
    print(results)


if __name__ == "__main__":
    print("Starting")
    ARGS = parse_args()
    sampling_params = SamplingParams(
        temperature=ARGS.temperature,
        max_tokens=ARGS.max_tokens,
        stop=ARGS.stopwords,
    )
    # Only enable LoRA support in vLLM if an adapter path was supplied.
    enable_lora = ARGS.lora_path is not None
    llm = LLM(model=ARGS.base_model_path, enable_lora=enable_lora, max_lora_rank=64)
    # Evaluate once for each number of chain-of-thought samples k.
    for k in range(ARGS.min_cots, ARGS.max_cots + 1):
        benchmark = BBHEvaluation(k=k)
        if enable_lora:
            results = benchmark(
                llm,
                sampling_params,
                ARGS.lora_path,
                postprocess_responses=ARGS.postprocess_responses,
            )
        else:
            results = benchmark(
                llm,
                sampling_params,
                lora_path=None,
                output_base_path=ARGS.base_model_path,
                postprocess_responses=ARGS.postprocess_responses,
            )
        print(f"Finished evaluation for k={k}")
        print(results)
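
# Sketch of how a LoRA adapter is typically attached per request in vLLM
# (assumption: BBHEvaluation presumably does something like this internally
# with the lora_path it receives; the adapter name and id are illustrative,
# not taken from this repo):
#
#   from vllm.lora.request import LoRARequest
#
#   outputs = llm.generate(
#       prompts,
#       sampling_params,
#       lora_request=LoRARequest("bbh_adapter", 1, lora_path),
#   )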