support load merged checkpoint #70105

Merged
merged 4 commits on Dec 17, 2024
131 changes: 124 additions & 7 deletions python/paddle/distributed/checkpoint/load_state_dict.py
@@ -28,8 +28,10 @@

from .metadata import LocalTensorIndex, LocalTensorMetadata
from .utils import (
check_unique_id,
compute_local_shape_and_global_offset,
flatten_state_dict,
get_max_id,
)

if TYPE_CHECKING:
@@ -50,23 +52,30 @@ class ReadItem:
PATH_TO_CHECKPOINT_FILES: dict[str, tuple[list, list]] = {}


def get_checkpoint_files(path, use_cache=True):
def get_checkpoint_files(path, use_cache=True, unique_id=None):
# if unique_id is None, all files ending with .metadata and .distcp are returned
if unique_id is None:
unique_id = ''
global PATH_TO_CHECKPOINT_FILES
if use_cache and path in PATH_TO_CHECKPOINT_FILES:
return PATH_TO_CHECKPOINT_FILES[path]
accessible_files = os.listdir(path)
metadata_files = [
file for file in accessible_files if file.endswith(".metadata")
file
for file in accessible_files
if file.endswith(f"{unique_id}.metadata")
]
assert (
len(metadata_files) > 0
), f"No metadata file found in the checkpoint directory:{path}."
), f"No metadata file ends with '{unique_id}.metadata' found in the checkpoint directory: {path}."
local_data_files = [
file for file in accessible_files if file.endswith(".distcp")
file
for file in accessible_files
if file.endswith(f"{unique_id}.distcp")
]
assert (
len(local_data_files) > 0
), f"No data file found in the checkpoint directory:{path}."
), f"No data file ends with '{unique_id}.distcp' found in the checkpoint directory:{path}."
if use_cache:
PATH_TO_CHECKPOINT_FILES[path] = (metadata_files, local_data_files)
return (metadata_files, local_data_files)
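For reference, a minimal sketch of how the new suffix filter selects files for a given unique_id. The data-file names follow the `<rank>_<unique_id>.distcp` pattern used elsewhere in this PR; the metadata file names here ("0.metadata", "1.metadata") are an assumption for illustration only.

# Hypothetical directory listing with two checkpoint versions (ids 0 and 1).
accessible_files = [
    "0_0.distcp", "1_0.distcp", "0.metadata",
    "0_1.distcp", "1_1.distcp", "1.metadata",
]
unique_id = 1

metadata_files = [
    f for f in accessible_files if f.endswith(f"{unique_id}.metadata")
]
local_data_files = [
    f for f in accessible_files if f.endswith(f"{unique_id}.distcp")
]

print(metadata_files)    # ['1.metadata']
print(local_data_files)  # ['0_1.distcp', '1_1.distcp']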
@@ -469,7 +478,8 @@ def load_state_dict(
path: str,
process_group: Group | None = None,
coordinator_rank: int = 0,
offload=False,
unique_id: int | None = None,
offload: bool = False,
) -> None:
"""
Load the state_dict inplace from a checkpoint path.
@@ -479,6 +489,7 @@
path(str): The directory to load checkpoint files.
process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards.
coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default.
unique_id(int): The unique id of the checkpoint, used to distinguish between different checkpoint versions. Default is None, in which case the maximum id found under the given path is used, so the newest checkpoint version is loaded.
offload(bool): Whether to offload the checkpoint data from GPU to CPU.
Example:
.. code-block:: python
@@ -524,8 +535,18 @@
if use_dist:
# sync to avoid some ranks not write path yet
paddle.distributed.barrier(process_group)
if unique_id is None:
unique_id = get_max_id(path)
else:
assert unique_id >= 0, f'{unique_id} should be >= 0'
logger.info(f"The unique_id:{unique_id} is uesed.")

metadata_files, local_data_files = get_checkpoint_files(path)
if use_dist:
check_unique_id(unique_id, process_group)

metadata_files, local_data_files = get_checkpoint_files(
path, unique_id=unique_id
)

metadata_list = []
for file in metadata_files:
@@ -732,3 +753,99 @@ def _load_state_dict(

if use_dist:
paddle.distributed.barrier(process_group)


def compute_global_shape(local_tensor_indexs):
rank = len(local_tensor_indexs[0].local_shape)  # tensor rank, i.e. number of dimensions
global_shape = []
for dim in range(rank):
max_size = max(
m.global_offset[dim] + m.local_shape[dim]
for m in local_tensor_indexs
)
global_shape.append(max_size)
return global_shape
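A worked, standalone sketch of what compute_global_shape computes; the namedtuple below is a stand-in for LocalTensorMetadata and carries only the two attributes the helper reads (an assumption for illustration).

from collections import namedtuple

# Stand-in for LocalTensorMetadata: only local_shape and global_offset.
Meta = namedtuple("Meta", ["local_shape", "global_offset"])

# Two row-wise shards of a [4, 8] tensor: rows 0-1 and rows 2-3.
shards = [
    Meta(local_shape=(2, 8), global_offset=(0, 0)),
    Meta(local_shape=(2, 8), global_offset=(2, 0)),
]

# Per dimension, the global extent is the largest offset + local extent.
global_shape = [
    max(m.global_offset[d] + m.local_shape[d] for m in shards)
    for d in range(len(shards[0].local_shape))
]
print(global_shape)  # [4, 8]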


def load_merged_state_dict(
path: str, prefix=None, unique_id=None, offload=False
):
"""
Load the distributed checkpoint and merge it to unsharded state_dict.

Args:
path(str): The directory to load checkpoint files.
prefix(str): The flat_mapping prefix of the state_dict keys, e.g. 'model'. Default is None.
unique_id(int): The unique id of the checkpoint, used to distinguish between different checkpoint versions. Default is None, in which case the maximum id found under the given path is used, so the newest checkpoint version is loaded.
offload(bool): Whether to offload the checkpoint data from GPU to CPU, set to True if GPU memory is not enough.

Returns:
dict: Merged state_dict.

Example:
.. code-block:: python

>>> # doctest: +SKIP('run in distributed mode.')
>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> w1 = paddle.arange(32).reshape([4, 8])
>>> mesh = dist.ProcessMesh([0, 1])
>>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
>>> state_dict = {"w1": sharded_w1}
>>> dist.save_state_dict(state_dict, ckpt_path) # save sharded checkpoint

>>> # doctest: +SKIP('run in single-card mode.')
>>> import paddle
>>> import paddle.distributed as dist
>>> ckpt_path = "./checkpoint"
>>> unsharded_state_dict = dist.checkpoint.load_merged_state_dict(ckpt_path) # load unsharded checkpoint
>>> print(f"unsharded_state_dict:{unsharded_state_dict}")
unsharded_state_dict:{'w1':
[[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ],
[8 , 9 , 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31]])}
>>> # doctest: -SKIP
"""
if unique_id is None:
unique_id = get_max_id(path)
else:
assert unique_id >= 0, f'{unique_id} should be >= 0'

metadata_files, local_data_files = get_checkpoint_files(
path, unique_id=unique_id
)

metadata_list = []
for file in metadata_files:
metadata_list.append(paddle.load(os.path.join(path, file)))

# create target state_dict by local_tensor_meta
state_dict_to_save = {}
for metadata in metadata_list:
for (
tensor_key,
local_tensor_meta,
) in metadata.state_dict_metadata.items():
if prefix is None or tensor_key.startswith(prefix):
global_shape = compute_global_shape(local_tensor_meta)
t = paddle.zeros(global_shape, dtype=local_tensor_meta[0].dtype)
if offload:
t = t.cpu()
state_dict_to_save[tensor_key] = t  # already moved to CPU above if offload is set
else:
continue

load_state_dict(state_dict_to_save, path, offload=offload)

# Update dictionary keys in place
for key in list(
state_dict_to_save.keys()
): # Use list(data.keys()) to avoid runtime error
if prefix and key.startswith(prefix):
new_key = key[len(prefix) + 1 :] # Remove the prefix and the '.' separator
state_dict_to_save[new_key] = state_dict_to_save.pop(
key
) # Add new key and remove the old one
return state_dict_to_save
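A small standalone sketch (with hypothetical keys and placeholder string values) of the prefix stripping performed at the end of load_merged_state_dict:

# Keys saved under a flat_mapping prefix such as "model" are renamed in
# place, dropping the prefix and the '.' separator that follows it.
state_dict_to_save = {"model.w1": "tensor-1", "model.b1": "tensor-2"}
prefix = "model"

for key in list(state_dict_to_save.keys()):
    if prefix and key.startswith(prefix):
        new_key = key[len(prefix) + 1 :]
        state_dict_to_save[new_key] = state_dict_to_save.pop(key)

print(state_dict_to_save)  # {'w1': 'tensor-1', 'b1': 'tensor-2'}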
42 changes: 20 additions & 22 deletions python/paddle/distributed/checkpoint/save_state_dict.py
@@ -24,8 +24,10 @@

from .metadata import LocalTensorIndex, LocalTensorMetadata, Metadata
from .utils import (
check_unique_id,
compute_local_shape_and_global_offset,
flatten_state_dict,
get_max_id,
)

if TYPE_CHECKING:
@@ -76,18 +78,6 @@ def copy_dict_to_cpu(nested_dict):
return new_dict


def check_file_name(file_name, process_group):
all_unique_id = []
unique_id = int(file_name.split(".")[0].split("_")[1])
paddle.distributed.all_gather_object(
all_unique_id, unique_id, process_group
)
for id in all_unique_id[1:]:
assert (
id == all_unique_id[0]
), f"id:{id} != all_unique_id[0]:{file_name}"


def merge_state_dict_metadata(global_state_dict_metadata):
assert isinstance(
global_state_dict_metadata, list
@@ -147,6 +137,7 @@ def save_state_dict(
path: str,
process_group: Group | None = None,
coordinator_rank: int = 0,
unique_id: int | None = None,
async_save: bool = False,
) -> None:
"""
@@ -156,9 +147,11 @@
state_dict(Dict[str, paddle.Tensor]): The state_dict to save.
path(str): The directory to save state_dict.
process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards.
coordinator_rank(int): The rank used to save non distributed values. Rank0 is used by default.
coordinator_rank(int): The rank used to save non distributed values. Rank 0 is used by default.
unique_id(int): The unique id of the checkpoint, used to distinguish between different checkpoint versions. Default is None, in which case the id is 0 the first time a checkpoint is saved and is increased by 1 on each subsequent call of save_state_dict with the same path.
async_save(bool): Async save the state_dict, default is False.

Note: If there is already checkpoint in
Review comment (Contributor):

Suggested change:
    Note: If there is already checkpoint in
becomes
    Note:
        If there is already checkpoint in

Examples:
.. code-block:: python

@@ -193,16 +186,21 @@
# Init the default global process group
paddle.distributed.init_parallel_env()

unique_id = 0
file_name = ""
while True:
file_name = f"{paddle.distributed.get_rank()}_{unique_id}.distcp"
if not os.path.exists(os.path.join(path, file_name)):
break
unique_id += 1
logger.debug(f"file_name:{file_name}")
if unique_id is None:
max_unique_id = get_max_id(path)
logger.debug(f"Max unique id: {max_unique_id}")
if max_unique_id is None:
unique_id = 0
else:
unique_id = max_unique_id
else:
assert unique_id >= 0, f'{unique_id} should be >= 0'
if use_dist:
check_file_name(file_name, process_group)
check_unique_id(unique_id, process_group)

file_name = f"{paddle.distributed.get_rank()}_{unique_id}.distcp"
logger.debug(f"The checkpoint is saved to file_name:{file_name}")

metadata = Metadata()
local_state_dict = {}
local_state_dict_metadata = {}
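For reference, a minimal sketch of how save_state_dict now resolves the checkpoint file name; the rank and directory layout below are hypothetical, and get_max_id is the helper added in utils.py further down.

# Suppose the directory already contains "0_2.distcp" and "1_2.distcp",
# so get_max_id(path) would return 2 for this layout.
rank = 1
max_unique_id = 2                 # assumed result of get_max_id(path)
unique_id = 0 if max_unique_id is None else max_unique_id
file_name = f"{rank}_{unique_id}.distcp"
print(file_name)                  # 1_2.distcp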
22 changes: 22 additions & 0 deletions python/paddle/distributed/checkpoint/utils.py
@@ -14,6 +14,8 @@
from __future__ import annotations

import copy
import os
import re
from typing import TYPE_CHECKING

import numpy as np
@@ -117,3 +119,23 @@ def unflatten_state_dict(flat_state_dict, mapping):
tmp[key_tuple[-1]] = value

return state_dict


def get_max_id(path):
numbers = []
pattern = re.compile(r"^(\d+)_(\d+)\.distcp$")
files = os.listdir(path)
for file in files:
match = pattern.match(file)
if match:
numbers.append(int(match.group(2)))
return max(numbers) if numbers else None
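A quick standalone check of the pattern get_max_id relies on, using a hypothetical file list instead of os.listdir:

import re

pattern = re.compile(r"^(\d+)_(\d+)\.distcp$")
files = ["0_0.distcp", "1_0.distcp", "0_3.distcp", "0.metadata", "notes.txt"]

numbers = []
for file in files:
    match = pattern.match(file)
    if match:
        # group(1) is the rank, group(2) is the unique_id.
        numbers.append(int(match.group(2)))

print(max(numbers) if numbers else None)  # 3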


def check_unique_id(unique_id, process_group):
all_unique_id = []
paddle.distributed.all_gather_object(
all_unique_id, unique_id, process_group
)
for id in all_unique_id[1:]:
assert id == all_unique_id[0], f"id:{id} != all_unique_id[0]:{all_unique_id[0]}"