Multi-batch support for text-generation (microsoft#40)
Co-authored-by: Michael Wyatt <[email protected]>
jeffra and mrwyattii authored Jul 21, 2022
1 parent 56056fc commit b67556d
Showing 12 changed files with 167 additions and 36 deletions.
@@ -1,4 +1,4 @@
-name: nv-torch18-v100-local
+name: nv-torch-latest-v100-local

 on:
   push:
@@ -17,7 +17,7 @@ concurrency:

 jobs:
   unit-tests:
-    runs-on: [self-hosted, nvidia, torch18, v100]
+    runs-on: [self-hosted, nvidia, cu113, v100]

     steps:
       - uses: actions/checkout@v2
@@ -29,12 +29,14 @@ jobs:
           python --version
           which nvcc
           nvcc --version
-          pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
+          pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
           python -c "import torch; print('torch:', torch.__version__, torch)"
           python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
       - name: Install MII
         run: |
-          git clone https://github.com/microsoft/DeepSpeed.git && pip install ./DeepSpeed
+          pip install git+https://github.com/microsoft/DeepSpeed.git
+          pip install git+https://github.com/huggingface/transformers.git
+          pip install -U accelerate
           pip install .[dev,local]
           ds_report
       - name: Unit tests
1 change: 0 additions & 1 deletion examples/local/checkpoints.json

This file was deleted.

18 changes: 15 additions & 3 deletions mii/grpc_related/modelresponse_server.py
@@ -29,10 +29,22 @@ def _unpack_proto_query_kwargs(self, query_kwargs):
     def GeneratorReply(self, request, context):
         query_kwargs = self._unpack_proto_query_kwargs(request.query_kwargs)
         start = time.time()
-        response = self.inference_pipeline(request.request, **query_kwargs)
+
+        # unpack grpc list into py-list
+        request = [r for r in request.request]
+
+        batched_responses = self.inference_pipeline(request, **query_kwargs)
         end = time.time()
-        return modelresponse_pb2.SingleStringReply(response=f"{response}",
-                                                   time_taken=end - start)
+
+        # response is a list
+        text_responses = []
+        for response in batched_responses:
+            text = response[0]['generated_text']
+            text_responses.append(text)
+
+        val = modelresponse_pb2.MultiStringReply(response=text_responses,
+                                                 time_taken=end - start)
+        return val

     def ClassificationReply(self, request, context):
         query_kwargs = self._unpack_proto_query_kwargs(request.query_kwargs)
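
For reference, the inference pipeline returns one list per prompt, each holding a dict with a 'generated_text' key (the Hugging Face pipeline convention that load_models.py now mirrors); the loop above flattens that into the flat list of strings carried by MultiStringReply. A minimal sketch of the unbatching, with toy stand-in data in place of real pipeline output:

# Toy stand-in for pipeline output on a two-prompt batch:
# one list per prompt, each containing a dict keyed by 'generated_text'.
batched_responses = [
    [{'generated_text': 'DeepSpeed is a deep learning optimization library'}],
    [{'generated_text': 'Seattle is a city in Washington'}],
]
text_responses = [r[0]['generated_text'] for r in batched_responses]
assert len(text_responses) == 2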
3 changes: 3 additions & 0 deletions mii/grpc_related/proto/build_script.sh
@@ -1,2 +1,5 @@
 #!/bin/bash
 python3 -m grpc_tools.protoc -I./ --python_out=. --grpc_python_out=. ./modelresponse.proto
+
+# update import to be global wrt mii
+sed -i 's/modelresponse_pb2/mii.grpc_related.proto.modelresponse_pb2/g' modelresponse_pb2_grpc.py
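
The sed step rewrites the import emitted by grpc_tools so the generated stub resolves from the installed mii package rather than the local directory. A sketch of the effect (not the verbatim generated file): the stub's import changes from the first form below to the second.

# as emitted by grpc_tools.protoc (module imported by bare name):
import modelresponse_pb2 as modelresponse__pb2
# after the sed rewrite (resolves anywhere mii is importable):
import mii.grpc_related.proto.modelresponse_pb2 as modelresponse__pb2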
14 changes: 11 additions & 3 deletions mii/grpc_related/proto/modelresponse.proto
@@ -24,7 +24,7 @@ package modelresponse;
 // The greeting service definition.
 service ModelResponse {
   // Sends a greeting
-  rpc GeneratorReply (SingleStringRequest) returns (SingleStringReply) {}
+  rpc GeneratorReply (MultiStringRequest) returns (MultiStringReply) {}
   rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {}
   rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {}
   rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {}
@@ -41,18 +41,26 @@ message Value {
   }
 }

-// The request message containing the user's name.
 message SingleStringRequest {
   string request = 1;
   map<string,Value> query_kwargs = 2;
 }

-// The response message containing the greetings
+message MultiStringRequest {
+  repeated string request = 1;
+  map<string,Value> query_kwargs = 2;
+}
+
 message SingleStringReply {
   string response = 1;
   float time_taken = 2;
 }

+message MultiStringReply {
+  repeated string response = 1;
+  float time_taken = 2;
+}
+
 message QARequest {
   string question = 1;
   string context = 2;
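
Since request is now a repeated string, callers pass a Python list when building the message. A small sketch using the regenerated bindings (import path per the sed rewrite in build_script.sh above):

from mii.grpc_related.proto import modelresponse_pb2

# 'repeated string request' accepts any iterable of strings.
req = modelresponse_pb2.MultiStringRequest(request=["DeepSpeed is", "Seattle is"])
print(len(req.request))  # 2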
69 changes: 58 additions & 11 deletions mii/grpc_related/proto/modelresponse_pb2.py

Some generated files are not rendered by default.

12 changes: 6 additions & 6 deletions mii/grpc_related/proto/modelresponse_pb2_grpc.py
@@ -16,8 +16,8 @@ def __init__(self, channel):
         """
         self.GeneratorReply = channel.unary_unary(
             '/modelresponse.ModelResponse/GeneratorReply',
-            request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString,
-            response_deserializer=modelresponse__pb2.SingleStringReply.FromString,
+            request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString,
+            response_deserializer=modelresponse__pb2.MultiStringReply.FromString,
         )
         self.ClassificationReply = channel.unary_unary(
             '/modelresponse.ModelResponse/ClassificationReply',
@@ -92,8 +92,8 @@ def add_ModelResponseServicer_to_server(servicer, server):
         'GeneratorReply':
         grpc.unary_unary_rpc_method_handler(
             servicer.GeneratorReply,
-            request_deserializer=modelresponse__pb2.SingleStringRequest.FromString,
-            response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString,
+            request_deserializer=modelresponse__pb2.MultiStringRequest.FromString,
+            response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString,
         ),
         'ClassificationReply':
         grpc.unary_unary_rpc_method_handler(
@@ -150,8 +150,8 @@ def GeneratorReply(request,
         request,
         target,
         '/modelresponse.ModelResponse/GeneratorReply',
-        modelresponse__pb2.SingleStringRequest.SerializeToString,
-        modelresponse__pb2.SingleStringReply.FromString,
+        modelresponse__pb2.MultiStringRequest.SerializeToString,
+        modelresponse__pb2.MultiStringReply.FromString,
         options,
         channel_credentials,
         insecure,
13 changes: 11 additions & 2 deletions mii/models/load_models.py
@@ -57,7 +57,10 @@ def __call__(self, inputs, min_length=20, do_sample=False, **kwargs):
         from deepspeed.inference.engine import InferenceEngine
         if isinstance(self.model, InferenceEngine):
             self.model = self.model.module
-        tokens = self.tokenizer.batch_encode_plus([inputs],
+
+        # expand proto list into py-list
+        inputs = [i for i in inputs]
+        tokens = self.tokenizer.batch_encode_plus(inputs,
                                                   return_tensors="pt",
                                                   padding=True)
         for t in tokens:
@@ -68,7 +71,13 @@ def __call__(self, inputs, min_length=20, do_sample=False, **kwargs):
                                      max_length=min_length,
                                      do_sample=do_sample)
         outputs = self.tokenizer.batch_decode(greedy_output, skip_special_tokens=True)
-        return outputs[0]
+
+        # construct output to align w. HF pipeline
+        output_dicts = []
+        for output in outputs:
+            output_dicts.append([{'generated_text': output}])
+
+        return output_dicts
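
With a list of prompts, batch_encode_plus pads every sequence to the longest in the batch so a single generate call can run over one tensor. A small illustration, assuming a GPT-2 tokenizer (GPT-2 ships without a pad token, so one must be assigned before padding works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

tokens = tokenizer.batch_encode_plus(["DeepSpeed is", "Seattle is a city"],
                                     return_tensors="pt",
                                     padding=True)
# Both prompts are padded to the longest sequence in the batch,
# e.g. torch.Size([2, 4]) here.
print(tokens["input_ids"].shape)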
9 changes: 6 additions & 3 deletions mii/server_client.py
@@ -234,9 +234,12 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
     async def _request_async_response(self, stub_id, request_dict, query_kwargs):
         proto_kwargs = kwarg_dict_to_proto(query_kwargs)
         if self.task == mii.Tasks.TEXT_GENERATION:
-            response = await self.stubs[stub_id].GeneratorReply(
-                mii.modelresponse_pb2.SingleStringRequest(request=request_dict['query'],
-                                                          query_kwargs=proto_kwargs))
+            # convert to batch of queries if they are not already
+            if not isinstance(request_dict['query'], list):
+                request_dict['query'] = [request_dict['query']]
+            req = mii.modelresponse_pb2.MultiStringRequest(request=request_dict['query'],
+                                                           query_kwargs=proto_kwargs)
+            response = await self.stubs[stub_id].GeneratorReply(req)

         elif self.task == mii.Tasks.TEXT_CLASSIFICATION:
             response = await self.stubs[stub_id].ClassificationReply(
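
End to end, this means a deployment's query handle can accept either a single string or a list of strings under the 'query' key. A minimal usage sketch (the deployment name below is a placeholder for a text-generation deployment that is already running):

import mii

generator = mii.mii_query_handle("text-gen-deployment")  # hypothetical deployment name

# A single string still works; it is wrapped into a one-element batch above.
single = generator.query({'query': "DeepSpeed is"})

# New: a list of prompts goes out as one MultiStringRequest batch.
batch = generator.query({'query': ["DeepSpeed is", "Seattle is"]})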
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,5 +1,5 @@
 asyncio
-deepspeed
+deepspeed>=0.6.7
 grpcio
 grpcio-tools
 pydantic
38 changes: 38 additions & 0 deletions scripts/model_download.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+import os
+import argparse
+
+from huggingface_hub import HfApi
+from transformers import AutoConfig, AutoTokenizer, AutoModel
+
+
+def dir_path(path_str):
+    if os.path.isdir(path_str):
+        return path_str
+    elif input(f"{path_str} does not exist, create directory? [y/n]").lower() == "y":
+        os.makedirs(path_str)
+        return path_str
+    else:
+        raise NotADirectoryError(path_str)
+
+
+class HFModelNotFoundError(Exception):
+    def __init__(self, model_str):
+        super().__init__(f"HuggingFace model not found: '{model_str}'")
+
+
+def hf_model(model_str):
+    api = HfApi()
+    models = [m.modelId for m in api.list_models()]
+    if model_str in models:
+        return model_str
+    else:
+        raise HFModelNotFoundError(model_str)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_path", '-o', type=dir_path, required=True, help="Output directory for downloaded model files")
+parser.add_argument("--model_name", '-m', type=hf_model, required=True, help="HuggingFace model name")
+args = parser.parse_args()
+
+for auto_func in [AutoConfig, AutoTokenizer, AutoModel]:
+    auto_func.from_pretrained(args.model_name, cache_dir=args.model_path)
+
+print(f"Cached files for '{args.model_name}' downloaded to '{args.model_path}'")
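
As a usage sketch (the model name and output directory below are arbitrary examples), the script caches the config, tokenizer, and model weights via each Auto class's from_pretrained:

python scripts/model_download.py --model_name gpt2 --model_path ./model_cache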
