Skip to content

Commit

Permalink
Merge pull request #46 from ddlBoJack/dev-zzasdf
Browse files Browse the repository at this point in the history
simplify the logic of saving checkpoint
  • Loading branch information
ddlBoJack authored Mar 24, 2024
2 parents e570afb + 474222f commit 2a72707
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 29 deletions.
36 changes: 13 additions & 23 deletions src/slam_llm/model_checkpointing/checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
import torch
import time
from collections import OrderedDict

from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
Expand Down Expand Up @@ -164,34 +165,23 @@ def save_model_checkpoint(

logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")

def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True):
    """Save a (possibly DDP-wrapped) model checkpoint under cfg.output_dir/checkpoint_name.

    Args:
        model: model to save; may be wrapped in DDP (unwrapped via ``.module``).
        optimizer: unused here; kept in the signature for call-site compatibility.
        rank: process rank; unused here (callers gate on rank == 0), kept for compatibility.
        cfg: training config; reads ``output_dir``, ``freeze_llm`` and ``enable_ddp``.
        checkpoint_name: name of the subdirectory created for this checkpoint.
        save_trainable_only: if True, persist only parameters with ``requires_grad=True``
            (e.g. adapters / projector) instead of the full state dict.
    """
    logger.info(f"--> saving model ...")
    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)

    if not cfg.freeze_llm:
        # HuggingFace-style save of the LLM weights; unwrap DDP if present.
        if hasattr(model, "module"):  # (FIX:MZY): a hack to deal with the model wrapped in DDP
            model.module.llm.save_pretrained(save_dir)
        else:
            model.llm.save_pretrained(save_dir)
        logger.info(f"llm saved at {save_dir}")

    save_full_path = os.path.join(save_dir, "model.pt")
    if cfg.enable_ddp:
        # Unwrap DDP so parameter names in named_parameters() match state_dict keys.
        model = model.module
    cpu_state = model.state_dict()
    if save_trainable_only:
        # Collect only the trainable subset to keep the checkpoint small.
        state_dict = OrderedDict()
        for name, para in model.named_parameters():
            if para.requires_grad:
                state_dict[name] = cpu_state[name]
    else:
        state_dict = cpu_state
    torch.save(state_dict, save_full_path)
    # FIX: the old message said "encoder saved"; what is written is the collected state dict.
    logger.info(f"model state saved at {save_full_path}")

    # FIX: the old trailing log referenced `epoch`/`step`, which no longer exist in this
    # signature and would raise NameError; log the checkpoint name instead.
    logger.info(f"model checkpoint saved as {checkpoint_name}\n")

def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
with FSDP.state_dict_type(
Expand Down
13 changes: 7 additions & 6 deletions src/slam_llm/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
eval_epoch_acc = rest[0] if rest else -1
checkpoint_start_time = time.perf_counter()
if train_config.save_model and (eval_epoch_loss < best_val_loss):
checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}"
if train_config.enable_fsdp or train_config.enable_ddp:
dist.barrier()
if train_config.use_peft:
Expand All @@ -182,19 +183,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
if rank==0:
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch, step=step
model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
)
dist.barrier()
elif train_config.enable_ddp:
if rank==0:
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch, step=step
model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
)
dist.barrier()
else:
# model.save_pretrained(train_config.output_dir)
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch, step=step
model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
)
if train_config.enable_fsdp or train_config.enable_ddp:
if rank==0:
Expand All @@ -212,18 +213,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
if rank==0:
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch, step=step
model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
)
dist.barrier()
elif train_config.enable_ddp:
if rank==0:
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch, step=step
model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
)
dist.barrier()
else:
save_model_checkpoint_peft(
model, optimizer, rank, train_config, epoch=epoch, step=step
model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
)

else:
Expand Down

0 comments on commit 2a72707

Please sign in to comment.