Skip to content

Commit

Permalink
img2img, inpaint; init_image, mask_image, strength
Browse files Browse the repository at this point in the history
  • Loading branch information
gadicc committed Sep 16, 2022
1 parent c8b4a6b commit 52aff34
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import base64
from io import BytesIO
import os
import PIL
from APP_VARS import MODEL_ID, PIPELINE, SCHEDULER

PIPELINES = [
Expand All @@ -21,7 +22,7 @@
]


def getPipeline(PIPELINE):
def getPipeline(PIPELINE: str):
print("PIPELINE = '" + PIPELINE + "'")
if PIPELINE == "StableDiffusionPipeline":
return StableDiffusionPipeline
Expand All @@ -31,7 +32,7 @@ def getPipeline(PIPELINE):
return StableDiffusionInpaintPipeline


def initScheduler(SCHEDULER):
def initScheduler(SCHEDULER: str):
print("SCHEDULER = '" + SCHEDULER + "'")
if SCHEDULER == "LMS":
return LMSDiscreteScheduler(
Expand Down Expand Up @@ -75,6 +76,10 @@ def init():
).to("cuda")


def decodeBase64Image(imageStr: str) -> PIL.Image:
return PIL.Image.open(BytesIO(base64.decodebytes(bytes(imageStr, "utf-8"))))


# Inference is ran for every server call
# Reference your preloaded global model variable here.
def inference(model_inputs: dict) -> dict:
Expand All @@ -87,6 +92,9 @@ def inference(model_inputs: dict) -> dict:
model_id = model_inputs.get("MODEL_ID")
pipeline_name = model_inputs.get("PIPELINE")
scheduler_name = model_inputs.get("SCHEDULER")
del model_inputs["MODEL_ID"]
del model_inputs["PIPELINE"]
del model_inputs["SCHEDULER"]

if (
last_model_id != model_id
Expand All @@ -109,33 +117,44 @@ def inference(model_inputs: dict) -> dict:
last_scheduler_name = scheduler_name

# Parse out your arguments
prompt = model_inputs.get("prompt", None)
if prompt == None:
return {"message": "No prompt provided"}

height = model_inputs.get("height", 512)
width = model_inputs.get("width", 512)
num_inference_steps = model_inputs.get("num_inference_steps", 50)
guidance_scale = model_inputs.get("guidance_scale", 7.5)
seed = model_inputs.get("seed", None)
# prompt = model_inputs.get("prompt", None)
# if prompt == None:
# return {"message": "No prompt provided"}
#
# height = model_inputs.get("height", 512)
# width = model_inputs.get("width", 512)
# num_inference_steps = model_inputs.get("num_inference_steps", 50)
# guidance_scale = model_inputs.get("guidance_scale", 7.5)
# seed = model_inputs.get("seed", None)

if (
pipeline_name == "StableDiffusionImg2ImgPipeline"
or pipeline_name == "StableDiffusionInpaintPipeline"
):
model_inputs.update(
{"init_image": decodeBase64Image(model_inputs.get("init_image"))}
)
strength = model_inputs.get("strength", 0.75)

if pipeline_name == "StableDiffusionInpaintPipeline":
model_inputs.update(
{"mask_image": decodeBase64Image(model_inputs.get("mask_image"))}
)

seed = model_inputs.get("seed", None)
if seed == None:
# generator = None;
generator = torch.Generator(device="cuda")
generator.seed()
else:
generator = torch.Generator(device="cuda").manual_seed(seed)
del model_inputs.seed

model_inputs.update({"generator": generator})

# Run the model
with autocast("cuda"):
image = model(
prompt=prompt,
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
).images[0]
image = model(**model_inputs).images[0]

buffered = BytesIO()
image.save(buffered, format="JPEG")
Expand Down

0 comments on commit 52aff34

Please sign in to comment.