Skip to content

Commit

Permalink
Make DS-Inference config readable from JSON (microsoft#2537)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii authored Nov 22, 2022
1 parent 57e0a55 commit 8b4318b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 12 deletions.
33 changes: 21 additions & 12 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sys
import types
import json
from typing import Optional, Union
import torch
from torch.optim import Optimizer
Expand Down Expand Up @@ -229,7 +230,7 @@ def default_inference_config():
return DeepSpeedInferenceConfig().dict()


def init_inference(model, config=None, **kwargs):
def init_inference(model, config={}, **kwargs):
"""Initialize the DeepSpeed InferenceEngine.
Description: all four cases are valid and supported in DS init_inference() API.
Expand Down Expand Up @@ -272,7 +273,7 @@ def init_inference(model, config=None, **kwargs):
Arguments:
model: Required: original nn.module object without any wrappers
config: Optional: instead of arguments, you can pass in a DS inference config dict
config: Optional: instead of arguments, you can pass in a DS inference config dict or path to JSON file
Returns:
A deepspeed.InferenceEngine wrapped model.
Expand All @@ -283,17 +284,25 @@ def init_inference(model, config=None, **kwargs):
__git_branch__),
ranks=[0])

# User did not pass a config, use defaults
if config is None:
config_dict = kwargs
else:
# Load config_dict from config first
if isinstance(config, str):
with open(config, "r") as f:
config_dict = json.load(f)
elif isinstance(config, dict):
config_dict = config

# if config and kwargs both are passed, merge them, and overwrite using kwargs
if config and kwargs:
config_dict = {}
config_dict.update(config)
config_dict.update(kwargs)
else:
raise ValueError(
f"'config' argument expected string or dictionary, got {type(config)}")

# Update with values from kwargs, ensuring no conflicting overlap between config and kwargs
overlap_keys = set(config_dict.keys()).intersection(kwargs.keys())
# If there is overlap, error out if values are different
for key in overlap_keys:
if config_dict[key] != kwargs[key]:
raise ValueError(
f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}"
)
config_dict.update(kwargs)

ds_inference_config = DeepSpeedInferenceConfig(**config_dict)

Expand Down
39 changes: 39 additions & 0 deletions tests/unit/inference/test_inference_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
import torch
import deepspeed
from unit.common import DistributedTest
from unit.simple_model import create_config_from_dict


@pytest.mark.inference
class TestInferenceConfig(DistributedTest):
    """Exercise the config/kwargs merging behavior of ``deepspeed.init_inference``.

    Covers: identical keys in both config and kwargs (allowed), conflicting
    values for the same key (ValueError), disjoint config + kwargs (merged),
    and loading the config from a JSON file path.
    """
    world_size = 1

    def test_overlap_kwargs(self):
        # The same key with the same value may appear in both config and kwargs.
        shared = {"replace_with_kernel_inject": True}
        engine = deepspeed.init_inference(torch.nn.Module(),
                                          config=dict(shared),
                                          **shared)
        assert engine._config.replace_with_kernel_inject

    def test_overlap_kwargs_conflict(self):
        # A key present in both config and kwargs with *different* values
        # must be rejected.
        with pytest.raises(ValueError):
            deepspeed.init_inference(torch.nn.Module(),
                                     config={"replace_with_kernel_inject": True},
                                     **{"replace_with_kernel_inject": False})

    def test_kwargs_and_config(self):
        # Disjoint keys from config and kwargs are merged into one config.
        cfg = {"replace_with_kernel_inject": True}
        extra = {"dtype": torch.float32}
        engine = deepspeed.init_inference(torch.nn.Module(), config=cfg, **extra)
        assert engine._config.replace_with_kernel_inject
        assert engine._config.dtype == extra["dtype"]

    def test_json_config(self, tmpdir):
        # A string path to a JSON file is accepted in place of a dict.
        config_path = create_config_from_dict(
            tmpdir,
            {"replace_with_kernel_inject": True})
        engine = deepspeed.init_inference(torch.nn.Module(), config=config_path)
        assert engine._config.replace_with_kernel_inject

0 comments on commit 8b4318b

Please sign in to comment.