Commit a33cb4f

Fix max sequence padding
Use transformers trainer
duzx16 committed Jul 4, 2023
1 parent b99e3d7 commit a33cb4f
Showing 4 changed files with 20 additions and 3,780 deletions.
ptuning/evaluate.sh (2 changes: 1 addition & 1 deletion)

@@ -10,7 +10,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
     --overwrite_cache \
     --prompt_column content \
     --response_column summary \
-    --model_name_or_path chatglm2-6b \
+    --model_name_or_path THUDM/chatglm2-6b \
     --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
     --output_dir ./output/$CHECKPOINT \
     --overwrite_output_dir \
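
This one-line change swaps a bare local path for the THUDM/chatglm2-6b Hugging Face Hub id, so evaluation works even without a pre-downloaded checkpoint directory. A rough sketch of how such an argument is typically consumed; the exact loading code lives in main.py and may differ:

```python
from transformers import AutoModel, AutoTokenizer

# "THUDM/chatglm2-6b" resolves against the Hugging Face Hub; the old value
# "chatglm2-6b" only worked if a directory of that name existed locally.
model_name_or_path = "THUDM/chatglm2-6b"

# ChatGLM2 ships custom modeling code, so trust_remote_code=True is required.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
```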
ptuning/main.py (4 changes: 2 additions & 2 deletions)

@@ -178,7 +178,7 @@ def preprocess_function_eval(examples):
         return model_inputs

     def preprocess_function_train(examples):
-        max_seq_length = data_args.max_source_length + data_args.max_target_length
+        max_seq_length = data_args.max_source_length + data_args.max_target_length + 1

         model_inputs = {
             "input_ids": [],
@@ -335,7 +335,7 @@ def compute_metrics(eval_preds):
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-        save_prefixencoder=model_args.pre_seq_len is not None
+        save_changed=model_args.pre_seq_len is not None
     )

     # Training
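
The renamed keyword generalizes the old save_prefixencoder flag: with pre_seq_len set (p-tuning v2), only the prefix encoder's parameters are trainable, so the trainer can checkpoint just the parameters that changed instead of the full 6B-parameter model. A sketch of how such a flag might be wired into a Trainer subclass; the class name and saving details here are assumptions, not necessarily the repo's exact code:

```python
import os
import torch
from transformers import Trainer

class PrefixTrainer(Trainer):  # hypothetical name for the repo's custom trainer
    def __init__(self, *args, save_changed=False, **kwargs):
        # When True, checkpoints contain only trainable (changed) parameters.
        self.save_changed = save_changed
        super().__init__(*args, **kwargs)

    def _save(self, output_dir=None, state_dict=None):
        if not self.save_changed:
            return super()._save(output_dir, state_dict)
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Keep only tensors with requires_grad, e.g. the p-tuning prefix encoder.
        changed = {
            name: param.detach().cpu()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        torch.save(changed, os.path.join(output_dir, "pytorch_model.bin"))
```

Passing save_changed=model_args.pre_seq_len is not None, as the hunk above does, keeps the stock saving path for full fine-tuning while p-tuning runs save only the small prefix-encoder checkpoint.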