Add lora training script (lm-sys#138)
ZYHowell authored Apr 6, 2023
1 parent 48c31a7 commit b0f6107
Showing 1 changed file with 108 additions and 0 deletions.
108 changes: 108 additions & 0 deletions fastchat/train/train_lora.py
@@ -0,0 +1,108 @@
# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
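# Example invocation (a sketch; the paths below are placeholders to adapt to your setup,
# and the flags come from the ModelArguments/DataArguments/TrainingArguments imported below):
#   deepspeed fastchat/train/train_lora.py \
#       --model_name_or_path <path_to_base_llama_weights> \
#       --data_path <path_to_training_data.json> \
#       --output_dir ./checkpoints_lora \
#       --deepspeed <path_to_deepspeed_config.json>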

# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
import pathlib
import typing

from peft import (
    LoraConfig,
    get_peft_model,
)
import transformers
from transformers import Trainer

from fastchat.train.train import (DataArguments, ModelArguments,
                                  TrainingArguments,
                                  make_supervised_data_module,
                                  smart_tokenizer_and_embedding_resize)

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"

@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    lora_weight_path: str = ""

def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
    (model_args, data_args, training_args,
     lora_args) = parser.parse_args_into_dataclasses()

    model = transformers.LlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    # Wrap the base model with LoRA adapters; only the adapter weights are trainable.
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
if "llama" in model_args.model_name_or_path:
tokenizer.add_special_tokens({
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
})

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    trainer = Trainer(model=model,
                      tokenizer=tokenizer,
                      args=training_args,
                      **data_module)

    # The KV cache is only useful for generation; disable it during training.
    model.config.use_cache = False

    # Resume from the most recent checkpoint if one already exists in output_dir.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    model.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    train()
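
After training, the LoRA adapter weights written by model.save_pretrained(training_args.output_dir) can be loaded back onto the base model with peft for inference. A minimal sketch, assuming placeholder paths for the base weights and the training output directory (neither is part of this commit):

import torch
import transformers
from peft import PeftModel

# Placeholders for illustration only: substitute your base weights and --output_dir.
base = transformers.LlamaForCausalLM.from_pretrained(
    "<path_to_base_llama_weights>", torch_dtype=torch.float16
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "<path_to_base_llama_weights>", use_fast=False
)
# Attach the trained adapter saved by this script.
model = PeftModel.from_pretrained(base, "<output_dir_from_training>")
model.eval()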
