
Commit

Changed the model loading paths; in finetune.py, modified data_collator; in tokenize_dataset_rows, modified the preprocess function, mainly to match the new data_collator.
duyupeng committed Oct 10, 2023
1 parent 9973930 commit 1fa9be2
Showing 2 changed files with 123 additions and 23 deletions.
90 changes: 67 additions & 23 deletions finetune.py
@@ -9,10 +9,12 @@
from dataclasses import dataclass, field
import datasets
import os
from utils import chatglm_path, chatglm2_path


tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(chatglm2_path, trust_remote_code=True)

@dataclass
class FinetuneArguments:
@@ -26,27 +28,66 @@ def forward(self, x):
return super().forward(x).to(torch.float32)


def data_collator(features: list) -> dict:
len_ids = [len(feature["input_ids"]) for feature in features]
longest = max(len_ids)
input_ids = []
labels_list = []
for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
ids = feature["input_ids"]
seq_len = feature["seq_len"]
labels = (
[-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
)
ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
_ids = torch.LongTensor(ids)
labels_list.append(torch.LongTensor(labels))
input_ids.append(_ids)
input_ids = torch.stack(input_ids)
labels = torch.stack(labels_list)
return {
"input_ids": input_ids,
"labels": labels,
# def data_collator(features: list) -> dict:
# len_ids = [len(feature["input_ids"]) for feature in features]
# longest = max(len_ids)
# input_ids = []
# labels_list = []
# for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
# ids = feature["input_ids"]
# seq_len = feature["seq_len"]
# labels = (
# [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
# )
# ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
# _ids = torch.LongTensor(ids)
# labels_list.append(torch.LongTensor(labels))
# input_ids.append(_ids)
# input_ids = torch.stack(input_ids)
# labels = torch.stack(labels_list)
# return {
# "input_ids": input_ids,
# "labels": labels,
# }
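# What changed (a reading of the diff): the old collator above padded each
# batch dynamically to its longest sequence, while the new one below pads
# every example to the fixed max_seq_length = 64 + 128 + 1 = 193 and builds
# the prompt with ChatGLM2's tokenizer.build_prompt. It appears adapted from
# preprocess_function_train in the ChatGLM2-6B ptuning example (kept commented
# out in tokenize_dataset_rows.py). Both collators share one labeling scheme:
#
#     labels = [-100] * prompt_len + answer_ids + [eos_id] + [-100] * pad_len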

def data_collator(examples):
max_source_length = 64
max_target_length = 128
max_seq_length = max_source_length + max_target_length + 1
model_inputs = {
"input_ids": [],
"labels": [],
}
for i in range(len(examples)):

query, answer = examples[i]['prompt'], examples[i]['target']

history = None
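        # build_prompt applies ChatGLM2's chat template to the bare query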
prompt = tokenizer.build_prompt(query, history)
# prompt = prefix + prompt
a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
max_length=max_source_length)
b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
max_length=max_target_length)
context_length = len(a_ids)
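        # full sequence = prompt + answer + EOS; in labels, the prompt span is
        # filled with pad tokens (turned into -100 below) so the loss is
        # computed only on the answer tokens and the trailing EOS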
input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]

pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
labels = labels + [tokenizer.pad_token_id] * pad_len
        # `if True` stands in for the upstream `ignore_pad_token_for_loss`
        # flag (see the reference block in tokenize_dataset_rows.py): padding
        # positions become -100 so the loss skips them
        if True:
            labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
        # convert to tensors
input_ids = torch.LongTensor(input_ids)
labels = torch.LongTensor(labels)

model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
model_inputs["input_ids"] = torch.stack(model_inputs["input_ids"])
model_inputs["labels"] = torch.stack(model_inputs["labels"])
return model_inputs
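# Hypothetical smoke test for the new collator (the "prompt"/"target" field
# names and the 64/128 lengths come from the code above; examples are made up):
#
#     batch = data_collator([
#         {"prompt": "What is the capital of France?", "target": "Paris."},
#         {"prompt": "1 + 1 = ?", "target": "2"},
#     ])
#     assert batch["input_ids"].shape == (2, 193)  # 64 + 128 + 1
#     assert batch["labels"].shape == (2, 193)     # prompt/pad positions are -100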



class ModifiedTrainer(Trainer):
@@ -74,8 +115,11 @@ def main():
).parse_args_into_dataclasses()

# init model
# model = AutoModel.from_pretrained(
# "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map="auto"
# )
model = AutoModel.from_pretrained(
"THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map="auto"
chatglm_path, load_in_8bit=True, trust_remote_code=True, device_map="auto"
)
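# utils.py is not shown in this diff; presumably it just defines local
# checkpoint directories, along these (hypothetical) lines:
#
#     # utils.py
#     chatglm_path = "/path/to/chatglm-6b"     # local clone of THUDM/chatglm-6b
#     chatglm2_path = "/path/to/chatglm2-6b"   # local clone of THUDM/chatglm2-6b
#
# Note the tokenizer above is loaded from chatglm2_path while the model uses
# chatglm_path; whether these point to matching checkpoints depends on utils.py.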
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
@@ -97,7 +141,7 @@ def main():
model = get_peft_model(model, peft_config)

# load dataset
dataset = datasets.load_from_disk(finetune_args.dataset_path)
dataset = datasets.load_from_disk(finetune_args.dataset_path)[:3]  # NOTE: slicing a Dataset yields a plain dict of columns; presumably a quick debug run
print(f"\n{len(dataset)=}\n")

# start train
56 changes: 56 additions & 0 deletions tokenize_dataset_rows.py
@@ -6,6 +6,19 @@
import transformers


# def preprocess(tokenizer, config, example, max_seq_length):
# prompt = example["context"]
# target = example["target"]
# prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
# target_ids = tokenizer.encode(
# target,
# max_length=max_seq_length,
# truncation=True,
# add_special_tokens=False)
# input_ids = prompt_ids + target_ids + [config.eos_token_id]
# return {"input_ids": input_ids, "seq_len": len(prompt_ids)}


def preprocess(tokenizer, config, example, max_seq_length):
prompt = example["context"]
target = example["target"]
@@ -19,6 +32,49 @@ def preprocess(tokenizer, config, example, max_seq_length):
return {"input_ids": input_ids, "seq_len": len(prompt_ids)}


def preprocess_chatglm2(tokenizer, config, example, max_seq_length):
prompt = example["context"]
target = example["target"]

return {"prompt": prompt, "target": target}

# def preprocess_function_train(examples):
# max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
#
# model_inputs = {
# "input_ids": [],
# "labels": [],
# }
# for i in range(len(examples[prompt_column])):
# if examples[prompt_column][i] and examples[response_column][i]:
# query, answer = examples[prompt_column][i], examples[response_column][i]
#
# history = examples[history_column][i] if history_column is not None else None
# prompt = tokenizer.build_prompt(query, history)
#
# prompt = prefix + prompt
# a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
# max_length=data_args.max_source_length)
# b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
# max_length=data_args.max_target_length)
#
# context_length = len(a_ids)
# input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
# labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
#
# pad_len = max_seq_length - len(input_ids)
# input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
# labels = labels + [tokenizer.pad_token_id] * pad_len
# if data_args.ignore_pad_token_for_loss:
# labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
#
# model_inputs["input_ids"].append(input_ids)
# model_inputs["labels"].append(labels)
#
# return model_inputs



def read_jsonl(path, max_seq_length, skip_overlength=False):
model_name = "THUDM/chatglm-6b"
tokenizer = transformers.AutoTokenizer.from_pretrained(
