[eval] GPT-3.5 QA baseline (lm-sys#48)
suquark authored Mar 25, 2023
1 parent c91c005 commit 9fca947
Showing 1 changed file with 41 additions and 26 deletions.
chatserver/eval/qa_baseline_gpt-3.5-turbo.py (67 changes: 41 additions & 26 deletions)
@@ -4,24 +4,34 @@
 import json
 import os
 import time
+import concurrent.futures

 import openai
 import tqdm


-def get_eval(question: str, max_tokens: int):
-    response = openai.ChatCompletion.create(
-        model='gpt-3.5-turbo',
-        messages=[{
-            'role': 'system',
-            'content': 'You are a helpful assistant.'
-        }, {
-            'role': 'user',
-            'content': question,
-        }],
-        max_tokens=max_tokens,
-    )
-    return response['choices'][0]['message']['content']
+MODEL = 'gpt-3.5-turbo'
+
+def get_answer(question_id: int, question: str, max_tokens: int):
+    for retries in range(3):
+        try:
+            response = openai.ChatCompletion.create(
+                model=MODEL,
+                messages=[{
+                    'role': 'system',
+                    'content': 'You are a helpful assistant.'
+                }, {
+                    'role': 'user',
+                    'content': question,
+                }],
+                max_tokens=max_tokens,
+            )
+            answer = response['choices'][0]['message']['content']
+            return {'id': question_id, 'answer': answer, 'model': MODEL}
+        except Exception as e:
+            print('[ERROR]', e)
+            time.sleep(1)
+    return {'id': question_id, 'answer': '#ERROR#', 'model': MODEL}


 if __name__ == '__main__':
@@ -31,22 +41,27 @@ def get_eval(question: str, max_tokens: int):
     parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
     args = parser.parse_args()

+    questions_dict = {}
     with open(os.path.expanduser(args.question)) as f:
-        question = json.load(f)
-        questions_dict = {q['id']: q['question'] for q in question['questions']}
+        for line in f:
+            if not line:
+                continue
+            q = json.loads(line)
+            questions_dict[q['id']] = q['question']

     answers = []

-    for qid, question in tqdm.tqdm(questions_dict.items()):
-        for retries in range(3):
-            try:
-                eval_result = get_eval(question, args.max_tokens)
-                answers.append({'id': qid, 'answer': eval_result})
-                break
-            except Exception as e:
-                print('Error: ', e)
-                if retries == 2:
-                    answers.append({'id': qid, 'answer': '#ERROR#'})
+    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+        futures = []
+        for qid, question in questions_dict.items():
+            future = executor.submit(get_answer, qid, question, args.max_tokens)
+            futures.append(future)
+
+        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
+            answers.append(future.result())
+
+    answers.sort(key=lambda x: x['id'])

     with open(os.path.expanduser(args.output), 'w') as f:
-        json.dump({'model': 'gpt-3.5-turbo', 'answers': answers}, f)
+        table = [json.dumps(ans) for ans in answers]
+        f.write('\n'.join(table))
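The main change in this commit is swapping the serial retry loop for a thread pool: each OpenAI request is I/O-bound, so up to 32 can be in flight at once, and the per-question retries now live inside get_answer. Below is a minimal, self-contained sketch of the same fan-out/collect pattern, with a stub standing in for the API call; the stub and sample data are illustrative, not part of the commit.

    import concurrent.futures

    def fake_answer(qid: int, question: str) -> dict:
        # Stand-in for openai.ChatCompletion.create; the real get_answer
        # also retries up to 3 times and falls back to '#ERROR#'.
        return {'id': qid, 'answer': f'echo: {question}', 'model': 'stub'}

    questions = {3: 'third', 1: 'first', 2: 'second'}
    answers = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [executor.submit(fake_answer, qid, q)
                   for qid, q in questions.items()]
        # as_completed yields futures in completion order, not submit order,
        # which is why the script sorts by id afterwards.
        for future in concurrent.futures.as_completed(futures):
            answers.append(future.result())
    answers.sort(key=lambda x: x['id'])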
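The file format also moves from a single JSON document to JSON Lines on both ends: questions are read one object per line, and answers are written the same way. A hypothetical sketch of that contract follows; the file name, sample questions, and the --question/--output flag spellings are inferred from args.question, args.output, and the parsing loop, since the argparse lines defining them sit in the collapsed part of the diff.

    # questions.jsonl -- one object per line, as the parsing loop expects:
    #   {"id": 1, "question": "What is the capital of France?"}
    # Each output line then has the shape produced by get_answer:
    #   {"id": 1, "answer": "...", "model": "gpt-3.5-turbo"}
    import json

    samples = [(1, 'What is the capital of France?'),
               (2, 'How does a binary search work?')]
    with open('questions.jsonl', 'w') as f:
        for qid, q in samples:
            f.write(json.dumps({'id': qid, 'question': q}) + '\n')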
