Add support for inpainting task in DS-MII (deepspeedai#410)
Co-authored-by: grajguru <[email protected]>
gauravrajguru and grajguru authored Feb 14, 2024
1 parent 76b9639 commit 8e0c7f1
Showing 9 changed files with 137 additions and 6 deletions.
7 changes: 7 additions & 0 deletions mii/legacy/client.py
@@ -134,6 +134,13 @@ def query(self, request_dict, **query_kwargs):
elif self.task == TaskType.TEXT2IMG:
args = (request_dict["prompt"], request_dict.get("negative_prompt", None))
kwargs = query_kwargs
elif self.task == TaskType.INPAINTING:
negative_prompt = request_dict.get("negative_prompt", None)
args = (request_dict["prompt"],
request_dict["image"],
request_dict["mask_image"],
negative_prompt)
kwargs = query_kwargs
else:
args = (request_dict["query"], )
kwargs = query_kwargs
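
For orientation, a client-side call for the new task mirrors the other image tasks: the request dict must carry the prompt, the base image, and the mask, with negative_prompt optional. A minimal usage sketch, assuming a running legacy-MII deployment (the deployment name is hypothetical, and per the proto below the images are passed as encoded bytes):

import mii

generator = mii.mii_query_handle("sd-inpaint-deployment")  # hypothetical deployment name
with open("photo.png", "rb") as f, open("mask.png", "rb") as m:
    result = generator.query({
        "prompt": "a red brick fireplace",
        "image": f.read(),        # encoded bytes of the base image
        "mask_image": m.read(),   # encoded bytes of the mask (white = repaint)
        "negative_prompt": "blurry, low quality",  # optional
    })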
6 changes: 6 additions & 0 deletions mii/legacy/constants.py
@@ -20,6 +20,7 @@ class TaskType(str, Enum):
CONVERSATIONAL = "conversational"
TEXT2IMG = "text-to-image"
ZERO_SHOT_IMAGE_CLASSIFICATION = "zero-shot-image-classification"
INPAINTING = "text-to-image-inpainting"


class ModelProvider(str, Enum):
@@ -60,6 +61,11 @@ class ModelProvider(str, Enum):
TaskType.TEXT2IMG: ["prompt"],
TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION: ["image",
"candidate_labels"],
TaskType.INPAINTING: [
"prompt",
"image",
"mask_image",
]
}

MII_CACHE_PATH = "MII_CACHE_PATH"
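
The required-keys table above drives request validation; a minimal sketch of the kind of check it enables (the table's actual variable name sits outside this hunk, so the names here are illustrative):

from mii.legacy.constants import TaskType

REQUIRED_KEYS = {TaskType.INPAINTING: ["prompt", "image", "mask_image"]}  # excerpt

def check_request(task: TaskType, request_dict: dict) -> None:
    # Raise early if the caller omitted a key the task needs.
    missing = [k for k in REQUIRED_KEYS[task] if k not in request_dict]
    if missing:
        raise ValueError(f"{task.value} request missing keys: {missing}")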
3 changes: 3 additions & 0 deletions mii/legacy/grpc_related/modelresponse_server.py
@@ -126,6 +126,9 @@ def ConversationalReply(self, request, context):
def ZeroShotImgClassificationReply(self, request, context):
return self._run_inference("ZeroShotImgClassificationReply", request)

def InpaintingReply(self, request, context):
return self._run_inference("InpaintingReply", request)


class AtomicCounter:
def __init__(self, initial_value=0):
9 changes: 9 additions & 0 deletions mii/legacy/grpc_related/proto/legacymodelresponse.proto
@@ -35,6 +35,7 @@ service ModelResponse {
rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {}
rpc Txt2ImgReply(Text2ImageRequest) returns (ImageReply) {}
rpc ZeroShotImgClassificationReply (ZeroShotImgClassificationRequest) returns (SingleStringReply) {}
rpc InpaintingReply(InpaintingRequest) returns (ImageReply) {}
}

message Value {
@@ -114,3 +115,11 @@ message ZeroShotImgClassificationRequest {
repeated string candidate_labels = 2;
map<string,Value> query_kwargs = 3;
}

message InpaintingRequest {
repeated string prompt = 1;
repeated bytes image = 2;
repeated bytes mask_image = 3;
repeated string negative_prompt = 4;
map<string,Value> query_kwargs = 5;
}
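
All payload fields on InpaintingRequest are repeated, so one request can carry a batch of prompt/image/mask triples. A minimal construction sketch using the generated Python module (image_bytes and mask_bytes are placeholders for encoded image data):

from mii.legacy.grpc_related.proto import legacymodelresponse_pb2 as pb2

request = pb2.InpaintingRequest(
    prompt=["a red brick fireplace"],
    image=[image_bytes],          # one encoded image per prompt
    mask_image=[mask_bytes],      # mask aligned with the image above
    negative_prompt=[""],         # keep lengths consistent with prompt
)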
13 changes: 9 additions & 4 deletions mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py
@@ -4,7 +4,6 @@
# DeepSpeed Team
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: legacymodelresponse.proto
# Protobuf Python Version: 4.25.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -17,7 +16,7 @@
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xdb\x01\n\x11Text2ImageRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\x17\n\x0fnegative_prompt\x18\x02 \x03(\t\x12M\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x37.legacymodelresponse.Text2ImageRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 
\x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\xb7\x08\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Y\n\x0cTxt2ImgReply\x12&.legacymodelresponse.Text2ImageRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x62\x06proto3'
b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xdb\x01\n\x11Text2ImageRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\x17\n\x0fnegative_prompt\x18\x02 \x03(\t\x12M\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x37.legacymodelresponse.Text2ImageRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xfe\x01\n\x11InpaintingRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\r\n\x05image\x18\x02 \x03(\x0c\x12\x12\n\nmask_image\x18\x03 
\x03(\x0c\x12\x17\n\x0fnegative_prompt\x18\x04 \x03(\t\x12M\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x37.legacymodelresponse.InpaintingRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\x95\t\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Y\n\x0cTxt2ImgReply\x12&.legacymodelresponse.Text2ImageRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\\\n\x0fInpaintingReply\x12&.legacymodelresponse.InpaintingRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x62\x06proto3'
)

_globals = globals()
@@ -38,6 +37,8 @@
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._options = None
_globals[
'_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._options = None
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001'
_globals['_VALUE']._serialized_start = 79
_globals['_VALUE']._serialized_end = 174
_globals['_SESSIONID']._serialized_start = 176
@@ -75,6 +76,10 @@
_globals[
'_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
_globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
_globals['_MODELRESPONSE']._serialized_start = 2009
_globals['_MODELRESPONSE']._serialized_end = 3088
_globals['_INPAINTINGREQUEST']._serialized_start = 2009
_globals['_INPAINTINGREQUEST']._serialized_end = 2263
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 331
_globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 409
_globals['_MODELRESPONSE']._serialized_start = 2266
_globals['_MODELRESPONSE']._serialized_end = 3439
# @@protoc_insertion_point(module_scope)
44 changes: 44 additions & 0 deletions mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py
@@ -81,6 +81,12 @@ def __init__(self, channel):
SerializeToString,
response_deserializer=legacymodelresponse__pb2.SingleStringReply.FromString,
)
self.InpaintingReply = channel.unary_unary(
'/legacymodelresponse.ModelResponse/InpaintingReply',
request_serializer=legacymodelresponse__pb2.InpaintingRequest.
SerializeToString,
response_deserializer=legacymodelresponse__pb2.ImageReply.FromString,
)


class ModelResponseServicer(object):
@@ -151,6 +157,12 @@ def ZeroShotImgClassificationReply(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def InpaintingReply(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_ModelResponseServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -231,6 +243,12 @@ def add_ModelResponseServicer_to_server(servicer, server):
response_serializer=legacymodelresponse__pb2.SingleStringReply.
SerializeToString,
),
'InpaintingReply':
grpc.unary_unary_rpc_method_handler(
servicer.InpaintingReply,
request_deserializer=legacymodelresponse__pb2.InpaintingRequest.FromString,
response_serializer=legacymodelresponse__pb2.ImageReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'legacymodelresponse.ModelResponse',
@@ -526,3 +544,29 @@ def ZeroShotImgClassificationReply(request,
wait_for_ready,
timeout,
metadata)

@staticmethod
def InpaintingReply(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/legacymodelresponse.ModelResponse/InpaintingReply',
legacymodelresponse__pb2.InpaintingRequest.SerializeToString,
legacymodelresponse__pb2.ImageReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata)
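
For completeness, the new RPC can also be invoked directly through the generated stub in the usual gRPC way (the address is hypothetical; request is an InpaintingRequest as sketched earlier):

import grpc
from mii.legacy.grpc_related.proto import legacymodelresponse_pb2_grpc as pb2_grpc

channel = grpc.insecure_channel("localhost:50050")  # hypothetical server address
stub = pb2_grpc.ModelResponseStub(channel)
reply = stub.InpaintingReply(request)               # returns an ImageReply
print(len(reply.images), reply.time_taken)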
48 changes: 47 additions & 1 deletion mii/legacy/method_table.py
@@ -9,7 +9,7 @@
from mii.legacy.constants import TaskType
from mii.legacy.grpc_related.proto import legacymodelresponse_pb2 as modelresponse_pb2
from mii.legacy.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs
from mii.legacy.models.utils import ImageResponse
from mii.legacy.models.utils import ImageResponse, convert_bytes_to_pil_image


def single_string_request_to_proto(self, request_dict, **query_kwargs):
@@ -312,6 +312,51 @@ def run_inference(self, inference_pipeline, args, kwargs):
return inference_pipeline(image, candidate_labels=candidate_labels, **kwargs)


class InpaintingMethods(Text2ImgMethods):
@property
def method(self):
return "InpaintingReply"

def run_inference(self, inference_pipeline, args, kwargs):
prompt, image, mask_image, negative_prompt = args
return inference_pipeline(prompt=prompt,
image=image,
mask_image=mask_image,
negative_prompt=negative_prompt,
**kwargs)

def pack_request_to_proto(self, request_dict, **query_kwargs):
prompt = request_dict["prompt"]
prompt = prompt if isinstance(prompt, list) else [prompt]
negative_prompt = request_dict.get("negative_prompt", [""] * len(prompt))
negative_prompt = negative_prompt if isinstance(negative_prompt,
list) else [negative_prompt]
image = request_dict["image"] if isinstance(request_dict["image"],
list) else [request_dict["image"]]
mask_image = request_dict["mask_image"] if isinstance(
request_dict["mask_image"],
list) else [request_dict["mask_image"]]

return modelresponse_pb2.InpaintingRequest(
prompt=prompt,
image=image,
mask_image=mask_image,
negative_prompt=negative_prompt,
query_kwargs=kwarg_dict_to_proto(query_kwargs),
)

def unpack_request_from_proto(self, request):
kwargs = unpack_proto_query_kwargs(request.query_kwargs)

image = [convert_bytes_to_pil_image(img) for img in request.image]
mask_image = [
convert_bytes_to_pil_image(mask_image) for mask_image in request.mask_image
]

args = (list(request.prompt), image, mask_image, list(request.negative_prompt))
return args, kwargs


GRPC_METHOD_TABLE = {
TaskType.TEXT_GENERATION: TextGenerationMethods(),
TaskType.TEXT_CLASSIFICATION: TextClassificationMethods(),
Expand All @@ -321,4 +366,5 @@ def run_inference(self, inference_pipeline, args, kwargs):
TaskType.CONVERSATIONAL: ConversationalMethods(),
TaskType.TEXT2IMG: Text2ImgMethods(),
TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION: ZeroShotImgClassificationMethods(),
TaskType.INPAINTING: InpaintingMethods(),
}
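
One note on the decode path: convert_bytes_to_pil_image is imported from mii.legacy.models.utils, whose diff is not shown in this view. A plausible minimal implementation, assuming the wire format is encoded image bytes as in the proto:

from io import BytesIO
from PIL import Image

def convert_bytes_to_pil_image(img_bytes: bytes) -> Image.Image:
    # Decode PNG/JPEG bytes from the proto into a PIL image for the pipeline.
    return Image.open(BytesIO(img_bytes)).convert("RGB")

The unpacked args tuple then matches the (prompt, image, mask_image, negative_prompt) order that InpaintingMethods.run_inference forwards to the diffusers-style inpainting pipeline.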