forked from lm-sys/FastChat
Commit
Update evaluation scripts and instructions (lm-sys#223)
1 parent 7b8fa6d, commit 6d98710
Showing 24 changed files with 1,143 additions and 141 deletions.
@@ -0,0 +1,85 @@
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import ray

from fastchat.conversation import default_conversation
from fastchat.utils import disable_torch_init


def run_eval(model_path, model_id, question_file, answer_file, num_gpus):
    # Split the question file into num_gpus chunks and dispatch one Ray task per chunk.
    ques_jsons = []
    with open(os.path.expanduser(question_file), "r") as ques_file:
        for line in ques_file:
            ques_jsons.append(line)

    chunk_size = len(ques_jsons) // num_gpus
    ans_handles = []
    for i in range(0, len(ques_jsons), chunk_size):
        ans_handles.append(get_model_answers.remote(
            model_path, model_id, ques_jsons[i:i + chunk_size]))

    ans_jsons = []
    for ans_handle in ans_handles:
        ans_jsons.extend(ray.get(ans_handle))

    with open(os.path.expanduser(answer_file), "w") as ans_file:
        for line in ans_jsons:
            ans_file.write(json.dumps(line) + "\n")


@ray.remote(num_gpus=1)
@torch.inference_mode()
def get_model_answers(model_path, model_id, question_jsons):
    disable_torch_init()
    model_path = os.path.expanduser(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16).cuda()

    ans_jsons = []
    for i, line in enumerate(tqdm(question_jsons)):
        ques_json = json.loads(line)
        idx = ques_json["question_id"]
        qs = ques_json["text"]
        conv = default_conversation.copy()
        conv.append_message(conv.roles[0], qs)
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])
        output_ids = model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024)
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        # Find the separator that ends the assistant's reply; if generation stopped
        # before emitting one, append it so the slice below still works.
        try:
            index = outputs.index(conv.sep, len(prompt))
        except ValueError:
            outputs += conv.sep
            index = outputs.index(conv.sep, len(prompt))

        # Strip the prompt and the "<role>: " prefix, keeping only the answer text.
        outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
        ans_id = shortuuid.uuid()
        ans_jsons.append({"question_id": idx,
                          "text": outputs,
                          "answer_id": ans_id,
                          "model_id": model_id,
                          "metadata": {}})
    return ans_jsons


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument("--question-file", type=str, required=True)
    parser.add_argument("--answer-file", type=str, default="answer.jsonl")
    parser.add_argument("--num-gpus", type=int, default=1)
    args = parser.parse_args()

    ray.init()
    run_eval(args.model_path, args.model_id, args.question_file,
             args.answer_file, args.num_gpus)
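For orientation, here is a minimal sketch of the JSONL records this script consumes and produces, inferred from the fields it reads ("question_id", "text") and writes above. The file names question.jsonl and answer.jsonl and the script name get_model_answer.py are illustrative assumptions, not taken from this commit.

import json

# Hypothetical question file: one JSON object per line, with at least the
# "question_id" and "text" fields that get_model_answers reads.
questions = [
    {"question_id": 1, "text": "What are the key differences between TCP and UDP?"},
    {"question_id": 2, "text": "Summarize the plot of Hamlet in three sentences."},
]
with open("question.jsonl", "w") as f:  # illustrative file name
    for q in questions:
        f.write(json.dumps(q) + "\n")

# The evaluation script would then be launched roughly as (illustrative script name):
#   python get_model_answer.py --model-path /path/to/model --model-id my-model \
#       --question-file question.jsonl --answer-file answer.jsonl --num-gpus 1
#
# Each line of answer.jsonl is one JSON object with the keys the loop above emits:
#   {"question_id": 1, "text": "<generated answer>", "answer_id": "<shortuuid>",
#    "model_id": "my-model", "metadata": {}}

Since run_eval collects the Ray results in the order the chunks were submitted, answers come back in the same order as the questions, and each record carries its question_id for joining back to the question file. Note also that the chunking uses integer division, so with, say, 80 questions and 3 GPUs the loop creates four Ray tasks (slices of 26, 26, 26, and 2 questions); Ray queues the extra task until a GPU frees up.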