Commit

fix bug in saving checkpoint when using ddp
zzasdf committed Mar 23, 2024
1 parent 618ba72 commit 158e3c8
Showing 2 changed files with 5 additions and 5 deletions.
src/slam_llm/model_checkpointing/checkpoint_handler.py: 3 additions & 4 deletions
@@ -170,10 +170,9 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="che
     save_dir = os.path.join(cfg.output_dir, checkpoint_name)
     os.makedirs(save_dir, exist_ok=True)
     save_full_path = os.path.join(save_dir, "model.pt")
-    if hasattr(model, "module"): #(FIX:MZY): a hack to deal with the model wrapped in DDP
-        cpu_state = model.module.state_dict()
-    else:
-        cpu_state = model.state_dict()
+    if cfg.enable_ddp:
+        model = model.module
+    cpu_state = model.state_dict()
     if save_trainable_only:
         state_dict = OrderedDict()
         for name, para in model.named_parameters():
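For context (not part of the commit): DistributedDataParallel wraps the model, so a state_dict() taken on the wrapper prefixes every key with "module.", which breaks later loads into the unwrapped model. The old hasattr(model, "module") check was a heuristic that any model defining a module attribute could trip; gating on cfg.enable_ddp makes the unwrap explicit. A minimal, self-contained sketch of the key-prefix behavior, using a single-process gloo group so DDP can be constructed on CPU:

# Illustration only: why the DDP wrapper is unwrapped before saving.
import os
import torch.distributed as dist
import torch.nn as nn

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(4, 2)
ddp_model = nn.parallel.DistributedDataParallel(model)

print(list(ddp_model.state_dict().keys()))        # ['module.weight', 'module.bias']
print(list(ddp_model.module.state_dict().keys())) # ['weight', 'bias']

dist.destroy_process_group()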
src/slam_llm/utils/train_utils.py: 2 additions & 1 deletion
@@ -166,7 +166,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             eval_epoch_acc = rest[0] if rest else -1
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and (eval_epoch_loss < best_val_loss):
-                checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}_acc_{eval_epoch_acc}"
+                checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}"
                 if train_config.enable_fsdp or train_config.enable_ddp:
                     dist.barrier()
                 if train_config.use_peft:
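A plausible reading of the naming change (my gloss, not the author's): under DDP every rank formats the checkpoint path itself before the dist.barrier(), and eval_epoch_acc can differ across ranks until it is reduced, so embedding it in the name lets ranks disagree about the save directory. Dropping it makes the name deterministic on all ranks. If the accuracy were still wanted in the name, one option is to synchronize it first; synced_mean below is a hypothetical helper, not repo code, and assumes an initialized process group:

import torch
import torch.distributed as dist

def synced_mean(value: float) -> float:
    # Average a scalar metric across ranks so every process sees the same number.
    t = torch.tensor([value], dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / dist.get_world_size()).item()

# e.g.: checkpoint_name += f"_acc_{synced_mean(eval_epoch_acc):.4f}"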
@@ -416,6 +416,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 )
             pbar.update(1)
             pbar.set_description(f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}")
+            break
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp or train_config.enable_ddp:
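The trailing context shows where per-rank evaluation results get reduced. Note that Python's "and" binds tighter than "or", so the guard parses as (torch.cuda.device_count() > 1 and train_config.enable_fsdp) or train_config.enable_ddp. A minimal sketch of such a cross-rank reduction (illustrative names, not repo code; assumes an initialized process group):

import torch
import torch.distributed as dist

def reduce_mean_across_ranks(eval_loss: torch.Tensor) -> torch.Tensor:
    # Sum the per-rank losses in place, then divide by the world size.
    dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
    return eval_loss / dist.get_world_size()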
