Add checkpoint sharding unit tests (microsoft#2561)
* added checkpoint sharding tests
mrwyattii authored Dec 8, 2022
1 parent 591744e commit ccb8eb8
Showing 4 changed files with 100 additions and 5 deletions.
6 changes: 3 additions & 3 deletions deepspeed/module_inject/replace_module.py
(file mode changed: 100755 → 100644)
@@ -1022,7 +1022,7 @@ def replace_fn(child, _policy, layer_id=0):
if transformer_name not in k
}),
f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
- new_config = json.dumps({
+ ckpt_config = json.dumps({
'type':
ckpt_name,
'base_dir':
@@ -1044,9 +1044,9 @@ def replace_fn(child, _policy, layer_id=0):
'dtype':
'int8' if quantize else ('float16' if fp16 else 'float32')
})
- with open(f"{config.save_mp_checkpoint_path}/ds-inference_config.json",
+ with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json",
"w") as cfg:
- cfg.write(new_config)
+ cfg.write(ckpt_config)

rep_sd = replaced_module.state_dict()
for n, p in replaced_module.named_parameters():
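For reference, the hunks above rename the serialized config variable (new_config → ckpt_config) and the on-disk file (ds-inference_config.json → ds_inference_config.json). A minimal sketch of reading that file back, using only the fields visible in this commit ('type', 'base_dir', 'dtype' here, plus 'version', 'parallelization', and 'mp_size' consumed in state_dict_factory.py below); any other keys in the real file are not shown:

import json
import os

def read_ds_inference_config(save_mp_checkpoint_path):
    # The file name now uses an underscore: ds_inference_config.json
    cfg_path = os.path.join(save_mp_checkpoint_path, "ds_inference_config.json")
    with open(cfg_path) as cfg:
        data = json.load(cfg)
    # Fields written/read in this commit; other keys omitted.
    print(data["type"])                    # e.g. a sharded-checkpoint type name
    print(data["base_dir"])                # directory holding the shards
    print(data["dtype"])                   # 'int8', 'float16', or 'float32'
    print(data.get("parallelization", "pp"), data.get("mp_size", 0))
    return data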
3 changes: 2 additions & 1 deletion deepspeed/module_inject/replace_policy.py
@@ -112,7 +112,8 @@ def __init__(
self.is_megatron_v2 = megatron_v2
self.mlp_act_func_type = mlp_act_func_type
self.pre_attn_norm = pre_attn_norm
- self.load_prefix = False
+ self.use_load_prefix = use_load_prefix
+ self.split_qkv = split_qkv

def attention(self):
"""
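The hunk above stores two new flags on the policy base class. As a rough, hypothetical illustration (not code from this commit), a concrete policy defined in this file could forward them through super().__init__, assuming the base class accepts them as keyword arguments and the other parameters keep their defaults:

class ExampleLayerPolicy(TransformerPolicy):  # hypothetical subclass for illustration
    def __init__(self, client_module, inference=True):
        # Forward the flags added in this hunk; remaining kwargs keep their defaults.
        super().__init__(inference,
                         use_load_prefix=True,   # checkpoint keys carry a load prefix
                         split_qkv=False)        # fused QKV weight is not split
        self.client_module = client_module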
2 changes: 1 addition & 1 deletion deepspeed/runtime/state_dict_factory.py
@@ -31,7 +31,7 @@ def get_sd_loader_json(json_file, checkpoint_engine):
version = data['version']
ckpt_type = data.get('parallelization', 'pp')
mp_size = data.get('mp_size', 0)
- if 'bloom' in sd_type.lower():
+ if sd_type.lower() in ['bloom', 'ds_model']:
return data
return SDLoaderFactory.get_sd_loader(ckpt_list,
checkpoint_engine,
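The changed branch means a checkpoint descriptor whose 'type' is either 'bloom' or 'ds_model' (the sharded format written by replace_module.py) is now returned as a plain dict instead of being wrapped in an SDLoader. A small illustrative sketch; field names beyond those read in this function, and all values, are placeholders:

# Hypothetical descriptor parsed by get_sd_loader_json; values are illustrative only.
data = {
    "type": "ds_model",            # previously only 'bloom' took the early-return path
    "version": 1.0,
    "parallelization": "tp",
    "mp_size": 2,
    "dtype": "float16",
}
sd_type = data["type"]
if sd_type.lower() in ["bloom", "ds_model"]:
    result = data                  # returned as-is for DeepSpeed inference to consume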
94 changes: 94 additions & 0 deletions tests/unit/inference/test_checkpoint_sharding.py
@@ -0,0 +1,94 @@
import os
import pytest
import torch
import deepspeed
from deepspeed.model_implementations import DeepSpeedTransformerInference
from transformers import AutoConfig, AutoModelForCausalLM
from unit.common import DistributedTest, DistributedFixture


def check_dtype(model, expected_dtype):
    def find_dtype(module):
        for child in module.children():
            if isinstance(child, DeepSpeedTransformerInference):
                return child.attention.attn_qkvw.dtype
            else:
                found_dtype = find_dtype(child)
                if found_dtype:
                    return found_dtype

    found_dtype = find_dtype(model)
    assert found_dtype, "Did not find DeepSpeedTransformerInference in model"
    assert (
        found_dtype == expected_dtype
    ), f"Expected transformer dtype {expected_dtype}, but found {found_dtype}"


@pytest.fixture(params=[
    "bigscience/bloom-560m",
    "EleutherAI/gpt-j-6B",
    "EleutherAI/gpt-neo-125M",
    "facebook/opt-125m"
])
def model_name(request):
    return request.param


@pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
def dtype(request):
    return request.param


class save_shard(DistributedFixture):
    world_size = 2

    def run(self, model_name, class_tmpdir):
        # Only write a checkpoint if one does not exist
        if not os.path.isdir(os.path.join(class_tmpdir, model_name)):
            world_size = int(os.getenv("WORLD_SIZE", "1"))
            inf_config = {
                "replace_with_kernel_inject": True,
                "dtype": torch.float16,
                "replace_method": "auto",
                "enable_cuda_graph": False,
                "tensor_parallel": {
                    "tp_size": world_size
                },
                "save_mp_checkpoint_path": os.path.join(str(class_tmpdir),
                                                        model_name),
            }

            # Load model and save sharded checkpoint
            model = AutoModelForCausalLM.from_pretrained(model_name,
                                                         torch_dtype=torch.float16)
            model = deepspeed.init_inference(model, config=inf_config)


@pytest.mark.seq_inference
class TestCheckpointShard(DistributedTest):
    world_size = 2

    def test(self, model_name, dtype, class_tmpdir, save_shard):
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        inf_config = {
            "replace_with_kernel_inject": True,
            "dtype": dtype,
            "replace_method": "auto",
            "enable_cuda_graph": False,
            "tensor_parallel": {
                "tp_size": world_size
            },
            "checkpoint": os.path.join(class_tmpdir,
                                       model_name,
                                       "ds_inference_config.json"),
        }

        # Load model on meta tensors
        model_config = AutoConfig.from_pretrained(model_name)
        # Note that we use half precision to load initially, even for int8
        with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
            model = AutoModelForCausalLM.from_config(model_config,
                                                     torch_dtype=torch.bfloat16)
        model = model.eval()
        model = deepspeed.init_inference(model, config=inf_config)
        check_dtype(model, dtype)
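Outside the pytest fixtures, the save-then-load flow these tests exercise looks roughly like the sketch below; the model name and output directory are placeholders, and it assumes the script runs under a multi-rank launch (so WORLD_SIZE is set) with the same config keys used in the test above:

import os
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "facebook/opt-125m"   # placeholder; any of the models above works
ckpt_dir = "/tmp/sharded-ckpt"     # placeholder output directory
world_size = int(os.getenv("WORLD_SIZE", "1"))

# 1) Save a tensor-parallel (sharded) checkpoint, mirroring the save_shard fixture.
save_config = {
    "replace_with_kernel_inject": True,
    "dtype": torch.float16,
    "replace_method": "auto",
    "tensor_parallel": {"tp_size": world_size},
    "save_mp_checkpoint_path": ckpt_dir,
}
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
deepspeed.init_inference(model, config=save_config)

# 2) Reload from the sharded checkpoint via its JSON descriptor, mirroring the test.
load_config = {
    "replace_with_kernel_inject": True,
    "dtype": torch.float16,
    "replace_method": "auto",
    "tensor_parallel": {"tp_size": world_size},
    "checkpoint": os.path.join(ckpt_dir, "ds_inference_config.json"),
}
model_config = AutoConfig.from_pretrained(model_name)
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.float16)
engine = deepspeed.init_inference(model.eval(), config=load_config)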
