Skip to content

Commit

Permalink
Enable checkpoint with FSDP on single device (Lightning-AI#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Apr 12, 2023
1 parent 5d98989 commit 7b2f995
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion lit_llama/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def save_model_checkpoint(fabric, model, file_path):
"""

if isinstance(fabric.strategy, FSDPStrategy):
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
state_dict = model._forward_module.state_dict()
else:
Expand Down

0 comments on commit 7b2f995

Please sign in to comment.