Skip to content

Commit

Permalink
Merge pull request LianjiaTech#335 from Flow3rDown/main
Browse files — browse the repository at this point in the history
修正final model保存问题 (Fix the final-model saving issue)
  • Loading branch information
mabaochang authored May 6, 2023
2 parents 3212ee8 + b015f9c commit 58a71c2
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,19 +386,19 @@ def evaluation(model, eval_dataloader):

model.tput_timer.update_epoch_count()

if args.output_dir is not None:
print_rank_0('saving the final model ...', args.global_rank)#It will overwrite the last epoch model
model = convert_lora_to_linear_layer(model)

if args.global_rank == 0:
save_hf_format(model, tokenizer, args)

if args.zero_stage == 3:
# For zero stage 3, each gpu only has a part of the model, so we need a special save function
save_zero_three_model(model,
args.global_rank,
args.output_dir,
zero_stage=args.zero_stage)
if args.output_dir is not None:
print_rank_0('saving the final model ...', args.global_rank)#It will overwrite the last epoch model
model = convert_lora_to_linear_layer(model)

if args.global_rank == 0:
save_hf_format(model, tokenizer, args)

if args.zero_stage == 3:
# For zero stage 3, each gpu only has a part of the model, so we need a special save function
save_zero_three_model(model,
args.global_rank,
args.output_dir,
zero_stage=args.zero_stage)


if __name__ == "__main__":
Expand Down

0 comments on commit 58a71c2

Please sign in to comment.