
Commit

Added LoRA adaptation for chatglm2 and chatglm, and adjusted parameters
duyupeng committed Nov 12, 2023
1 parent 0f513c1 commit 4534a30
Showing 2 changed files with 41 additions and 108 deletions.
81 changes: 11 additions & 70 deletions finetune.py
@@ -9,18 +9,19 @@
 from dataclasses import dataclass, field
 import datasets
 import os
-from utils import chatglm_path,chatglm2_path
+from utils import chatglm_path, chatglm2_path
 
+tokenizer = ''
+
 
 
-# tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-tokenizer = AutoTokenizer.from_pretrained(chatglm2_path, trust_remote_code=True)
 
 @dataclass
 class FinetuneArguments:
     dataset_path: str = field(default="data/alpaca")
     model_path: str = field(default="output")
     lora_rank: int = field(default=8)
+    chatglm_path: str = field(default='model_path/chatglm')
 
 
 class CastOutputToFloat(nn.Sequential):
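
Side note on the new field: HfArgumentParser (used in main() below) turns every dataclass field into a command-line flag, so the checkpoint location is now configurable at launch instead of being hardcoded. A minimal sketch of that mechanism, with illustrative flag values:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class FinetuneArguments:
        dataset_path: str = field(default="data/alpaca")
        model_path: str = field(default="output")
        lora_rank: int = field(default=8)
        chatglm_path: str = field(default='model_path/chatglm')

    # argv is passed explicitly here only for demonstration; in finetune.py
    # the flags come from the real command line.
    (args,) = HfArgumentParser(FinetuneArguments).parse_args_into_dataclasses(
        args=["--chatglm_path", "model_path/chatglm2", "--lora_rank", "16"]
    )
    assert args.chatglm_path == "model_path/chatglm2" and args.lora_rank == 16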
@@ -36,8 +37,9 @@ def data_collator(features: list) -> dict:
     for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
         ids = feature["input_ids"]
         seq_len = feature["seq_len"]
+
         labels = (
-            [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
+            [-100] * (seq_len - 1) + ids[(seq_len - 1):] + [-100] * (longest - ids_l)
         )
         ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
         _ids = torch.LongTensor(ids)
@@ -50,68 +52,6 @@ def data_collator(features: list) -> dict:
         "labels": labels,
     }
 
-# def data_collator(examples):
-#     max_source_length = 64
-#     max_target_length = 128
-#     max_seq_length = max_source_length + max_target_length + 1
-#     model_inputs = {
-#         "input_ids": [],
-#         "labels": [],
-#     }
-#     for i in range(len(examples)):
-#
-#         query, answer = examples[i]['prompt'], examples[i]['target']
-#
-#         history = None
-#         prompt = tokenizer.build_prompt(query, history)
-#         # prompt = prefix + prompt
-#         a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
-#                                  max_length=max_source_length)
-#         b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
-#                                  max_length=max_target_length)
-#         context_length = len(a_ids)
-#         input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
-#         labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
-#
-#         pad_len = max_seq_length - len(input_ids)
-#         input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
-#         labels = labels + [tokenizer.pad_token_id] * pad_len
-#         if True:
-#             labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
-#         # convert to tensors
-#         input_ids = torch.LongTensor(input_ids)
-#         labels = torch.LongTensor(labels)
-#
-#         model_inputs["input_ids"].append(input_ids)
-#         model_inputs["labels"].append(labels)
-#     model_inputs["input_ids"] = torch.stack(model_inputs["input_ids"])
-#     model_inputs["labels"] = torch.stack(model_inputs["labels"])
-#     return model_inputs
-
-# def data_collator(features: list) -> dict:
-#     len_ids = [len(feature["input_ids"]) for feature in features]
-#     longest = max(len_ids)
-#     input_ids = []
-#     labels_list = []
-#     for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
-#         ids = feature["input_ids"]
-#         seq_len = feature["seq_len"]
-#         labels = (
-#             [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
-#         )
-#         ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
-#         _ids = torch.LongTensor(ids)
-#         labels_list.append(torch.LongTensor(labels))
-#         input_ids.append(_ids)
-#     input_ids = torch.stack(input_ids)
-#     labels = torch.stack(labels_list)
-#     return {
-#         "input_ids": input_ids,
-#         "labels": labels,
-#     }
-#
-#
-#
 
 class ModifiedTrainer(Trainer):
     def compute_loss(self, model, inputs, return_outputs=False):
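
What the collator's masking does: every position before the prompt/response boundary, and all right padding, is labeled -100, the index Hugging Face loss functions ignore, so cross-entropy is computed on the response alone. A worked toy example with made-up token ids:

    ids = [5, 6, 7, 8, 9]   # prompt = [5, 6, 7], response = [8, 9]
    seq_len = 3             # length of the prompt portion
    longest = 7             # longest sequence in this batch
    ids_l = len(ids)

    labels = [-100] * (seq_len - 1) + ids[(seq_len - 1):] + [-100] * (longest - ids_l)
    # -> [-100, -100, 7, 8, 9, -100, -100]
    # Only the boundary token and the response tokens remain as targets;
    # every other position is ignored by the loss.
    ids = ids + [0] * (longest - ids_l)   # 0 standing in for tokenizer.pad_token_id
    # -> [5, 6, 7, 8, 9, 0, 0]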
@@ -133,17 +73,18 @@ def save_model(self, output_dir=None, _internal_call=False):
 
 def main():
     writer = SummaryWriter()
 
     finetune_args, training_args = HfArgumentParser(
         (FinetuneArguments, TrainingArguments)
     ).parse_args_into_dataclasses()
 
     # init model
     # model = AutoModel.from_pretrained(
     #     "THUDM/chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map="auto"
     # )
 
     model = AutoModel.from_pretrained(
-        chatglm2_path, load_in_8bit=True, trust_remote_code=True, device_map="auto"
+        finetune_args.chatglm_path, load_in_8bit=True, trust_remote_code=True, device_map="auto"
     )
+    global tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(finetune_args.chatglm_path, trust_remote_code=True)
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
     model.is_parallelizable = True
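
A design note on the tokenizer handling above: the module now starts with the placeholder tokenizer = '' and main() rebinds it after parsing arguments. The global statement is essential, because data_collator resolves the module-level name at call time; without it the assignment would only create a local variable. A minimal self-contained sketch of the pattern (load_tokenizer is a hypothetical stand-in for AutoTokenizer.from_pretrained):

    tokenizer = ''  # module-level placeholder, as in the diff

    def collator_stub():
        # looks up the module-level `tokenizer` at call time
        return tokenizer.pad_token_id

    def load_tokenizer(path):
        # hypothetical loader standing in for AutoTokenizer.from_pretrained(path, ...)
        class FakeTokenizer:
            pad_token_id = 3
        return FakeTokenizer()

    def main():
        global tokenizer  # rebind the module-level name, not a local
        tokenizer = load_tokenizer('model_path/chatglm')

    main()
    assert collator_stub() == 3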
68 changes: 30 additions & 38 deletions tokenize_dataset_rows.py
@@ -6,59 +6,49 @@
 import transformers
 
 
-def preprocess(tokenizer, config, example, max_seq_length):
-    prompt = example["context"]
-    target = example["target"]
-    prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
-    target_ids = tokenizer.encode(
-        target,
-        max_length=max_seq_length,
-        truncation=True,
-        add_special_tokens=False)
-    input_ids = prompt_ids + target_ids + [config.eos_token_id]
-    return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
+def preprocess(tokenizer, config, example, max_seq_length, version):
+    if version == 'v1':
+        prompt = example["context"]
+        target = example["target"]
+        prompt_ids = tokenizer.encode(prompt, max_length=max_seq_length, truncation=True)
+        target_ids = tokenizer.encode(
+            target,
+            max_length=max_seq_length,
+            truncation=True,
+            add_special_tokens=False)
+        input_ids = prompt_ids + target_ids + [config.eos_token_id]
+        return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
+
+    if version == 'v2':
+        query = example["context"]
+        target = example["target"]
+        history = None
+        prompt = tokenizer.build_prompt(query, history)
+
+        a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
+                                 max_length=max_seq_length)
+        b_ids = tokenizer.encode(text=target, add_special_tokens=False, truncation=True,
+                                 max_length=max_seq_length)
+
+        input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
+
+        return {"input_ids": input_ids, "seq_len": len(a_ids)}
 
-
-def preprocess_chatglm2(tokenizer, config, example, max_seq_length):
-    query = example["context"]
-    target = example["target"]
-
-    history = None
-    prompt = tokenizer.build_prompt(query, history)
-
-    a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
-                             max_length=max_seq_length)
-    b_ids = tokenizer.encode(text=target, add_special_tokens=False, truncation=True,
-                             max_length=max_seq_length)
-
-
-    input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
-
-
-    return {"input_ids": input_ids, "seq_len": len(a_ids)}
-
-
-
-
-
-def read_jsonl(path, max_seq_length, skip_overlength=False):
-    # model_name = "THUDM/chatglm-6b"
-    model_chatglm_path = '/home/adminz/ChatGLM-Tuning/model_path/chatglm'
-    model_chatglm2_path = '/home/adminz/ChatGLM-Tuning/model_path/chat2glm'
+
+def read_jsonl(path, max_seq_length, chatglm_path, version='v1', skip_overlength=False):
 
     tokenizer = transformers.AutoTokenizer.from_pretrained(
-        model_chatglm2_path, trust_remote_code=True)
+        chatglm_path, trust_remote_code=True)
     config = transformers.AutoConfig.from_pretrained(
-        model_chatglm2_path, trust_remote_code=True, device_map='auto')
+        chatglm_path, trust_remote_code=True, device_map='auto')
     with open(path, "r") as f:
         for line in tqdm(f.readlines()):
             example = json.loads(line)
-            # feature = preprocess(tokenizer, config, example, max_seq_length)
-            feature = preprocess_chatglm2(tokenizer, config, example, max_seq_length)
+            feature = preprocess(tokenizer, config, example, max_seq_length, version)
             if skip_overlength and len(feature["input_ids"]) > max_seq_length:
                 continue
             # feature["input_ids"] = feature["input_ids"][:max_seq_length]
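
The merged preprocess makes the two data formats explicit: v1 (ChatGLM) tokenizes the raw context string and appends config.eos_token_id, while v2 (ChatGLM2) first wraps the query in the model's chat template via tokenizer.build_prompt and uses tokenizer.eos_token_id. A hedged sketch of the difference, assuming a locally downloaded ChatGLM2 checkpoint (the path and example record are illustrative):

    import transformers

    chatglm2_path = 'model_path/chatglm2'  # illustrative checkpoint location
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        chatglm2_path, trust_remote_code=True)

    example = {"context": "你好", "target": "你好!有什么可以帮你?"}

    # v1 path: the raw context string is tokenized as-is
    prompt_ids = tokenizer.encode(example["context"], max_length=384, truncation=True)

    # v2 path: the query goes through ChatGLM2's chat template first; for a
    # single turn with no history this is roughly "[Round 1]\n\n问:你好\n\n答:"
    # (the exact template is defined by the tokenizer shipped with the checkpoint)
    prompt = tokenizer.build_prompt(example["context"], None)
    a_ids = tokenizer.encode(text=prompt, add_special_tokens=True,
                             truncation=True, max_length=384)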
@@ -71,10 +61,12 @@ def main():
     parser.add_argument("--save_path", type=str, default="data/alpaca")
     parser.add_argument("--max_seq_length", type=int, default=384)
     parser.add_argument("--skip_overlength", type=bool, default=False)
+    parser.add_argument("--chatglm_path", type=str, default='model_path/chatglm')
+    parser.add_argument("--version", type=str, default='v1')
     args = parser.parse_args()
 
     dataset = datasets.Dataset.from_generator(
-        lambda: read_jsonl(args.jsonl_path, args.max_seq_length, args.skip_overlength)
+        lambda: read_jsonl(args.jsonl_path, args.max_seq_length, args.chatglm_path, args.version, args.skip_overlength)
     )
     dataset.save_to_disk(args.save_path)
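
With the new flags, the tokenizer script can target either model family; a typical invocation for ChatGLM2 data might look like this (paths are illustrative):

    python tokenize_dataset_rows.py \
        --jsonl_path data/alpaca_data.jsonl \
        --save_path data/alpaca \
        --max_seq_length 384 \
        --chatglm_path model_path/chatglm2 \
        --version v2

One caveat carried over from the existing argument definitions: argparse's type=bool converts any non-empty string to True, so --skip_overlength False would still enable skipping; leave the flag at its default unless skipping is actually wanted.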
