Skip to content

Commit

Permalink
Bug fix for loading models with TP=1 when mii grpc server is not needed
Browse files Browse the repository at this point in the history
  • Loading branch information
samyam committed Apr 11, 2022
1 parent 772087e commit e18abbc
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/azure-local/gpt2-azure-local-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
resource_group=os.environ["AZ_RESOURCE_GROUP"])

mii_configs = mii.constants.MII_CONFIGS_DEFAULT
mii_configs[mii.constants.TENSOR_PARALLEL_KEY] = 2
mii_configs[mii.constants.TENSOR_PARALLEL_KEY] = 1

mii.deploy(task_name="text-generation",
model_name="gpt2",
Expand Down
1 change: 1 addition & 0 deletions mii/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .utils import setup_task, get_task, get_task_name, check_if_task_and_model_is_supported
from .grpc_related.proto import modelresponse_pb2_grpc
from .grpc_related.proto import modelresponse_pb2
from .models.load_models import load_models
2 changes: 1 addition & 1 deletion mii/server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _is_server_process_alive(self):
def _initialize_service(self, model_name, model_path, ds_optimize):
process = None
if not self.use_grpc_server:
self.model = mii.load_model(model_name, model_path)
self.model = mii.load_models(mii.get_task_name(self.task), model_name, model_path, ds_optimize)
else:
if self._is_socket_open(self.port_number):
raise RuntimeError(
Expand Down
2 changes: 2 additions & 0 deletions mii/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def setup_task():
return get_model_path(), not is_aml(), is_aml()




log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
Expand Down

0 comments on commit e18abbc

Please sign in to comment.