Skip to content

Commit

Permalink
[Improve] Add internlm7b qlora oasst1 hf config (InternLM#68)
Browse files: browse the repository at this point in the history
add hf internlm qlora oasst1
  • Loading branch information
LZHgrla authored Aug 30, 2023
1 parent cfc4b66 commit 8ff60e9
Showing 1 changed file with 77 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, Trainer, TrainingArguments)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
from xtuner.utils import PROMPT_TEMPLATE

# Framework tag for this config.
# NOTE(review): presumably read by the xtuner launcher to pick the training
# stack -- confirm against the caller.
framework = 'huggingface'
# Identifiers for the base model and the training dataset; both are reused by
# the tokenizer, model and train_dataset settings further down in this file.
pretrained_model_name_or_path = 'internlm/internlm-7b'
dataset_name_or_path = 'timdettmers/openassistant-guanaco'
# Sequence-length settings forwarded to the dataset pipeline (train_dataset).
max_length = 2048
pack_to_max_length = True
# Prompt template forwarded to the dataset's template map function.
prompt_template = PROMPT_TEMPLATE.openassistant

# The Trainer class itself, not an instance.
# NOTE(review): presumably instantiated later by the launcher -- confirm.
trainer = Trainer

# Keyword arguments for `TrainingArguments`. The `type` entry names the class
# to build. NOTE(review): the object is presumably constructed lazily by the
# launcher from this dict -- confirm against the caller.
training_args = dict(
    type=TrainingArguments,
    do_train=True,
    # Optimisation schedule.
    learning_rate=2e-4,
    weight_decay=0,
    lr_scheduler_type='cosine',
    warmup_steps=100,
    # One sample per device per step, with gradients accumulated over 16
    # steps before each optimizer update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=3,
    # Half-precision training; matches the float16 dtypes in the model config.
    fp16=True,
    logging_steps=1,
    optim='paged_adamw_32bit',
    # Checkpointing: save every 1000 steps and keep at most two checkpoints.
    save_strategy='steps',
    save_steps=1000,
    save_total_limit=2,
    ddp_find_unused_parameters=False,
)

# Tokenizer spec: `type` is the factory callable and the remaining entries are
# its keyword arguments. The same dict is reused by the train_dataset config.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    # Allow custom code shipped with the checkpoint to be loaded.
    trust_remote_code=True,
    padding_side='right',
)

# Base-model spec: a causal LM loaded in float16 with a nested 4-bit
# bitsandbytes quantization config (the QLoRA setup).
model = dict(
    type=AutoModelForCausalLM.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    # Allow custom code shipped with the checkpoint to be loaded.
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=dict(
        type=BitsAndBytesConfig,
        # 4-bit loading is selected; 8-bit loading is explicitly off.
        load_in_4bit=True,
        load_in_8bit=False,
        # NOTE(review): the llm_int8_* options look relevant only to 8-bit
        # loading, which is disabled here -- confirm before removing.
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        # 4-bit specifics: float16 compute, double quantization, NF4 type.
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    ),
)

# LoRA adapter spec: keyword arguments for `LoraConfig`.
lora = dict(
    type=LoraConfig,
    # Adapter rank and scaling numerator.
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    # No bias terms are trained.
    bias='none',
    task_type='CAUSAL_LM',
)

# Training-dataset spec handed to `process_hf_dataset`. Nested dicts follow
# the same convention as the outer one: `type` is the callable, the rest are
# its keyword arguments.
train_dataset = dict(
    type=process_hf_dataset,
    # Raw data source, loaded by name via `load_dataset`.
    dataset=dict(type=load_dataset, path=dataset_name_or_path),
    tokenizer=tokenizer,
    max_length=max_length,
    # Per-sample mapping: oasst1-specific map function, then the prompt
    # template chosen at the top of this config.
    dataset_map_fn=oasst1_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory,
        template=prompt_template,
    ),
    remove_unused_columns=True,
    # Packing options. NOTE(review): presumably samples are shuffled and then
    # concatenated up to `max_length` -- confirm in process_hf_dataset.
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length,
)

0 comments on commit 8ff60e9

Please sign in to comment.