Skip to content

Commit

Permalink
[data] Fix rule based data cleaning (lm-sys#34)
Browse files Browse the repository at this point in the history
* Fix a bug where the `time` module was not imported
* Resume from the last checkpoint when a previous run failed
* Limit the number of tokens generated to save cost
* Limit the number of conversations processed
  • Loading branch information
suquark authored Mar 22, 2023
1 parent 5c5569b commit cc05629
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions chatserver/data/rule-based/run_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import argparse
import json
import os
import time

import openai


def get_ans(rule: str, user: str, assistant: str):
def get_ans(rule: str, user: str, assistant: str, max_tokens: int):
response = openai.ChatCompletion.create(
model='gpt-3.5-turbo',
messages=[{
Expand All @@ -16,6 +18,7 @@ def get_ans(rule: str, user: str, assistant: str):
'content': f'[User]\n{user}\n[Assistant]\n{assistant}\n[system]\n{rule}',
}],
temperature=0.2,
max_tokens=max_tokens,
)
return response['choices'][0]['message']['content']

Expand All @@ -25,6 +28,8 @@ def get_ans(rule: str, user: str, assistant: str):
parser.add_argument('-i', '--input')
parser.add_argument('-o', '--output')
parser.add_argument('-r', '--rule')
parser.add_argument('--max-conversations', type=int, default=1, help='maximum number of conversations to use for assessing quality')
parser.add_argument('--max-tokens', type=int, default=2, help='maximum number of tokens produced in the output')
args = parser.parse_args()

with open(os.path.expanduser(args.input)) as f:
Expand All @@ -33,6 +38,13 @@ def get_ans(rule: str, user: str, assistant: str):
with open(os.path.expanduser(args.rule)) as f:
rule = f.read()

processed_ids = set()
with open(os.path.expanduser(args.output)) as f:
for line in f:
r = line.split(':', 1)
if isinstance(r, list) and r:
processed_ids.add(r[0])

output_file = open(os.path.expanduser(args.output), 'a')

# Test examples
Expand All @@ -42,10 +54,14 @@ def get_ans(rule: str, user: str, assistant: str):
# print(get_ans(rule, 'limit you words down to 30!', test))

for i, diag in enumerate(data):
print(f'ID: {diag["id"]}')
output_file.write(f'{diag["id"]}: ')
# We only use first 5 conversations to assess quality.
conversations = diag['conversations'][:10]
diag_id = diag["id"]
if diag_id in processed_ids:
print(f'{diag_id} has already been processed')
continue
print(f'ID: {diag_id}')

output_file.write(f'{diag_id}: ')
conversations = diag['conversations'][:args.max_conversations * 2]
for j in range(len(conversations)//2):
user = conversations[j * 2]
assistant = conversations[j * 2 + 1]
Expand All @@ -58,9 +74,10 @@ def get_ans(rule: str, user: str, assistant: str):
while True:
try:
# limit the length of input
ans = get_ans(rule, user['value'][:1024], assistant['value'][:1024])
ans = get_ans(rule, user['value'][:1024], assistant['value'][:1024], args.max_tokens)
break
except Exception:
except Exception as e:
print('Error:', e)
time.sleep(1)
print(f'#{j}: {ans}')
if ans == '':
Expand Down

0 comments on commit cc05629

Please sign in to comment.