Skip to content

Commit

Permalink
guess mode
Browse files Browse the repository at this point in the history
  • Loading branch information
lllyasviel committed Feb 20, 2023
1 parent bdf496b commit 005008b
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 40 deletions.
9 changes: 5 additions & 4 deletions gradio_canny2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ddim_sampler = DDIMSampler(model)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed, eta, low_threshold, high_threshold):
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
with torch.no_grad():
img = resize_image(HWC3(input_image), image_resolution)
H, W, C = img.shape
Expand All @@ -43,13 +43,13 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
model.low_vram_shift(is_diffusing=False)

cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
un_cond = {"c_concat": [torch.zeros_like(control) if guess_mode else control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)

if config.save_memory:
model.low_vram_shift(is_diffusing=True)

model.control_scales = [strength] * 13
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
Expand Down Expand Up @@ -78,6 +78,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
Expand All @@ -89,7 +90,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed, eta, low_threshold, high_threshold]
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold]
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


Expand Down
9 changes: 5 additions & 4 deletions gradio_depth2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ddim_sampler = DDIMSampler(model)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta):
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
with torch.no_grad():
input_image = HWC3(input_image)
detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
Expand All @@ -45,13 +45,13 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
model.low_vram_shift(is_diffusing=False)

cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
un_cond = {"c_concat": [torch.zeros_like(control) if guess_mode else control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)

if config.save_memory:
model.low_vram_shift(is_diffusing=True)

model.control_scales = [strength] * 13
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
Expand Down Expand Up @@ -80,6 +80,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1)
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
Expand All @@ -90,7 +91,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta]
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


Expand Down
9 changes: 5 additions & 4 deletions gradio_fake_scribble2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ddim_sampler = DDIMSampler(model)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta):
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
with torch.no_grad():
input_image = HWC3(input_image)
detected_map = apply_hed(resize_image(input_image, detect_resolution))
Expand All @@ -49,13 +49,13 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
model.low_vram_shift(is_diffusing=False)

cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
un_cond = {"c_concat": [torch.zeros_like(control) if guess_mode else control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)

if config.save_memory:
model.low_vram_shift(is_diffusing=True)

model.control_scales = [strength] * 13
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
Expand Down Expand Up @@ -84,6 +84,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
Expand All @@ -94,7 +95,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta]
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


Expand Down
9 changes: 5 additions & 4 deletions gradio_hed2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ddim_sampler = DDIMSampler(model)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta):
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
with torch.no_grad():
input_image = HWC3(input_image)
detected_map = apply_hed(resize_image(input_image, detect_resolution))
Expand All @@ -45,13 +45,13 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
model.low_vram_shift(is_diffusing=False)

cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
un_cond = {"c_concat": [torch.zeros_like(control) if guess_mode else control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)

if config.save_memory:
model.low_vram_shift(is_diffusing=True)

model.control_scales = [strength] * 13
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
Expand Down Expand Up @@ -80,6 +80,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
Expand All @@ -90,7 +91,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta]
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


Expand Down
9 changes: 5 additions & 4 deletions gradio_hough2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ddim_sampler = DDIMSampler(model)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta, value_threshold, distance_threshold):
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold):
with torch.no_grad():
input_image = HWC3(input_image)
detected_map = apply_mlsd(resize_image(input_image, detect_resolution), value_threshold, distance_threshold)
Expand All @@ -45,13 +45,13 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
model.low_vram_shift(is_diffusing=False)

cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
un_cond = {"c_concat": [torch.zeros_like(control) if guess_mode else control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)

if config.save_memory:
model.low_vram_shift(is_diffusing=True)

model.control_scales = [strength] * 13
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
Expand Down Expand Up @@ -80,6 +80,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
detect_resolution = gr.Slider(label="Hough Resolution", minimum=128, maximum=1024, value=512, step=1)
value_threshold = gr.Slider(label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01)
distance_threshold = gr.Slider(label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01)
Expand All @@ -92,7 +93,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta, value_threshold, distance_threshold]
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, value_threshold, distance_threshold]
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


Expand Down
9 changes: 5 additions & 4 deletions gradio_normal2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ddim_sampler = DDIMSampler(model)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta, bg_threshold):
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold):
with torch.no_grad():
input_image = HWC3(input_image)
_, detected_map = apply_midas(resize_image(input_image, detect_resolution), bg_th=bg_threshold)
Expand All @@ -45,13 +45,13 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
model.low_vram_shift(is_diffusing=False)

cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
un_cond = {"c_concat": [torch.zeros_like(control) if guess_mode else control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)

if config.save_memory:
model.low_vram_shift(is_diffusing=True)

model.control_scales = [strength] * 13
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
Expand Down Expand Up @@ -80,6 +80,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
detect_resolution = gr.Slider(label="Normal Resolution", minimum=128, maximum=1024, value=384, step=1)
bg_threshold = gr.Slider(label="Normal background threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.01)
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
Expand All @@ -91,7 +92,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, strength, scale, seed, eta, bg_threshold]
ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, bg_threshold]
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


Expand Down
Loading

0 comments on commit 005008b

Please sign in to comment.