preprocess.py
import argparse
import json
import os

from transformers import AutoTokenizer

JSONL_FILE_DIR = "jsonl_files"
TOKENIZED_FILE_DIR = "tokenized_files"
T5_PROMPT = {
    "wiki_bio": "convert the table to text: ",
    "totto_meta": "",
    "common_gen": "generate a sentence with: ",
    "multi_news": "summarize: ",
    "xsum": "summarize: ",
    "wmt16_ro-en": "translate Romanian to English: ",
    "java": "<java> ",
    "python": "<python> ",
}
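
# Each line of jsonl_files/<dataset>/{train,val}.jsonl is assumed to hold one
# JSON object with "source" and "target" string fields, e.g. (illustrative
# values, not taken from the real data):
#   {"source": "name: alice | occupation: chemist", "target": "alice is a chemist."}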
def tokenize_raw(ds_name, model="t5-small", ptm_alias="t5", prompt=""):
    """Tokenize the raw train/val jsonl files of a dataset.

    ds_name: directory name of the raw dataset
    model: name of the pretrained model whose tokenizer is used
    ptm_alias: alias for the model, used in the output file names
    prompt: optional prompt prepended to the source for t5-based models
    """
    tokenizer = AutoTokenizer.from_pretrained(model)
    base_dir = f"{JSONL_FILE_DIR}/{ds_name}"
    tokenized_dir = f"{TOKENIZED_FILE_DIR}/{ds_name}"
    os.makedirs(tokenized_dir, exist_ok=True)
    files = ["val.jsonl", "train.jsonl"]
    files_tokenized = [f"val.{ptm_alias}.jsonl", f"train.{ptm_alias}.jsonl"]
    insts_list = []
    for file in files:
        insts = []
        with open(os.path.join(base_dir, file), encoding="utf-8") as f:
            for line in f:
                insts.append(json.loads(line))
        insts_list.append(insts)
    for i, insts in enumerate(insts_list):
        for inst in insts:
            # t5-style models expect a task prompt prepended to the source
            if "t5" in model:
                source = prompt + inst["source"]
            else:
                source = inst["source"]
            target = inst["target"]
            # encode() returns token ids, including the tokenizer's special tokens
            src_id = tokenizer.encode(source)
            tgt_id = tokenizer.encode(target)
            inst["src_id"] = src_id
            inst["tgt_id"] = tgt_id
        print("writing to ...", os.path.join(tokenized_dir, files_tokenized[i]))
        with open(os.path.join(tokenized_dir, files_tokenized[i]), "w", encoding="utf-8") as f:
            for inst in insts:
                print(json.dumps(inst, ensure_ascii=False), file=f)
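

# A minimal sketch of reading the tokenized output back; load_tokenized is an
# illustrative helper, not part of the original script.
def load_tokenized(path):
    """Yield (src_id, tgt_id) token-id pairs from a tokenized jsonl file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            inst = json.loads(line)
            yield inst["src_id"], inst["tgt_id"]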


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", required=True,
                        help="name of the pretrained model, passed to AutoTokenizer.from_pretrained")
    parser.add_argument("--dataset", required=True, help="selected dataset")
    parser.add_argument("--ptm", default=None,
                        help="alias used to mark the tokenized files")
    args = parser.parse_args()
    if not args.ptm:
        # derive an alias from the model name, e.g. "t5-small" -> "t5"
        args.ptm = args.model_name.split("/")[-1].split("-")[0]
    print("Using pretrained model alias:", args.ptm)
    # prepend a task prompt for t5-based models if one is defined
    prompt = ""
    if "t5" in args.model_name:
        prompt = T5_PROMPT[args.dataset]
    tokenize_raw(args.dataset, args.model_name, args.ptm, prompt)
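
# Example usage (assuming jsonl_files/xsum/{train,val}.jsonl exist):
#   python preprocess.py --model_name t5-small --dataset xsum
# This writes tokenized_files/xsum/{train,val}.t5.jsonl, where each line is the
# original JSON object extended with "src_id" and "tgt_id" token-id lists.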