forked from lm-sys/FastChat
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add lora training script (lm-sys#138)
- Loading branch information
Showing
1 changed file
with
108 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> | ||
|
||
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: | ||
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from dataclasses import dataclass | ||
import pathlib | ||
import typing | ||
|
||
from peft import ( | ||
LoraConfig, | ||
get_peft_model, | ||
) | ||
import transformers | ||
from transformers import Trainer | ||
|
||
from fastchat.train.train import (DataArguments, ModelArguments, | ||
TrainingArguments, | ||
make_supervised_data_module, | ||
smart_tokenizer_and_embedding_resize) | ||
|
||
IGNORE_INDEX = -100 | ||
DEFAULT_PAD_TOKEN = "[PAD]" | ||
DEFAULT_EOS_TOKEN = "</s>" | ||
DEFAULT_BOS_TOKEN = "</s>" | ||
DEFAULT_UNK_TOKEN = "</s>" | ||
|
||
# TODO: the lora_target_modules cannot support list | ||
@dataclass | ||
class LoraArguments: | ||
lora_r: int = 8, | ||
lora_alpha: int = 16, | ||
lora_dropout: float = 0.05, | ||
lora_target_modules: typing.List[str] = ["q_proj", "v_proj"], | ||
lora_weight_path: str = "" | ||
|
||
def train(): | ||
parser = transformers.HfArgumentParser( | ||
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)) | ||
(model_args, data_args, training_args, | ||
lora_args) = parser.parse_args_into_dataclasses() | ||
|
||
model = transformers.LlamaForCausalLM.from_pretrained( | ||
model_args.model_name_or_path, | ||
cache_dir=training_args.cache_dir, | ||
) | ||
lora_config = LoraConfig( | ||
r=lora_args.lora_r, | ||
lora_alpha=lora_args.lora_alpha, | ||
target_modules=["q_proj", "v_proj"], | ||
lora_dropout=lora_args.lora_dropout, | ||
bias="none", | ||
task_type="CAUSAL_LM", | ||
) | ||
model = get_peft_model(model, lora_config) | ||
model.print_trainable_parameters() | ||
|
||
tokenizer = transformers.AutoTokenizer.from_pretrained( | ||
model_args.model_name_or_path, | ||
cache_dir=training_args.cache_dir, | ||
model_max_length=training_args.model_max_length, | ||
padding_side="right", | ||
use_fast=False, | ||
) | ||
if tokenizer.pad_token is None: | ||
smart_tokenizer_and_embedding_resize( | ||
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), | ||
tokenizer=tokenizer, | ||
model=model, | ||
) | ||
if "llama" in model_args.model_name_or_path: | ||
tokenizer.add_special_tokens({ | ||
"eos_token": DEFAULT_EOS_TOKEN, | ||
"bos_token": DEFAULT_BOS_TOKEN, | ||
"unk_token": DEFAULT_UNK_TOKEN, | ||
}) | ||
|
||
data_module = make_supervised_data_module(tokenizer=tokenizer, | ||
data_args=data_args) | ||
trainer = Trainer(model=model, | ||
tokenizer=tokenizer, | ||
args=training_args, | ||
**data_module) | ||
|
||
model.config.use_cache = False | ||
|
||
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): | ||
trainer.train(resume_from_checkpoint=True) | ||
else: | ||
trainer.train() | ||
trainer.save_state() | ||
model.save_pretrained(training_args.output_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
train() |